2018年8月4日土曜日

scikit-learnで漢数字認識


scikit-learnにサンプルで付いてくる手書き数字認識のデータセットを弄ってみるのが目的。
(※画像の処理にOpenCVを使ってます※)
手書き数字のデータ(digits.csv)を書き換えたいので、付属のデータをコピーして使用する。
"pip"でscikit-learnをインストールした私の環境で、データは以下のディレクトリにある。

※Windowsの場合
"C:/Users/<username>/AppData/Local/Programs/Python/Python36-32/lib/site-packages/sklearn/datasets/data/digits.csv.gz"

※Arch Linux 32の場合
"/usr/lib/python3.6/site-packages/sklearn/datasets/data/digits.csv.gz"

予め任意のディレクトリを作成してその中に"data"ディレクトリを作成し、その"data"ディレクトリに
"digits.csv.gz"をコピーする。更にエディタ等で直接扱いたいので、解凍もしておく。
データを置いたディレクトリより一つ上の階層に以下のPythonスクリプトを置く。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
from sklearn import datasets, svm

import numpy as np
import cv2
from time import sleep
from os.path import dirname, join
from sklearn.datasets.base import Bunch
import sys

y_or_n = 'n'
module_path = dirname(__file__)

def load_digits(n_class=10, return_X_y=False):
    module_path = dirname(__file__)
    data = np.loadtxt(join(module_path, 'data', 'digits.csv'),
                      delimiter=',')
    #with open(join(module_path, 'descr', 'digits.rst')) as f:
    #    descr = f.read()
    target = data[:, -1].astype(np.int)
    flat_data = data[:, :-1]
    images = flat_data.view()
    images.shape = (-1, 8, 8)

    if n_class < 10:
        idx = target < n_class
        flat_data, target = flat_data[idx], target[idx]
        images = images[idx]

    if return_X_y:
        return flat_data, target

    return Bunch(data=flat_data,
                 target=target,
                 target_names=np.arange(10),
                 images=images)
                 #DESCR=descr)


while(True):
    if y_or_n == 'n':
        digits = load_digits()

        n_samples = len(digits.images)
        data = digits.images.reshape((n_samples, -1))

        classifier = svm.SVC(C=100, gamma=0.001)
        print(classifier)
        #classifier.fit(data[:n_samples], digits.target[:n_samples])
        classifier.fit(data, digits.target)

    images_and_labels = list(zip(digits.images, digits.target))
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(3, 5, index + 1)
        plt.axis('off')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training %i' % label)

    img = cv2.imread('digit_test.png')

    size = (8, 8)
    im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
    #print(imgray)
    grayim = np.round((255 - imgray) / 16)
    grayim = grayim.reshape(1, -1)
    print(grayim)
    np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')

    predicted = classifier.predict(grayim)
    print('判定は、' + str(predicted))

    plt.subplot(3, 5, 11)
    plt.axis('off')
    plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Original')
    plt.subplot(3, 5, 12)
    plt.axis('off')
    plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prodicted %i' % predicted)
    plt.show()

    if sys.version_info.major != 2:
        y_or_n = input('正しいですか?(y or n)')
    else:
        y_or_n = raw_input('正しいですか?(y or n)')

    if y_or_n == 'n':
        correct_digit = input('正しい数字を教えてください:')

        grayim = np.append(grayim, int(correct_digit))
        grayim = grayim.reshape(1, -1)

        f = open(join(module_path, 'data', 'digits.csv'), 'a')
        np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
        f.close()
        sleep(0.5)
    else:
        if sys.version_info.major != 2:
            next_y_or_n = input('続けますか?(y or n)')
        else:
            next_y_or_n = raw_input('続けますか?(y or n)')

        if next_y_or_n == 'n':
            break

"digits.csv"には8x8のサイズで0~9の数字の画像データを、0~16の階調で格納し、最後にその画像が
示す数字の正解ラベルが付加されている。
 1行に64個(8x8)の値と、1個の正解ラベルで、65個の値を持ち、全部で1796行になる。
「scikit-learn 手書き 数字」等で検索すると、1796のデータを学習用と分類用に分けて、その結果を
 示すサンプルが多く見受けられる。
このスクリプトでは、"digits.csv"のデータをすべて読み込んでサポートベクトルマシンによる分類器を
用意し、同じディレクトリ内の画像データ"digit_test.png"を読み込んで分類器を使った判定を行う。

画像"digit_test.png"は8x8ピクセルである必要はなく、例えば100x100でも良いし、80x100とかでも
8x8へ変換するようになっている。

スクリプトを実行させると、10個の凡例と手書き画像のオリジナル・8x8変換後の判定が表示される。


一旦その表示ウィンドウを閉じると、「正しいですか?」と聞いてくるので答えを入力する。


「n」を入力すると、「正しい数字を教えてください」と聞いてくるので正解を入力する。
読み込んだデータに正解ラベルを付加した状態でデータに追加し、再度分類器を生成して判定を やり直す。
一度で駄目でも、何回か繰り返せば正解の判定を導き出すようになる。
「続けますか?」に答える前に、次に判定したい画像を"digit_test.png"として用意しておけば、
新たな画像を判定する。でないと、同じ画像の判定を繰り返すだけになるので注意。

 ただ重要なのは、数字を理解して判定している訳ではなく、64個の値の組み合わせに正解を
紐付けているだけだと言うこと。

だから、アラビア数字じゃなく漢数字の画像データでも良いはず!

漢数字の画像を用意して、データの追加を何回か繰り替えして行けば、漢数字での判定が出きるようになる。

 (要は、正解ラベルで意味づけをしているのは人間なので、数字である必要も、画像データである必要も無い)  

おまけ(漢数字用データ)
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,6,6,6,6,8,8,0,3,4,4,4,4,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,6,6,6,6,8,8,0,3,4,4,4,4,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,10,10,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,5,10,10,11,9,8,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,10,10,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,5,10,10,11,9,8,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2
0,0,0,0,0,0,0,0,0,0,9,11,10,8,0,0,0,0,0,0,0,0,0,0,0,0,2,9,9,0,0,0,0,0,0,1,0,0,0,0,2,4,5,6,9,11,9,1,3,6,6,4,1,0,4,5,0,0,0,0,0,0,0,0,3
0,0,0,0,0,0,0,0,0,0,9,11,10,8,0,0,0,0,0,0,0,0,0,0,0,0,2,9,9,0,0,0,0,0,0,1,0,0,0,0,2,4,5,6,9,11,9,1,3,6,6,4,1,0,4,5,0,0,0,0,0,0,0,0,3
0,0,0,0,0,0,0,0,1,1,3,5,6,8,1,0,9,6,10,9,11,8,4,0,4,7,6,5,10,6,4,0,2,12,11,0,9,13,2,0,2,9,4,4,6,12,3,0,0,3,6,6,5,6,2,0,0,0,0,0,0,0,0,0,4
0,0,0,0,0,0,0,0,1,1,3,5,6,8,1,0,9,6,10,9,11,8,4,0,4,7,6,5,10,6,4,0,2,12,11,0,9,13,2,0,2,9,4,4,6,12,3,0,0,3,6,6,5,6,2,0,0,0,0,0,0,0,0,0,4
0,0,0,0,0,0,0,0,0,3,10,10,10,8,0,0,0,0,0,10,2,0,0,0,0,3,7,13,6,4,0,0,0,1,7,9,4,10,0,0,0,0,7,4,4,7,0,0,7,10,12,10,12,11,11,10,0,0,0,0,0,0,0,0,5
0,0,0,0,0,0,0,0,0,3,10,10,10,8,0,0,0,0,0,10,2,0,0,0,0,3,7,13,6,4,0,0,0,1,7,9,4,10,0,0,0,0,7,4,4,7,0,0,7,10,12,10,12,11,11,10,0,0,0,0,0,0,0,0,5
0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,2,0,10,4,5,6,1,0,9,10,10,8,6,5,1,0,0,3,0,1,6,0,0,0,2,10,0,0,10,8,0,0,11,2,0,0,0,6,5,0,1,0,0,0,0,0,0,6
0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,2,0,10,4,5,6,1,0,9,10,10,8,6,5,1,0,0,3,0,1,6,0,0,0,2,10,0,0,10,8,0,0,11,2,0,0,0,6,5,0,1,0,0,0,0,0,0,6
0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,10,0,0,1,0,0,2,8,13,10,11,10,3,0,0,2,11,1,0,0,0,0,0,0,11,0,0,0,0,0,0,0,7,11,10,10,7,0,0,0,0,0,0,0,0,0,7
0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,10,0,0,1,0,0,2,8,13,10,11,10,3,0,0,2,11,1,0,0,0,0,0,0,11,0,0,0,0,0,0,0,7,11,10,10,7,0,0,0,0,0,0,0,0,0,7
0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,7,0,11,0,0,0,0,4,7,0,6,5,0,0,0,10,2,0,1,11,0,0,3,9,0,0,0,8,5,0,1,1,0,0,0,1,11,2,0,0,0,0,0,0,0,0,8
0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,7,0,11,0,0,0,0,4,7,0,6,5,0,0,0,10,2,0,1,11,0,0,3,9,0,0,0,8,5,0,1,1,0,0,0,1,11,2,0,0,0,0,0,0,0,0,8
0,0,0,5,0,0,0,0,0,0,0,11,0,0,0,0,0,2,8,13,12,10,0,0,0,0,8,6,7,4,0,0,0,0,11,0,8,2,0,0,0,2,10,0,6,7,0,7,0,8,4,0,0,9,11,7,0,1,0,0,0,0,0,0,9
0,0,0,5,0,0,0,0,0,0,0,11,0,0,0,0,0,2,8,13,12,10,0,0,0,0,8,6,7,4,0,0,0,0,11,0,8,2,0,0,0,2,10,0,6,7,0,7,0,8,4,0,0,9,11,7,0,1,0,0,0,0,0,0,9


MLPClassifier使用版
※追記1(2018/08/14 16:20)※
Python2にも対応したつもり
※追記2(2018/08/14 19:40)※
バグ取り

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.neural_network import MLPClassifier

import numpy as np
import cv2
from time import sleep
from os.path import dirname, join
from sklearn.datasets.base import Bunch
import sys

y_or_n = 'n'
module_path = dirname(__file__)

def load_digits():
    module_path = dirname(__file__)
    data = np.loadtxt(join(module_path, 'data', 'digits.csv'),
                      delimiter=',')

    target = data[:, -1].astype(np.int)
    flat_data = data[:, :-1]
    images = flat_data.view()
    images.shape = (-1, 8, 8)

    return Bunch(data=flat_data,
                 target=target,
                 target_names=np.arange(10),
                 images=images)


while(True):
    if y_or_n == 'n':
        digits = load_digits()
        n_samples = len(digits.images)
        data = digits.images.reshape((n_samples, -1))

        classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10), max_iter=10000, tol=0.00001, random_state=1)
        print(classifier)
        classifier.fit(data, digits.target)

    images_and_labels = list(zip(digits.images, digits.target))
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(3, 5, index + 1)
        plt.axis('off')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training %i' % label)

    img = cv2.imread('digit_test.png')

    size = (8, 8)
    im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
    #print(imgray)
    grayim = np.round((255 - imgray) / 16)
    grayim = grayim.reshape(1, -1)
    print(grayim)
    np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')

    predicted = classifier.predict(grayim)
    print('判定は、' + str(predicted))

    plt.subplot(3, 5, 11)
    plt.axis('off')
    plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Original')
    plt.subplot(3, 5, 12)
    plt.axis('off')
    plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prodicted %i' % predicted)
    plt.show()

    if sys.version_info.major != 2:
        y_or_n = input('正しいですか?(y or n)')
    else:
        y_or_n = raw_input('正しいですか?(y or n)')

    if y_or_n == 'n':
        correct_digit = input('正しい数字を教えてください:')

        grayim = np.append(grayim, int(correct_digit))
        #print(grayim)
        grayim = grayim.reshape(1, -1)

        f = open(join(module_path, 'data', 'digits.csv'), 'a')
        np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
        f.close()
        sleep(1)
    else:
        if sys.version_info.major != 2:
            next_y_or_n = input('続けますか?(y or n)')
        else:
            next_y_or_n = raw_input('続けますか?(y or n)')

        if next_y_or_n == 'n':
            break

1 件のコメント :

  1. 済みません。openしたファイルをcloseしていなかったので、修正しました。

    返信削除