Google ColaboratoryでPyTorchでMNISTを学習したモデルを保存し、それを読み出して使う簡単サンプル
PyTorchの勉強に作ってみました。
MNISTをCNNで学習したモデルを保存する
pytorchのexamples/mnist/main.pyをjupyter notebookで動くようにちょっと改造。
Google colaboratoryで動かすことを目的としています。
Google colaboratoryでコードを実行する前に、編集 -> ノートブックの設定 から ハードウェアアクセラレータ -> GPU を選択します。
これでGPUを使用して、高速で学習をすることができます。
このコードを実行して学習が終わると"mnist_cnn.pt"というファイルが作成されます。
Google colaboratoryで動かすと、"ファイル"に"mnist_cnn.pt"が保存されます。
ファイルから"mnist_cnn.pt"を選択してダウンロード。
学習したモデルを読み込んでMNISTを識別
先ほどダウンロードしたmnist_cnn.pyをファイルにアップロード。
その後、このコードを実行でMNISTの手書き文字が識別されます。