2020年3月7日土曜日

WebSocketでチャット(Tornadoを使う)

暇なので、WebSocketでチャットしてみた。
(Tornadoを使う)
クライアントからアクセスするIPアドレスやポートは適当に変更する。

※追記3(2020/03/09 20:40)※
HTTPとWebSocketサーバを一つにしてみた。

サーバ(HTTPサーバ及びWebSocketサーバ)
例えば:tornado_server.py というファイル名にする。
このファイルを置いたディレクトリに templates と static という名の子ディレクトリを作っておく。
#!/usr/bin/env python
#-*- coding:utf-8 -*-

import tornado.ioloop
import tornado.web
import tornado.websocket
import datetime
import os.path

cl = []
dic = {}
todaydetail = datetime.datetime.today()
today = todaydetail.strftime("%Y_%m_%d")
print(today)
today_log = str(today) + "_log.txt"
print(today_log)
if not os.path.exists(today_log):
    f = open(today_log, "w")

#クライアントからメッセージを受けるとopen → on_message → on_closeが起動する
class WebSocketHandler(tornado.websocket.WebSocketHandler):
    def check_origin(self, origin):
        return True

    #websocketオープン
    def open(self):
        print("open")
        if self not in cl:
            cl.append(self)
            print(len(dic))
            dic[self] = len(dic) 
            print(self)

            f = open(today_log, "r")
            for row in f:
                self.write_message(row.strip())
            f.close()

            self.write_message("Welcome !;;")
 
    #処理
    def on_message(self, message):
        print("on_message")
        mess = str(message) + "; ID=" + str(self)[37:47]
        f = open(today_log, "a")
        f.write(mess + "\n")
        f.close()

        for client in cl:
            print(str(self))
            print(message)
            #クライアントへメッセージを送信
            client.write_message(mess)
 
    #websockeクローズ
    def on_close(self):
        print("close")
        if self in cl:
            cl.remove(self)

#HTTPサーバ
class MainHandler(tornado.web.RequestHandler):
    def get(self):
        self.render("index.html")

app = tornado.web.Application([
    (r"/", MainHandler),
    (r"/websocket", WebSocketHandler),
    ],

    #HTTPサーバで使うpathを設定
    template_path = os.path.join(os.getcwd(),  "templates"),
    static_path = os.path.join(os.getcwd(),  "static"),
)

if __name__ == "__main__":
   app.listen(8888)
   tornado.ioloop.IOLoop.instance().start()


インデックス
index.html と言うファイル名にする。これを templates ディレクトリに置く。
HTTPサーバにアクセスすると、この index.html が呼ばれる。
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,
 initial-scale=1.0,user-scalable=yes" />
</head>
 <body>
  Tornado is awesome !<br>
  <a href="{{static_url("websocket_chat_client.html")}}">Chat Client</a>
 </body>
</html>


クライアント
websocket_chat_client.html と言うファイル名にする。これを static ディレクトリに置く。
インデックスのリンクから呼び出される。
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,
  initial-scale=1.0,user-scalable=yes" />
<title>WebSocketでチャットしてみよう</title>
<style>
    #name{
        width:100px;
        height:20px;
    }

    #message{
        width:250px;
        height:20px;
    }
</style>
<script langage="JavaScript">
    ws = new WebSocket("ws://192.168.1.11:8888/websocket");

    ws.onopen = function(e) {
//  ws.send("Hi Open; ");
    }

    ws.onmessage = function(e) {
        var mon = document.getElementById("monitor");
        var mes = e.data;
        mes_line = mes.split(";");

        var div = document.createElement("div");
        mon.appendChild(div);
        div.style.width = "500px";
        div.style.padding = "10px 10px 10px 10px";
        div.style.margin = "10px 0px 10px 0px";
        div.style.border = "solid 1px #0000ff";
        var div_att = "<div class='mess' style='overflow:auto;padding:5px 5px 5px 15px'>";
        div.innerHTML = mes_line[0] + mes_line[2] + div_att + mes_line[1] + "</div>";

//alert(document.body.clientHeight);
        var c_height = document.body.clientHeight;
        if(mes_line[0] == "Welcome !"){
            c_height = 0;
        }
        window.scrollTo(0, c_height);
    }

    function button00(mes){
        var weeks = new Array('日','月','火','水','木','金','土');
        var now = new Date();

        var year = now.getYear(); // 年
        var month = now.getMonth() + 1; // 月
        var day = now.getDate(); // 日
        var week = weeks[ now.getDay() ]; // 曜日
        var hour = now.getHours(); // 時
        var min = now.getMinutes(); // 分
        var sec = now.getSeconds(); // 秒

        if(year < 2000) { year += 1900; }

        // 数値が1桁の場合、頭に0を付けて2桁で表示する指定
        if(month < 10) { month = "0" + month; }
        if(day < 10) { day = "0" + day; }
        if(hour < 10) { hour = "0" + hour; }
        if(min < 10) { min = "0" + min; }
        if(sec < 10) { sec = "0" + sec; }

        var now_date = year + "/" + month + "/" + day + "(" + week + ")";
        var now_time = hour + ":" + min + ":" + sec;
        var nam = document.getElementById("name").value;
        ws.send(nam + " " + now_date + " " + now_time + ";" + mes);

        document.getElementById("message").value = "";
    }

</script>
</head>
<body>
<h1>WebSocketでチャットしてみよう</h1>
<h3>メッセージ</h3>
<div id="monitor"></div>

<p>
表示する名前
<input type="text" id="name" value="ななし" class="name" />
</p>

<p>
<input type="text" id="message" value="text" class="message"/>
<button onClick="button00(document.getElementById('message').value)">送信</button>
</p>
<p>
※ WebSocketのテストです ※
</p>
</body>
</html>

2018年9月6日木曜日

MNISTは「重心」が重要

scikit-learnの手書き数字認識で、MNISTのデータを使ってみることにした。
先ずは、どんな数字の画像データなのか現物確認しないと始まらないので、MNISTデータを
画像に落とすスクリプトを書いた。
どうせなら、先に作ったスクリプトで使えるファイル名で画像保存する。

"counter_limit"で各数字ごとに保存する件数を指定。
下の例だと、「0」1,000件、「1」1,000件、、、「9」1,000件、と合計10,000件の画像が保存される。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from sklearn import datasets
import cv2, os

counter = [''] * 10
counter_limit = 1000

try:
    os.mkdir('mnist_img')
except:
    pass

for i in range (0, 10):
    counter[i] = 0
    try:
        os.mkdir(str(i) + 'test')
    except:
        continue


mnist = datasets.fetch_mldata('MNIST original')
X, Y = mnist.data,mnist.target

for i in range(0, 70000):
    for n in range(0, 10):
        if Y[i] == n:
            counter[n] += 1
            if counter[n] <= counter_limit:
                img =255 - X[i].reshape(28,28)
                #cv2.imshow('mnist', img)
                #k = cv2.waitKey()
                print('i = ' + str(i) + ', target = ' + str(Y[i]))

                save_name0 = './' + str(n) + 'test/' + str(counter[n] * 10 + n) + 'test.png'
                cv2.imwrite(save_name0, img)

                save_name1 = './mnist_img/' + str(counter[n] * 10 + n) + 'test.png'
                cv2.imwrite(save_name1, img)


因みに、70,000件全ての画像を保存するには"counter_limit"を(余裕を見て)"counter_limit = 8000"とかにする。
そうすると、各数字に以下の件数の画像が保存される。

0:6903
1:7877
2:6990
3:7141
4:6824
5:6313
6:6876
7:7293
8:6825
9:6958

さて、保存した画像を眺めると、書かれた数字はバラエティがあるように見えるが、意外と偏っていて、
描画領域については広範に描かれている訳ではなさそうだ。

例えば、
「7」は28x28の上部に余白が多い
「6」だと28x28の下部に余白が多い
「4」が右に寄っているように見える(縦棒が中心より右に書かれている傾向にあるから)
「9」も28x28の上部に余白がある
等々。
ちょっと見ただけでも上記のような偏り(傾向)が分かる。

ただ、これには訳がある。
MNISTを公開しているサイトに以下の記述がある。
the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
※次のページが参考になりました。ありがとうございます。
「自分の手書きデータをTensorFlowで予測する」
「MNISTを学習させたモデルの気持ちを調べる」

「重心」を計算して中央に配置したデータのため、重心が上にあるものは下に移動して
上部に余白ができる。

重心が下にあるものは上に移動して下部に余白ができる。
このデータで学習をした場合、同じ様に重心で移動したデータの検証ならば正解率が
上がるが、そうでないデータでは誤認識が生じやすくなる。

そこで、MNISTのデータで学習をする場合は、それに合わせて判定用データも加工する
スクリプトを書いてみた。
前「おまけのおまけ」スクリプト"mnist_img"ディレクトリに置き、引数に28を渡して
生成した「28x28.csv」を選択すると、MNISTデータでの学習を前提にして判定用データ
を加工する。
具体的には"def mnist_mod(img):"の部分がそれにあたる。
横方向の移動に関しては、重心が中央よりも右方向に行くようにしてある。
※追記1(2018/09/09 10:00)※
"def mnist_mod(img):"の重心移動方法を修正。
※追記2(2018/09/10 22:00)※
digits.csv用の前処理を追加したつもり。
※追記3(2018/09/12 20:10)※
縦横比が同じでないテスト用画像の読み込みに対応したつもり。
現状下の様に、パーツに分かれた数字は、重心を求めていないためMNISTの学習結果では、正常に判定できない。
※追記4(2018/09/14 20:30)※
パーツに分かれた数字の重心(と言うか画像の重心)を求めるように修正。

ある程度の調整は出きるが、、、。うーん、今一。
ちなみに分類機のパラメータは出鱈目なので悪しからず。

※追記5(2019/12/05 11:10)※
OpenCVのバージョンにより、findContoursの返り値の数が異なる事に対応(?)。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import cv2
from time import sleep
from os.path import dirname, join, basename, exists
from sklearn.datasets.base import Bunch
import sys
from sklearn.externals import joblib

py_version = sys.version_info.major
cv_version = cv2.__version__[0:1]
module_path = dirname(__file__)

if py_version != 2:
    import tkinter as tk, tkinter.filedialog as tkfiledialog, tkinter.messagebox as tkmessage, \
           tkinter.simpledialog as tksimpledialog
else:
    import Tkinter as tk, tkFileDialog  as tkfiledialog, tkMessageBox as tkmessage, \
           tkSimpleDialog as tksimpledialog

def clf_select():
    def check():
        root.quit()
        return v.get()

    root = tk.Tk()
    root.title('digit_test')

    v = tk.IntVar()
    v.set(0)

    radio1 = tk.Radiobutton(text='SVM', variable=v, value=1, font=("",14))
    radio1.pack(anchor='w')

    radio2 = tk.Radiobutton(text='MLPClassifier', variable=v, value=2, font=("",14))
    radio2.pack(anchor='w')

    radio3 = tk.Radiobutton(text='LogisticRegression', variable=v, value=3, font=("",14))
    radio3.pack(anchor='w')

    radio4 = tk.Radiobutton(text='DecisionTreeClassifier', variable=v, value=4, font=("",14))
    radio4.pack(anchor='w')

    button1 = tk.Button(root, text='決 定', command= lambda: check(), width=30, font=("",14))
    button1.pack()

    root.mainloop()
    try:
        root.destroy()
    except:
        sys.exit()

    return check()

def csv_select():
    root = tk.Tk()
    root.withdraw()

    file_type = [('', '*csv')]
    data_path = join(dirname(__file__), 'data')
    tkmessage.showinfo('digit_test', 'データ(csv)を選択してください')
    test_csv = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
    if test_csv == () or test_csv == '':
        test_csv = 'digits.csv'
    else:
        test_csv = basename(test_csv)

    root.destroy()
    print(test_csv)
    return test_csv

def img_select():
    root = tk.Tk()
    root.withdraw()

    file_type = [('', '*png')]
    data_path = dirname(__file__)
    #tkmessage.showinfo(test_csv, '画像(png)を選択してください')
    test_img = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
    if test_img == () or test_img == '':
        test_img = 'digit_test.png'

    root.destroy()
    print(test_img)
    return test_img

def question0():
    root = tk.Tk()
    root.withdraw()

    y_or_n = tkmessage.askquestion(test_csv, '判定は' + str(predicted) + '\n正しいですか?')

    root.destroy()
    return y_or_n

def question1():
    root = tk.Tk()
    root.withdraw()

    next_y_or_n = tkmessage.askquestion(test_csv, '続けますか?')

    root.destroy()
    return next_y_or_n

def correct_input():
    root = tk.Tk()
    root.withdraw()
    root.after(1, lambda: root.focus_force())
    correct_digit = tksimpledialog.askinteger(test_csv, '正解を教えてください', 
                                              initialvalue='ここに入力',parent=root)

    root.destroy()
    return correct_digit

def load_digits():
    module_path = dirname(__file__)
    data = np.loadtxt(join(module_path, 'data', test_csv),
                      delimiter=',')

    data_len = int(np.sqrt(len(data[0]) - 1))

    target = data[:, -1].astype(np.int)
    flat_data = data[:, :-1]
    images = flat_data.view()
    images.shape = (-1, data_len, data_len)

    return Bunch(data=flat_data,
                 target=target,
                 target_names=np.arange(10),
                 images=images)

def pre_mod(img):
    data_len = int(np.sqrt(len(data[0])))
    size = (data_len, data_len)
    temp_h, temp_w = img.shape[:2]

    if temp_w > temp_h:
        temp_size = temp_w
    else:
        temp_size = temp_h

    if test_csv == '28x28.csv':
        temp_scale = 0.79
    elif test_csv == 'digits.csv':
        temp_scale = 1
    else:
        temp_scale = 0.9

    temp_re_size = temp_size + 2

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = np.round(255 - img)
    M = np.float32([[1,0,1],[0,1,1]])
    img = cv2.warpAffine(img, M, (temp_re_size, temp_re_size))
    img = np.round(255 - img)

    ret,thresh = cv2.threshold(img, 50, 255, 0)
    if cv_version == '3':
        imgEdge,contours,hierarchy = cv2.findContours(thresh, cv2.RETR_LIST,
                                                      cv2.CHAIN_APPROX_SIMPLE)
    else:
        contours,hierarchy = cv2.findContours(thresh, cv2.RETR_LIST,
                                              cv2.CHAIN_APPROX_SIMPLE)

    x_locale = []
    y_locale = []
    if len(contours) > 2:
        n = len(contours) - 1
    else:
        n = 1

    for i in range(0, n):
        x,y,w,h = cv2.boundingRect(contours[i])
        x_locale.append(x)
        x_locale.append(x + w)
        y_locale.append(y)
        y_locale.append(y + h)
    x = min(x_locale)
    y = min(y_locale)
    w = max(x_locale) - min(x_locale)
    h = max(y_locale) - min(y_locale)
    img = img[y:y+h, x:x+w]

    if h >= w:
        r_scale_x = int(w * data_len * temp_scale / h)
        r_scale_y = int(data_len * temp_scale)
    else:
        r_scale_x = int(data_len * temp_scale)
        r_scale_y = int(h * data_len * temp_scale / w)

    img = cv2.resize(img, (r_scale_x, r_scale_y), interpolation=cv2.INTER_AREA)
    img = np.round(255 - img)

    if test_csv == '28x28.csv':
        M = np.float32([[1,0,1],[0,1,1]])
        img = cv2.warpAffine(img, M, (r_scale_x + 2, r_scale_y + 2))

        M = cv2.moments(img)
        cx = int(M['m10']/M['m00']) + 1
        cy = int(M['m01']/M['m00']) + 1

    else:
        cx = int(r_scale_x / 2)
        cy = int(r_scale_y / 2)

    print(cx, cy)

    xc = int(data_len / 2) - cx
    yc = int(data_len / 2) - cy
    print(xc, yc)
    M = np.float32([[1,0,xc],[0,1,yc]])
    img = cv2.warpAffine(img, M, size)
    img = np.round(img / 16)
    return img


test_clf = clf_select()
if test_clf == 1:
    print('SVM')
    from sklearn import svm
    clf_name = 'svm'
    classifier = svm.SVC(C=10, gamma='scale')
elif test_clf == 2:
    print('MLPClassifier')
    from sklearn.neural_network import MLPClassifier
    clf_name = 'mlp'
    classifier = MLPClassifier(hidden_layer_sizes=(256, 128, 64, 64, 10),
                               verbose=True, alpha=0.001,
                               max_iter=10000, tol=0.00001, random_state=1)
elif test_clf == 3:
    print('LogisticRegression')
    from sklearn.linear_model import LogisticRegression
    clf_name = 'lr'
    classifier = LogisticRegression()
elif test_clf == 4:
    print('DecisionTreeClassifier')
    from sklearn.tree import DecisionTreeClassifier
    clf_name = 'dt'
    classifier = DecisionTreeClassifier()
else:
    sys.exit()

test_csv = csv_select()
test_img = img_select()

if exists('./' + test_csv + '.' + clf_name + '.pkl'):
    y_or_n = ''
    #y_or_n = 'no'
else:
    y_or_n = 'no'

while(True):
    digits = load_digits()
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))

    if y_or_n == 'no':
        print(classifier)
        classifier.fit(data, digits.target)
        joblib.dump(classifier, test_csv + '.' + clf_name + '.pkl', compress=True)
    else:
        classifier = joblib.load(test_csv + '.' + clf_name + '.pkl')

    images_and_labels = list(zip(digits.images, digits.target))
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(3, 5, index + 1)
        plt.axis('on')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training %i' % label)

    img = cv2.imread(test_img)
    grayim = pre_mod(img)
    im = grayim

    print(grayim)
    grayim = grayim.reshape(1, -1)
    np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')

    predicted = classifier.predict(grayim)
    print('判定は、' + str(predicted))

    plt.subplot(3, 5, 11)
    plt.axis('on')
    plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Original')
    plt.subplot(3, 5, 12)
    plt.axis('on')
    plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prodicted %i' % predicted)
    plt.show()

    y_or_n = question0()

    if y_or_n == 'no':
        correct_digit = correct_input()
        if correct_digit != None:
            grayim = np.append(grayim, correct_digit)
            grayim = grayim.reshape(1, -1)

            f = open(join(module_path, 'data', test_csv), 'a')
            np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
            f.close()
            sleep(1)
        else:
            correct_digit = str(predicted[0])
    else:
        next_y_or_n = question1()

        if next_y_or_n == 'no':
            sys.exit()

        test_img = img_select()


2018年8月20日月曜日

scikit-learn おまけの話(おまけ付き)

scikit-learn付属のdigits.csvを使ってみたが、手書き数字の認識をしたいなら、別のデザインの方が良いだろう。

「手書き数字の認識 → scikit-learn」ではなくて、「scikit-learnの演習 → 手書き数字の認識例を使用」
である事を忘れると、勘違いが生じる。

と、書きながらも勘違いのスクリプトをあげる。
前回と同じ動作の、scikit-learnを使った数字認識のスクリプトだが、tkinterを使ってみた。
データのcsvファイルの選択、判定する画像pngファイルの選択ダイアログやメッセージ等がtkinterで表示される。

※追記1(2018/08/21 14:30)※
分類器を保存するようにした。
スクリプトでは分類器の保存ファイルが存在すれば、それを使うようにしているが、
実際問題としては意味が無い

※追記2(2018/08/26 9:30)※
「8x8=64 + ラベル」以外に対応するようにした。
(画像の)縦横比が同じトレーニング用データならば、それに合わせてテスト用画像を
リサイズして判定する。
分類器の保存ファイルは実際の判定では使わないように変更。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.neural_network import MLPClassifier

import numpy as np
import cv2
from time import sleep
from os.path import dirname, join, basename, exists
from sklearn.datasets.base import Bunch
import sys
from sklearn.externals import joblib

py_version = sys.version_info.major

module_path = dirname(__file__)
#print(dirname(abspath('__file__')))
#print(module_path)

if py_version != 2:
    import tkinter as tk, tkinter.filedialog as tkfiledialog, tkinter.messagebox as tkmessage, \
           tkinter.simpledialog as tksimpledialog
else:
    import Tkinter as tk, tkFileDialog  as tkfiledialog, tkMessageBox as tkmessage, \
           tkSimpleDialog as tksimpledialog

def csv_select():
    root = tk.Tk()
    root.withdraw()

    file_type = [('', '*csv')]
    data_path = join(dirname(__file__), 'data')
    tkmessage.showinfo('digit_test', 'データ(csv)を選択してください')
    test_csv = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
    if test_csv == () or test_csv == '':
        test_csv = 'digits.csv'
    else:
        test_csv = basename(test_csv)

    root.destroy()
    #print(test_csv)
    return test_csv

def img_select():
    root = tk.Tk()
    root.withdraw()

    file_type = [('', '*png')]
    #data_path = dirname(abspath('__file__'))
    data_path = dirname(__file__)
    tkmessage.showinfo(test_csv, '画像(png)を選択してください')
    test_img = tkfiledialog.askopenfilename(filetypes=file_type, initialdir=data_path)
    #print(dirname(abspath('__file__')))
    if test_img == () or test_img == '':
        test_img = 'digit_test.png'

    root.destroy()
    #print(test_img)
    return test_img

def question0():
    root = tk.Tk()
    root.withdraw()

    y_or_n = tkmessage.askquestion(test_csv, '判定は' + str(predicted) + '\n正しいですか?')

    root.destroy()
    return y_or_n

def question1():
    root = tk.Tk()
    root.withdraw()

    next_y_or_n = tkmessage.askquestion(test_csv, '続けますか?')

    root.destroy()
    return next_y_or_n

def correct_input():
    root = tk.Tk()
    root.withdraw()
    root.after(1, lambda: root.focus_force())
    correct_digit = tksimpledialog.askinteger(test_csv, '正解を教えてください', 
                                              initialvalue='ここに入力',parent=root)

    root.destroy()
    return correct_digit

def load_digits():
    module_path = dirname(__file__)
    data = np.loadtxt(join(module_path, 'data', test_csv),
                      delimiter=',')

    data_len = int(np.sqrt(len(data[0]) - 1))
    #print(data_len)

    target = data[:, -1].astype(np.int)
    flat_data = data[:, :-1]
    images = flat_data.view()
    images.shape = (-1, data_len, data_len)

    return Bunch(data=flat_data,
                 target=target,
                 target_names=np.arange(10),
                 images=images)


test_csv = csv_select()
test_img = img_select()

if exists('./' + test_csv + '.mlp.pkl'):
    #y_or_n = ''
    y_or_n = 'no'
else:
    y_or_n = 'no'

while(True):
    digits = load_digits()
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))

    if y_or_n == 'no':
        classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100, 10),
                                   max_iter=10000, tol=0.00001, random_state=1)
        print(classifier)
        classifier.fit(data, digits.target)
        joblib.dump(classifier, test_csv + '.mlp.pkl', compress=True)
    else:
        classifier = joblib.load(test_csv + '.mlp.pkl')
        #print(classifier)

    images_and_labels = list(zip(digits.images, digits.target))
    #plt.figure(figsize=(5.5, 3))
    for index, (image, label) in enumerate(images_and_labels[:10]):
        plt.subplot(3, 5, index + 1)
        plt.axis('off')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training %i' % label)

    img = cv2.imread(test_img)

    data_len = int(np.sqrt(len(data[0])))
    #print(data_len)

    size = (data_len, data_len)
    im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
    #print(imgray)
    grayim = np.round((255 - imgray) / 16)
    grayim = grayim.reshape(1, -1)
    print(grayim)
    np.savetxt('digit_test.csv', grayim, delimiter = ',', fmt = '%.0f')

    predicted = classifier.predict(grayim)
    print('判定は、' + str(predicted))

    plt.subplot(3, 5, 11)
    plt.axis('off')
    plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Original')
    plt.subplot(3, 5, 12)
    plt.axis('off')
    plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prodicted %i' % predicted)
    plt.show()

    y_or_n = question0()

    if y_or_n == 'no':
        correct_digit = correct_input()
        if correct_digit != None:
            grayim = np.append(grayim, correct_digit)
            #print(grayim)
            grayim = grayim.reshape(1, -1)

            f = open(join(module_path, 'data', test_csv), 'a')
            np.savetxt(f, grayim, delimiter = ',', fmt = '%.0f')
            f.close()
            sleep(1)
        else:
            correct_digit = str(predicted[0])
    else:
        next_y_or_n = question1()

        if next_y_or_n == 'no':
            break

        test_img = img_select()


おまけのおまけスクリプト

10個の画像から、トレーニング用csvデータを生成するスクリプト。
スクリプトと同一ディレクトリ内の「0test.png」「1test.png」・・・「9test.png」という
ファイル名の画像ファイルを、順番に読込み「8x8=64 + ラベル」の型でcsvファイルを生成する。
デフォルトでは8x8=64だが、スクリプトの引数に数値を与えるとそれを基にしたデータになる。
例えば、「5」を引数にすると、「5x5=25 + ラベル」の型になる。
(仮にスクリプトのファイル名を「prepare.py」とした場合、「python prepare.py 5」と言った記述)

生成されたcsvは、「5x5.csv」と言うファイル名で同一ディレクトリに保存される。
これをdataディレクトリへ移動して上のスクリプトで読み込めば、そのデータで分類器を生成してテスト
画像を判定する。

※追記3(2018/08/27 11:45)※
ファイル名からラベルの取り方を変更

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import sys, cv2

args = sys.argv
print(args)
if len(args) > 2:
    print('引数が多すぎます')
    sys.exit()
elif len(args) == 2:
    try:
        type(int(args[1])) == int
        pix = int(args[1])
    except ValueError:
        print('引数の型が違います')
        sys.exit()
else:
    print('既定のデータ型を使います')
    pix = 8

for num in range(0, 10):
    try:
        img_file = str(num) + 'test.png'
        im = cv2.imread(img_file)

        size = (pix, pix)
        im = cv2.resize(im, size, interpolation=cv2.INTER_AREA)

        imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
        print(imgray)
        grayim = np.round((255 - imgray) / 16)
        grayim = grayim.reshape(1, -1)
        grayim = np.array(grayim, dtype = np.uint8)
        grayim = np.append(grayim, int(str(num)[-1]))
        grayim = grayim.reshape(1, -1)
        print(grayim)

        f = open(str(pix) + 'x' + str(pix) + '.csv', 'a')
        np.savetxt(f, grayim, delimiter = ",", fmt = "%.0f")
        f.close()

        #cv2.imshow(str(pix) + 'test', imgray)
        #cv2.waitKey(0)
    except:
        print(str(num) + 'test.png をスキップ')


10個より多くのデータを生成したい場合は"range(0, 10)"を書き換える
for num in range(0, 1000):