2018年9月6日木曜日

MNISTは「重心」が重要

scikit-learnの手書き数字認識で、MNISTのデータを使ってみることにした。
先ずは、どんな数字の画像データなのか現物確認しないと始まらないので、MNISTデータを
画像に落とすスクリプトを書いた。
どうせなら、先に作ったスクリプトで使えるファイル名で画像保存する。

"counter_limit"で各数字ごとに保存する件数を指定。
下の例だと、「0」1,000件、「1」1,000件、、、「9」1,000件、と合計10,000件の画像が保存される。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from sklearn import datasets
import cv2, os

counter = [''] * 10
counter_limit = 1000

try:
    os.mkdir('mnist_img')
except:
    pass

for i in range (0, 10):
    counter[i] = 0
    try:
        os.mkdir(str(i) + 'test')
    except:
        continue


mnist = datasets.fetch_mldata('MNIST original')
X, Y = mnist.data,mnist.target

for i in range(0, 70000):
    for n in range(0, 10):
        if Y[i] == n:
            counter[n] += 1
            if counter[n] <= counter_limit:
                img =255 - X[i].reshape(28,28)
                #cv2.imshow('mnist', img)
                #k = cv2.waitKey()
                print('i = ' + str(i) + ', target = ' + str(Y[i]))

                save_name0 = './' + str(n) + 'test/' + str(counter[n] * 10 + n) + 'test.png'
                cv2.imwrite(save_name0, img)

                save_name1 = './mnist_img/' + str(counter[n] * 10 + n) + 'test.png'
                cv2.imwrite(save_name1, img)


因みに、70,000件全ての画像を保存するには"counter_limit"を(余裕を見て)"counter_limit = 8000"とかにする。
そうすると、各数字に以下の件数の画像が保存される。

0:6903
1:7877
2:6990
3:7141
4:6824
5:6313
6:6876
7:7293
8:6825
9:6958

さて、保存した画像を眺めると、書かれた数字はバラエティがあるように見えるが、意外と偏っていて、
描画領域については広範に描かれている訳ではなさそうだ。

例えば、
「7」は28x28の上部に余白が多い
「6」だと28x28の下部に余白が多い
「4」が右に寄っているように見える(縦棒が中心より右に書かれている傾向にあるから)
「9」も28x28の上部に余白がある
等々。
ちょっと見ただけでも上記のような偏り(傾向)が分かる。

ただ、これには訳がある。
MNISTを公開しているサイトに以下の記述がある。
the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
※次のページが参考になりました。ありがとうございます。
「自分の手書きデータをTensorFlowで予測する」
「MNISTを学習させたモデルの気持ちを調べる」

「重心」を計算して中央に配置したデータのため、重心が上にあるものは下に移動して
上部に余白ができる。

重心が下にあるものは上に移動して下部に余白ができる。
このデータで学習をした場合、同じ様に重心で移動したデータの検証ならば正解率が
上がるが、そうでないデータでは誤認識が生じやすくなる。

そこで、MNISTのデータで学習をする場合は、それに合わせて判定用データも加工する
スクリプトを書いてみた。
前「おまけのおまけ」スクリプト"mnist_img"ディレクトリに置き、引数に28を渡して
生成した「28x28.csv」を選択すると、MNISTデータでの学習を前提にして判定用データ
を加工する。
具体的には"def mnist_mod(img):"の部分がそれにあたる。
横方向の移動に関しては、重心が中央よりも右方向に行くようにしてある。
※追記1(2018/09/09 10:00)※
"def mnist_mod(img):"の重心移動方法を修正。
※追記2(2018/09/10 22:00)※
digits.csv用の前処理を追加したつもり。
※追記3(2018/09/12 20:10)※
縦横比が同じでないテスト用画像の読み込みに対応したつもり。
現状下の様に、パーツに分かれた数字は、重心を求めていないためMNISTの学習結果では、正常に判定できない。
※追記4(2018/09/14 20:30)※
パーツに分かれた数字の重心(と言うか画像の重心)を求めるように修正。

ある程度の調整は出きるが、、、。うーん、今一。
ちなみに分類機のパラメータは出鱈目なので悪しからず。

※追記5(2019/12/05 11:10)※
OpenCVのバージョンにより、findContoursの返り値の数が異なる事に対応(?)。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import cv2
from time import sleep
from os.path import dirname, join, basename, exists
from sklearn.datasets.base import Bunch
import sys
from sklearn.externals import joblib

py_version = sys.version_info.major
cv_version = cv2.__version__[0:1]
module_path = dirname(__file__)

if py_version != 2:
    import tkinter as tk, tkinter.filedialog as tkfiledialog, tkinter.messagebox as tkmessage, \
           tkinter.simpledialog as tksimpledialog
else:
    import Tkinter as tk, tkFileDialog  as tkfiledialog, tkMessageBox as tkmessage, \
           tkSimpleDialog as tksimpledialog

def clf_select():
    def check():
        root.quit()
        return v.get()

    root = tk.Tk()
    root.title('digit_test')

    v = tk.IntVar()
    v.set(0)

    radio1 = tk.Radiobutton(text='SVM', variable=v, value=1, font=("",14))
    radio1.pack(anchor='w')

    radio2 = tk.Radiobutton(text='MLPClassifier', variable=v, value=2, font=("",14))
    radio2.pack(anchor='w')

    radio3 = tk.Radiobutton(text='LogisticRegression', variable=v, value=3, font=("",14))
    radio3.pack(anchor='w')

    radio4 = tk.Radiobutton(text='DecisionTreeClassifier', variable=v, value=4, font=("",14))
    radio4.pack(anchor='w')

    button1 = tk.Button(root, text='決 定', command= lambda: check(), width=30, font=("",14))
    button1.pack()

    root.mainloop()
    try:
        root.destroy()
    except:
        sys.exit()

    return check()

def csv_select():
    root = tk.Tk()
    root.withdraw()

    file_type = [('', '*csv')]
    data_path = join(dirname(__file__), 'data')
    tkmessage.showinfo('digit_test', 'データ(csv)を選択してください')
    test_csv = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
    if test_csv == () or test_csv == '':
        test_csv = 'digits.csv'
    else:
        test_csv = basename(test_csv)

    root.destroy()
    print(test_csv)
    return test_csv

def img_select():
    root = tk.Tk()
    root.withdraw()

    file_type = [('', '*png')]
    data_path = dirname(__file__)
    #tkmessage.showinfo(test_csv, '画像(png)を選択してください')
    test_img = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
    if test_img == () or test_img == '':
        test_img = 'digit_test.png'

    root.destroy()
    print(test_img)
    return test_img

def question0():
    root = tk.Tk()
    root.withdraw()

    y_or_n = tkmessage.askquestion(test_csv, '判定は' + str(predicted) + '\n正しいですか?')

    root.destroy()
    return y_or_n

def question1():
    root = tk.Tk()
    root.withdraw()

    next_y_or_n = tkmessage.askquestion(test_csv, '続けますか?')

    root.destroy()
    return next_y_or_n

def correct_input():
    root = tk.Tk()
    root.withdraw()
    root.after(1, lambda: root.focus_force())
    correct_digit = tksimpledialog.askinteger(test_csv, '正解を教えてください', 
                                              initialvalue='ここに入力',parent=root)

    root.destroy()
    return correct_digit

def load_digits():
    module_path = dirname(__file__)
    data = np.loadtxt(join(module_path, 'data', test_csv),
                      delimiter=',')

    data_len = int(np.sqrt(len(data[0]) - 1))

    target = data[:, -1].astype(np.int)
    flat_data = data[:, :-1]
    images = flat_data.view()
    images.shape = (-1, data_len, data_len)

    return Bunch(data=flat_data,
                 target=target,
                 target_names=np.arange(10),
                 images=images)

def pre_mod(img):
    data_len = int(np.sqrt(len(data[0])))
    size = (data_len, data_len)
    temp_h, temp_w = img.shape[:2]

    if temp_w > temp_h:
        temp_size = temp_w
    else:
        temp_size = temp_h

    if test_csv == '28x28.csv':
        temp_scale = 0.79
    elif test_csv == 'digits.csv':
        temp_scale = 1
    else:
        temp_scale = 0.9

    temp_re_size = temp_size + 2

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = np.round(255 - img)
    M = np.float32([[1,0,1],[0,1,1]])
    img = cv2.warpAffine(img, M, (temp_re_size, temp_re_size))
    img = np.round(255 - img)

    ret,thresh = cv2.threshold(img, 50, 255, 0)
    if cv_version == '3':
        imgEdge,contours,hierarchy = cv2.findContours(thresh, cv2.RETR_LIST,
                                                      cv2.CHAIN_APPROX_SIMPLE)
    else:
        contours,hierarchy = cv2.findContours(thresh, cv2.RETR_LIST,
                                              cv2.CHAIN_APPROX_SIMPLE)

    x_locale = []
    y_locale = []
    if len(contours) > 2:
        n = len(contours) - 1
    else:
        n = 1

    for i in range(0, n):
        x,y,w,h = cv2.boundingRect(contours[i])
        x_locale.append(x)
        x_locale.append(x + w)
        y_locale.append(y)
        y_locale.append(y + h)
    x = min(x_locale)
    y = min(y_locale)
    w = max(x_locale) - min(x_locale)
    h = max(y_locale) - min(y_locale)
    img = img[y:y+h, x:x+w]

    if h >= w:
        r_scale_x = int(w * data_len * temp_scale / h)
        r_scale_y = int(data_len * temp_scale)
    else:
        r_scale_x = int(data_len * temp_scale)
        r_scale_y = int(h * data_len * temp_scale / w)

    img = cv2.resize(img, (r_scale_x, r_scale_y), interpolation=cv2.INTER_AREA)
    img = np.round(255 - img)

    if test_csv == '28x28.csv':
        M = np.float32([[1,0,1],[0,1,1]])
        img = cv2.warpAffine(img, M, (r_scale_x + 2, r_scale_y + 2))

        M = cv2.moments(img)
        cx = int(M['m10']/M['m00']) + 1
        cy = int(M['m01']/M['m00']) + 1

    else:
        cx = int(r_scale_x / 2)
        cy = int(r_scale_y / 2)

    print(cx, cy)

    xc = int(data_len / 2) - cx
    yc = int(data_len / 2) - cy
    print(xc, yc)
    M = np.float32([[1,0,xc],[0,1,yc]])
    img = cv2.warpAffine(img, M, size)
    img = np.round(img / 16)
    return img


test_clf = clf_select()
if test_clf == 1:
    print('SVM')
    from sklearn import svm
    clf_name = 'svm'
    classifier = svm.SVC(C=10, gamma='scale')
elif test_clf == 2:
    print('MLPClassifier')
    from sklearn.neural_network import MLPClassifier
    clf_name = 'mlp'
    classifier = MLPClassifier(hidden_layer_sizes=(256, 128, 64, 64, 10),
                               verbose=True, alpha=0.001,
                               max_iter=10000, tol=0.00001, random_state=1)
elif test_clf == 3:
    print('LogisticRegression')
    from sklearn.linear_model import LogisticRegression
    clf_name = 'lr'
    classifier = LogisticRegression()
elif test_clf == 4:
    print('DecisionTreeClassifier')
    from sklearn.tree import DecisionTreeClassifier
    clf_name = 'dt'
    classifier = DecisionTreeClassifier()
else:
    sys.exit()

test_csv = csv_select()
test_img = img_select()

if exists('./' + test_csv + '.' + clf_name + '.pkl'):
    y_or_n = ''
    #y_or_n = 'no'
else:
    y_or_n = 'no'

while(True):
    digits = load_digits()
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))

    if y_or_n == 'no':
        print(classifier)
        classifier.fit(data, digits.target)
        joblib.dump(classifier, test_csv + '.' + clf_name + '.pkl', compress=True)
    else:
        classifier = joblib.load(test_csv + '.' + clf_name + '.pkl')

    images_and_labels = list(zip(digits.images, digits.target))
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(3, 5, index + 1)
        plt.axis('on')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training %i' % label)

    img = cv2.imread(test_img)
    grayim = pre_mod(img)
    im = grayim

    print(grayim)
    grayim = grayim.reshape(1, -1)
    np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')

    predicted = classifier.predict(grayim)
    print('判定は、' + str(predicted))

    plt.subplot(3, 5, 11)
    plt.axis('on')
    plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Original')
    plt.subplot(3, 5, 12)
    plt.axis('on')
    plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prodicted %i' % predicted)
    plt.show()

    y_or_n = question0()

    if y_or_n == 'no':
        correct_digit = correct_input()
        if correct_digit != None:
            grayim = np.append(grayim, correct_digit)
            grayim = grayim.reshape(1, -1)

            f = open(join(module_path, 'data', test_csv), 'a')
            np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
            f.close()
            sleep(1)
        else:
            correct_digit = str(predicted[0])
    else:
        next_y_or_n = question1()

        if next_y_or_n == 'no':
            sys.exit()

        test_img = img_select()