人工知能と競プロやってくブログ

深層学習・機械学習・AI・atcoder・競技プログラミングについて調べてやってみたことをまとめるブログです

KerasでMNISTの手書き画像を3つの方法で表示してみる

はてなブログ最初のエントリーなんで、お試しがてら書いてみます。

MNISTといえば「深層学習業界のHelloWorld!」と言われる定番データセット
とりあえずKerasを入れたらexamplesに入ってるmnist_cnn.pyとかを、みんな動かしてみてると思います。

かくいう私も学習は回してましたが
「そういえば、学習は回しても中身を表示したことないな?」
と思い、どうやるか調べて試してみました。

PILを使ったMNIST手書き画像の簡単表示

f:id:uchidamax:20171219130023p:plain

下記のコードを実行するとMNISTの手書き画像が4枚表示されます。
これがMACで実行してみた画像。
けっこう簡単に表示はできんのね。

import keras
from keras.datasets import mnist
import numpy as np
from PIL import Image

# 文字画像表示
def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()

# Kerasの関数でデータの読み込み。データをシャッフルして学習データと訓練データに分割
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# MNISTの文字画像を読み出して表示
num = 4
for i in range(num):
    img_show(x_train[i].reshape(28, 28))

PILを使って一枚の画像にMNIST手書き画像をまとめて表示

f:id:uchidamax:20171219130643p:plain

上の簡単表示だと、たくさん表示するとウインドウがたくさん出てしまうんで一枚の画像にまとめて表示するコードを書いてみました。
表示ついでに、結果をJPEGに保存してます。

import keras
from keras.datasets import mnist
import numpy as np
from PIL import Image

# 文字画像表示
def ConvertToImg(img):
    return Image.fromarray(np.uint8(img))

# Kerasの関数でデータの読み込み。データをシャッフルして学習データと訓練データに分割
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# MNIST一文字の幅
chr_w = 28
# MNIST一文字の高さ
chr_h = 28
# 表示する文字数
num = 128

# MNISTの文字をPILで1枚の画像に描画する
canvas = Image.new('RGB', (int(chr_w * num/2), int(chr_h * num/2)), (255, 255, 255))

# MNISTの文字を読み込んで描画
i = 0
for y in range( int(num/2) ):
    for x in range( int(num/2) ):
        chrImg = ConvertToImg(x_train[i].reshape(chr_w, chr_h))
        canvas.paste(chrImg, (chr_w*x, chr_h*y))
        i = i + 1

canvas.show()
# 表示した画像をJPEGとして保存
canvas.save('mnist.jpg', 'JPEG', quality=100, optimize=True)

matplotを使ってMNIST手書き画像を表示

f:id:uchidamax:20171219130809p:plain

PILでなく、matplotで表示してみます。 実行すると、こんな感じ。 以下コード。

import keras
from keras.datasets import mnist
import matplotlib.pyplot as plt

# Kerasの関数でデータの読み込み。データをシャッフルして学習データと訓練データに分割
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# MNISTデータの表示
W = 16  # 横に並べる個数
H = 8   # 縦に並べる個数
fig = plt.figure(figsize=(H, W))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1.0, hspace=0.05, wspace=0.05)
for i in range(W*H):
    ax = fig.add_subplot(H, W, i + 1, xticks=[], yticks=[])
    ax.imshow(x_train[i].reshape((28, 28)), cmap='gray')

plt.show()

参考

次のサイトを参考にさせていただきました。

tekenuko.hatenablog.com

oikakeru.hateblo.jp

Python 3.5 対応画像処理ライブラリ Pillow (PIL) の使い方 - Librabuch