2018年8月20日月曜日

scikit-learn おまけの話(おまけ付き)

scikit-learn付属のdigits.csvを使ってみたが、手書き数字の認識をしたいなら、別のデザインの方が良いだろう。

「手書き数字の認識 → scikit-learn」ではなくて、「scikit-learnの演習 → 手書き数字の認識例を使用」
である事を忘れると、勘違いが生じる。

と、書きながらも勘違いのスクリプトをあげる。
前回と同じ動作の、scikit-learnを使った数字認識のスクリプトだが、tkinterを使ってみた。
データのcsvファイルの選択、判定する画像pngファイルの選択ダイアログやメッセージ等がtkinterで表示される。

※追記1(2018/08/21 14:30)※
分類器を保存するようにした。
スクリプトでは分類器の保存ファイルが存在すれば、それを使うようにしているが、
実際問題としては意味が無い

※追記2(2018/08/26 9:30)※
「8x8=64 + ラベル」以外に対応するようにした。
(画像の)縦横比が同じトレーニング用データならば、それに合わせてテスト用画像を
リサイズして判定する。
分類器の保存ファイルは実際の判定では使わないように変更。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.neural_network import MLPClassifier

import numpy as np
import cv2
from time import sleep
from os.path import dirname, join, basename, exists
from sklearn.datasets.base import Bunch
import sys
from sklearn.externals import joblib

py_version = sys.version_info.major

module_path = dirname(__file__)
#print(dirname(abspath('__file__')))
#print(module_path)

if py_version != 2:
    import tkinter as tk, tkinter.filedialog as tkfiledialog, tkinter.messagebox as tkmessage, \
           tkinter.simpledialog as tksimpledialog
else:
    import Tkinter as tk, tkFileDialog  as tkfiledialog, tkMessageBox as tkmessage, \
           tkSimpleDialog as tksimpledialog

def csv_select():
    root = tk.Tk()
    root.withdraw()

    file_type = [('', '*csv')]
    data_path = join(dirname(__file__), 'data')
    tkmessage.showinfo('digit_test', 'データ(csv)を選択してください')
    test_csv = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
    if test_csv == () or test_csv == '':
        test_csv = 'digits.csv'
    else:
        test_csv = basename(test_csv)

    root.destroy()
    #print(test_csv)
    return test_csv

def img_select():
    root = tk.Tk()
    root.withdraw()

    file_type = [('', '*png')]
    #data_path = dirname(abspath('__file__'))
    data_path = dirname(__file__)
    tkmessage.showinfo(test_csv, '画像(png)を選択してください')
    test_img = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
    #print(dirname(abspath('__file__')))
    if test_img == () or test_img == '':
        test_img = 'digit_test.png'

    root.destroy()
    #print(test_img)
    return test_img

def question0():
    root = tk.Tk()
    root.withdraw()

    y_or_n = tkmessage.askquestion(test_csv, '判定は' + str(predicted) + '\n正しいですか?')

    root.destroy()
    return y_or_n

def question1():
    root = tk.Tk()
    root.withdraw()

    next_y_or_n = tkmessage.askquestion(test_csv, '続けますか?')

    root.destroy()
    return next_y_or_n

def correct_input():
    root = tk.Tk()
    root.withdraw()
    root.after(1, lambda: root.focus_force())
    correct_digit = tksimpledialog.askinteger(test_csv, '正解を教えてください', 
                                              initialvalue='ここに入力',parent=root)

    root.destroy()
    return correct_digit

def load_digits():
    module_path = dirname(__file__)
    data = np.loadtxt(join(module_path, 'data', test_csv),
                      delimiter=',')

    data_len = int(np.sqrt(len(data[0]) - 1))
    #print(data_len)

    target = data[:, -1].astype(np.int)
    flat_data = data[:, :-1]
    images = flat_data.view()
    images.shape = (-1, data_len, data_len)

    return Bunch(data=flat_data,
                 target=target,
                 target_names=np.arange(10),
                 images=images)


test_csv = csv_select()
test_img = img_select()

if exists('./' + test_csv + '.mlp.pkl'):
    #y_or_n = ''
    y_or_n = 'no'
else:
    y_or_n = 'no'

while(True):
    digits = load_digits()
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))

    if y_or_n == 'no':
        classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10),
                                   max_iter=10000, tol=0.00001, random_state=1)
        print(classifier)
        classifier.fit(data, digits.target)
        joblib.dump(classifier, test_csv + '.mlp.pkl', compress=True)
    else:
        classifier = joblib.load(test_csv + '.mlp.pkl')
        #print(classifier)

    images_and_labels = list(zip(digits.images, digits.target))
    #plt.figure(figsize=(5.5, 3))
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(3, 5, index + 1)
        plt.axis('off')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training %i' % label)

    img = cv2.imread(test_img)

    data_len = int(np.sqrt(len(data[0])))
    #print(data_len)

    size = (data_len, data_len)
    im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
    #print(imgray)
    grayim = np.round((255 - imgray) / 16)
    grayim = grayim.reshape(1, -1)
    print(grayim)
    np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')

    predicted = classifier.predict(grayim)
    print('判定は、' + str(predicted))

    plt.subplot(3, 5, 11)
    plt.axis('off')
    plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Original')
    plt.subplot(3, 5, 12)
    plt.axis('off')
    plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prodicted %i' % predicted)
    plt.show()

    y_or_n = question0()

    if y_or_n == 'no':
        correct_digit = correct_input()
        if correct_digit != None:
            grayim = np.append(grayim, correct_digit)
            #print(grayim)
            grayim = grayim.reshape(1, -1)

            f = open(join(module_path, 'data', test_csv), 'a')
            np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
            f.close()
            sleep(1)
        else:
            correct_digit = str(predicted[0])
    else:
        next_y_or_n = question1()

        if next_y_or_n == 'no':
            break

        test_img = img_select()


おまけのおまけスクリプト

10個の画像から、トレーニング用csvデータを生成するスクリプト。
スクリプトと同一ディレクトリ内の「0test.png」「1test.png」・・・「9test.png」という
ファイル名の画像ファイルを、順番に読込み「8x8=64 + ラベル」の型でcsvファイルを生成する。
デフォルトでは8x8=64だが、スクリプトの引数に数値を与えるとそれを基にしたデータになる。
例えば、「5」を引数にすると、「5x5=25 + ラベル」の型になる。
(仮にスクリプトのファイル名を「prepare.py」とした場合、「python prepare.py 5」と言った記述)

生成されたcsvは、「5x5.csv」と言うファイル名で同一ディレクトリに保存される。
これをdataディレクトリへ移動して上のスクリプトで読み込めば、そのデータで分類器を生成してテスト
画像を判定する。

※追記3(2018/08/27 11:45)※
ファイル名からラベルの取り方を変更

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import sys, cv2

args = sys.argv
print(args)
if len(args) > 2:
    print('引数が多すぎます')
    sys.exit()
elif len(args) == 2:
    try:
        type(int(args[1])) == int
        pix = int(args[1])
    except ValueError:
        print('引数の型が違います')
        sys.exit()
else:
    print('既定のデータ型を使います')
    pix = 8

for num in range(0, 10):
    try:
        img_file = str(num) + 'test.png'
        im = cv2.imread(img_file)

        size = (pix, pix)
        im = cv2.resize(im, size, interpolation=cv2.INTER_AREA)

        imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
        print(imgray)
        grayim = np.round((255 - imgray) / 16)
        grayim = grayim.reshape(1, -1)
        grayim = np.array(grayim, dtype = np.uint8)
        grayim = np.append(grayim, int(str(num)[-1]))
        grayim = grayim.reshape(1, -1)
        print(grayim)

        f = open(str(pix) + 'x' + str(pix) + '.csv', 'a')
        np.savetxt(f, grayim, delimiter = ",", fmt = "%.0f")
        f.close()

        #cv2.imshow(str(pix) + 'test', imgray)
        #cv2.waitKey(0)
    except:
        print(str(num) + 'test.png をスキップ')


10個より多くのデータを生成したい場合は"range(0, 10)"を書き換える
for num in range(0, 1000):

2018年8月4日土曜日

scikit-learnで漢数字認識


scikit-learnにサンプルで付いてくる手書き数字認識のデータセットを弄ってみるのが目的。
(※画像の処理にOpenCVを使ってます※)
手書き数字のデータ(digits.csv)を書き換えたいので、付属のデータをコピーして使用する。
"pip"でscikit-learnをインストールした私の環境で、データは以下のディレクトリにある。

※Windowsの場合
"C:/Users/<username>/AppData/Local/Programs/Python/Python36-32/lib/site-packages/sklearn/datasets/data/digits.csv.gz"

※Arch Linux 32の場合
"/usr/lib/python3.6/site-packages/sklearn/datasets/data/digits.csv.gz"

予め任意のディレクトリを作成してその中に"data"ディレクトリを作成し、その"data"ディレクトリに
"digits.csv.gz"をコピーする。更にエディタ等で直接扱いたいので、解凍もしておく。
データを置いたディレクトリより一つ上の階層に以下のPythonスクリプトを置く。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
from sklearn import datasets, svm

import numpy as np
import cv2
from time import sleep
from os.path import dirname, join
from sklearn.datasets.base import Bunch
import sys

y_or_n = 'n'
module_path = dirname(__file__)

def load_digits(n_class=10, return_X_y=False):
    module_path = dirname(__file__)
    data = np.loadtxt(join(module_path, 'data', 'digits.csv'),
                      delimiter=',')
    #with open(join(module_path, 'descr', 'digits.rst')) as f:
    #    descr = f.read()
    target = data[:, -1].astype(np.int)
    flat_data = data[:, :-1]
    images = flat_data.view()
    images.shape = (-1, 8, 8)

    if n_class < 10:
        idx = target < n_class
        flat_data, target = flat_data[idx], target[idx]
        images = images[idx]

    if return_X_y:
        return flat_data, target

    return Bunch(data=flat_data,
                 target=target,
                 target_names=np.arange(10),
                 images=images)
                 #DESCR=descr)


while(True):
    if y_or_n == 'n':
        digits = load_digits()

        n_samples = len(digits.images)
        data = digits.images.reshape((n_samples, -1))

        classifier = svm.SVC(C=100, gamma=0.001)
        print(classifier)
        #classifier.fit(data[:n_samples], digits.target[:n_samples])
        classifier.fit(data, digits.target)

    images_and_labels = list(zip(digits.images, digits.target))
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(3, 5, index + 1)
        plt.axis('off')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training %i' % label)

    img = cv2.imread('digit_test.png')

    size = (8, 8)
    im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
    #print(imgray)
    grayim = np.round((255 - imgray) / 16)
    grayim = grayim.reshape(1, -1)
    print(grayim)
    np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')

    predicted = classifier.predict(grayim)
    print('判定は、' + str(predicted))

    plt.subplot(3, 5, 11)
    plt.axis('off')
    plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Original')
    plt.subplot(3, 5, 12)
    plt.axis('off')
    plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prodicted %i' % predicted)
    plt.show()

    if sys.version_info.major != 2:
        y_or_n = input('正しいですか?(y or n)')
    else:
        y_or_n = raw_input('正しいですか?(y or n)')

    if y_or_n == 'n':
        correct_digit = input('正しい数字を教えてください:')

        grayim = np.append(grayim, int(correct_digit))
        grayim = grayim.reshape(1, -1)

        f = open(join(module_path, 'data', 'digits.csv'), 'a')
        np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
        f.close()
        sleep(0.5)
    else:
        if sys.version_info.major != 2:
            next_y_or_n = input('続けますか?(y or n)')
        else:
            next_y_or_n = raw_input('続けますか?(y or n)')

        if next_y_or_n == 'n':
            break

"digits.csv"には8x8のサイズで0~9の数字の画像データを、0~16の階調で格納し、最後にその画像が
示す数字の正解ラベルが付加されている。
 1行に64個(8x8)の値と、1個の正解ラベルで、65個の値を持ち、全部で1796行になる。
「scikit-learn 手書き 数字」等で検索すると、1796のデータを学習用と分類用に分けて、その結果を
 示すサンプルが多く見受けられる。
このスクリプトでは、"digits.csv"のデータをすべて読み込んでサポートベクトルマシンによる分類器を
用意し、同じディレクトリ内の画像データ"digit_test.png"を読み込んで分類器を使った判定を行う。

画像"digit_test.png"は8x8ピクセルである必要はなく、例えば100x100でも良いし、80x100とかでも
8x8へ変換するようになっている。

スクリプトを実行させると、10個の凡例と手書き画像のオリジナル・8x8変換後の判定が表示される。


一旦その表示ウィンドウを閉じると、「正しいですか?」と聞いてくるので答えを入力する。


「n」を入力すると、「正しい数字を教えてください」と聞いてくるので正解を入力する。
読み込んだデータに正解ラベルを付加した状態でデータに追加し、再度分類器を生成して判定を やり直す。
一度で駄目でも、何回か繰り返せば正解の判定を導き出すようになる。
「続けますか?」に答える前に、次に判定したい画像を"digit_test.png"として用意しておけば、
新たな画像を判定する。でないと、同じ画像の判定を繰り返すだけになるので注意。

 ただ重要なのは、数字を理解して判定している訳ではなく、64個の値の組み合わせに正解を
紐付けているだけだと言うこと。

だから、アラビア数字じゃなく漢数字の画像データでも良いはず!

漢数字の画像を用意して、データの追加を何回か繰り替えして行けば、漢数字での判定が出きるようになる。

 (要は、正解ラベルで意味づけをしているのは人間なので、数字である必要も、画像データである必要も無い)  

おまけ(漢数字用データ)
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,6,6,6,6,8,8,0,3,4,4,4,4,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,6,6,6,6,8,8,0,3,4,4,4,4,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,10,10,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,5,10,10,11,9,8,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,10,10,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,5,10,10,11,9,8,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2
0,0,0,0,0,0,0,0,0,0,9,11,10,8,0,0,0,0,0,0,0,0,0,0,0,0,2,9,9,0,0,0,0,0,0,1,0,0,0,0,2,4,5,6,9,11,9,1,3,6,6,4,1,0,4,5,0,0,0,0,0,0,0,0,3
0,0,0,0,0,0,0,0,0,0,9,11,10,8,0,0,0,0,0,0,0,0,0,0,0,0,2,9,9,0,0,0,0,0,0,1,0,0,0,0,2,4,5,6,9,11,9,1,3,6,6,4,1,0,4,5,0,0,0,0,0,0,0,0,3
0,0,0,0,0,0,0,0,1,1,3,5,6,8,1,0,9,6,10,9,11,8,4,0,4,7,6,5,10,6,4,0,2,12,11,0,9,13,2,0,2,9,4,4,6,12,3,0,0,3,6,6,5,6,2,0,0,0,0,0,0,0,0,0,4
0,0,0,0,0,0,0,0,1,1,3,5,6,8,1,0,9,6,10,9,11,8,4,0,4,7,6,5,10,6,4,0,2,12,11,0,9,13,2,0,2,9,4,4,6,12,3,0,0,3,6,6,5,6,2,0,0,0,0,0,0,0,0,0,4
0,0,0,0,0,0,0,0,0,3,10,10,10,8,0,0,0,0,0,10,2,0,0,0,0,3,7,13,6,4,0,0,0,1,7,9,4,10,0,0,0,0,7,4,4,7,0,0,7,10,12,10,12,11,11,10,0,0,0,0,0,0,0,0,5
0,0,0,0,0,0,0,0,0,3,10,10,10,8,0,0,0,0,0,10,2,0,0,0,0,3,7,13,6,4,0,0,0,1,7,9,4,10,0,0,0,0,7,4,4,7,0,0,7,10,12,10,12,11,11,10,0,0,0,0,0,0,0,0,5
0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,2,0,10,4,5,6,1,0,9,10,10,8,6,5,1,0,0,3,0,1,6,0,0,0,2,10,0,0,10,8,0,0,11,2,0,0,0,6,5,0,1,0,0,0,0,0,0,6
0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,2,0,10,4,5,6,1,0,9,10,10,8,6,5,1,0,0,3,0,1,6,0,0,0,2,10,0,0,10,8,0,0,11,2,0,0,0,6,5,0,1,0,0,0,0,0,0,6
0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,10,0,0,1,0,0,2,8,13,10,11,10,3,0,0,2,11,1,0,0,0,0,0,0,11,0,0,0,0,0,0,0,7,11,10,10,7,0,0,0,0,0,0,0,0,0,7
0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,10,0,0,1,0,0,2,8,13,10,11,10,3,0,0,2,11,1,0,0,0,0,0,0,11,0,0,0,0,0,0,0,7,11,10,10,7,0,0,0,0,0,0,0,0,0,7
0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,7,0,11,0,0,0,0,4,7,0,6,5,0,0,0,10,2,0,1,11,0,0,3,9,0,0,0,8,5,0,1,1,0,0,0,1,11,2,0,0,0,0,0,0,0,0,8
0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,7,0,11,0,0,0,0,4,7,0,6,5,0,0,0,10,2,0,1,11,0,0,3,9,0,0,0,8,5,0,1,1,0,0,0,1,11,2,0,0,0,0,0,0,0,0,8
0,0,0,5,0,0,0,0,0,0,0,11,0,0,0,0,0,2,8,13,12,10,0,0,0,0,8,6,7,4,0,0,0,0,11,0,8,2,0,0,0,2,10,0,6,7,0,7,0,8,4,0,0,9,11,7,0,1,0,0,0,0,0,0,9
0,0,0,5,0,0,0,0,0,0,0,11,0,0,0,0,0,2,8,13,12,10,0,0,0,0,8,6,7,4,0,0,0,0,11,0,8,2,0,0,0,2,10,0,6,7,0,7,0,8,4,0,0,9,11,7,0,1,0,0,0,0,0,0,9


MLPClassifier使用版
※追記1(2018/08/14 16:20)※
Python2にも対応したつもり
※追記2(2018/08/14 19:40)※
バグ取り

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.neural_network import MLPClassifier

import numpy as np
import cv2
from time import sleep
from os.path import dirname, join
from sklearn.datasets.base import Bunch
import sys

y_or_n = 'n'
module_path = dirname(__file__)

def load_digits():
    module_path = dirname(__file__)
    data = np.loadtxt(join(module_path, 'data', 'digits.csv'),
                      delimiter=',')

    target = data[:, -1].astype(np.int)
    flat_data = data[:, :-1]
    images = flat_data.view()
    images.shape = (-1, 8, 8)

    return Bunch(data=flat_data,
                 target=target,
                 target_names=np.arange(10),
                 images=images)


while(True):
    if y_or_n == 'n':
        digits = load_digits()
        n_samples = len(digits.images)
        data = digits.images.reshape((n_samples, -1))

        classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10), max_iter=10000, tol=0.00001, random_state=1)
        print(classifier)
        classifier.fit(data, digits.target)

    images_and_labels = list(zip(digits.images, digits.target))
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(3, 5, index + 1)
        plt.axis('off')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training %i' % label)

    img = cv2.imread('digit_test.png')

    size = (8, 8)
    im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
    #print(imgray)
    grayim = np.round((255 - imgray) / 16)
    grayim = grayim.reshape(1, -1)
    print(grayim)
    np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')

    predicted = classifier.predict(grayim)
    print('判定は、' + str(predicted))

    plt.subplot(3, 5, 11)
    plt.axis('off')
    plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Original')
    plt.subplot(3, 5, 12)
    plt.axis('off')
    plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prodicted %i' % predicted)
    plt.show()

    if sys.version_info.major != 2:
        y_or_n = input('正しいですか?(y or n)')
    else:
        y_or_n = raw_input('正しいですか?(y or n)')

    if y_or_n == 'n':
        correct_digit = input('正しい数字を教えてください:')

        grayim = np.append(grayim, int(correct_digit))
        #print(grayim)
        grayim = grayim.reshape(1, -1)

        f = open(join(module_path, 'data', 'digits.csv'), 'a')
        np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
        f.close()
        sleep(1)
    else:
        if sys.version_info.major != 2:
            next_y_or_n = input('続けますか?(y or n)')
        else:
            next_y_or_n = raw_input('続けますか?(y or n)')

        if next_y_or_n == 'n':
            break