人工知能プログラミングやってくブログ

深層学習・機械学習・AIについて調べてやってみたことをまとめるブログです

Google ColaboratoryでPyTorchでMNISTを学習したモデルを保存し、それを読み出して使う簡単サンプル

Pytorch predict mnist

PyTorchの勉強に作ってみました。

MNISTをCNNで学習したモデルを保存する

pytorchのexamples/mnist/main.pyをjupyter notebookで動くようにちょっと改造。
Google colaboratoryで動かすことを目的としています。

ノートブックの設定で「GPU」

Google colaboratoryでコードを実行する前に、編集 -> ノートブックの設定 から ハードウェアアクセラレータ -> GPU を選択します。
これでGPUを使用して、高速で学習をすることができます。

このコードを実行して学習が終わると"mnist_cnn.pt"というファイルが作成されます。

Google colaboratoryで動かすと、"ファイル"に"mnist_cnn.pt"が保存されます。

PyTorch-MNIST ファイルから"mnist_cnn.pt"を選択してダウンロード。

学習したモデルを読み込んでMNISTを識別

先ほどダウンロードしたmnist_cnn.pyをファイルにアップロード。 Google colaboratoryにmnist_cnn.pyをアップロード

その後、このコードを実行でMNISTの手書き文字が識別されます。