මෙය AMSGrad ප්රශස්තකරණය පුළුල් කරන අතර උනුසුම් අවධියක් එක් කරයි.
11import math
12from typing import Dict
13
14from labml_nn.optimizers import WeightDecay
15from labml_nn.optimizers.amsgrad import AMSGradමෙමපන්තිය AMSGrad ප්රශස්තකරණයෙන් අර්ථ දක්වා ඇත amsgrad.py
.
18class AdamWarmupCosineDecay(AMSGrad):params
යනු පරාමිතීන් ලැයිස්තුවයි lr
යනු ඉගෙනුම් අනුපාතයයි betas
(, ) ක tuple වේ eps
හෝ මත පදනම් වේ optimized_update
weight_decay
WeightDecay
අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
amsgrad
ආදම් සරල කිරීම සඳහා AMSGrad හෝ වැටීම භාවිතා කළ යුතුද යන්න දැක්වෙන ධජයකි warmup
උනුසුම් පියවර ගණන total_steps
මුළු පියවර ගණන. කොසයින් ක්ෂය වීම මේ වන විට 0 දක්වා ළඟා වේ, නමුත් අප ගන්නා lr
නිසා 10% ක රැඳී සිටියි defaults
කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් AdamWarmup
වේ. 27 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
28 weight_decay: WeightDecay = WeightDecay(),
29 optimized_update: bool = True,
30 amsgrad=False, warmup=0, total_steps=1e10, defaults=None):49 defaults = {} if defaults is None else defaults
50 defaults.update(dict(warmup=warmup, total_steps=total_steps))
51 super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)53 def get_lr(self, state: Dict[str, any], group: Dict[str, any]):අපිඋනුසුම් අවධියක සිටී නම්
61 if group['warmup'] > state['step']:සිට රේඛීයව වැඩිවන ඉගෙනුම් අනුපාතය
63 return 1e-8 + state['step'] * group['lr'] / group['warmup']
64 else:නිරන්තරඉගෙනුම් අනුපාතය
66 progress = (state['step'] - group['warmup']) / max(1, group['total_steps'] - group['warmup'])
67 return group['lr'] * max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))70def _test_lr():76 import matplotlib.pyplot as plt
77 import numpy as np
78 from torch import nn
79
80 model = nn.Linear(10, 10)
81 opt = AdamWarmupCosineDecay(model.parameters(), warmup=5000, lr=1e-4, total_steps=4e6)
82 steps = 20_000
83 plt.plot(np.arange(1, steps), [opt.get_lr({'step': i}, opt.defaults) for i in range(1, steps)])
84 plt.legend(["5000:4e6", "5000:2e6", "5000:1e6"])
85 plt.title("Learning Rate")
86 plt.show()
87
88 steps = int(6e6)
89 step_size = 1000
90 plt.plot(np.arange(1, steps, step_size), [opt.get_lr({'step': i}, opt.defaults) for i in range(1, steps, step_size)])
91 plt.legend(["5000:4e6", "5000:2e6", "5000:1e6"])
92 plt.title("Learning Rate")
93 plt.show()
94
95
96if __name__ == '__main__':
97 _test_lr()