2018年8月20日月曜日

scikit-learn おまけの話(おまけ付き)

scikit-learn付属のdigits.csvを使ってみたが、手書き数字の認識をしたいなら、別のデザインの方が良いだろう。

「手書き数字の認識 → scikit-learn」ではなくて、「scikit-learnの演習 → 手書き数字の認識例を使用」
である事を忘れると、勘違いが生じる。

と、書きながらも勘違いのスクリプトをあげる。
前回と同じ動作の、scikit-learnを使った数字認識のスクリプトだが、tkinterを使ってみた。
データのcsvファイルの選択、判定する画像pngファイルの選択ダイアログやメッセージ等がtkinterで表示される。

※追記1(2018/08/21 14:30)※
分類器を保存するようにした。
スクリプトでは分類器の保存ファイルが存在すれば、それを使うようにしているが、
実際問題としては意味が無い

※追記2(2018/08/26 9:30)※
「8x8=64 + ラベル」以外に対応するようにした。
(画像の)縦横比が同じトレーニング用データならば、それに合わせてテスト用画像を
リサイズして判定する。
分類器の保存ファイルは実際の判定では使わないように変更。

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import matplotlib.pyplot as plt
  5. from sklearn import datasets
  6. from sklearn.neural_network import MLPClassifier
  7.  
  8. import numpy as np
  9. import cv2
  10. from time import sleep
  11. from os.path import dirname, join, basename, exists
  12. from sklearn.datasets.base import Bunch
  13. import sys
  14. from sklearn.externals import joblib
  15.  
  16. py_version = sys.version_info.major
  17.  
  18. module_path = dirname(__file__)
  19. #print(dirname(abspath('__file__')))
  20. #print(module_path)
  21.  
  22. if py_version != 2:
  23. import tkinter as tk, tkinter.filedialog as tkfiledialog, tkinter.messagebox as tkmessage, \
  24. tkinter.simpledialog as tksimpledialog
  25. else:
  26. import Tkinter as tk, tkFileDialog as tkfiledialog, tkMessageBox as tkmessage, \
  27. tkSimpleDialog as tksimpledialog
  28.  
  29. def csv_select():
  30. root = tk.Tk()
  31. root.withdraw()
  32.  
  33. file_type = [('', '*csv')]
  34. data_path = join(dirname(__file__), 'data')
  35. tkmessage.showinfo('digit_test', 'データ(csv)を選択してください')
  36. test_csv = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
  37. if test_csv == () or test_csv == '':
  38. test_csv = 'digits.csv'
  39. else:
  40. test_csv = basename(test_csv)
  41.  
  42. root.destroy()
  43. #print(test_csv)
  44. return test_csv
  45.  
  46. def img_select():
  47. root = tk.Tk()
  48. root.withdraw()
  49.  
  50. file_type = [('', '*png')]
  51. #data_path = dirname(abspath('__file__'))
  52. data_path = dirname(__file__)
  53. tkmessage.showinfo(test_csv, '画像(png)を選択してください')
  54. test_img = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
  55. #print(dirname(abspath('__file__')))
  56. if test_img == () or test_img == '':
  57. test_img = 'digit_test.png'
  58.  
  59. root.destroy()
  60. #print(test_img)
  61. return test_img
  62.  
  63. def question0():
  64. root = tk.Tk()
  65. root.withdraw()
  66.  
  67. y_or_n = tkmessage.askquestion(test_csv, '判定は' + str(predicted) + '\n正しいですか?')
  68.  
  69. root.destroy()
  70. return y_or_n
  71.  
  72. def question1():
  73. root = tk.Tk()
  74. root.withdraw()
  75.  
  76. next_y_or_n = tkmessage.askquestion(test_csv, '続けますか?')
  77.  
  78. root.destroy()
  79. return next_y_or_n
  80.  
  81. def correct_input():
  82. root = tk.Tk()
  83. root.withdraw()
  84. root.after(1, lambda: root.focus_force())
  85. correct_digit = tksimpledialog.askinteger(test_csv, '正解を教えてください',
  86. initialvalue='ここに入力',parent=root)
  87.  
  88. root.destroy()
  89. return correct_digit
  90.  
  91. def load_digits():
  92. module_path = dirname(__file__)
  93. data = np.loadtxt(join(module_path, 'data', test_csv),
  94. delimiter=',')
  95.  
  96. data_len = int(np.sqrt(len(data[0]) - 1))
  97. #print(data_len)
  98.  
  99. target = data[:, -1].astype(np.int)
  100. flat_data = data[:, :-1]
  101. images = flat_data.view()
  102. images.shape = (-1, data_len, data_len)
  103.  
  104. return Bunch(data=flat_data,
  105. target=target,
  106. target_names=np.arange(10),
  107. images=images)
  108.  
  109.  
  110. test_csv = csv_select()
  111. test_img = img_select()
  112.  
  113. if exists('./' + test_csv + '.mlp.pkl'):
  114. #y_or_n = ''
  115. y_or_n = 'no'
  116. else:
  117. y_or_n = 'no'
  118.  
  119. while(True):
  120. digits = load_digits()
  121. n_samples = len(digits.images)
  122. data = digits.images.reshape((n_samples, -1))
  123.  
  124. if y_or_n == 'no':
  125. classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10),
  126. max_iter=10000, tol=0.00001, random_state=1)
  127. print(classifier)
  128. classifier.fit(data, digits.target)
  129. joblib.dump(classifier, test_csv + '.mlp.pkl', compress=True)
  130. else:
  131. classifier = joblib.load(test_csv + '.mlp.pkl')
  132. #print(classifier)
  133.  
  134. images_and_labels = list(zip(digits.images, digits.target))
  135. #plt.figure(figsize=(5.5, 3))
  136. for index, (image, label) in enumerate(images_and_labels[:10]):
  137. plt.subplot(3, 5, index + 1)
  138. plt.axis('off')
  139. plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
  140. plt.title('Training %i' % label)
  141.  
  142. img = cv2.imread(test_img)
  143.  
  144. data_len = int(np.sqrt(len(data[0])))
  145. #print(data_len)
  146.  
  147. size = (data_len, data_len)
  148. im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
  149.  
  150. imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
  151. #print(imgray)
  152. grayim = np.round((255 - imgray) / 16)
  153. grayim = grayim.reshape(1, -1)
  154. print(grayim)
  155. np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')
  156.  
  157. predicted = classifier.predict(grayim)
  158. print('判定は、' + str(predicted))
  159.  
  160. plt.subplot(3, 5, 11)
  161. plt.axis('off')
  162. plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
  163. plt.title('Original')
  164. plt.subplot(3, 5, 12)
  165. plt.axis('off')
  166. plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
  167. plt.title('Prodicted %i' % predicted)
  168. plt.show()
  169.  
  170. y_or_n = question0()
  171.  
  172. if y_or_n == 'no':
  173. correct_digit = correct_input()
  174. if correct_digit != None:
  175. grayim = np.append(grayim, correct_digit)
  176. #print(grayim)
  177. grayim = grayim.reshape(1, -1)
  178.  
  179. f = open(join(module_path, 'data', test_csv), 'a')
  180. np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
  181. f.close()
  182. sleep(1)
  183. else:
  184. correct_digit = str(predicted[0])
  185. else:
  186. next_y_or_n = question1()
  187.  
  188. if next_y_or_n == 'no':
  189. break
  190.  
  191. test_img = img_select()
  192.  

おまけのおまけスクリプト

10個の画像から、トレーニング用csvデータを生成するスクリプト。
スクリプトと同一ディレクトリ内の「0test.png」「1test.png」・・・「9test.png」という
ファイル名の画像ファイルを、順番に読込み「8x8=64 + ラベル」の型でcsvファイルを生成する。
デフォルトでは8x8=64だが、スクリプトの引数に数値を与えるとそれを基にしたデータになる。
例えば、「5」を引数にすると、「5x5=25 + ラベル」の型になる。
(仮にスクリプトのファイル名を「prepare.py」とした場合、「python prepare.py 5」と言った記述)

生成されたcsvは、「5x5.csv」と言うファイル名で同一ディレクトリに保存される。
これをdataディレクトリへ移動して上のスクリプトで読み込めば、そのデータで分類器を生成してテスト
画像を判定する。

※追記3(2018/08/27 11:45)※
ファイル名からラベルの取り方を変更

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import numpy as np
  5. import sys, cv2
  6.  
  7. args = sys.argv
  8. print(args)
  9. if len(args) > 2:
  10. print('引数が多すぎます')
  11. sys.exit()
  12. elif len(args) == 2:
  13. try:
  14. type(int(args[1])) == int
  15. pix = int(args[1])
  16. except ValueError:
  17. print('引数の型が違います')
  18. sys.exit()
  19. else:
  20. print('既定のデータ型を使います')
  21. pix = 8
  22.  
  23. for num in range(0, 10):
  24. try:
  25. img_file = str(num) + 'test.png'
  26. im = cv2.imread(img_file)
  27.  
  28. size = (pix, pix)
  29. im = cv2.resize(im, size, interpolation=cv2.INTER_AREA)
  30.  
  31. imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
  32. print(imgray)
  33. grayim = np.round((255 - imgray) / 16)
  34. grayim = grayim.reshape(1, -1)
  35. grayim = np.array(grayim, dtype = np.uint8)
  36. grayim = np.append(grayim, int(str(num)[-1]))
  37. grayim = grayim.reshape(1, -1)
  38. print(grayim)
  39.  
  40. f = open(str(pix) + 'x' + str(pix) + '.csv', 'a')
  41. np.savetxt(f, grayim, delimiter = ",", fmt = "%.0f")
  42. f.close()
  43.  
  44. #cv2.imshow(str(pix) + 'test', imgray)
  45. #cv2.waitKey(0)
  46. except:
  47. print(str(num) + 'test.png をスキップ')
  48.  

10個より多くのデータを生成したい場合は"range(0, 10)"を書き換える
  1. for num in range(0, 1000):

0 件のコメント :

コメントを投稿