උත්පාදකඅහිතකර ජාල (GAN)

මෙය Generative Aversarial Networkක්රියාත්මක කිරීමයි.

උත්පාදකයන්ත්රය, දත්ත බෙදා හැරීමට ගැලපෙන සාම්පල ජනනය කරන අතර වෙනස්කම් කරන්නා, වඩා දත්ත වලින් පැමිණි සම්භාවිතාව ලබා දෙයි .

වටිනාකම්ක්රියාකාරිත්වය සහිත ක්රීඩක දෙකක මිනි-මැක්ස් ක්රීඩාවක් සඳහා අපි පුහුණු කරමු .

යනු දත්ත වලට වඩා සම්භාවිතා බෙදා හැරීම වන අතර සම්භාවිතා ව්යාප්තිය වන අතර එය ගවුසියානු ශබ්දයට සකසා ඇත.

මෙමගොනුව පාඩු කාර්යයන් අර්ථ දක්වයි. මෙන්න උත්පාදක යන්ත්රය සහ වෙනස්කම් කරන්නා සඳහා බහු ස්ථර perceptron දෙකක් සහිත MNIST උදාහරණයකි.

34import torch
35import torch.nn as nn
36import torch.utils.data
37import torch.utils.data
38
39from labml_helpers.module import Module

වෙනස්කම්කරන්නාගේ පාඩුව

වෙනස්කම්කරන්නා ශ්රේණිය මතට නැග්විය යුතුය,

කුඩා කණ්ඩායම් ප්රමාණය වන අතර කුඩා කණ්ඩායමේ සාම්පල දර්ශකය සඳහා භාවිතා කරයි. වෙතින් සාම්පල වන අතර ඒවා සාම්පල වේ.

42class DiscriminatorLogitsLoss(Module):
57    def __init__(self, smoothing: float = 0.2):
58        super().__init__()

අපිභාවිතා කරමු PyTorch ද්විමය ක්රොස් එන්ට්රොපි නැතිවීම , එනම් ලේබල් සහ අනාවැකි කොහේද යන්නයි. සෘණලකුණ සටහන් කරන්න. අපි සිට සමාන ලේබල් භාවිතා කරන අතර සිට සමාන ලේබල් එවිට මේවායේ එකතුව මතට බැස යාම ඉහත ශ්රේණීය මත නැඟීම හා සමාන වේ.

BCEWithLogitsLoss සොෆ්ට්මැක්ස් සහ ද්විමය හරස් එන්ට්රොපි නැතිවීම ඒකාබද්ධ කරයි.

69        self.loss_true = nn.BCEWithLogitsLoss()
70        self.loss_false = nn.BCEWithLogitsLoss()

අපිලේබල් සුමටනය භාවිතා කරන්නේ එය සමහර අවස්ථාවලදී වඩා හොඳින් ක්රියා කරන බව පෙනෙන බැවිනි

73        self.smoothing = smoothing

ලේබලබෆර් ලෙස ලියාපදිංචි කර ඇති අතර නොපසුබට උත්සාහය සකසා False ඇත.

76        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
77        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)

logits_true සිට පිවිසුම් logits_false වන අතර සිට පිවිසුම් වේ

79    def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
84        if len(logits_true) > len(self.labels_true):
85            self.register_buffer("labels_true",
86                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
87        if len(logits_false) > len(self.labels_false):
88            self.register_buffer("labels_false",
89                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
90
91        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
92                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))

උත්පාදකනැතිවීම

උත්පාදකයන්ත්රය අනුක්රමික මතට බැස යා යුතුය,

95class GeneratorLogitsLoss(Module):
105    def __init__(self, smoothing: float = 0.2):
106        super().__init__()
107        self.loss_true = nn.BCEWithLogitsLoss()
108        self.smoothing = smoothing

අපිසමාන ලේබල් භාවිතා කරමු සිට එවිට මෙම අලාභය මත බැස යාම ඉහත ශ්රේණියෙන් බැස යාම හා සමාන වේ.

112        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
114    def forward(self, logits: torch.Tensor):
115        if len(logits) > len(self.fake_labels):
116            self.register_buffer("fake_labels",
117                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
118
119        return self.loss_true(logits, self.fake_labels[:len(logits)])

සුමටලේබල සාදන්න

122def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
126    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)