MNISTසමඟ WGAN-GP අත්හදා බැලීම

10import torch
11
12from labml import experiment, tracker

වොසර්ස්ටයින් අත්හදා බැලීමෙන් වින්යාසයන් ආනයනය කරන්න

14from labml_nn.gan.wasserstein.experiment import Configs as OriginalConfigs

16from labml_nn.gan.wasserstein.gradient_penalty import GradientPenalty

වින්යාසපන්තිය

අපි මුල් GAN ක්රියාත්මක කිරීම දීර් extend කර වර්ගීකරණ ද penalty ුවම ඇතුළත් කිරීම සඳහා වෙනස්කම් කරන්නා (විචාරක) පාඩු ගණනය කිරීම අභිබවා යමු.

19class Configs(OriginalConfigs):

ශ්රේණියේදණ්ඩන සංගුණකය

28    gradient_penalty_coefficient: float = 10.0

30    gradient_penalty = GradientPenalty()

මෙයමුල් වෙනස්කම් කරන්නාගේ අලාභය ගණනය කිරීම අභිබවා යන අතර ශ්රේණියේ ද penalty ුවම් ද ඇතුළත් වේ.

32    def calc_discriminator_loss(self, data: torch.Tensor):

ඵලයඅනුක්රමික දඬුවම ගණනය කිරීමට මත ඵලය අනුක්රමික අවශ්ය

38        data.requires_grad_()

නියැදිය

40        latent = self.sample_z(data.shape[0])

42        f_real = self.discriminator(data)

44        f_fake = self.discriminator(self.generator(latent).detach())

වෙනස්කම්කරන්නාගේ පාඩු ලබා ගන්න

46        loss_true, loss_false = self.discriminator_loss(f_real, f_fake)

පුහුණුප්රකාරයේදී ශ්රේණියේ ද ties ුවම් ගණනය කරන්න

48        if self.mode.is_train:
49            gradient_penalty = self.gradient_penalty(data, f_real)
50            tracker.add("loss.gp.", gradient_penalty)
51            loss = loss_true + loss_false + self.gradient_penalty_coefficient * gradient_penalty

වෙනත්ආකාරයකින් ශ්රේණියේ ද penalty ුවම මඟ හරින්න

53        else:
54            loss = loss_true + loss_false

ලොග්දේවල්

57        tracker.add("loss.discriminator.true.", loss_true)
58        tracker.add("loss.discriminator.false.", loss_false)
59        tracker.add("loss.discriminator.", loss)
60
61        return loss
64def main():

වින්යාසවස්තුව සාදන්න

66    conf = Configs()

අත්හදාබැලීම සාදන්න

68    experiment.create(name='mnist_wassertein_gp_dcgan')

වින්යාසයන්අභිබවා යන්න

70    experiment.configs(conf,
71                       {
72                           'discriminator': 'cnn',
73                           'generator': 'cnn',
74                           'label_smoothing': 0.01,
75                           'generator_loss': 'wasserstein',
76                           'discriminator_loss': 'wasserstein',
77                           'discriminator_k': 5,
78                       })

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

81    with experiment.start():
82        conf.run()
83
84
85if __name__ == '__main__':
86    main()