මෙයPyTorch සමඟ MNIST ඉලක්කම් වර්ගීකරණය කිරීම සඳහා විනීත පයිටෝච් කේතයකි.
මෙමලිපිය කඩදාසි විස්තර කර ඇති අත්හදා බැලීම ක්රියාත්මක කරයි ඩයිනමික් රවුටින් කැප්සියුල අතර.
14from typing import Any
15
16import torch.nn as nn
17import torch.nn.functional as F
18import torch.utils.data
19
20from labml import experiment, tracker
21from labml.configs import option
22from labml_helpers.datasets.mnist import MNISTConfigs
23from labml_helpers.metrics.accuracy import AccuracyDirect
24from labml_helpers.module import Module
25from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
26from labml_nn.capsule_networks import Squash, Router, MarginLoss29class MNISTCapsuleNetworkModel(Module):34 def __init__(self):
35 super().__init__()පළමුකැටි ගැසුණු ස්ථරය , convolution කර්නල්
37 self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)දෙවනස්ථරය (ප්රාථමික කරල්) s convolutional කරල් නාලිකා (කරලක් අනුව ලක්ෂණ) සමග convolutional කරලක් ස්ථරය. එනම්, සෑම ප්රාථමික කැප්සියුලයකම 9 × 9 කර්නලයක් සහ 2 ක ඉරි සහිත සංයුක්ත ඒකක 8 ක් අඩංගු වේ. මෙය ක්රියාත්මක කිරීම සඳහා අපි නාලිකා සහිත සංවහන තට්ටුවක් නිර්මාණය කර එක් එක් විශේෂාංග කැප්සියුල ලබා ගැනීම සඳහා එහි ප්රතිදානය නැවත සකස් කර පරිපූර්ණ කරමු.
43 self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
44 self.squash = Squash()රවුටින්ස්තරය ප්රාථමික කැප්සියුල ලැබෙන අතර කැප්සියුල නිෂ්පාදනය කරයි. සෑම ප්රාථමික කැප්සියුලයකම විශේෂාංග ඇති අතර ප්රතිදාන කැප්සියුල (ඉලක්කම් කැප්සියුල) විශේෂාංග ඇත. රවුටින් ඇල්ගොරිතම වරක් පුනරාවර්තනය වේ.
50 self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)මෙමකඩදාසි සඳහන් විකේතකය වේ. එය ඉලක්කම් කැප්සියුල වල ප්රතිදානයන් ගනී, එක් එක් රූපය ප්රතිනිෂ්පාදනය කිරීම සඳහා විශේෂාංග ඇත. එය ප්රමාණවලින් රේඛීය ස්ථර හරහා සහ සක්රිය කිරීම් සමඟ ගමන් කරයි.
55 self.decoder = nn.Sequential(
56 nn.Linear(16 * 10, 512),
57 nn.ReLU(),
58 nn.Linear(512, 1024),
59 nn.ReLU(),
60 nn.Linear(1024, 784),
61 nn.Sigmoid()
62 ) data
හැඩය සහිත MNIST රූප [batch_size, 1, 28, 28]
64 def forward(self, data: torch.Tensor):පළමුකැටි ගැසුණු ස්තරය හරහා ගමන් කරන්න. මෙම ස්ථරයේ ප්රතිදානය හැඩය ඇත [batch_size, 256, 20, 20]
70 x = F.relu(self.conv1(data))දෙවනකැටි ගැසුණු ස්තරය හරහා ගමන් කරන්න. මෙම ප්රතිදානය හැඩය ඇත [batch_size, 32 * 8, 6, 6]
. මෙමස්ථරයට දිගු දිගක් ඇති බව සලකන්න .
74 x = self.conv2(x)කැප්සියුලලබා ගැනීම සඳහා ප්රමාණය වෙනස් කර permutate කරන්න
77 caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)කැප්සියුලස්කොෂ් කරන්න
79 caps = self.squash(caps)ඉලක්කම්කැප්සියුල ලබා ගැනීම සඳහා රවුටරය හරහා ඒවා රැගෙන යන්න. මෙය හැඩය ඇත [batch_size, 10, 16]
.
82 caps = self.digit_capsules(caps)ප්රතිනිර්මාණයසඳහා වෙස් මුහුණු ලබා ගන්න
85 with torch.no_grad():කැප්සියුලජාලය විසින් පුරෝකථනය කරනු ලබන්නේ දිගම දිග සහිත කැප්සියුලයයි
87 pred = (caps ** 2).sum(-1).argmax(-1)අනෙක්සියලුම කැප්සියුල වෙස්මුහුණ දීමට වෙස්මුහුණක් සාදන්න
89 mask = torch.eye(10, device=data.device)[pred]අනාවැකියකළ කැප්සියුලය පමණක් ලබා ගැනීම සඳහා ඉලක්කම් කැප්සියුල Mask කර ප්රතිනිර්මාණය ලබා ගැනීම සඳහා විකේතකය හරහා එය රැගෙන යන්න
93 reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))රූපමානයන් ගැලපෙන පරිදි ප්රතිනිර්මාණය නැවත සකස් කරන්න
95 reconstructions = reconstructions.view(-1, 1, 28, 28)
96
97 return caps, reconstructions, predMNISTදත්ත සහ දුම්රිය සහ වලංගු කිරීමේ සැකසුම සමඟ වින්යාස කිරීම්
100class Configs(MNISTConfigs, SimpleTrainValidConfigs):104 epochs: int = 10
105 model: nn.Module = 'capsule_network_model'
106 reconstruction_loss = nn.MSELoss()
107 margin_loss = MarginLoss(n_labels=10)
108 accuracy = AccuracyDirect()110 def init(self):පාඩුසහ නිරවද්යතාව තිරයට මුද්රණය කරන්න
112 tracker.set_scalar('loss.*', True)
113 tracker.set_scalar('accuracy.*', True)පුහුණුවසහ වලංගු කිරීම සඳහා එපෝච් සඳහා ඒවා ගණනය කිරීම සඳහා ප්රමිතික සකස් කළ යුතුය
116 self.state_modules = [self.accuracy]මෙමක්රමය පුහුණුකරු විසින් කැඳවනු ලැබේ
118 def step(self, batch: Any, batch_idx: BatchIndex):ආදර්ශප්රකාරය සකසන්න
123 self.model.train(self.mode.is_train)පින්තූරසහ ලේබල් ලබාගෙන ඒවා ආකෘතියේ උපාංගයට ගෙන යන්න
126 data, target = batch[0].to(self.device), batch[1].to(self.device)පුහුණුමාදිලියේ වර්ධක පියවර
129 if self.mode.is_train:
130 tracker.add_global_step(len(data))සක්රියකිරීම් ලොග් කළ යුතුද යන්න
133 with self.mode.update(is_log_activations=batch_idx.is_last):ආකෘතියධාවනය කරන්න
135 caps, reconstructions, pred = self.model(data)සම්පූර්ණඅලාභය ගණනය කරන්න
138 loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
139 tracker.add("loss.", loss)ඇමතුම්නිරවද්යතාව මෙට්රික්
142 self.accuracy(pred, target)
143
144 if self.mode.is_train:
145 loss.backward()
146
147 self.optimizer.step()ලොග්පරාමිතීන් සහ අනුක්රමික
149 if batch_idx.is_last:
150 tracker.add('model', self.model)
151 self.optimizer.zero_grad()
152
153 tracker.save()ආකෘතියසකසන්න
156@option(Configs.model)
157def capsule_network_model(c: Configs):159 return MNISTCapsuleNetworkModel().to(c.device)අත්හදාබැලීම ක්රියාත්මක කරන්න
162def main():166 experiment.create(name='capsule_network_mnist')
167 conf = Configs()
168 experiment.configs(conf, {'optimizer.optimizer': 'Adam',
169 'optimizer.learning_rate': 1e-3})
170
171 experiment.add_pytorch_models({'model': conf.model})
172
173 with experiment.start():
174 conf.run()
175
176
177if __name__ == '__main__':
178 main()