පොන්ඩර්නෙට් සමානාත්මතා කාර්ය අත්හදා බැලීම

මෙය සමානාත්මතා කාර්ය පිළිබඳ පොන්ඩර්නෙට් එකක් පුහුණු කරයි.

13from typing import Any
14
15import torch
16from torch import nn
17from torch.utils.data import DataLoader
18
19from labml import tracker, experiment
20from labml_helpers.metrics.accuracy import AccuracyDirect
21from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
22from labml_nn.adaptive_computation.parity import ParityDataset
23from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss

සරල පුහුණු ලූපයක්සහිත වින්යාස කිරීම්

26class Configs(SimpleTrainValidConfigs):

එපොච්ගණන

33    epochs: int = 100

එක්ඊපෝච්චයකට කණ්ඩායම් ගණන

35    n_batches: int = 500

කණ්ඩායම්ප්රමාණය

37    batch_size: int = 128

ආකෘතිය

40    model: ParityPonderGRU

43    loss_rec: ReconstructionLoss

45    loss_reg: RegularizationLoss

ආදානදෛශිකයේ මූලද්රව්ය ගණන. නිරූපණයසඳහා අපි එය අඩු මට්ටමක තබා ගනිමු; එසේ නොමැතිනම් පුහුණුව සඳහා බොහෝ කාලයක් ගත වේ. සමානාත්මතා කාර්යය සරල යැයි පෙනුනද, සාම්පල දෙස බැලීමෙන් රටාව අවබොධ කිරීම තරමක් අපහසුය.

51    n_elems: int = 8

සැඟවුණුස්ථරයේ ඒකක ගණන (රාජ්ය)

53    n_hidden: int = 64

උපරිමපියවර ගණන

55    max_steps: int = 20

ජ්යාමිතික ව්යාප්තිය සඳහා

58    lambda_p: float = 0.2

පාඩු සංගුණකය විධිමත් කිරීම

60    beta: float = 0.01

සම්මතයඅනුව ශ්රේණිය ක්ලිපින් කිරීම

63    grad_norm_clip: float = 1.0

පුහුණුවසහ වලංගු කිරීමේ කාරකයන්

66    train_loader: DataLoader
67    valid_loader: DataLoader

නිරවද්යතාවයකැල්ක්යුලේටරය

70    accuracy = AccuracyDirect()
72    def init(self):

තිරයවෙත දර්ශක මුද්රණය කරන්න

74        tracker.set_scalar('loss.*', True)
75        tracker.set_scalar('loss_reg.*', True)
76        tracker.set_scalar('accuracy.*', True)
77        tracker.set_scalar('steps.*', True)

පුහුණුවසහ වලංගු කිරීම සඳහා එපෝච් සඳහා ඒවා ගණනය කිරීම සඳහා ප්රමිතික සකස් කළ යුතුය

80        self.state_modules = [self.accuracy]

ආකෘතියආරම්භ කරන්න

83        self.model = ParityPonderGRU(self.n_elems, self.n_hidden, self.max_steps).to(self.device)

85        self.loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none')).to(self.device)

87        self.loss_reg = RegularizationLoss(self.lambda_p, self.max_steps).to(self.device)

පුහුණුවසහ වලංගු කිරීමේ කාරකයන්

90        self.train_loader = DataLoader(ParityDataset(self.batch_size * self.n_batches, self.n_elems),
91                                       batch_size=self.batch_size)
92        self.valid_loader = DataLoader(ParityDataset(self.batch_size * 32, self.n_elems),
93                                       batch_size=self.batch_size)

මෙමක්රමය එක් එක් කණ්ඩායම සඳහා පුහුණුකරු විසින් කැඳවනු ලැබේ

95    def step(self, batch: Any, batch_idx: BatchIndex):

ආදර්ශප්රකාරය සකසන්න

100        self.model.train(self.mode.is_train)

ආදානසහ ලේබල ලබාගෙන ඒවා ආකෘතියේ උපාංගයට ගෙන යන්න

103        data, target = batch[0].to(self.device), batch[1].to(self.device)

පුහුණුමාදිලියේ වර්ධක පියවර

106        if self.mode.is_train:
107            tracker.add_global_step(len(data))

ආකෘතියධාවනය කරන්න

110        p, y_hat, p_sampled, y_hat_sampled = self.model(data)

ප්රතිසංස්කරණඅලාභය ගණනය කරන්න

113        loss_rec = self.loss_rec(p, y_hat, target.to(torch.float))
114        tracker.add("loss.", loss_rec)

නියාමනයකිරීමේ අලාභය ගණනය කරන්න

117        loss_reg = self.loss_reg(p)
118        tracker.add("loss_reg.", loss_reg)

121        loss = loss_rec + self.beta * loss_reg

ගෙනඇති පියවර ගණන ගණනය කරන්න

124        steps = torch.arange(1, p.shape[0] + 1, device=p.device)
125        expected_steps = (p * steps[:, None]).sum(dim=0)
126        tracker.add("steps.", expected_steps)

ඇමතුම්නිරවද්යතාව මෙට්රික්

129        self.accuracy(y_hat_sampled > 0, target)
130
131        if self.mode.is_train:

අනුක්රමිකගණනය

133            loss.backward()

ක්ලිප්අනුක්රමික

135            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

ප්රශස්තකරණය

137            self.optimizer.step()

අනුක්රමිකපැහැදිලි කරන්න

139            self.optimizer.zero_grad()

141            tracker.save()

අත්හදාබැලීම ක්රියාත්මක කරන්න

144def main():
148    experiment.create(name='ponder_net')
149
150    conf = Configs()
151    experiment.configs(conf, {
152        'optimizer.optimizer': 'Adam',
153        'optimizer.learning_rate': 0.0003,
154    })
155
156    with experiment.start():
157        conf.run()

160if __name__ == '__main__':
161    main()