これに書いてあるとおりですが、やり方わかって実際動いたので貼ってみる。
class LrScheduler(extension.Extension):
trigger = (1, 'epoch')
def __init__(self, base_lr, epochs, optimizer_name='main', lr_name='lr'):
self._base_lr = base_lr
self._epochs = [int(e) for e in epochs]
self._optimizer_name = optimizer_name
self._lr_name = lr_name
def __call__(self, trainer):
optimizer = trainer.updater.get_optimizer(self._optimizer_name)
e = trainer.updater.epochif e < self._epochs[0]:
lr = self._base_lr * 0.1 + e * 0.9 * self._base_lr / self._epochs[0]
elif e > self._epochs[1]:
lr = self._base_lr * (0.9 ** (e - self._epochs[1]))
else:
lr = self._base_lr
setattr(optimizer, self._lr_name, lr)
# end class
setattr(optimizer, 'lr', 0.1*args.lr)trainer.extend(LrScheduler(base_lr=args.lr, epochs=('3', '5')))
これで、epoch3までかけてlrが徐々に増加することで突然の局所解を回避。epoch5までは定常飛行して、epoch7から徐々になましていく、という動きが実現できる。
progressを貼り付けるとこんな感じ。
epoch elapsed_time lr main/loss
1 9.02064 0.001 2.81905
2 14.4791 0.004 0.757175
3 19.9546 0.007 0.333922
4 25.3488 0.01 0.73813
5 31.0634 0.01 1.48048
6 37.1175 0.01 1.3194
7 42.8478 0.009 0.645383
8 48.345 0.0081 0.404334
9 53.6649 0.00729 0.359146