「手書き数字の認識 → scikit-learn」ではなくて、「scikit-learnの演習 → 手書き数字の認識例を使用」
である事を忘れると、勘違いが生じる。
と、書きながらも勘違いのスクリプトをあげる。
前回と同じ動作の、scikit-learnを使った数字認識のスクリプトだが、tkinterを使ってみた。
データのcsvファイルの選択、判定する画像pngファイルの選択ダイアログやメッセージ等がtkinterで表示される。
※追記1(2018/08/21 14:30)※
分類器を保存するようにした。
スクリプトでは分類器の保存ファイルが存在すれば、それを使うようにしているが、
実際問題としては意味が無い。
※追記2(2018/08/26 9:30)※
「8x8=64 + ラベル」以外に対応するようにした。
(画像の)縦横比が同じトレーニング用データならば、それに合わせてテスト用画像を
リサイズして判定する。
分類器の保存ファイルは実際の判定では使わないように変更。
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- import matplotlib.pyplot as plt
- from sklearn import datasets
- from sklearn.neural_network import MLPClassifier
- import numpy as np
- import cv2
- from time import sleep
- from os.path import dirname, join, basename, exists
- from sklearn.datasets.base import Bunch
- import sys
- from sklearn.externals import joblib
- py_version = sys.version_info.major
- module_path = dirname(__file__)
- #print(dirname(abspath('__file__')))
- #print(module_path)
- if py_version != 2:
- import tkinter as tk, tkinter.filedialog as tkfiledialog, tkinter.messagebox as tkmessage, \
- tkinter.simpledialog as tksimpledialog
- else:
- import Tkinter as tk, tkFileDialog as tkfiledialog, tkMessageBox as tkmessage, \
- tkSimpleDialog as tksimpledialog
- def csv_select():
- root = tk.Tk()
- root.withdraw()
- file_type = [('', '*csv')]
- data_path = join(dirname(__file__), 'data')
- tkmessage.showinfo('digit_test', 'データ(csv)を選択してください')
- test_csv = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
- if test_csv == () or test_csv == '':
- test_csv = 'digits.csv'
- else:
- test_csv = basename(test_csv)
- root.destroy()
- #print(test_csv)
- return test_csv
- def img_select():
- root = tk.Tk()
- root.withdraw()
- file_type = [('', '*png')]
- #data_path = dirname(abspath('__file__'))
- data_path = dirname(__file__)
- tkmessage.showinfo(test_csv, '画像(png)を選択してください')
- test_img = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
- #print(dirname(abspath('__file__')))
- if test_img == () or test_img == '':
- test_img = 'digit_test.png'
- root.destroy()
- #print(test_img)
- return test_img
- def question0():
- root = tk.Tk()
- root.withdraw()
- y_or_n = tkmessage.askquestion(test_csv, '判定は' + str(predicted) + '\n正しいですか?')
- root.destroy()
- return y_or_n
- def question1():
- root = tk.Tk()
- root.withdraw()
- next_y_or_n = tkmessage.askquestion(test_csv, '続けますか?')
- root.destroy()
- return next_y_or_n
- def correct_input():
- root = tk.Tk()
- root.withdraw()
- root.after(1, lambda: root.focus_force())
- correct_digit = tksimpledialog.askinteger(test_csv, '正解を教えてください',
- initialvalue='ここに入力',parent=root)
- root.destroy()
- return correct_digit
- def load_digits():
- module_path = dirname(__file__)
- data = np.loadtxt(join(module_path, 'data', test_csv),
- delimiter=',')
- data_len = int(np.sqrt(len(data[0]) - 1))
- #print(data_len)
- target = data[:, -1].astype(np.int)
- flat_data = data[:, :-1]
- images = flat_data.view()
- images.shape = (-1, data_len, data_len)
- return Bunch(data=flat_data,
- target=target,
- target_names=np.arange(10),
- images=images)
- test_csv = csv_select()
- test_img = img_select()
- if exists('./' + test_csv + '.mlp.pkl'):
- #y_or_n = ''
- y_or_n = 'no'
- else:
- y_or_n = 'no'
- while(True):
- digits = load_digits()
- n_samples = len(digits.images)
- data = digits.images.reshape((n_samples, -1))
- if y_or_n == 'no':
- classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10),
- max_iter=10000, tol=0.00001, random_state=1)
- print(classifier)
- classifier.fit(data, digits.target)
- joblib.dump(classifier, test_csv + '.mlp.pkl', compress=True)
- else:
- classifier = joblib.load(test_csv + '.mlp.pkl')
- #print(classifier)
- images_and_labels = list(zip(digits.images, digits.target))
- #plt.figure(figsize=(5.5, 3))
- for index, (image, label) in enumerate(images_and_labels[:10]):
- plt.subplot(3, 5, index + 1)
- plt.axis('off')
- plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
- plt.title('Training %i' % label)
- img = cv2.imread(test_img)
- data_len = int(np.sqrt(len(data[0])))
- #print(data_len)
- size = (data_len, data_len)
- im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
- imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
- #print(imgray)
- grayim = np.round((255 - imgray) / 16)
- grayim = grayim.reshape(1, -1)
- print(grayim)
- np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')
- predicted = classifier.predict(grayim)
- print('判定は、' + str(predicted))
- plt.subplot(3, 5, 11)
- plt.axis('off')
- plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
- plt.title('Original')
- plt.subplot(3, 5, 12)
- plt.axis('off')
- plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
- plt.title('Prodicted %i' % predicted)
- plt.show()
- y_or_n = question0()
- if y_or_n == 'no':
- correct_digit = correct_input()
- if correct_digit != None:
- grayim = np.append(grayim, correct_digit)
- #print(grayim)
- grayim = grayim.reshape(1, -1)
- f = open(join(module_path, 'data', test_csv), 'a')
- np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
- f.close()
- sleep(1)
- else:
- correct_digit = str(predicted[0])
- else:
- next_y_or_n = question1()
- if next_y_or_n == 'no':
- break
- test_img = img_select()
おまけのおまけスクリプト
10個の画像から、トレーニング用csvデータを生成するスクリプト。スクリプトと同一ディレクトリ内の「0test.png」「1test.png」・・・「9test.png」という
ファイル名の画像ファイルを、順番に読込み「8x8=64 + ラベル」の型でcsvファイルを生成する。
デフォルトでは8x8=64だが、スクリプトの引数に数値を与えるとそれを基にしたデータになる。
例えば、「5」を引数にすると、「5x5=25 + ラベル」の型になる。
(仮にスクリプトのファイル名を「prepare.py」とした場合、「python prepare.py 5」と言った記述)
生成されたcsvは、「5x5.csv」と言うファイル名で同一ディレクトリに保存される。
これをdataディレクトリへ移動して上のスクリプトで読み込めば、そのデータで分類器を生成してテスト
画像を判定する。
※追記3(2018/08/27 11:45)※
ファイル名からラベルの取り方を変更
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- import numpy as np
- import sys, cv2
- args = sys.argv
- print(args)
- if len(args) > 2:
- print('引数が多すぎます')
- sys.exit()
- elif len(args) == 2:
- try:
- type(int(args[1])) == int
- pix = int(args[1])
- except ValueError:
- print('引数の型が違います')
- sys.exit()
- else:
- print('既定のデータ型を使います')
- pix = 8
- for num in range(0, 10):
- try:
- img_file = str(num) + 'test.png'
- im = cv2.imread(img_file)
- size = (pix, pix)
- im = cv2.resize(im, size, interpolation=cv2.INTER_AREA)
- imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
- print(imgray)
- grayim = np.round((255 - imgray) / 16)
- grayim = grayim.reshape(1, -1)
- grayim = np.array(grayim, dtype = np.uint8)
- grayim = np.append(grayim, int(str(num)[-1]))
- grayim = grayim.reshape(1, -1)
- print(grayim)
- f = open(str(pix) + 'x' + str(pix) + '.csv', 'a')
- np.savetxt(f, grayim, delimiter = ",", fmt = "%.0f")
- f.close()
- #cv2.imshow(str(pix) + 'test', imgray)
- #cv2.waitKey(0)
- except:
- print(str(num) + 'test.png をスキップ')
10個より多くのデータを生成したい場合は"range(0, 10)"を書き換える
- for num in range(0, 1000):