2018年9月6日木曜日

MNISTは「重心」が重要

scikit-learnの手書き数字認識で、MNISTのデータを使ってみることにした。
先ずは、どんな数字の画像データなのか現物確認しないと始まらないので、MNISTデータを
画像に落とすスクリプトを書いた。
どうせなら、先に作ったスクリプトで使えるファイル名で画像保存する。

"counter_limit"で各数字ごとに保存する件数を指定。
下の例だと、「0」1,000件、「1」1,000件、、、「9」1,000件、と合計10,000件の画像が保存される。

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import numpy as np
  5. from sklearn import datasets
  6. import cv2, os
  7.  
  8. counter = [''] * 10
  9. counter_limit = 1000
  10.  
  11. try:
  12. os.mkdir('mnist_img')
  13. except:
  14. pass
  15.  
  16. for i in range (0, 10):
  17. counter[i] = 0
  18. try:
  19. os.mkdir(str(i) + 'test')
  20. except:
  21. continue
  22.  
  23.  
  24. mnist = datasets.fetch_mldata('MNIST original')
  25. X, Y = mnist.data,mnist.target
  26.  
  27. for i in range(0, 70000):
  28. for n in range(0, 10):
  29. if Y[i] == n:
  30. counter[n] += 1
  31. if counter[n] <= counter_limit:
  32. img =255 - X[i].reshape(28,28)
  33. #cv2.imshow('mnist', img)
  34. #k = cv2.waitKey()
  35. print('i = ' + str(i) + ', target = ' + str(Y[i]))
  36.  
  37. save_name0 = './' + str(n) + 'test/' + str(counter[n] * 10 + n) + 'test.png'
  38. cv2.imwrite(save_name0, img)
  39.  
  40. save_name1 = './mnist_img/' + str(counter[n] * 10 + n) + 'test.png'
  41. cv2.imwrite(save_name1, img)
  42.  

因みに、70,000件全ての画像を保存するには"counter_limit"を(余裕を見て)"counter_limit = 8000"とかにする。
そうすると、各数字に以下の件数の画像が保存される。

0:6903
1:7877
2:6990
3:7141
4:6824
5:6313
6:6876
7:7293
8:6825
9:6958

さて、保存した画像を眺めると、書かれた数字はバラエティがあるように見えるが、意外と偏っていて、
描画領域については広範に描かれている訳ではなさそうだ。

例えば、
「7」は28x28の上部に余白が多い
「6」だと28x28の下部に余白が多い
「4」が右に寄っているように見える(縦棒が中心より右に書かれている傾向にあるから)
「9」も28x28の上部に余白がある
等々。
ちょっと見ただけでも上記のような偏り(傾向)が分かる。

ただ、これには訳がある。
MNISTを公開しているサイトに以下の記述がある。
the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
※次のページが参考になりました。ありがとうございます。
「自分の手書きデータをTensorFlowで予測する」
「MNISTを学習させたモデルの気持ちを調べる」

「重心」を計算して中央に配置したデータのため、重心が上にあるものは下に移動して
上部に余白ができる。

重心が下にあるものは上に移動して下部に余白ができる。
このデータで学習をした場合、同じ様に重心で移動したデータの検証ならば正解率が
上がるが、そうでないデータでは誤認識が生じやすくなる。

そこで、MNISTのデータで学習をする場合は、それに合わせて判定用データも加工する
スクリプトを書いてみた。
前「おまけのおまけ」スクリプト"mnist_img"ディレクトリに置き、引数に28を渡して
生成した「28x28.csv」を選択すると、MNISTデータでの学習を前提にして判定用データ
を加工する。
具体的には"def mnist_mod(img):"の部分がそれにあたる。
横方向の移動に関しては、重心が中央よりも右方向に行くようにしてある。
※追記1(2018/09/09 10:00)※
"def mnist_mod(img):"の重心移動方法を修正。
※追記2(2018/09/10 22:00)※
digits.csv用の前処理を追加したつもり。
※追記3(2018/09/12 20:10)※
縦横比が同じでないテスト用画像の読み込みに対応したつもり。
現状下の様に、パーツに分かれた数字は、重心を求めていないためMNISTの学習結果では、正常に判定できない。
※追記4(2018/09/14 20:30)※
パーツに分かれた数字の重心(と言うか画像の重心)を求めるように修正。

ある程度の調整は出きるが、、、。うーん、今一。
ちなみに分類機のパラメータは出鱈目なので悪しからず。

※追記5(2019/12/05 11:10)※
OpenCVのバージョンにより、findContoursの返り値の数が異なる事に対応(?)。

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import cv2
  7. from time import sleep
  8. from os.path import dirname, join, basename, exists
  9. from sklearn.datasets.base import Bunch
  10. import sys
  11. from sklearn.externals import joblib
  12.  
  13. py_version = sys.version_info.major
  14. cv_version = cv2.__version__[0:1]
  15. module_path = dirname(__file__)
  16.  
  17. if py_version != 2:
  18. import tkinter as tk, tkinter.filedialog as tkfiledialog, tkinter.messagebox as tkmessage, \
  19. tkinter.simpledialog as tksimpledialog
  20. else:
  21. import Tkinter as tk, tkFileDialog as tkfiledialog, tkMessageBox as tkmessage, \
  22. tkSimpleDialog as tksimpledialog
  23.  
  24. def clf_select():
  25. def check():
  26. root.quit()
  27. return v.get()
  28.  
  29. root = tk.Tk()
  30. root.title('digit_test')
  31.  
  32. v = tk.IntVar()
  33. v.set(0)
  34.  
  35. radio1 = tk.Radiobutton(text='SVM', variable=v, value=1, font=("",14))
  36. radio1.pack(anchor='w')
  37.  
  38. radio2 = tk.Radiobutton(text='MLPClassifier', variable=v, value=2, font=("",14))
  39. radio2.pack(anchor='w')
  40.  
  41. radio3 = tk.Radiobutton(text='LogisticRegression', variable=v, value=3, font=("",14))
  42. radio3.pack(anchor='w')
  43.  
  44. radio4 = tk.Radiobutton(text='DecisionTreeClassifier', variable=v, value=4, font=("",14))
  45. radio4.pack(anchor='w')
  46.  
  47. button1 = tk.Button(root, text='決 定', command= lambda: check(), width=30, font=("",14))
  48. button1.pack()
  49.  
  50. root.mainloop()
  51. try:
  52. root.destroy()
  53. except:
  54. sys.exit()
  55.  
  56. return check()
  57.  
  58. def csv_select():
  59. root = tk.Tk()
  60. root.withdraw()
  61.  
  62. file_type = [('', '*csv')]
  63. data_path = join(dirname(__file__), 'data')
  64. tkmessage.showinfo('digit_test', 'データ(csv)を選択してください')
  65. test_csv = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
  66. if test_csv == () or test_csv == '':
  67. test_csv = 'digits.csv'
  68. else:
  69. test_csv = basename(test_csv)
  70.  
  71. root.destroy()
  72. print(test_csv)
  73. return test_csv
  74.  
  75. def img_select():
  76. root = tk.Tk()
  77. root.withdraw()
  78.  
  79. file_type = [('', '*png')]
  80. data_path = dirname(__file__)
  81. #tkmessage.showinfo(test_csv, '画像(png)を選択してください')
  82. test_img = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
  83. if test_img == () or test_img == '':
  84. test_img = 'digit_test.png'
  85.  
  86. root.destroy()
  87. print(test_img)
  88. return test_img
  89.  
  90. def question0():
  91. root = tk.Tk()
  92. root.withdraw()
  93.  
  94. y_or_n = tkmessage.askquestion(test_csv, '判定は' + str(predicted) + '\n正しいですか?')
  95.  
  96. root.destroy()
  97. return y_or_n
  98.  
  99. def question1():
  100. root = tk.Tk()
  101. root.withdraw()
  102.  
  103. next_y_or_n = tkmessage.askquestion(test_csv, '続けますか?')
  104.  
  105. root.destroy()
  106. return next_y_or_n
  107.  
  108. def correct_input():
  109. root = tk.Tk()
  110. root.withdraw()
  111. root.after(1, lambda: root.focus_force())
  112. correct_digit = tksimpledialog.askinteger(test_csv, '正解を教えてください',
  113. initialvalue='ここに入力',parent=root)
  114.  
  115. root.destroy()
  116. return correct_digit
  117.  
  118. def load_digits():
  119. module_path = dirname(__file__)
  120. data = np.loadtxt(join(module_path, 'data', test_csv),
  121. delimiter=',')
  122.  
  123. data_len = int(np.sqrt(len(data[0]) - 1))
  124.  
  125. target = data[:, -1].astype(np.int)
  126. flat_data = data[:, :-1]
  127. images = flat_data.view()
  128. images.shape = (-1, data_len, data_len)
  129.  
  130. return Bunch(data=flat_data,
  131. target=target,
  132. target_names=np.arange(10),
  133. images=images)
  134.  
  135. def pre_mod(img):
  136. data_len = int(np.sqrt(len(data[0])))
  137. size = (data_len, data_len)
  138. temp_h, temp_w = img.shape[:2]
  139.  
  140. if temp_w > temp_h:
  141. temp_size = temp_w
  142. else:
  143. temp_size = temp_h
  144.  
  145. if test_csv == '28x28.csv':
  146. temp_scale = 0.79
  147. elif test_csv == 'digits.csv':
  148. temp_scale = 1
  149. else:
  150. temp_scale = 0.9
  151.  
  152. temp_re_size = temp_size + 2
  153.  
  154. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  155. img = np.round(255 - img)
  156. M = np.float32([[1,0,1],[0,1,1]])
  157. img = cv2.warpAffine(img, M, (temp_re_size, temp_re_size))
  158. img = np.round(255 - img)
  159.  
  160. ret,thresh = cv2.threshold(img, 50, 255, 0)
  161. if cv_version == '3':
  162. imgEdge,contours,hierarchy = cv2.findContours(thresh, cv2.RETR_LIST,
  163. cv2.CHAIN_APPROX_SIMPLE)
  164. else:
  165. contours,hierarchy = cv2.findContours(thresh, cv2.RETR_LIST,
  166. cv2.CHAIN_APPROX_SIMPLE)
  167.  
  168. x_locale = []
  169. y_locale = []
  170. if len(contours) > 2:
  171. n = len(contours) - 1
  172. else:
  173. n = 1
  174.  
  175. for i in range(0, n):
  176. x,y,w,h = cv2.boundingRect(contours[i])
  177. x_locale.append(x)
  178. x_locale.append(x + w)
  179. y_locale.append(y)
  180. y_locale.append(y + h)
  181. x = min(x_locale)
  182. y = min(y_locale)
  183. w = max(x_locale) - min(x_locale)
  184. h = max(y_locale) - min(y_locale)
  185. img = img[y:y+h, x:x+w]
  186.  
  187. if h >= w:
  188. r_scale_x = int(w * data_len * temp_scale / h)
  189. r_scale_y = int(data_len * temp_scale)
  190. else:
  191. r_scale_x = int(data_len * temp_scale)
  192. r_scale_y = int(h * data_len * temp_scale / w)
  193.  
  194. img = cv2.resize(img, (r_scale_x, r_scale_y), interpolation=cv2.INTER_AREA)
  195. img = np.round(255 - img)
  196.  
  197. if test_csv == '28x28.csv':
  198. M = np.float32([[1,0,1],[0,1,1]])
  199. img = cv2.warpAffine(img, M, (r_scale_x + 2, r_scale_y + 2))
  200.  
  201. M = cv2.moments(img)
  202. cx = int(M['m10']/M['m00']) + 1
  203. cy = int(M['m01']/M['m00']) + 1
  204.  
  205. else:
  206. cx = int(r_scale_x / 2)
  207. cy = int(r_scale_y / 2)
  208.  
  209. print(cx, cy)
  210.  
  211. xc = int(data_len / 2) - cx
  212. yc = int(data_len / 2) - cy
  213. print(xc, yc)
  214. M = np.float32([[1,0,xc],[0,1,yc]])
  215. img = cv2.warpAffine(img, M, size)
  216. img = np.round(img / 16)
  217. return img
  218.  
  219.  
  220. test_clf = clf_select()
  221. if test_clf == 1:
  222. print('SVM')
  223. from sklearn import svm
  224. clf_name = 'svm'
  225. classifier = svm.SVC(C=10, gamma='scale')
  226. elif test_clf == 2:
  227. print('MLPClassifier')
  228. from sklearn.neural_network import MLPClassifier
  229. clf_name = 'mlp'
  230. classifier = MLPClassifier(hidden_layer_sizes=(256, 128, 64, 64, 10),
  231. verbose=True, alpha=0.001,
  232. max_iter=10000, tol=0.00001, random_state=1)
  233. elif test_clf == 3:
  234. print('LogisticRegression')
  235. from sklearn.linear_model import LogisticRegression
  236. clf_name = 'lr'
  237. classifier = LogisticRegression()
  238. elif test_clf == 4:
  239. print('DecisionTreeClassifier')
  240. from sklearn.tree import DecisionTreeClassifier
  241. clf_name = 'dt'
  242. classifier = DecisionTreeClassifier()
  243. else:
  244. sys.exit()
  245.  
  246. test_csv = csv_select()
  247. test_img = img_select()
  248.  
  249. if exists('./' + test_csv + '.' + clf_name + '.pkl'):
  250. y_or_n = ''
  251. #y_or_n = 'no'
  252. else:
  253. y_or_n = 'no'
  254.  
  255. while(True):
  256. digits = load_digits()
  257. n_samples = len(digits.images)
  258. data = digits.images.reshape((n_samples, -1))
  259.  
  260. if y_or_n == 'no':
  261. print(classifier)
  262. classifier.fit(data, digits.target)
  263. joblib.dump(classifier, test_csv + '.' + clf_name + '.pkl', compress=True)
  264. else:
  265. classifier = joblib.load(test_csv + '.' + clf_name + '.pkl')
  266.  
  267. images_and_labels = list(zip(digits.images, digits.target))
  268. for index, (image, label) in enumerate(images_and_labels[:10]):
  269. plt.subplot(3, 5, index + 1)
  270. plt.axis('on')
  271. plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
  272. plt.title('Training %i' % label)
  273.  
  274. img = cv2.imread(test_img)
  275. grayim = pre_mod(img)
  276. im = grayim
  277.  
  278. print(grayim)
  279. grayim = grayim.reshape(1, -1)
  280. np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')
  281.  
  282. predicted = classifier.predict(grayim)
  283. print('判定は、' + str(predicted))
  284.  
  285. plt.subplot(3, 5, 11)
  286. plt.axis('on')
  287. plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
  288. plt.title('Original')
  289. plt.subplot(3, 5, 12)
  290. plt.axis('on')
  291. plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
  292. plt.title('Prodicted %i' % predicted)
  293. plt.show()
  294.  
  295. y_or_n = question0()
  296.  
  297. if y_or_n == 'no':
  298. correct_digit = correct_input()
  299. if correct_digit != None:
  300. grayim = np.append(grayim, correct_digit)
  301. grayim = grayim.reshape(1, -1)
  302.  
  303. f = open(join(module_path, 'data', test_csv), 'a')
  304. np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
  305. f.close()
  306. sleep(1)
  307. else:
  308. correct_digit = str(predicted[0])
  309. else:
  310. next_y_or_n = question1()
  311.  
  312. if next_y_or_n == 'no':
  313. sys.exit()
  314.  
  315. test_img = img_select()
  316.  

2018年8月20日月曜日

scikit-learn おまけの話(おまけ付き)

scikit-learn付属のdigits.csvを使ってみたが、手書き数字の認識をしたいなら、別のデザインの方が良いだろう。

「手書き数字の認識 → scikit-learn」ではなくて、「scikit-learnの演習 → 手書き数字の認識例を使用」
である事を忘れると、勘違いが生じる。

と、書きながらも勘違いのスクリプトをあげる。
前回と同じ動作の、scikit-learnを使った数字認識のスクリプトだが、tkinterを使ってみた。
データのcsvファイルの選択、判定する画像pngファイルの選択ダイアログやメッセージ等がtkinterで表示される。

※追記1(2018/08/21 14:30)※
分類器を保存するようにした。
スクリプトでは分類器の保存ファイルが存在すれば、それを使うようにしているが、
実際問題としては意味が無い

※追記2(2018/08/26 9:30)※
「8x8=64 + ラベル」以外に対応するようにした。
(画像の)縦横比が同じトレーニング用データならば、それに合わせてテスト用画像を
リサイズして判定する。
分類器の保存ファイルは実際の判定では使わないように変更。

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import matplotlib.pyplot as plt
  5. from sklearn import datasets
  6. from sklearn.neural_network import MLPClassifier
  7.  
  8. import numpy as np
  9. import cv2
  10. from time import sleep
  11. from os.path import dirname, join, basename, exists
  12. from sklearn.datasets.base import Bunch
  13. import sys
  14. from sklearn.externals import joblib
  15.  
  16. py_version = sys.version_info.major
  17.  
  18. module_path = dirname(__file__)
  19. #print(dirname(abspath('__file__')))
  20. #print(module_path)
  21.  
  22. if py_version != 2:
  23. import tkinter as tk, tkinter.filedialog as tkfiledialog, tkinter.messagebox as tkmessage, \
  24. tkinter.simpledialog as tksimpledialog
  25. else:
  26. import Tkinter as tk, tkFileDialog as tkfiledialog, tkMessageBox as tkmessage, \
  27. tkSimpleDialog as tksimpledialog
  28.  
  29. def csv_select():
  30. root = tk.Tk()
  31. root.withdraw()
  32.  
  33. file_type = [('', '*csv')]
  34. data_path = join(dirname(__file__), 'data')
  35. tkmessage.showinfo('digit_test', 'データ(csv)を選択してください')
  36. test_csv = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
  37. if test_csv == () or test_csv == '':
  38. test_csv = 'digits.csv'
  39. else:
  40. test_csv = basename(test_csv)
  41.  
  42. root.destroy()
  43. #print(test_csv)
  44. return test_csv
  45.  
  46. def img_select():
  47. root = tk.Tk()
  48. root.withdraw()
  49.  
  50. file_type = [('', '*png')]
  51. #data_path = dirname(abspath('__file__'))
  52. data_path = dirname(__file__)
  53. tkmessage.showinfo(test_csv, '画像(png)を選択してください')
  54. test_img = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
  55. #print(dirname(abspath('__file__')))
  56. if test_img == () or test_img == '':
  57. test_img = 'digit_test.png'
  58.  
  59. root.destroy()
  60. #print(test_img)
  61. return test_img
  62.  
  63. def question0():
  64. root = tk.Tk()
  65. root.withdraw()
  66.  
  67. y_or_n = tkmessage.askquestion(test_csv, '判定は' + str(predicted) + '\n正しいですか?')
  68.  
  69. root.destroy()
  70. return y_or_n
  71.  
  72. def question1():
  73. root = tk.Tk()
  74. root.withdraw()
  75.  
  76. next_y_or_n = tkmessage.askquestion(test_csv, '続けますか?')
  77.  
  78. root.destroy()
  79. return next_y_or_n
  80.  
  81. def correct_input():
  82. root = tk.Tk()
  83. root.withdraw()
  84. root.after(1, lambda: root.focus_force())
  85. correct_digit = tksimpledialog.askinteger(test_csv, '正解を教えてください',
  86. initialvalue='ここに入力',parent=root)
  87.  
  88. root.destroy()
  89. return correct_digit
  90.  
  91. def load_digits():
  92. module_path = dirname(__file__)
  93. data = np.loadtxt(join(module_path, 'data', test_csv),
  94. delimiter=',')
  95.  
  96. data_len = int(np.sqrt(len(data[0]) - 1))
  97. #print(data_len)
  98.  
  99. target = data[:, -1].astype(np.int)
  100. flat_data = data[:, :-1]
  101. images = flat_data.view()
  102. images.shape = (-1, data_len, data_len)
  103.  
  104. return Bunch(data=flat_data,
  105. target=target,
  106. target_names=np.arange(10),
  107. images=images)
  108.  
  109.  
  110. test_csv = csv_select()
  111. test_img = img_select()
  112.  
  113. if exists('./' + test_csv + '.mlp.pkl'):
  114. #y_or_n = ''
  115. y_or_n = 'no'
  116. else:
  117. y_or_n = 'no'
  118.  
  119. while(True):
  120. digits = load_digits()
  121. n_samples = len(digits.images)
  122. data = digits.images.reshape((n_samples, -1))
  123.  
  124. if y_or_n == 'no':
  125. classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10),
  126. max_iter=10000, tol=0.00001, random_state=1)
  127. print(classifier)
  128. classifier.fit(data, digits.target)
  129. joblib.dump(classifier, test_csv + '.mlp.pkl', compress=True)
  130. else:
  131. classifier = joblib.load(test_csv + '.mlp.pkl')
  132. #print(classifier)
  133.  
  134. images_and_labels = list(zip(digits.images, digits.target))
  135. #plt.figure(figsize=(5.5, 3))
  136. for index, (image, label) in enumerate(images_and_labels[:10]):
  137. plt.subplot(3, 5, index + 1)
  138. plt.axis('off')
  139. plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
  140. plt.title('Training %i' % label)
  141.  
  142. img = cv2.imread(test_img)
  143.  
  144. data_len = int(np.sqrt(len(data[0])))
  145. #print(data_len)
  146.  
  147. size = (data_len, data_len)
  148. im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
  149.  
  150. imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
  151. #print(imgray)
  152. grayim = np.round((255 - imgray) / 16)
  153. grayim = grayim.reshape(1, -1)
  154. print(grayim)
  155. np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')
  156.  
  157. predicted = classifier.predict(grayim)
  158. print('判定は、' + str(predicted))
  159.  
  160. plt.subplot(3, 5, 11)
  161. plt.axis('off')
  162. plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
  163. plt.title('Original')
  164. plt.subplot(3, 5, 12)
  165. plt.axis('off')
  166. plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
  167. plt.title('Prodicted %i' % predicted)
  168. plt.show()
  169.  
  170. y_or_n = question0()
  171.  
  172. if y_or_n == 'no':
  173. correct_digit = correct_input()
  174. if correct_digit != None:
  175. grayim = np.append(grayim, correct_digit)
  176. #print(grayim)
  177. grayim = grayim.reshape(1, -1)
  178.  
  179. f = open(join(module_path, 'data', test_csv), 'a')
  180. np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
  181. f.close()
  182. sleep(1)
  183. else:
  184. correct_digit = str(predicted[0])
  185. else:
  186. next_y_or_n = question1()
  187.  
  188. if next_y_or_n == 'no':
  189. break
  190.  
  191. test_img = img_select()
  192.  

おまけのおまけスクリプト

10個の画像から、トレーニング用csvデータを生成するスクリプト。
スクリプトと同一ディレクトリ内の「0test.png」「1test.png」・・・「9test.png」という
ファイル名の画像ファイルを、順番に読込み「8x8=64 + ラベル」の型でcsvファイルを生成する。
デフォルトでは8x8=64だが、スクリプトの引数に数値を与えるとそれを基にしたデータになる。
例えば、「5」を引数にすると、「5x5=25 + ラベル」の型になる。
(仮にスクリプトのファイル名を「prepare.py」とした場合、「python prepare.py 5」と言った記述)

生成されたcsvは、「5x5.csv」と言うファイル名で同一ディレクトリに保存される。
これをdataディレクトリへ移動して上のスクリプトで読み込めば、そのデータで分類器を生成してテスト
画像を判定する。

※追記3(2018/08/27 11:45)※
ファイル名からラベルの取り方を変更

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import numpy as np
  5. import sys, cv2
  6.  
  7. args = sys.argv
  8. print(args)
  9. if len(args) > 2:
  10. print('引数が多すぎます')
  11. sys.exit()
  12. elif len(args) == 2:
  13. try:
  14. type(int(args[1])) == int
  15. pix = int(args[1])
  16. except ValueError:
  17. print('引数の型が違います')
  18. sys.exit()
  19. else:
  20. print('既定のデータ型を使います')
  21. pix = 8
  22.  
  23. for num in range(0, 10):
  24. try:
  25. img_file = str(num) + 'test.png'
  26. im = cv2.imread(img_file)
  27.  
  28. size = (pix, pix)
  29. im = cv2.resize(im, size, interpolation=cv2.INTER_AREA)
  30.  
  31. imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
  32. print(imgray)
  33. grayim = np.round((255 - imgray) / 16)
  34. grayim = grayim.reshape(1, -1)
  35. grayim = np.array(grayim, dtype = np.uint8)
  36. grayim = np.append(grayim, int(str(num)[-1]))
  37. grayim = grayim.reshape(1, -1)
  38. print(grayim)
  39.  
  40. f = open(str(pix) + 'x' + str(pix) + '.csv', 'a')
  41. np.savetxt(f, grayim, delimiter = ",", fmt = "%.0f")
  42. f.close()
  43.  
  44. #cv2.imshow(str(pix) + 'test', imgray)
  45. #cv2.waitKey(0)
  46. except:
  47. print(str(num) + 'test.png をスキップ')
  48.  

10個より多くのデータを生成したい場合は"range(0, 10)"を書き換える
  1. for num in range(0, 1000):

2018年8月4日土曜日

scikit-learnで漢数字認識


scikit-learnにサンプルで付いてくる手書き数字認識のデータセットを弄ってみるのが目的。
(※画像の処理にOpenCVを使ってます※)
手書き数字のデータ(digits.csv)を書き換えたいので、付属のデータをコピーして使用する。
"pip"でscikit-learnをインストールした私の環境で、データは以下のディレクトリにある。

※Windowsの場合
"C:/Users/<username>/AppData/Local/Programs/Python/Python36-32/lib/site-packages/sklearn/datasets/data/digits.csv.gz"

※Arch Linux 32の場合
"/usr/lib/python3.6/site-packages/sklearn/datasets/data/digits.csv.gz"

予め任意のディレクトリを作成してその中に"data"ディレクトリを作成し、その"data"ディレクトリに
"digits.csv.gz"をコピーする。更にエディタ等で直接扱いたいので、解凍もしておく。
データを置いたディレクトリより一つ上の階層に以下のPythonスクリプトを置く。

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import matplotlib.pyplot as plt
  5. from sklearn import datasets, svm
  6.  
  7. import numpy as np
  8. import cv2
  9. from time import sleep
  10. from os.path import dirname, join
  11. from sklearn.datasets.base import Bunch
  12. import sys
  13.  
  14. y_or_n = 'n'
  15. module_path = dirname(__file__)
  16.  
  17. def load_digits(n_class=10, return_X_y=False):
  18. module_path = dirname(__file__)
  19. data = np.loadtxt(join(module_path, 'data', 'digits.csv'),
  20. delimiter=',')
  21. #with open(join(module_path, 'descr', 'digits.rst')) as f:
  22. # descr = f.read()
  23. target = data[:, -1].astype(np.int)
  24. flat_data = data[:, :-1]
  25. images = flat_data.view()
  26. images.shape = (-1, 8, 8)
  27.  
  28. if n_class < 10:
  29. idx = target < n_class
  30. flat_data, target = flat_data[idx], target[idx]
  31. images = images[idx]
  32.  
  33. if return_X_y:
  34. return flat_data, target
  35.  
  36. return Bunch(data=flat_data,
  37. target=target,
  38. target_names=np.arange(10),
  39. images=images)
  40. #DESCR=descr)
  41.  
  42.  
  43. while(True):
  44. if y_or_n == 'n':
  45. digits = load_digits()
  46.  
  47. n_samples = len(digits.images)
  48. data = digits.images.reshape((n_samples, -1))
  49.  
  50. classifier = svm.SVC(C=100, gamma=0.001)
  51. print(classifier)
  52. #classifier.fit(data[:n_samples], digits.target[:n_samples])
  53. classifier.fit(data, digits.target)
  54.  
  55. images_and_labels = list(zip(digits.images, digits.target))
  56. for index, (image, label) in enumerate(images_and_labels[:10]):
  57. plt.subplot(3, 5, index + 1)
  58. plt.axis('off')
  59. plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
  60. plt.title('Training %i' % label)
  61.  
  62. img = cv2.imread('digit_test.png')
  63.  
  64. size = (8, 8)
  65. im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
  66.  
  67. imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
  68. #print(imgray)
  69. grayim = np.round((255 - imgray) / 16)
  70. grayim = grayim.reshape(1, -1)
  71. print(grayim)
  72. np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')
  73.  
  74. predicted = classifier.predict(grayim)
  75. print('判定は、' + str(predicted))
  76.  
  77. plt.subplot(3, 5, 11)
  78. plt.axis('off')
  79. plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
  80. plt.title('Original')
  81. plt.subplot(3, 5, 12)
  82. plt.axis('off')
  83. plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
  84. plt.title('Prodicted %i' % predicted)
  85. plt.show()
  86.  
  87. if sys.version_info.major != 2:
  88. y_or_n = input('正しいですか?(y or n)')
  89. else:
  90. y_or_n = raw_input('正しいですか?(y or n)')
  91.  
  92. if y_or_n == 'n':
  93. correct_digit = input('正しい数字を教えてください:')
  94.  
  95. grayim = np.append(grayim, int(correct_digit))
  96. grayim = grayim.reshape(1, -1)
  97.  
  98. f = open(join(module_path, 'data', 'digits.csv'), 'a')
  99. np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
  100. f.close()
  101. sleep(0.5)
  102. else:
  103. if sys.version_info.major != 2:
  104. next_y_or_n = input('続けますか?(y or n)')
  105. else:
  106. next_y_or_n = raw_input('続けますか?(y or n)')
  107.  
  108. if next_y_or_n == 'n':
  109. break
  110.  
"digits.csv"には8x8のサイズで0~9の数字の画像データを、0~16の階調で格納し、最後にその画像が
示す数字の正解ラベルが付加されている。
 1行に64個(8x8)の値と、1個の正解ラベルで、65個の値を持ち、全部で1796行になる。
「scikit-learn 手書き 数字」等で検索すると、1796のデータを学習用と分類用に分けて、その結果を
 示すサンプルが多く見受けられる。
このスクリプトでは、"digits.csv"のデータをすべて読み込んでサポートベクトルマシンによる分類器を
用意し、同じディレクトリ内の画像データ"digit_test.png"を読み込んで分類器を使った判定を行う。

画像"digit_test.png"は8x8ピクセルである必要はなく、例えば100x100でも良いし、80x100とかでも
8x8へ変換するようになっている。

スクリプトを実行させると、10個の凡例と手書き画像のオリジナル・8x8変換後の判定が表示される。


一旦その表示ウィンドウを閉じると、「正しいですか?」と聞いてくるので答えを入力する。


「n」を入力すると、「正しい数字を教えてください」と聞いてくるので正解を入力する。
読み込んだデータに正解ラベルを付加した状態でデータに追加し、再度分類器を生成して判定を やり直す。
一度で駄目でも、何回か繰り返せば正解の判定を導き出すようになる。
「続けますか?」に答える前に、次に判定したい画像を"digit_test.png"として用意しておけば、
新たな画像を判定する。でないと、同じ画像の判定を繰り返すだけになるので注意。

 ただ重要なのは、数字を理解して判定している訳ではなく、64個の値の組み合わせに正解を
紐付けているだけだと言うこと。

だから、アラビア数字じゃなく漢数字の画像データでも良いはず!

漢数字の画像を用意して、データの追加を何回か繰り替えして行けば、漢数字での判定が出きるようになる。

 (要は、正解ラベルで意味づけをしているのは人間なので、数字である必要も、画像データである必要も無い)  

おまけ(漢数字用データ)
  1. 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,6,6,6,6,8,8,0,3,4,4,4,4,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
  2. 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,6,6,6,6,8,8,0,3,4,4,4,4,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
  3. 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,10,10,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,5,10,10,11,9,8,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2
  4. 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,10,10,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,5,10,10,11,9,8,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2
  5. 0,0,0,0,0,0,0,0,0,0,9,11,10,8,0,0,0,0,0,0,0,0,0,0,0,0,2,9,9,0,0,0,0,0,0,1,0,0,0,0,2,4,5,6,9,11,9,1,3,6,6,4,1,0,4,5,0,0,0,0,0,0,0,0,3
  6. 0,0,0,0,0,0,0,0,0,0,9,11,10,8,0,0,0,0,0,0,0,0,0,0,0,0,2,9,9,0,0,0,0,0,0,1,0,0,0,0,2,4,5,6,9,11,9,1,3,6,6,4,1,0,4,5,0,0,0,0,0,0,0,0,3
  7. 0,0,0,0,0,0,0,0,1,1,3,5,6,8,1,0,9,6,10,9,11,8,4,0,4,7,6,5,10,6,4,0,2,12,11,0,9,13,2,0,2,9,4,4,6,12,3,0,0,3,6,6,5,6,2,0,0,0,0,0,0,0,0,0,4
  8. 0,0,0,0,0,0,0,0,1,1,3,5,6,8,1,0,9,6,10,9,11,8,4,0,4,7,6,5,10,6,4,0,2,12,11,0,9,13,2,0,2,9,4,4,6,12,3,0,0,3,6,6,5,6,2,0,0,0,0,0,0,0,0,0,4
  9. 0,0,0,0,0,0,0,0,0,3,10,10,10,8,0,0,0,0,0,10,2,0,0,0,0,3,7,13,6,4,0,0,0,1,7,9,4,10,0,0,0,0,7,4,4,7,0,0,7,10,12,10,12,11,11,10,0,0,0,0,0,0,0,0,5
  10. 0,0,0,0,0,0,0,0,0,3,10,10,10,8,0,0,0,0,0,10,2,0,0,0,0,3,7,13,6,4,0,0,0,1,7,9,4,10,0,0,0,0,7,4,4,7,0,0,7,10,12,10,12,11,11,10,0,0,0,0,0,0,0,0,5
  11. 0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,2,0,10,4,5,6,1,0,9,10,10,8,6,5,1,0,0,3,0,1,6,0,0,0,2,10,0,0,10,8,0,0,11,2,0,0,0,6,5,0,1,0,0,0,0,0,0,6
  12. 0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,2,0,10,4,5,6,1,0,9,10,10,8,6,5,1,0,0,3,0,1,6,0,0,0,2,10,0,0,10,8,0,0,11,2,0,0,0,6,5,0,1,0,0,0,0,0,0,6
  13. 0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,10,0,0,1,0,0,2,8,13,10,11,10,3,0,0,2,11,1,0,0,0,0,0,0,11,0,0,0,0,0,0,0,7,11,10,10,7,0,0,0,0,0,0,0,0,0,7
  14. 0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,10,0,0,1,0,0,2,8,13,10,11,10,3,0,0,2,11,1,0,0,0,0,0,0,11,0,0,0,0,0,0,0,7,11,10,10,7,0,0,0,0,0,0,0,0,0,7
  15. 0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,7,0,11,0,0,0,0,4,7,0,6,5,0,0,0,10,2,0,1,11,0,0,3,9,0,0,0,8,5,0,1,1,0,0,0,1,11,2,0,0,0,0,0,0,0,0,8
  16. 0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,7,0,11,0,0,0,0,4,7,0,6,5,0,0,0,10,2,0,1,11,0,0,3,9,0,0,0,8,5,0,1,1,0,0,0,1,11,2,0,0,0,0,0,0,0,0,8
  17. 0,0,0,5,0,0,0,0,0,0,0,11,0,0,0,0,0,2,8,13,12,10,0,0,0,0,8,6,7,4,0,0,0,0,11,0,8,2,0,0,0,2,10,0,6,7,0,7,0,8,4,0,0,9,11,7,0,1,0,0,0,0,0,0,9
  18. 0,0,0,5,0,0,0,0,0,0,0,11,0,0,0,0,0,2,8,13,12,10,0,0,0,0,8,6,7,4,0,0,0,0,11,0,8,2,0,0,0,2,10,0,6,7,0,7,0,8,4,0,0,9,11,7,0,1,0,0,0,0,0,0,9
  19.  

MLPClassifier使用版
※追記1(2018/08/14 16:20)※
Python2にも対応したつもり
※追記2(2018/08/14 19:40)※
バグ取り

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import matplotlib.pyplot as plt
  5. from sklearn import datasets
  6. from sklearn.neural_network import MLPClassifier
  7.  
  8. import numpy as np
  9. import cv2
  10. from time import sleep
  11. from os.path import dirname, join
  12. from sklearn.datasets.base import Bunch
  13. import sys
  14.  
  15. y_or_n = 'n'
  16. module_path = dirname(__file__)
  17.  
  18. def load_digits():
  19. module_path = dirname(__file__)
  20. data = np.loadtxt(join(module_path, 'data', 'digits.csv'),
  21. delimiter=',')
  22.  
  23. target = data[:, -1].astype(np.int)
  24. flat_data = data[:, :-1]
  25. images = flat_data.view()
  26. images.shape = (-1, 8, 8)
  27.  
  28. return Bunch(data=flat_data,
  29. target=target,
  30. target_names=np.arange(10),
  31. images=images)
  32.  
  33.  
  34. while(True):
  35. if y_or_n == 'n':
  36. digits = load_digits()
  37. n_samples = len(digits.images)
  38. data = digits.images.reshape((n_samples, -1))
  39.  
  40. classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10), max_iter=10000, tol=0.00001, random_state=1)
  41. print(classifier)
  42. classifier.fit(data, digits.target)
  43.  
  44. images_and_labels = list(zip(digits.images, digits.target))
  45. for index, (image, label) in enumerate(images_and_labels[:10]):
  46. plt.subplot(3, 5, index + 1)
  47. plt.axis('off')
  48. plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
  49. plt.title('Training %i' % label)
  50.  
  51. img = cv2.imread('digit_test.png')
  52.  
  53. size = (8, 8)
  54. im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
  55.  
  56. imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
  57. #print(imgray)
  58. grayim = np.round((255 - imgray) / 16)
  59. grayim = grayim.reshape(1, -1)
  60. print(grayim)
  61. np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')
  62.  
  63. predicted = classifier.predict(grayim)
  64. print('判定は、' + str(predicted))
  65.  
  66. plt.subplot(3, 5, 11)
  67. plt.axis('off')
  68. plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
  69. plt.title('Original')
  70. plt.subplot(3, 5, 12)
  71. plt.axis('off')
  72. plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
  73. plt.title('Prodicted %i' % predicted)
  74. plt.show()
  75.  
  76. if sys.version_info.major != 2:
  77. y_or_n = input('正しいですか?(y or n)')
  78. else:
  79. y_or_n = raw_input('正しいですか?(y or n)')
  80.  
  81. if y_or_n == 'n':
  82. correct_digit = input('正しい数字を教えてください:')
  83.  
  84. grayim = np.append(grayim, int(correct_digit))
  85. #print(grayim)
  86. grayim = grayim.reshape(1, -1)
  87.  
  88. f = open(join(module_path, 'data', 'digits.csv'), 'a')
  89. np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
  90. f.close()
  91. sleep(1)
  92. else:
  93. if sys.version_info.major != 2:
  94. next_y_or_n = input('続けますか?(y or n)')
  95. else:
  96. next_y_or_n = raw_input('続けますか?(y or n)')
  97.  
  98. if next_y_or_n == 'n':
  99. break
  100.  

2018年5月3日木曜日

ServoBlaster が更新されました

去年こちらに書いたServoBlasterの問題が修正されていました。

 https://github.com/richardghirst/PiBits/tree/master/ServoBlaster

kernel 4.9以降に"/proc/cpuinfo"の"Hardware"情報を元に"board_model"を判定する方法
(実際は、"board_model"からgpioのベースになるアドレスを設定する方法)に不具合が
出ていましたが、こちらの"bcm_host_get_peripheral_address()" を使う方法になっています。

手持ちのRespberry Pi 2 Model Bでの動作確認ができました。

2018年2月17日土曜日

OpenCVで遊ぶ

OpenCVのテンプレートマッチングを少し試していて思いついた。
テンプレートの検出画面をターゲティングモニタ風にしてみた。

2018年2月11日日曜日

動画でOpenCV(テンプレートマッチング)

WebカメラLogicool c270を買ったので 「OpenCVのテンプレートマッチングを試す 」
動画でやってみた。

今回は、OpenCVのチュートリアル「動画を扱う」を参考にした。


テンプレート画像として、"temp100.png" "temp200.png" "temp300.png"を用意。
Logicool c270から読み込んだ画像を処理する。
Raspberry Pi 2 Model Bの処理を考えて、サイズを320x240、フレームレート15を設定。

※注意※
自分の環境では1回目の実行は問題ないが、2回目の実行で"cap = cv2.VideoCapture(0)"が
正常に行われないようなので、"cap.get(cv2.CAP_PROP_FPS) == 0"で判定しながら
繰り返し処理を入れてあります。

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import cv2
  5. import numpy as np
  6. import tkinter
  7.  
  8. cap = cv2.VideoCapture(0)
  9.  
  10. while(cap.get(cv2.CAP_PROP_FPS) == 0):
  11. cap.release()
  12. cap = cv2.VideoCapture(0)
  13.  
  14. cap.set(3, 320)
  15. cap.set(4, 240)
  16. cap.set(5, 15.0)
  17.  
  18. print(cap.get(3))
  19. print(cap.get(4))
  20. print(cap.get(5))
  21.  
  22. grid_x = (25, 115, 205, 295)
  23. grid_y = (45, 95, 145, 195)
  24.  
  25. fourcc = cv2.VideoWriter_fourcc(*"XVID")
  26. out = cv2.VideoWriter('output.avi', fourcc, 15.0, (320,240))
  27.  
  28. while(True):
  29. res_100 = "???"
  30. res_200 = "???"
  31. res_300 = "???"
  32. text ="???"
  33.  
  34. # Capture frame-by-frame
  35. ret, frame = cap.read()
  36. img = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  37.  
  38. for temp_num in [100, 200, 300]:
  39. temp = cv2.imread("temp" + str(temp_num) + ".png", 0)
  40. result = cv2.matchTemplate(img, temp, cv2.TM_CCOEFF_NORMED)
  41. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
  42.  
  43. if max_val > 0.8:
  44. b = 0
  45. g = 0
  46. r = 0
  47. if temp_num == 100:
  48. b = 255
  49. if temp_num == 200:
  50. g = 255
  51. if temp_num == 300:
  52. r = 255
  53.  
  54. top_left = max_loc
  55. w, h = temp.shape[::-1]
  56. bottom_right = (top_left[0] + w, top_left[1] + h)
  57.  
  58. res_x = []
  59. res_y = []
  60.  
  61. for i in [0, 1, 2]:
  62. if top_left[0] >= grid_x[i] and top_left[0] < grid_x[i + 1]:
  63. if (i + 1 in res_x) == False:
  64. res_x.append(i + 1)
  65. for n in [0, 1, 2]:
  66. if top_left[1] >= grid_y[n] and top_left[1] < grid_y[n + 1]:
  67. if (n + 1 in res_y) == False:
  68. res_y.append(n + 1)
  69.  
  70. #print("temp" + str(temp_num) + ".png = " + str(res_x[0]) + "-" + str(res_y[0]))
  71.  
  72. if temp_num == 100:
  73. res_100 = str(res_x[0]) + "-" + str(res_y[0])
  74. if temp_num == 200:
  75. res_200 = str(res_x[0]) + "-" + str(res_y[0])
  76. if temp_num == 300:
  77. res_300 = str(res_x[0]) + "-" + str(res_y[0])
  78.  
  79. text = str(res_x[0]) + "-" + str(res_y[0]) + "(" + str("{:.3}".format(max_val)) + ")"
  80.  
  81. cv2.rectangle(frame, top_left, bottom_right, (b, g, r), 2)
  82. cv2.putText(frame, text, (top_left[0] - 5, top_left[1] - 5), cv2.FONT_HERSHEY_PLAIN, 0.8, (b, g, r), 1, cv2.LINE_AA)
  83.  
  84. # Display the resulting frame
  85. # write the frame
  86. out.write(frame)
  87.  
  88. cv2.imshow("frame", frame)
  89.  
  90. key = cv2.waitKey(1)&0xff
  91. if key == ord("s"):
  92.  
  93. label = tkinter.Label(None, text = "1 = " + res_100 + "\n2 = " + res_200 + "\n3 = " + res_300, font=("Times", "28"))
  94. label.pack()
  95. label.mainloop()
  96.  
  97. cv2.imwrite("capture.png", frame)
  98. #break
  99.  
  100. if key == ord("q"):
  101. break
  102.  
  103. # When everything done, release the capture
  104. cap.release()
  105. out.release()
  106. cv2.destroyAllWindows()
  107.  

上のコードでは、"s"キーが押された時に"tkinter"を使って検出対象の位置情報を表示させているが、
これは一例なので、検出結果から「ここで何かする」って意味になる。
更に別スレッド、別プロセスにすれば動画(画像)の取得を止めること無く「何かする」
が実行出きるので、実際の運用はそう言った使い方になる筈。

2018年1月4日木曜日

OpenCVのテンプレートマッチングを試す

以前にOpenCVを使って顔検出をやってみたが、今回はテンプレートマッチングで、
簡易的に数字を検出してみる。
  
下のページを参考にさせてもらいました。有り難うございました。
http://www.tech-tech.xyz/archives/3065942.html

本当は画像の中から文字を検出したかったのだが、意外と面倒そうだった。
手書きの文字やフォント不明の文字を認識したいわけではなく、予め用意した数字と
同じものさえ検出できればよかったので、テンプレート画像との一致を見つけることで
良しとした。

先ず、テンプレート画像を用意する。

上の画像から、検出したい数字の部分を切り取ってテンプレート画像を作る。
これら「1」「2」「3」のテンプレートを使って検出を行う。
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import cv2
  5. import shutil
  6.  
  7. shutil.copyfile("test.jpg","result.jpg")
  8.  
  9. img = cv2.imread("result.jpg", 0)
  10.  
  11. for temp_num in [1, 2, 3]:
  12. temp = cv2.imread("temp" + str(temp_num) + ".png", 0)
  13.  
  14. result = cv2.matchTemplate(img, temp, cv2.TM_CCOEFF_NORMED)
  15.  
  16. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
  17.  
  18. if max_val > 0.8:
  19. top_left = max_loc
  20. w, h = temp.shape[::-1]
  21. bottom_right = (top_left[0] + w, top_left[1] + h)
  22.  
  23. result = cv2.imread("result.jpg")
  24. cv2.rectangle(result,top_left, bottom_right, (255, 0, 0), 2)
  25. cv2.imwrite("result.jpg", result)

※追記1(2018/01/04 20:20)※
  誤検出を減らすように、テンプレートマッチングの結果から
"max_val"(グレースケールで最も高い輝度の値)が0.8より大きい時だけ 
枠表示をするように変更しました。
何回か試したところ、正常に検出できた場合の"max_val"は、ほぼ0.9以上。
誤検出の場合は、0.7以下でした。
そこで「エイヤッ」と、0.8を設定しました。


今回は検出する画像がjpg、テンプレート画像がpngという変則コードになってしまったが
実行できれば問題ない(本当かよ?)

実際のカードを作って、それらを並べ替えてやってみた。
元画像を印刷した後、切り取って数字カードを作る。

それらのカードを適当に並べ替えて検出してみた。

更に、枠を外してやってみた。

<結果>
・影があったり、余分な物が写っている画像でも、正常に検出ができた。

<問題点>
・検出する画像の数字部分とテンプレート画像のサイズが(大体)合っていないと、
以下のようになる。
間違った検出。画像の数字に対してテンプレートが小さい
 ・検出する画像が大きいと時間が掛かる。