මෙය සමානාත්මතා කාර්ය පිළිබඳ පොන්ඩර්නෙට් එකක් පුහුණු කරයි.
13from typing import Any
14
15import torch
16from torch import nn
17from torch.utils.data import DataLoader
18
19from labml import tracker, experiment
20from labml_helpers.metrics.accuracy import AccuracyDirect
21from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
22from labml_nn.adaptive_computation.parity import ParityDataset
23from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLossසරල පුහුණු ලූපයක්සහිත වින්යාස කිරීම්
26class Configs(SimpleTrainValidConfigs):එපොච්ගණන
33 epochs: int = 100එක්ඊපෝච්චයකට කණ්ඩායම් ගණන
35 n_batches: int = 500කණ්ඩායම්ප්රමාණය
37 batch_size: int = 128ආකෘතිය
40 model: ParityPonderGRU43 loss_rec: ReconstructionLoss45 loss_reg: RegularizationLossආදානදෛශිකයේ මූලද්රව්ය ගණන. නිරූපණයසඳහා අපි එය අඩු මට්ටමක තබා ගනිමු; එසේ නොමැතිනම් පුහුණුව සඳහා බොහෝ කාලයක් ගත වේ. සමානාත්මතා කාර්යය සරල යැයි පෙනුනද, සාම්පල දෙස බැලීමෙන් රටාව අවබොධ කිරීම තරමක් අපහසුය.
51 n_elems: int = 8සැඟවුණුස්ථරයේ ඒකක ගණන (රාජ්ය)
53 n_hidden: int = 64උපරිමපියවර ගණන
55 max_steps: int = 20ජ්යාමිතික ව්යාප්තිය සඳහා
58 lambda_p: float = 0.2පාඩු සංගුණකය විධිමත් කිරීම
60 beta: float = 0.01සම්මතයඅනුව ශ්රේණිය ක්ලිපින් කිරීම
63 grad_norm_clip: float = 1.0පුහුණුවසහ වලංගු කිරීමේ කාරකයන්
66 train_loader: DataLoader
67 valid_loader: DataLoaderනිරවද්යතාවයකැල්ක්යුලේටරය
70 accuracy = AccuracyDirect()72 def init(self):තිරයවෙත දර්ශක මුද්රණය කරන්න
74 tracker.set_scalar('loss.*', True)
75 tracker.set_scalar('loss_reg.*', True)
76 tracker.set_scalar('accuracy.*', True)
77 tracker.set_scalar('steps.*', True)පුහුණුවසහ වලංගු කිරීම සඳහා එපෝච් සඳහා ඒවා ගණනය කිරීම සඳහා ප්රමිතික සකස් කළ යුතුය
80 self.state_modules = [self.accuracy]ආකෘතියආරම්භ කරන්න
83 self.model = ParityPonderGRU(self.n_elems, self.n_hidden, self.max_steps).to(self.device)85 self.loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none')).to(self.device)87 self.loss_reg = RegularizationLoss(self.lambda_p, self.max_steps).to(self.device)පුහුණුවසහ වලංගු කිරීමේ කාරකයන්
90 self.train_loader = DataLoader(ParityDataset(self.batch_size * self.n_batches, self.n_elems),
91 batch_size=self.batch_size)
92 self.valid_loader = DataLoader(ParityDataset(self.batch_size * 32, self.n_elems),
93 batch_size=self.batch_size)මෙමක්රමය එක් එක් කණ්ඩායම සඳහා පුහුණුකරු විසින් කැඳවනු ලැබේ
95 def step(self, batch: Any, batch_idx: BatchIndex):ආදර්ශප්රකාරය සකසන්න
100 self.model.train(self.mode.is_train)ආදානසහ ලේබල ලබාගෙන ඒවා ආකෘතියේ උපාංගයට ගෙන යන්න
103 data, target = batch[0].to(self.device), batch[1].to(self.device)පුහුණුමාදිලියේ වර්ධක පියවර
106 if self.mode.is_train:
107 tracker.add_global_step(len(data))ආකෘතියධාවනය කරන්න
110 p, y_hat, p_sampled, y_hat_sampled = self.model(data)ප්රතිසංස්කරණඅලාභය ගණනය කරන්න
113 loss_rec = self.loss_rec(p, y_hat, target.to(torch.float))
114 tracker.add("loss.", loss_rec)නියාමනයකිරීමේ අලාභය ගණනය කරන්න
117 loss_reg = self.loss_reg(p)
118 tracker.add("loss_reg.", loss_reg)121 loss = loss_rec + self.beta * loss_regගෙනඇති පියවර ගණන ගණනය කරන්න
124 steps = torch.arange(1, p.shape[0] + 1, device=p.device)
125 expected_steps = (p * steps[:, None]).sum(dim=0)
126 tracker.add("steps.", expected_steps)ඇමතුම්නිරවද්යතාව මෙට්රික්
129 self.accuracy(y_hat_sampled > 0, target)
130
131 if self.mode.is_train:අනුක්රමිකගණනය
133 loss.backward()ක්ලිප්අනුක්රමික
135 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)ප්රශස්තකරණය
137 self.optimizer.step()අනුක්රමිකපැහැදිලි කරන්න
139 self.optimizer.zero_grad()141 tracker.save()අත්හදාබැලීම ක්රියාත්මක කරන්න
144def main():148 experiment.create(name='ponder_net')
149
150 conf = Configs()
151 experiment.configs(conf, {
152 'optimizer.optimizer': 'Adam',
153 'optimizer.learning_rate': 0.0003,
154 })
155
156 with experiment.start():
157 conf.run()160if __name__ == '__main__':
161 main()