scikit-learnにサンプルで付いてくる手書き数字認識のデータセットを弄ってみるのが目的。
(※画像の処理にOpenCVを使ってます※)
手書き数字のデータ(digits.csv)を書き換えたいので、付属のデータをコピーして使用する。
"pip"でscikit-learnをインストールした私の環境で、データは以下のディレクトリにある。
※Windowsの場合
"C:/Users/<username>/AppData/Local/Programs/Python/Python36-32/lib/site-packages/sklearn/datasets/data/digits.csv.gz"
※Arch Linux 32の場合
"/usr/lib/python3.6/site-packages/sklearn/datasets/data/digits.csv.gz"
予め任意のディレクトリを作成してその中に"data"ディレクトリを作成し、その"data"ディレクトリに
"digits.csv.gz"をコピーする。更にエディタ等で直接扱いたいので、解凍もしておく。
データを置いたディレクトリより一つ上の階層に以下のPythonスクリプトを置く。
#!/usr/bin/env python # -*- coding: utf-8 -*- import matplotlib.pyplot as plt from sklearn import datasets, svm import numpy as np import cv2 from time import sleep from os.path import dirname, join from sklearn.datasets.base import Bunch import sys y_or_n = 'n' module_path = dirname(__file__) def load_digits(n_class=10, return_X_y=False): module_path = dirname(__file__) data = np.loadtxt(join(module_path, 'data', 'digits.csv'), delimiter=',') #with open(join(module_path, 'descr', 'digits.rst')) as f: # descr = f.read() target = data[:, -1].astype(np.int) flat_data = data[:, :-1] images = flat_data.view() images.shape = (-1, 8, 8) if n_class < 10: idx = target < n_class flat_data, target = flat_data[idx], target[idx] images = images[idx] if return_X_y: return flat_data, target return Bunch(data=flat_data, target=target, target_names=np.arange(10), images=images) #DESCR=descr) while(True): if y_or_n == 'n': digits = load_digits() n_samples = len(digits.images) data = digits.images.reshape((n_samples, -1)) classifier = svm.SVC(C=100, gamma=0.001) print(classifier) #classifier.fit(data[:n_samples], digits.target[:n_samples]) classifier.fit(data, digits.target) images_and_labels = list(zip(digits.images, digits.target)) for index, (image, label) in enumerate(images_and_labels[:10]): plt.subplot(3, 5, index + 1) plt.axis('off') plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Training %i' % label) img = cv2.imread('digit_test.png') size = (8, 8) im = cv2.resize(img, size, interpolation=cv2.INTER_AREA) imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY) #print(imgray) grayim = np.round((255 - imgray) / 16) grayim = grayim.reshape(1, -1) print(grayim) np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f') predicted = classifier.predict(grayim) print('判定は、' + str(predicted)) plt.subplot(3, 5, 11) plt.axis('off') plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Original') plt.subplot(3, 5, 12) plt.axis('off') plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Prodicted %i' % predicted) plt.show() if sys.version_info.major != 2: y_or_n = input('正しいですか?(y or n)') else: y_or_n = raw_input('正しいですか?(y or n)') if y_or_n == 'n': correct_digit = input('正しい数字を教えてください:') grayim = np.append(grayim, int(correct_digit)) grayim = grayim.reshape(1, -1) f = open(join(module_path, 'data', 'digits.csv'), 'a') np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f') f.close() sleep(0.5) else: if sys.version_info.major != 2: next_y_or_n = input('続けますか?(y or n)') else: next_y_or_n = raw_input('続けますか?(y or n)') if next_y_or_n == 'n': break"digits.csv"には8x8のサイズで0~9の数字の画像データを、0~16の階調で格納し、最後にその画像が
示す数字の正解ラベルが付加されている。
1行に64個(8x8)の値と、1個の正解ラベルで、65個の値を持ち、全部で1796行になる。
「scikit-learn 手書き 数字」等で検索すると、1796のデータを学習用と分類用に分けて、その結果を
示すサンプルが多く見受けられる。
このスクリプトでは、"digits.csv"のデータをすべて読み込んでサポートベクトルマシンによる分類器を
用意し、同じディレクトリ内の画像データ"digit_test.png"を読み込んで分類器を使った判定を行う。
画像"digit_test.png"は8x8ピクセルである必要はなく、例えば100x100でも良いし、80x100とかでも
8x8へ変換するようになっている。
スクリプトを実行させると、10個の凡例と手書き画像のオリジナル・8x8変換後の判定が表示される。
一旦その表示ウィンドウを閉じると、「正しいですか?」と聞いてくるので答えを入力する。
「n」を入力すると、「正しい数字を教えてください」と聞いてくるので正解を入力する。
読み込んだデータに正解ラベルを付加した状態でデータに追加し、再度分類器を生成して判定を やり直す。
一度で駄目でも、何回か繰り返せば正解の判定を導き出すようになる。
「続けますか?」に答える前に、次に判定したい画像を"digit_test.png"として用意しておけば、
新たな画像を判定する。でないと、同じ画像の判定を繰り返すだけになるので注意。
ただ重要なのは、数字を理解して判定している訳ではなく、64個の値の組み合わせに正解を
紐付けているだけだと言うこと。
だから、アラビア数字じゃなく漢数字の画像データでも良いはず!
漢数字の画像を用意して、データの追加を何回か繰り替えして行けば、漢数字での判定が出きるようになる。
(要は、正解ラベルで意味づけをしているのは人間なので、数字である必要も、画像データである必要も無い)
おまけ(漢数字用データ)
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,6,6,6,6,8,8,0,3,4,4,4,4,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,6,6,6,6,8,8,0,3,4,4,4,4,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,10,10,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,5,10,10,11,9,8,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,10,10,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,5,10,10,11,9,8,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 0,0,0,0,0,0,0,0,0,0,9,11,10,8,0,0,0,0,0,0,0,0,0,0,0,0,2,9,9,0,0,0,0,0,0,1,0,0,0,0,2,4,5,6,9,11,9,1,3,6,6,4,1,0,4,5,0,0,0,0,0,0,0,0,3 0,0,0,0,0,0,0,0,0,0,9,11,10,8,0,0,0,0,0,0,0,0,0,0,0,0,2,9,9,0,0,0,0,0,0,1,0,0,0,0,2,4,5,6,9,11,9,1,3,6,6,4,1,0,4,5,0,0,0,0,0,0,0,0,3 0,0,0,0,0,0,0,0,1,1,3,5,6,8,1,0,9,6,10,9,11,8,4,0,4,7,6,5,10,6,4,0,2,12,11,0,9,13,2,0,2,9,4,4,6,12,3,0,0,3,6,6,5,6,2,0,0,0,0,0,0,0,0,0,4 0,0,0,0,0,0,0,0,1,1,3,5,6,8,1,0,9,6,10,9,11,8,4,0,4,7,6,5,10,6,4,0,2,12,11,0,9,13,2,0,2,9,4,4,6,12,3,0,0,3,6,6,5,6,2,0,0,0,0,0,0,0,0,0,4 0,0,0,0,0,0,0,0,0,3,10,10,10,8,0,0,0,0,0,10,2,0,0,0,0,3,7,13,6,4,0,0,0,1,7,9,4,10,0,0,0,0,7,4,4,7,0,0,7,10,12,10,12,11,11,10,0,0,0,0,0,0,0,0,5 0,0,0,0,0,0,0,0,0,3,10,10,10,8,0,0,0,0,0,10,2,0,0,0,0,3,7,13,6,4,0,0,0,1,7,9,4,10,0,0,0,0,7,4,4,7,0,0,7,10,12,10,12,11,11,10,0,0,0,0,0,0,0,0,5 0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,2,0,10,4,5,6,1,0,9,10,10,8,6,5,1,0,0,3,0,1,6,0,0,0,2,10,0,0,10,8,0,0,11,2,0,0,0,6,5,0,1,0,0,0,0,0,0,6 0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,2,0,10,4,5,6,1,0,9,10,10,8,6,5,1,0,0,3,0,1,6,0,0,0,2,10,0,0,10,8,0,0,11,2,0,0,0,6,5,0,1,0,0,0,0,0,0,6 0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,10,0,0,1,0,0,2,8,13,10,11,10,3,0,0,2,11,1,0,0,0,0,0,0,11,0,0,0,0,0,0,0,7,11,10,10,7,0,0,0,0,0,0,0,0,0,7 0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,10,0,0,1,0,0,2,8,13,10,11,10,3,0,0,2,11,1,0,0,0,0,0,0,11,0,0,0,0,0,0,0,7,11,10,10,7,0,0,0,0,0,0,0,0,0,7 0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,7,0,11,0,0,0,0,4,7,0,6,5,0,0,0,10,2,0,1,11,0,0,3,9,0,0,0,8,5,0,1,1,0,0,0,1,11,2,0,0,0,0,0,0,0,0,8 0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,7,0,11,0,0,0,0,4,7,0,6,5,0,0,0,10,2,0,1,11,0,0,3,9,0,0,0,8,5,0,1,1,0,0,0,1,11,2,0,0,0,0,0,0,0,0,8 0,0,0,5,0,0,0,0,0,0,0,11,0,0,0,0,0,2,8,13,12,10,0,0,0,0,8,6,7,4,0,0,0,0,11,0,8,2,0,0,0,2,10,0,6,7,0,7,0,8,4,0,0,9,11,7,0,1,0,0,0,0,0,0,9 0,0,0,5,0,0,0,0,0,0,0,11,0,0,0,0,0,2,8,13,12,10,0,0,0,0,8,6,7,4,0,0,0,0,11,0,8,2,0,0,0,2,10,0,6,7,0,7,0,8,4,0,0,9,11,7,0,1,0,0,0,0,0,0,9
MLPClassifier使用版
※追記1(2018/08/14 16:20)※
Python2にも対応したつもり
※追記2(2018/08/14 19:40)※
バグ取り
#!/usr/bin/env python # -*- coding: utf-8 -*- import matplotlib.pyplot as plt from sklearn import datasets from sklearn.neural_network import MLPClassifier import numpy as np import cv2 from time import sleep from os.path import dirname, join from sklearn.datasets.base import Bunch import sys y_or_n = 'n' module_path = dirname(__file__) def load_digits(): module_path = dirname(__file__) data = np.loadtxt(join(module_path, 'data', 'digits.csv'), delimiter=',') target = data[:, -1].astype(np.int) flat_data = data[:, :-1] images = flat_data.view() images.shape = (-1, 8, 8) return Bunch(data=flat_data, target=target, target_names=np.arange(10), images=images) while(True): if y_or_n == 'n': digits = load_digits() n_samples = len(digits.images) data = digits.images.reshape((n_samples, -1)) classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10), max_iter=10000, tol=0.00001, random_state=1) print(classifier) classifier.fit(data, digits.target) images_and_labels = list(zip(digits.images, digits.target)) for index, (image, label) in enumerate(images_and_labels[:10]): plt.subplot(3, 5, index + 1) plt.axis('off') plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Training %i' % label) img = cv2.imread('digit_test.png') size = (8, 8) im = cv2.resize(img, size, interpolation=cv2.INTER_AREA) imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY) #print(imgray) grayim = np.round((255 - imgray) / 16) grayim = grayim.reshape(1, -1) print(grayim) np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f') predicted = classifier.predict(grayim) print('判定は、' + str(predicted)) plt.subplot(3, 5, 11) plt.axis('off') plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Original') plt.subplot(3, 5, 12) plt.axis('off') plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Prodicted %i' % predicted) plt.show() if sys.version_info.major != 2: y_or_n = input('正しいですか?(y or n)') else: y_or_n = raw_input('正しいですか?(y or n)') if y_or_n == 'n': correct_digit = input('正しい数字を教えてください:') grayim = np.append(grayim, int(correct_digit)) #print(grayim) grayim = grayim.reshape(1, -1) f = open(join(module_path, 'data', 'digits.csv'), 'a') np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f') f.close() sleep(1) else: if sys.version_info.major != 2: next_y_or_n = input('続けますか?(y or n)') else: next_y_or_n = raw_input('続けますか?(y or n)') if next_y_or_n == 'n': break
済みません。openしたファイルをcloseしていなかったので、修正しました。
返信削除