ウォームアップとコサインディケイを備えた Adam オプティマイザー

これにより AMSgrad オプティマイザが拡張され、ウォームアップステージが追加されます。

11import math
12from typing import Dict
13
14from labml_nn.optimizers import WeightDecay
15from labml_nn.optimizers.amsgrad import AMSGrad

ウォームアップとコサインディケイを備えた Adam オプティマイザー

このクラスは、で定義されている AMSGrad オプティマイザを拡張したものです。amsgrad.py

18class AdamWarmupCosineDecay(AMSGrad):

オプティマイザを初期化

  • params はパラメータのリストです
  • lr は学習率
  • betas (,) のタプルです
  • eps またはそれに基づいている optimized_update
  • weight_decay WeightDecay で定義されているクラスのインスタンスです __init__.py
  • 'optimized_update'は追加後に行うことでセカンドモーメントのバイアス補正を最適化するかどうかのフラグです
  • amsgrad amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグです
  • warmup ウォームアップステップ数
  • total_steps ステップの総数。この時点でコサイン減衰は0に達しますが、lr 取るため10%にとどまります
  • defaults グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdamWarmup
27    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
28                 weight_decay: WeightDecay = WeightDecay(),
29                 optimized_update: bool = True,
30                 amsgrad=False, warmup=0, total_steps=1e10, defaults=None):
49        defaults = {} if defaults is None else defaults
50        defaults.update(dict(warmup=warmup, total_steps=total_steps))
51        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

学習率を取得

はウォームアップステップの数です。

53    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):

ウォームアップ段階の場合

61        if group['warmup'] > state['step']:

学習率が 1 から 1 に直線的に増加している

63            return 1e-8 + state['step'] * group['lr'] / group['warmup']
64        else:

一定の学習率

66            progress = (state['step'] - group['warmup']) / max(1, group['total_steps'] - group['warmup'])
67            return group['lr'] * max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

さまざまなウォームアップとモデルサイズの学習率をプロット

Plot of learning rate

70def _test_lr():
76    import matplotlib.pyplot as plt
77    import numpy as np
78    from torch import nn
79
80    model = nn.Linear(10, 10)
81    opt = AdamWarmupCosineDecay(model.parameters(), warmup=5000, lr=1e-4, total_steps=4e6)
82    steps = 20_000
83    plt.plot(np.arange(1, steps), [opt.get_lr({'step': i}, opt.defaults) for i in range(1, steps)])
84    plt.legend(["5000:4e6", "5000:2e6", "5000:1e6"])
85    plt.title("Learning Rate")
86    plt.show()
87
88    steps = int(6e6)
89    step_size = 1000
90    plt.plot(np.arange(1, steps, step_size), [opt.get_lr({'step': i}, opt.defaults) for i in range(1, steps, step_size)])
91    plt.legend(["5000:4e6", "5000:2e6", "5000:1e6"])
92    plt.title("Learning Rate")
93    plt.show()
94
95
96if __name__ == '__main__':
97    _test_lr()