Chainerで訓練の半自動再開(resume semi-automatically)
会社的に推奨されたので、適当にsnapshotとってresumeする仕組みを実装した。
やりかったことは
- 適当なタイミングでtrainerのsnapshotをとって、そこから学習再開できるようにする
- chainermnだとmaybe_loadという便利関数があってiteration数とか解決してくれるんだけど、普通のchainerにはそれがない
- ので、resume用のsnapshotは名前を固定、随時上書き方針として、resume flagを立てて起動すると固定のsnapshotファイルがあればロードしてそこからやりなおす
適当なタイミングでsnapshotをとる@固定の名前
frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
snapshot_filename = args.prefix + "_snapshot_for_resume.npz"
trainer.extend(extensions.snapshot(filename=snapshot_filename), trigger=(frequency, 'epoch'))
こんな感じ。argsにfrequencyというsnapshotの頻度指定を入れておくのと、共通ファイル名は{args.prefix}_snapshot_for_resume.npz。
これで、frequency epochごとに trainerのout引数に指定したディレクトリに{args.prefix}_snapshot_for_resume.npzが保存される
フラグでresumeチェックする
parser.add_argument('--resume', '-r', action="store_true", help='Resume training from snapshot')
これをonにすると、起動時に上記ファイルがあるかどうかチェックする
if args.resume and os.path.exists(model_dir + "/" + snapshot_filename):
print("Resuming a training from a snapshot '{}'".format(model_dir + "/" + snapshot_filename))
chainer.serializers.load_npz(model_dir + "/" + snapshot_filename, trainer)trainer.run()
model_dirというのが、trainerのout引数に指定したディレクトリだと思ってください。
trainer.run()の直前に上記のコードを書き足す
あとは、train.pyの起動時に -f 10 -r と書き足すと、10エポックごとにsnapshotとるのと、もしすでにsnapshotがあればそこから学習する