人工知能と競プロやってくブログ

深層学習・機械学習・AI・atcoder・競技プログラミングについて調べてやってみたことをまとめるブログです

PyTorch用フレームワークCatalystをGoogle Colaboratory上で使ってMNISTを学習させてみる簡単サンプル

Catalyst kaggleに上がってるカーネルをみてたら、catalystというフレームワークを使っているものを見かけました。 なかなか便利そうなので、試しに使ってみました。

Catalystとは?

Catalystは、Pytorch上に構築された高レベルのフレームワークです。
レーニングループや他の多くのことは組み込み済みで、最初から記述する必要はありません。

学習用ソースコード

CNNでMNISTを学習するコードにCatalystを導入してみました。
PyTorch特有のtrain、testのコードを書く必要がなくてスッキリしますね。

checkpointsに保存されたモデルの重みファイル(pthファイル)から識別をかける

先ほどの学習でcheckpointsフォルダの中に作られてるbes.pthを、Google Colaboratoryの「ファイル」にアップロードして実行します。

参考

github.com