උත්පාදකඅහිතකර ජාල MNIST සමඟ අත්හදා බැලීම

10from typing import Any
11
12import torch
13import torch.nn as nn
14import torch.utils.data
15from torchvision import transforms
16
17from labml import tracker, monit, experiment
18from labml.configs import option, calculate
19from labml_helpers.datasets.mnist import MNISTConfigs
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_helpers.optimizer import OptimizerConfigs
23from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
24from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
27def weights_init(m):
28    classname = m.__class__.__name__
29    if classname.find('Linear') != -1:
30        nn.init.normal_(m.weight.data, 0.0, 0.02)
31    elif classname.find('BatchNorm') != -1:
32        nn.init.normal_(m.weight.data, 1.0, 0.02)
33        nn.init.constant_(m.bias.data, 0)

සරලඑම්එල්පී උත්පාදක යන්ත්රය

මෙය LeakyReLU සක්රිය කිරීම් සමඟ ප්රමාණය වැඩි කිරීමේ රේඛීය ස්ථර තුනක් ඇත. අවසාන ස්ථරය සක්රිය කිරීමක් ඇත.

36class Generator(Module):
44    def __init__(self):
45        super().__init__()
46        layer_sizes = [256, 512, 1024]
47        layers = []
48        d_prev = 100
49        for size in layer_sizes:
50            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
51            d_prev = size
52
53        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())
54
55        self.apply(weights_init)
57    def forward(self, x):
58        return self.layers(x).view(x.shape[0], 1, 28, 28)

සරලඑම්එල්පී වෙනස්කම් කරන්නා

මෙය LeakyReLU සක්රිය කිරීම් සමඟ ප්රමාණය අඩු කිරීමේ රේඛීය ස්ථර තුනක් ඇත. අවසාන ස්ථරයට තනි ප්රතිදානයක් ඇති අතර එමඟින් ආදානය සැබෑ හෝ ව්යාජ ද යන්න පිළිබඳ පිවිසුම ලබා දේ. එය සිග්මෝයිඩ් ගණනය කිරීමෙන් ඔබට සම්භාවිතාව ලබා ගත හැකිය.

61class Discriminator(Module):
70    def __init__(self):
71        super().__init__()
72        layer_sizes = [1024, 512, 256]
73        layers = []
74        d_prev = 28 * 28
75        for size in layer_sizes:
76            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
77            d_prev = size
78
79        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
80        self.apply(weights_init)
82    def forward(self, x):
83        return self.layers(x.view(x.shape[0], -1))

වින්යාසකිරීම්

අපගේක්රියාත්මක කිරීම සරල කිරීම සඳහා දත්ත පැටවුම් සහ පුහුණු සහ වලංගු කිරීමේ ලූප වින්යාසයන් ලබා ගැනීම සඳහා මෙය MNIST වින්යාසයන් පුළුල් කරයි.

86class Configs(MNISTConfigs, TrainValidConfigs):
94    device: torch.device = DeviceConfigs()
95    dataset_transforms = 'mnist_gan_transforms'
96    epochs: int = 10
97
98    is_save_models = True
99    discriminator: Module = 'mlp'
100    generator: Module = 'mlp'
101    generator_optimizer: torch.optim.Adam
102    discriminator_optimizer: torch.optim.Adam
103    generator_loss: GeneratorLogitsLoss = 'original'
104    discriminator_loss: DiscriminatorLogitsLoss = 'original'
105    label_smoothing: float = 0.2
106    discriminator_k: int = 1

ආරම්භකකරණය

108    def init(self):
112        self.state_modules = []
113
114        hook_model_outputs(self.mode, self.generator, 'generator')
115        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
116        tracker.set_scalar("loss.generator.*", True)
117        tracker.set_scalar("loss.discriminator.*", True)
118        tracker.set_image("generated", True, 1 / 100)

120    def sample_z(self, batch_size: int):
124        return torch.randn(batch_size, 100, device=self.device)

පුහුණුපියවරක් ගන්න

126    def step(self, batch: Any, batch_idx: BatchIndex):

ආදර්ශතත්වයන් සකසන්න

132        self.generator.train(self.mode.is_train)
133        self.discriminator.train(self.mode.is_train)

MNISTරූප ලබා ගන්න

136        data = batch[0].to(self.device)

පුහුණුමාදිලියේ වර්ධක පියවර

139        if self.mode.is_train:
140            tracker.add_global_step(len(data))

වෙනස්කම්කරන්නා පුහුණු කරන්න

143        with monit.section("discriminator"):

වෙනස්කම්කරන්නාගේ පාඩුව ලබා ගන්න

145            loss = self.calc_discriminator_loss(data)

දුම්රිය

148            if self.mode.is_train:
149                self.discriminator_optimizer.zero_grad()
150                loss.backward()
151                if batch_idx.is_last:
152                    tracker.add('discriminator', self.discriminator)
153                self.discriminator_optimizer.step()

සෑමවිටම උත්පාදක යන්ත්රය පුහුණු කරන්න discriminator_k

156        if batch_idx.is_interval(self.discriminator_k):
157            with monit.section("generator"):
158                loss = self.calc_generator_loss(data.shape[0])

දුම්රිය

161                if self.mode.is_train:
162                    self.generator_optimizer.zero_grad()
163                    loss.backward()
164                    if batch_idx.is_last:
165                        tracker.add('generator', self.generator)
166                    self.generator_optimizer.step()
167
168        tracker.save()

වෙනස්කම්කරන්නාගේ පාඩුව ගණනය කරන්න

170    def calc_discriminator_loss(self, data):
174        latent = self.sample_z(data.shape[0])
175        logits_true = self.discriminator(data)
176        logits_false = self.discriminator(self.generator(latent).detach())
177        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
178        loss = loss_true + loss_false

ලොග්දේවල්

181        tracker.add("loss.discriminator.true.", loss_true)
182        tracker.add("loss.discriminator.false.", loss_false)
183        tracker.add("loss.discriminator.", loss)
184
185        return loss

උත්පාදකඅලාභය ගණනය කරන්න

187    def calc_generator_loss(self, batch_size: int):
191        latent =  self.sample_z(batch_size)
192        generated_images = self.generator(latent)
193        logits = self.discriminator(generated_images)
194        loss = self.generator_loss(logits)

ලොග්දේවල්

197        tracker.add('generated', generated_images[0:6])
198        tracker.add("loss.generator.", loss)
199
200        return loss
205@option(Configs.dataset_transforms)
206def mnist_gan_transforms():
207    return transforms.Compose([
208        transforms.ToTensor(),
209        transforms.Normalize((0.5,), (0.5,))
210    ])
211
212
213@option(Configs.discriminator_optimizer)
214def _discriminator_optimizer(c: Configs):
215    opt_conf = OptimizerConfigs()
216    opt_conf.optimizer = 'Adam'
217    opt_conf.parameters = c.discriminator.parameters()
218    opt_conf.learning_rate = 2.5e-4

ශ්රේණියේපළමු මොහොත සඳහා on ාතීය ක්ෂය වීමේ අනුපාතය සැකසීම වැදගත් 0.5 වේ. 0.9 අසමත් වීමේ පෙරනිමි.

222    opt_conf.betas = (0.5, 0.999)
223    return opt_conf
226@option(Configs.generator_optimizer)
227def _generator_optimizer(c: Configs):
228    opt_conf = OptimizerConfigs()
229    opt_conf.optimizer = 'Adam'
230    opt_conf.parameters = c.generator.parameters()
231    opt_conf.learning_rate = 2.5e-4

ශ්රේණියේපළමු මොහොත සඳහා on ාතීය ක්ෂය වීමේ අනුපාතය සැකසීම වැදගත් 0.5 වේ. 0.9 අසමත් වීමේ පෙරනිමි.

235    opt_conf.betas = (0.5, 0.999)
236    return opt_conf
237
238
239calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
240calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
241calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
242calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
245def main():
246    conf = Configs()
247    experiment.create(name='mnist_gan', comment='test')
248    experiment.configs(conf,
249                       {'label_smoothing': 0.01})
250    with experiment.start():
251        conf.run()
252
253
254if __name__ == '__main__':
255    main()