スマートフォン・タブレットからインターネットサーバーオペレーション

APPW.jp

Keras の MNIST モデルで Predict を体験

Keras で学習済みの MNIST モデルを使用して、その Predict を体験してみます。

MNIST は、Mixed National Institute of Standards and Technology database の略で、エムニストと読みます。MNIST データは、数字の手書き画像とその画像に表されている数字のラベルからなります。

MNIST は、機械学習とディープラーニングの入門の題材として、よく利用されています。

次の黒い背景の部分に手書きで1文字の数字を書き込むと、そこに書かれた数字を予測します。

数字予測精度数字予測精度
0 - 5 -
1 - 6 -
2 - 7 -
3 - 8 -
4 - 9 -

学習モデルの作成には、Keras の examples に収録されている https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py の最終行に、モデルを保存する次のコードを1行追加して、実行しました。



model.save('mnist_model.h5')

ちなみに、今回の学習の実行環境は、お名前.com VPS の 2GB プランで、およそ1週間かかっています。

学習済みモデルを使って、自前で作成した example.jpg の手書き数字を判定する例です。



from __future__ import print_function
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img

img_width, img_height, channel = 28, 28, 1

test_model = load_model('mnist_model.h5')
img = load_img('example.jpg',False,target_size=(img_width,img_height))
x = img.convert('L')
x = img_to_array(x)
x = np.expand_dims(x, axis=0)
x = x.reshape(x.shape[0], img_width, img_height, channel)
x = x.astype('float32')
x /= 255
preds = test_model.predict_classes(x)
pred = preds.tolist()
prob = test_model.predict_proba(x)
prlist = prob.tolist()
print(pred, prlist)

load_img() は、PIL 形式の画像データを返します。Web から POST された Base64 の JPEG 画像データを直接処理する場合など、load_img() の代替例です。



import re
import json
import base64
import io 
from PIL import Image

def getPILimage(base64_img):
  img_data = re.sub('^data:image/.+;base64,', '', base64_img)
  dec_img = base64.b64decode(img_data)

  img_bin = io.BytesIO(dec_img)
  pil_img = Image.open(img_bin)

  return pil_img

『Keras の MNIST モデルで Predict を体験』を公開しました。