කැප්සියුලජාල සමඟ MNIST ඉලක්කම් වර්ගීකරණය කරන්න

මෙයPyTorch සමඟ MNIST ඉලක්කම් වර්ගීකරණය කිරීම සඳහා විනීත පයිටෝච් කේතයකි.

මෙමලිපිය කඩදාසි විස්තර කර ඇති අත්හදා බැලීම ක්රියාත්මක කරයි ඩයිනමික් රවුටින් කැප්සියුල අතර.

14from typing import Any
15
16import torch.nn as nn
17import torch.nn.functional as F
18import torch.utils.data
19
20from labml import experiment, tracker
21from labml.configs import option
22from labml_helpers.datasets.mnist import MNISTConfigs
23from labml_helpers.metrics.accuracy import AccuracyDirect
24from labml_helpers.module import Module
25from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
26from labml_nn.capsule_networks import Squash, Router, MarginLoss

MNISTඉලක්කම් වර්ගීකරණය කිරීමේ ආකෘතිය

29class MNISTCapsuleNetworkModel(Module):
34    def __init__(self):
35        super().__init__()

පළමුකැටි ගැසුණු ස්ථරය , convolution කර්නල්

37        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)

දෙවනස්ථරය (ප්රාථමික කරල්) s convolutional කරල් නාලිකා (කරලක් අනුව ලක්ෂණ) සමග convolutional කරලක් ස්ථරය. එනම්, සෑම ප්රාථමික කැප්සියුලයකම 9 × 9 කර්නලයක් සහ 2 ක ඉරි සහිත සංයුක්ත ඒකක 8 ක් අඩංගු වේ. මෙය ක්රියාත්මක කිරීම සඳහා අපි නාලිකා සහිත සංවහන තට්ටුවක් නිර්මාණය කර එක් එක් විශේෂාංග කැප්සියුල ලබා ගැනීම සඳහා එහි ප්රතිදානය නැවත සකස් කර පරිපූර්ණ කරමු.

43        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
44        self.squash = Squash()

රවුටින්ස්තරය ප්රාථමික කැප්සියුල ලැබෙන අතර කැප්සියුල නිෂ්පාදනය කරයි. සෑම ප්රාථමික කැප්සියුලයකම විශේෂාංග ඇති අතර ප්රතිදාන කැප්සියුල (ඉලක්කම් කැප්සියුල) විශේෂාංග ඇත. රවුටින් ඇල්ගොරිතම වරක් පුනරාවර්තනය වේ.

50        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)

මෙමකඩදාසි සඳහන් විකේතකය වේ. එය ඉලක්කම් කැප්සියුල වල ප්රතිදානයන් ගනී, එක් එක් රූපය ප්රතිනිෂ්පාදනය කිරීම සඳහා විශේෂාංග ඇත. එය ප්රමාණවලින් රේඛීය ස්ථර හරහා සහ සක්රිය කිරීම් සමඟ ගමන් කරයි.

55        self.decoder = nn.Sequential(
56            nn.Linear(16 * 10, 512),
57            nn.ReLU(),
58            nn.Linear(512, 1024),
59            nn.ReLU(),
60            nn.Linear(1024, 784),
61            nn.Sigmoid()
62        )

data හැඩය සහිත MNIST රූප [batch_size, 1, 28, 28]

64    def forward(self, data: torch.Tensor):

පළමුකැටි ගැසුණු ස්තරය හරහා ගමන් කරන්න. මෙම ස්ථරයේ ප්රතිදානය හැඩය ඇත [batch_size, 256, 20, 20]

70        x = F.relu(self.conv1(data))

දෙවනකැටි ගැසුණු ස්තරය හරහා ගමන් කරන්න. මෙම ප්රතිදානය හැඩය ඇත [batch_size, 32 * 8, 6, 6] . මෙමස්ථරයට දිගු දිගක් ඇති බව සලකන්න .

74        x = self.conv2(x)

කැප්සියුලලබා ගැනීම සඳහා ප්රමාණය වෙනස් කර permutate කරන්න

77        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)

කැප්සියුලස්කොෂ් කරන්න

79        caps = self.squash(caps)

ඉලක්කම්කැප්සියුල ලබා ගැනීම සඳහා රවුටරය හරහා ඒවා රැගෙන යන්න. මෙය හැඩය ඇත [batch_size, 10, 16] .

82        caps = self.digit_capsules(caps)

ප්රතිනිර්මාණයසඳහා වෙස් මුහුණු ලබා ගන්න

85        with torch.no_grad():

කැප්සියුලජාලය විසින් පුරෝකථනය කරනු ලබන්නේ දිගම දිග සහිත කැප්සියුලයයි

87            pred = (caps ** 2).sum(-1).argmax(-1)

අනෙක්සියලුම කැප්සියුල වෙස්මුහුණ දීමට වෙස්මුහුණක් සාදන්න

89            mask = torch.eye(10, device=data.device)[pred]

අනාවැකියකළ කැප්සියුලය පමණක් ලබා ගැනීම සඳහා ඉලක්කම් කැප්සියුල Mask කර ප්රතිනිර්මාණය ලබා ගැනීම සඳහා විකේතකය හරහා එය රැගෙන යන්න

93        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))

රූපමානයන් ගැලපෙන පරිදි ප්රතිනිර්මාණය නැවත සකස් කරන්න

95        reconstructions = reconstructions.view(-1, 1, 28, 28)
96
97        return caps, reconstructions, pred

MNISTදත්ත සහ දුම්රිය සහ වලංගු කිරීමේ සැකසුම සමඟ වින්යාස කිරීම්

100class Configs(MNISTConfigs, SimpleTrainValidConfigs):
104    epochs: int = 10
105    model: nn.Module = 'capsule_network_model'
106    reconstruction_loss = nn.MSELoss()
107    margin_loss = MarginLoss(n_labels=10)
108    accuracy = AccuracyDirect()
110    def init(self):

පාඩුසහ නිරවද්යතාව තිරයට මුද්රණය කරන්න

112        tracker.set_scalar('loss.*', True)
113        tracker.set_scalar('accuracy.*', True)

පුහුණුවසහ වලංගු කිරීම සඳහා එපෝච් සඳහා ඒවා ගණනය කිරීම සඳහා ප්රමිතික සකස් කළ යුතුය

116        self.state_modules = [self.accuracy]

මෙමක්රමය පුහුණුකරු විසින් කැඳවනු ලැබේ

118    def step(self, batch: Any, batch_idx: BatchIndex):

ආදර්ශප්රකාරය සකසන්න

123        self.model.train(self.mode.is_train)

පින්තූරසහ ලේබල් ලබාගෙන ඒවා ආකෘතියේ උපාංගයට ගෙන යන්න

126        data, target = batch[0].to(self.device), batch[1].to(self.device)

පුහුණුමාදිලියේ වර්ධක පියවර

129        if self.mode.is_train:
130            tracker.add_global_step(len(data))

සක්රියකිරීම් ලොග් කළ යුතුද යන්න

133        with self.mode.update(is_log_activations=batch_idx.is_last):

ආකෘතියධාවනය කරන්න

135            caps, reconstructions, pred = self.model(data)

සම්පූර්ණඅලාභය ගණනය කරන්න

138        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
139        tracker.add("loss.", loss)

ඇමතුම්නිරවද්යතාව මෙට්රික්

142        self.accuracy(pred, target)
143
144        if self.mode.is_train:
145            loss.backward()
146
147            self.optimizer.step()

ලොග්පරාමිතීන් සහ අනුක්රමික

149            if batch_idx.is_last:
150                tracker.add('model', self.model)
151            self.optimizer.zero_grad()
152
153            tracker.save()

ආකෘතියසකසන්න

156@option(Configs.model)
157def capsule_network_model(c: Configs):
159    return MNISTCapsuleNetworkModel().to(c.device)

අත්හදාබැලීම ක්රියාත්මක කරන්න

162def main():
166    experiment.create(name='capsule_network_mnist')
167    conf = Configs()
168    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
169                              'optimizer.learning_rate': 1e-3})
170
171    experiment.add_pytorch_models({'model': conf.model})
172
173    with experiment.start():
174        conf.run()
175
176
177if __name__ == '__main__':
178    main()