KerasでMNISTの手書き画像を3つの方法で表示してみる
はてなブログ最初のエントリーなんで、お試しがてら書いてみます。
MNISTといえば「深層学習業界のHelloWorld!」と言われる定番データセット。
とりあえずKerasを入れたらexamplesに入ってるmnist_cnn.pyとかを、みんな動かしてみてると思います。
かくいう私も学習は回してましたが
「そういえば、学習は回しても中身を表示したことないな?」
と思い、どうやるか調べて試してみました。
PILを使ったMNIST手書き画像の簡単表示
下記のコードを実行するとMNISTの手書き画像が4枚表示されます。
これがMACで実行してみた画像。
けっこう簡単に表示はできんのね。
import keras from keras.datasets import mnist import numpy as np from PIL import Image # 文字画像表示 def img_show(img): pil_img = Image.fromarray(np.uint8(img)) pil_img.show() # Kerasの関数でデータの読み込み。データをシャッフルして学習データと訓練データに分割 (x_train, y_train), (x_test, y_test) = mnist.load_data() # MNISTの文字画像を読み出して表示 num = 4 for i in range(num): img_show(x_train[i].reshape(28, 28))
PILを使って一枚の画像にMNIST手書き画像をまとめて表示
上の簡単表示だと、たくさん表示するとウインドウがたくさん出てしまうんで一枚の画像にまとめて表示するコードを書いてみました。
表示ついでに、結果をJPEGに保存してます。
import keras from keras.datasets import mnist import numpy as np from PIL import Image # 文字画像表示 def ConvertToImg(img): return Image.fromarray(np.uint8(img)) # Kerasの関数でデータの読み込み。データをシャッフルして学習データと訓練データに分割 (x_train, y_train), (x_test, y_test) = mnist.load_data() # MNIST一文字の幅 chr_w = 28 # MNIST一文字の高さ chr_h = 28 # 表示する文字数 num = 128 # MNISTの文字をPILで1枚の画像に描画する canvas = Image.new('RGB', (int(chr_w * num/2), int(chr_h * num/2)), (255, 255, 255)) # MNISTの文字を読み込んで描画 i = 0 for y in range( int(num/2) ): for x in range( int(num/2) ): chrImg = ConvertToImg(x_train[i].reshape(chr_w, chr_h)) canvas.paste(chrImg, (chr_w*x, chr_h*y)) i = i + 1 canvas.show() # 表示した画像をJPEGとして保存 canvas.save('mnist.jpg', 'JPEG', quality=100, optimize=True)
matplotを使ってMNIST手書き画像を表示
PILでなく、matplotで表示してみます。 実行すると、こんな感じ。 以下コード。
import keras from keras.datasets import mnist import matplotlib.pyplot as plt # Kerasの関数でデータの読み込み。データをシャッフルして学習データと訓練データに分割 (x_train, y_train), (x_test, y_test) = mnist.load_data() # MNISTデータの表示 W = 16 # 横に並べる個数 H = 8 # 縦に並べる個数 fig = plt.figure(figsize=(H, W)) fig.subplots_adjust(left=0, right=1, bottom=0, top=1.0, hspace=0.05, wspace=0.05) for i in range(W*H): ax = fig.add_subplot(H, W, i + 1, xticks=[], yticks=[]) ax.imshow(x_train[i].reshape((28, 28)), cmap='gray') plt.show()
参考
次のサイトを参考にさせていただきました。