人工知能と競プロやってくブログ

深層学習・機械学習・AI・atcoder・競技プログラミングについて調べてやってみたことをまとめるブログです

Google ColaboratoryでPyTorchでMNISTを学習したモデルを保存し、それを読み出して使う簡単サンプル

Pytorch predict mnist

PyTorchの勉強に作ってみました。

MNISTをCNNで学習したモデルを保存する

pytorchのexamples/mnist/main.pyをjupyter notebookで動くようにちょっと改造。
Google colaboratoryで動かすことを目的としています。

ノートブックの設定で「GPU」

Google colaboratoryでコードを実行する前に、編集 -> ノートブックの設定 から ハードウェアアクセラレータ -> GPU を選択します。
これでGPUを使用して、高速で学習をすることができます。

このコードを実行して学習が終わると"mnist_cnn.pt"というファイルが作成されます。

Google colaboratoryで動かすと、"ファイル"に"mnist_cnn.pt"が保存されます。

PyTorch-MNIST ファイルから"mnist_cnn.pt"を選択してダウンロード。

学習したモデルを読み込んでMNISTを識別

先ほどダウンロードしたmnist_cnn.pyをファイルにアップロード。 Google colaboratoryにmnist_cnn.pyをアップロード

その後、このコードを実行でMNISTの手書き文字が識別されます。