චක්රයGAN

මෙය PyTorch ක්රියාත්මක කිරීම/නිබන්ධනයක් වන අතර එය Cycle-Consistent adversarial Networks භාවිතා කරමින් PyTorch Image පරිවර්තනය නොකළ අනුරූපය-රූප පරිවර්තනය .

මම එරික්ලින්ඩර්නොරන්/පයිටෝර්ච්-ගන්වෙතින් කේත කෑලි ගෙන ඇත්තෙමි. ඔබට වෙනත් GAN වෙනස්කම් පරීක්ෂා කිරීමට අවශ්ය නම් එය ඉතා හොඳ සම්පතකි.

චක්රයGAN රූපය-කිරීමට රූප පරිවර්තනය කරයි. ලබා දී ඇති බෙදාහැරීමෙන් රූපයක් තවත් එකකට පරිවර්තනය කිරීම සඳහා ආකෘතියක් පුහුණු කරයි, පවසන්නේ, A සහ B පන්තියේ රූප යම් බෙදාහැරීමක් පිළිබඳ රූප යම් ශෛලියක හෝ ස්වභාවධර්මයේ රූප වැනි දේවල් විය හැකිය. ආකෘති A සහ B අතර යුගල රූප අවශ්ය නොවේ එක් එක් පන්තියේ රූප සමූහයක් පමණක් ප්රමාණවත් වේ. රූප මෝස්තර, ආලෝකකරණ වෙනස්වීම්, රටා වෙනස්වීම් යනාදිය වෙනස් කිරීම සඳහා මෙය ඉතා හොඳින් ක්රියා කරයි. නිදසුනක් ලෙස ගිම්හානය ශීත to තුව දක්වා වෙනස් කිරීම, ඡායාරූප වලට පින්තාරු කිරීමේ ශෛලිය සහ අශ්වයන් සීබ්රා වෙත.

චක්රයGAN උත්පාදක ආකෘති දෙකක් සහ වෙනස්කම් කිරීමේ ආකෘති දෙකක් දුම්රිය කරයි. එක් උත්පාදක යන්ත්රයක් A සිට B දක්වාත් අනෙක B සිට A දක්වාත් පරිවර්තනය කරයි වෙනස්කම් කරන්නන් ජනනය කරන ලද රූප සැබෑ ලෙස පෙනෙනවාද යන්න පරීක්ෂා කරයි.

මෙමගොනුවේ ආදර්ශ කේතය මෙන්ම පුහුණු කේතයද අඩංගු වේ. අපට ගූගල් කොලැබ් සටහන් පොතක් ද ඇත.

Open In Colab View Run

36import itertools
37import random
38import zipfile
39from typing import Tuple
40
41import torch
42import torch.nn as nn
43import torchvision.transforms as transforms
44from PIL import Image
45from torch.utils.data import DataLoader, Dataset
46from torchvision.transforms import InterpolationMode
47from torchvision.utils import make_grid
48
49from labml import lab, tracker, experiment, monit
50from labml.configs import BaseConfigs
51from labml.utils.download import download_file
52from labml.utils.pytorch import get_modules
53from labml_helpers.device import DeviceConfigs
54from labml_helpers.module import Module

උත්පාදකයන්ත්රය අවශේෂ ජාලයකි.

57class GeneratorResNet(Module):
62    def __init__(self, input_channels: int, n_residual_blocks: int):
63        super().__init__()

මෙමපළමු කොටස සංකෝචනය වන අතර රූපය විශේෂාංග සිතියමකට සිතියම් ගත කරයි. නිමැවුම් විශේෂාංග සිතියමට සමාන උස හා පළල ඇත්තේ අපට පෑඩින් එකක් ඇති බැවිනි . පරාවර්තන පුරවන එය දාරවල වඩා හොඳ රූපයේ ගුණාත්මක බවක් ලබා දෙන නිසා භාවිතා වේ.

inplace=True තුළ මතකය ටිකක් ReLU ඉතිරි කරයි.

71        out_features = 64
72        layers = [
73            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
74            nn.InstanceNorm2d(out_features),
75            nn.ReLU(inplace=True),
76        ]
77        in_features = out_features

අපි2 ක stride සමග convolutions දෙකක් සමග පහළ-ආදර්ශ

81        for _ in range(2):
82            out_features *= 2
83            layers += [
84                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
85                nn.InstanceNorm2d(out_features),
86                nn.ReLU(inplace=True),
87            ]
88            in_features = out_features

අපිමෙය හරහා ගන්නෙමු n_residual_blocks . මෙම මොඩියුලය පහත අර්ථ දක්වා ඇත.

92        for _ in range(n_residual_blocks):
93            layers += [ResidualBlock(out_features)]

එවිටඑහි ප්රතිඵලයක් ලක්ෂණය සිතියම මුල් රූපයේ උස හා පළල ගැලපෙන දක්වා sampled ඇත.

97        for _ in range(2):
98            out_features //= 2
99            layers += [
100                nn.Upsample(scale_factor=2),
101                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
102                nn.InstanceNorm2d(out_features),
103                nn.ReLU(inplace=True),
104            ]
105            in_features = out_features

අවසානවශයෙන් අපි විශේෂාංග සිතියම RGB රූපයකට සිතියම් ගත කරමු

108        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

ස්ථරසමඟ අනුක්රමික මොඩියුලයක් සාදන්න

111        self.layers = nn.Sequential(*layers)

බරආරම්භ කරන්න

114        self.apply(weights_init_normal)
116    def forward(self, x):
117        return self.layers(x)

මෙයඅවශේෂ කොටස වන අතර, කැටි ගැසුණු ස්ථර දෙකක් ඇත.

120class ResidualBlock(Module):
125    def __init__(self, in_features: int):
126        super().__init__()
127        self.block = nn.Sequential(
128            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
129            nn.InstanceNorm2d(in_features),
130            nn.ReLU(inplace=True),
131            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
132            nn.InstanceNorm2d(in_features),
133            nn.ReLU(inplace=True),
134        )
136    def forward(self, x: torch.Tensor):
137        return x + self.block(x)

මෙයවෙනස්කම් කරන්නා ය.

140class Discriminator(Module):
145    def __init__(self, input_shape: Tuple[int, int, int]):
146        super().__init__()
147        channels, height, width = input_shape

වෙනස්කම්කරන්නාගේ ප්රතිදානය ද සම්භාවිතාවන්ගේ සිතියමකි, රූපයේ එක් එක් කලාපය සැබෑ හෝ ජනනය වේද යන්න

151        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
152
153        self.layers = nn.Sequential(

මෙමඑක් එක් බ්ලොක් 2 ක සාධකයක් මගින් උස හා පළල හැකිලෙනු ඇත

155            DiscriminatorBlock(channels, 64, normalize=False),
156            DiscriminatorBlock(64, 128),
157            DiscriminatorBlock(128, 256),
158            DiscriminatorBlock(256, 512),

නිමැවුම්උස සහ පළල කර්නලය සමඟ සමානව තබා ගැනීම සඳහා ඉහළින් සහ වමේ ශුන්ය පෑඩ්

161            nn.ZeroPad2d((1, 0, 1, 0)),
162            nn.Conv2d(512, 1, kernel_size=4, padding=1)
163        )

බරආරම්භ කරන්න

166        self.apply(weights_init_normal)
168    def forward(self, img):
169        return self.layers(img)

මෙයවෙනස්කම් කිරීමේ බ්ලොක් මොඩියුලයයි. එය convolution, විකල්ප සාමාන්යකරණය, සහ කාන්දු RELU කරන්නේ.

එයආදාන විශේෂාංග සිතියමේ උස හා පළල අඩකින් හැකිලී යයි.

172class DiscriminatorBlock(Module):
180    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
181        super().__init__()
182        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
183        if normalize:
184            layers.append(nn.InstanceNorm2d(out_filters))
185        layers.append(nn.LeakyReLU(0.2, inplace=True))
186        self.layers = nn.Sequential(*layers)
188    def forward(self, x: torch.Tensor):
189        return self.layers(x)

කැටිගැසුණු ස්ථර බර ආරම්භ කරන්න

192def weights_init_normal(m):
196    classname = m.__class__.__name__
197    if classname.find("Conv") != -1:
198        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

අළුපරිමාණයෙන් නම් රූපයක් පටවා RGB වෙත වෙනස් කරන්න.

201def load_image(path: str):
205    image = Image.open(path)
206    if image.mode != 'RGB':
207        image = Image.new("RGB", image.size).paste(image)
208
209    return image

රූපපූරණය කිරීම සඳහා දත්ත කට්ටලය

212class ImageDataset(Dataset):

දත්තකට්ටලය බාගත කර දත්ත උපුටා ගන්න

217    @staticmethod
218    def download(dataset_name: str):

URL

223        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'

බාගතෆෝල්ඩරය

225        root = lab.get_data_path() / 'cycle_gan'
226        if not root.exists():
227            root.mkdir(parents=True)

ගමනාන්තයබාගන්න

229        archive = root / f'{dataset_name}.zip'

බාගතගොනුව (සාමාන්යයෙන් ~ 100MB)

231        download_file(url, archive)

සංරක්ෂිතයඋපුටා ගන්න

233        with zipfile.ZipFile(archive, 'r') as f:
234            f.extractall(root)

දත්තසමුදාය ආරම්භ කරන්න

  • dataset_name දත්ත සමුදාය නම
  • transforms_ රූප පරිවර්තනය කිරීමේ කට්ටලය වේ
  • mode එක්කෝ train හෝ test
236    def __init__(self, dataset_name: str, transforms_, mode: str):

දත්තසමුදාය මාර්ගය

245        root = lab.get_data_path() / 'cycle_gan' / dataset_name

අස්ථානගතවුවහොත් බාගත කරන්න

247        if not root.exists():
248            self.download(dataset_name)

රූපපරිවර්තනය කරයි

251        self.transform = transforms.Compose(transforms_)

රූපමාර්ග ලබා ගන්න

254        path_a = root / f'{mode}A'
255        path_b = root / f'{mode}B'
256        self.files_a = sorted(str(f) for f in path_a.iterdir())
257        self.files_b = sorted(str(f) for f in path_b.iterdir())
259    def __getitem__(self, index):

රූපයුගලයක් ආපසු එවන්න. මෙම යුගල එකට එකතු වන අතර ඒවා පුහුණුවීම්වල යුගල මෙන් ක්රියා නොකරයි. එබැවින් අපි සෑම විටම එකම යුගලයක් ලබා දීම හරි ය.

263        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
264                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
266    def __len__(self):

දත්තකට්ටුවේ පින්තූර ගණන

268        return max(len(self.files_a), len(self.files_b))

බෆරයනැවත ධාවනය කරන්න

වෙනස්කම්කරන්නා පුහුණු කිරීම සඳහා නැවත ධාවනය කිරීමේ බෆරය භාවිතා කරයි. ජනනය කරන ලද පින්තූර නැවත ධාවනය කිරීමේ බෆරයට එකතු කර එයින් සාම්පල ලබා ගනී.

නැවතධාවනය වන බෆරය අලුතින් එකතු කරන ලද රූපය සම්භාවිතාවයකින් නැවත ලබා දෙයි . එසේ නොමැතිනම්, එය පැරණි ජනනය කරන ලද රූපයක් යවන අතර පැරණි රූපය අලුතින් ජනනය කරන ලද රූපය සමඟ ප්රතිස්ථාපනය කරයි.

ආදර්ශදෝලනය අඩු කිරීම සඳහා මෙය සිදු කෙරේ.

271class ReplayBuffer:
285    def __init__(self, max_size: int = 50):
286        self.max_size = max_size
287        self.data = []

රූපයක්එකතු කරන්න/ලබා ගන්න

289    def push_and_pop(self, data: torch.Tensor):
291        data = data.detach()
292        res = []
293        for element in data:
294            if len(self.data) < self.max_size:
295                self.data.append(element)
296                res.append(element)
297            else:
298                if random.uniform(0, 1) > 0.5:
299                    i = random.randint(0, self.max_size - 1)
300                    res.append(self.data[i].clone())
301                    self.data[i] = element
302                else:
303                    res.append(element)
304        return torch.stack(res)

වින්යාසකිරීම්

307class Configs(BaseConfigs):

DeviceConfigs තිබේ නම් GPU එකක් තෝරා ගනු ඇත

311    device: torch.device = DeviceConfigs()

අධිපරාමිතීන්

314    epochs: int = 200
315    dataset_name: str = 'monet2photo'
316    batch_size: int = 1
317
318    data_loader_workers = 8
319
320    learning_rate = 0.0002
321    adam_betas = (0.5, 0.999)
322    decay_start = 100

කඩදාසියෝජනා කරන්නේ සෘණ ලොග් වීමේ සම්භාවිතාව වෙනුවට අවම චතුරස්ර අලාභයක් භාවිතා කිරීමයි, එය වඩා ස්ථායී බව සොයාගෙන ඇත.

326    gan_loss = torch.nn.MSELoss()

L1අලාභය චක්රීය අලාභය සහ අනන්යතාවය නැතිවීම සඳහා භාවිතා කරයි

329    cycle_loss = torch.nn.L1Loss()
330    identity_loss = torch.nn.L1Loss()

රූපමානයන්

333    img_height = 256
334    img_width = 256
335    img_channels = 3

උත්පාදකයන්ත්රයේ අවශේෂ කොටස් ගණන

338    n_residual_blocks = 9

පාඩුසංගුණක

341    cyclic_loss_coefficient = 10.0
342    identity_loss_coefficient = 5.
343
344    sample_interval = 500

ආකෘති

347    generator_xy: GeneratorResNet
348    generator_yx: GeneratorResNet
349    discriminator_x: Discriminator
350    discriminator_y: Discriminator

ප්රශස්තකරණය

353    generator_optimizer: torch.optim.Adam
354    discriminator_optimizer: torch.optim.Adam

ඉගෙනුම්අනුපාත කාලසටහන්

357    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
358    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

දත්තකාරකයන්

361    dataloader: DataLoader
362    valid_dataloader: DataLoader

පරීක්ෂණකට්ටලයෙන් සාම්පල ජනනය කර ඒවා සුරකින්න

364    def sample_images(self, n: int):
366        batch = next(iter(self.valid_dataloader))
367        self.generator_xy.eval()
368        self.generator_yx.eval()
369        with torch.no_grad():
370            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
371            gen_y = self.generator_xy(data_x)
372            gen_x = self.generator_yx(data_y)

x-අක්ෂය ඔස්සේ රූප සකස් කරන්න

375            data_x = make_grid(data_x, nrow=5, normalize=True)
376            data_y = make_grid(data_y, nrow=5, normalize=True)
377            gen_x = make_grid(gen_x, nrow=5, normalize=True)
378            gen_y = make_grid(gen_y, nrow=5, normalize=True)

Y-අක්ෂය ඔස්සේ රූප සකස් කරන්න

381            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

සාම්පලපෙන්වන්න

384        plot_image(image_grid)

ආකෘතිසහ දත්ත කාරකයන් ආරම්භ කරන්න

386    def initialize(self):
390        input_shape = (self.img_channels, self.img_height, self.img_width)

ආකෘතිසාදන්න

393        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
394        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
395        self.discriminator_x = Discriminator(input_shape).to(self.device)
396        self.discriminator_y = Discriminator(input_shape).to(self.device)

මෙමoptmizers නිර්මාණය

399        self.generator_optimizer = torch.optim.Adam(
400            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
401            lr=self.learning_rate, betas=self.adam_betas)
402        self.discriminator_optimizer = torch.optim.Adam(
403            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
404            lr=self.learning_rate, betas=self.adam_betas)

ඉගෙනුම්අනුපාත කාලසටහන් සාදන්න. ඉගෙනුම් අනුපාතය decay_start එපොච් තෙක් පැතලි වන අතර පසුව පුහුණුව අවසානයේ රේඛීයව අඩු වේ.

409        decay_epochs = self.epochs - self.decay_start
410        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
411            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
412        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
413            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)

රූපපරිවර්තනයන්

416        transforms_ = [
417            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
418            transforms.RandomCrop((self.img_height, self.img_width)),
419            transforms.RandomHorizontalFlip(),
420            transforms.ToTensor(),
421            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
422        ]

පුහුණුදත්ත පැටවුම

425        self.dataloader = DataLoader(
426            ImageDataset(self.dataset_name, transforms_, 'train'),
427            batch_size=self.batch_size,
428            shuffle=True,
429            num_workers=self.data_loader_workers,
430        )

වලංගුදත්ත පැටවුම

433        self.valid_dataloader = DataLoader(
434            ImageDataset(self.dataset_name, transforms_, "test"),
435            batch_size=5,
436            shuffle=True,
437            num_workers=self.data_loader_workers,
438        )

පුහුණු

විසඳීමටඅපගේ ඉලක්කය:

කොහෙන්ද, පින්තූර පරිවර්තනය කරයි , පින්තූර පරිවර්තනය කරයි , රූප තිබේදැයි පරීක්ෂා කරයි අභ්යවකාශයේ සිට, රූප අභ්යවකාශයේ සිට නම් පරීක්ෂණ සහ

යනු මුල් GAN කඩදාසි වලින් උත්පාදක අහිතකර අලාභයයි.

යනු චක්රීය අලාභයයි, එහිදී අපි සමාන වීමට හා සමාන වීමට උත්සාහ කරමු . මූලික වශයෙන් ජනක යන්ත්ර දෙක (පරිවර්තනයන්) ශ්රේණිගතව යෙදුවහොත් එය මුල් රූපය ආපසු ලබා දිය යුතුය. මෙම ලිපියේ ප්රධාන දායකත්වය මෙයයි. මුල් රූපයට සමාන වෙනත් බෙදාහැරීමේ රූපයක් ජනනය කිරීමට එය ජනක යන්ත්ර පුහුණු කරයි. මෙම අලාභය නොමැතිව බෙදා හැරීමෙන් ඕනෑම දෙයක් ජනනය කළ හැකිය. දැන් එය බෙදා හැරීමෙන් යමක් ජනනය කළ යුතු නමුත් තවමත් ගුණාංග ඇත , එවිට වැනි දෙයක් නැවත ජනනය කළ හැකිය .

අනන්යතාව අහිමි වේ. ආදානය සහ ප්රතිදානය අතර වර්ණ සංයුතිය ආරක්ෂා කර ගැනීම සඳහා සිතියම්කරණය දිරිමත් කිරීම සඳහා මෙය භාවිතා කරන ලදී.

විසඳීමට , වෙනස්කම් කරන්නන් සහ ශ්රේණිය මත නගින්නේ සිටිය යුතුය,

බව සෘණ ලොග්-සම්භාවිතාව අහිමි මත බැස ඇත.

පුහුණුවස්ථාවර කිරීම සඳහා negative ණාත්මක ලොග්- සම්භාවිතාව පරමාර්ථය අවම වශයෙන් කොටු වූ අලාභයක් මගින් ප්රතිස්ථාපනය කරන ලදි - වෙනස්කම් කරන්නාගේ අවම කොටු දෝෂය, 1 සමඟ සැබෑ රූප ලේබල් කිරීම සහ 0 සමඟ රූප ජනනය කිරීම. ඒ නිසා අපි ඵලය අනුක්රමික මත බැස කිරීමට අවශ්ය,

අපිජනක යන්ත්ර සඳහා අඩු චතුරස්රයන් ද භාවිතා කරමු. ජනක යන්ත්ර ශ්රේණිය මතට බැස යා යුතුය,

අපි generator_xy generator_yx සඳහා සහ භාවිතා කරමු . අපි discriminator_x discriminator_y සඳහා සහ භාවිතා කරමු .

440    def run(self):

ජනනයසාම්පල තබා ගැනීමට බෆර නැවත ධාවනය

542        gen_x_buffer = ReplayBuffer()
543        gen_y_buffer = ReplayBuffer()

එපොච්හරහා ලූප්

546        for epoch in monit.loop(self.epochs):

දත්තකට්ටලය හරහා ලූප

548            for i, batch in monit.enum('Train', self.dataloader):

උපාංගයටරූප ගෙනයන්න

550                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

සැබෑලේබල වලට සමාන වේ

553                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
554                                         device=self.device, requires_grad=False)

ව්යාජලේබල වලට සමාන වේ

556                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
557                                           device=self.device, requires_grad=False)

ජනකයන්ත්ර පුහුණු කරන්න. මෙය ජනනය කරන ලද රූප නැවත ලබා දෙයි.

561                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

වෙනස්කම්කරන්නන් පුහුණු කරන්න

564                self.optimize_discriminator(data_x, data_y,
565                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
566                                            true_labels, false_labels)

පුහුණුසංඛ්යාලේඛන සුරකින්න සහ ගෝලීය පියවර කවුන්ටරය වැඩි කරන්න

569                tracker.save()
570                tracker.add_global_step(max(len(data_x), len(data_y)))

කාලපරතරයන් පින්තූර සුරකින්න

573                batches_done = epoch * len(self.dataloader) + i
574                if batches_done % self.sample_interval == 0:

පින්තූරනියැදි කිරීමේදී ආකෘති සුරකින්න

576                    experiment.save_checkpoint()

නියැදිරූප

578                    self.sample_images(batches_done)

ඉගෙනුම්අනුපාත යාවත්කාලීන කරන්න

581            self.generator_lr_scheduler.step()
582            self.discriminator_lr_scheduler.step()

නවපේළිය

584            tracker.new_line()

අනන්යතාවය, ගන් සහ චක්රීය පාඩු සහිත ජනක යන්ත්ර ප්රශස්ත කරන්න.

586    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):

පුහුණුප්රකාරයට වෙනස් කරන්න

592        self.generator_xy.train()
593        self.generator_yx.train()

අනන්යතාවයනැතිවීම

598        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
599                         self.identity_loss(self.generator_xy(data_y), data_y))

රූප ජනනය කරන්න

602        gen_y = self.generator_xy(data_x)
603        gen_x = self.generator_yx(data_y)

GANපාඩුව

608        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
609                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

චක්රයඅහිමි

616        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
617                      self.cycle_loss(self.generator_xy(gen_x), data_y))

මුළුඅලාභය

620        loss_generator = (loss_gan +
621                          self.cyclic_loss_coefficient * loss_cycle +
622                          self.identity_loss_coefficient * loss_identity)

ප්රශස්තකරණයේපියවරක් ගන්න

625        self.generator_optimizer.zero_grad()
626        loss_generator.backward()
627        self.generator_optimizer.step()

ලොග්පාඩු

630        tracker.add({'loss.generator': loss_generator,
631                     'loss.generator.cycle': loss_cycle,
632                     'loss.generator.gan': loss_gan,
633                     'loss.generator.identity': loss_identity})

ජනනයකරන ලද රූප ආපසු

636        return gen_x, gen_y

Ganඅහිමි වීමෙන් වෙනස්කම් කරන්නන් ප්රශස්ත කරන්න.

638    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
639                               gen_x: torch.Tensor, gen_y: torch.Tensor,
640                               true_labels: torch.Tensor, false_labels: torch.Tensor):

GANපාඩුව

653        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
654                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
655                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
656                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

ප්රශස්තකරණයේපියවරක් ගන්න

659        self.discriminator_optimizer.zero_grad()
660        loss_discriminator.backward()
661        self.discriminator_optimizer.step()

ලොග්පාඩු

664        tracker.add({'loss.discriminator': loss_discriminator})

දුම්රියචක්රය GAN

667def train():

වින්යාසයන්සාදන්න

672    conf = Configs()

අත්හදාබැලීමක් සාදන්න

674    experiment.create(name='cycle_gan')

වින්යාසයන්ගණනය කරන්න. එය ගණනය කරනු conf.run ඇති අතර එයට අවශ්ය අනෙකුත් සියලුම වින්යාස.

677    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
678    conf.initialize()

සුරැකීමසහ පැටවීම සඳහා ආකෘති ලියාපදිංචි කරන්න. get_modules ශබ්ද කෝෂයක් ලබා nn.Modules දෙයි conf . ඔබට ආකෘති වල අභිරුචි ශබ්ද කෝෂයක් ද නියම කළ හැකිය.

683    experiment.add_pytorch_models(get_modules(conf))

අත්හදාබැලීම ආරම්භ කර නරඹන්න

685    with experiment.start():

පුහුණුවක්රියාත්මක කරන්න

687        conf.run()

මැට්ප්ලොට්ලිබ්සමඟ රූපයක් සැලසුම් කරන්න

690def plot_image(img: torch.Tensor):
694    from matplotlib import pyplot as plt

ආතකයCPU වෙත ගෙන යන්න

697    img = img.cpu()

සාමාන්යකරණයසඳහා රූපයේ විනාඩි සහ උපරිම අගයන් ලබා ගන්න

699    img_min, img_max = img.min(), img.max()

රූපඅගයන් 0... 1 විය යුතුය

701    img = (img - img_min) / (img_max - img_min + 1e-5)

අපිHWC වෙත මානයන් අනුපිළිවෙල වෙනස් කළ යුතුය.

703    img = img.permute(1, 2, 0)

රූපයපෙන්වන්න

705    plt.imshow(img)

අපටඅක්ෂ අවශ්ය නොවේ

707    plt.axis('off')

සංදර්ශකය

709    plt.show()

පුහුණුපාපැදි GAN ඇගයීම

712def evaluate():

පුහුණුධාවනයෙන් UUID ධාවනය කරන්න

717    trained_run_uuid = 'f73c1164184711eb9190b74249275441'

වින්යාසවස්තුව සාදන්න

719    conf = Configs()

අත්හදාබැලීම සාදන්න

721    experiment.create(name='cycle_gan_inference')

පුහුණුවසඳහා සකසා ඇති අධි පරාමිතීන් පටවන්න

723    conf_dict = experiment.load_configs(trained_run_uuid)

වින්යාසයන්ගණනය කරන්න. අපි ජනක යන්ත්ර නියම කර ඇති අතර 'generator_xy', 'generator_yx' එමඟින් එය පටවන්නේ ඒවා සහ ඒවායේ පරායත්තතාවයන් පමණි. වින්යාස කිරීම device හා ගණනය img_channels කරනු ලැබේ, මේවා අවශ්ය වන බැවින් generator_xy සහ generator_yx .

ඔබටවෙනත් පරාමිතීන් අවශ්ය නම් dataset_name ඔබ ඒවා මෙහි සඳහන් කළ යුතුය. ඔබ කිසිවක් සඳහන් නොකරන්නේ නම්, දත්ත පැටවුම් ඇතුළුව සියලු වින්යාසයන් ගණනය කරනු ලැබේ. වින්යාසයන් සහ ඒවායේ පරායත්තතා ගණනය කිරීම ඔබ අමතන විට සිදුවනු ඇත experiment.start

732    experiment.configs(conf, conf_dict)
733    conf.initialize()

සුරැකීමසහ පැටවීම සඳහා ආකෘති ලියාපදිංචි කරන්න. get_modules ශබ්ද කෝෂයක් ලබා nn.Modules දෙයි conf . ඔබට ආකෘති වල අභිරුචි ශබ්ද කෝෂයක් ද නියම කළ හැකිය.

738    experiment.add_pytorch_models(get_modules(conf))

කුමනධාවනයෙන් පූරණය කළ යුතු දැයි සඳහන් කරන්න. ඔබ කතා කරන විට පැටවීම ඇත්ත වශයෙන්ම සිදුවනු ඇත experiment.start

741    experiment.load(trained_run_uuid)

අත්හදාබැලීම ආරම්භ කරන්න

744    with experiment.start():

රූපපරිවර්තනයන්

746        transforms_ = [
747            transforms.ToTensor(),
748            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
749        ]

ඔබේමදත්ත පූරණය කරන්න. මෙන්න අපි පරීක්ෂණ කට්ටලය උත්සාහ කරමු. මම යෝස්මයිට් ඡායාරූප සමඟ උත්සාහ කරමින් සිටියෙමි, ඒවා නියමයි. ඔබට ඇමතුමේ ගණනය කිරීමට අවශ්ය වූ දෙයක් dataset_name ලෙස ඔබ නියම කර ඇත්නම් conf.dataset_name , ඔබට භාවිතා කළ

හැකිය experiment.configs
755        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')

දත්තකට්ටලයෙන් රූපයක් ලබා ගන්න

757        x_image = dataset[10]['x']

රූපයපෙන්වන්න

759        plot_image(x_image)

ඇගයීම්මාදිලිය

762        conf.generator_xy.eval()
763        conf.generator_yx.eval()

අපටඅනුක්රමික අවශ්ය නොවේ

766        with torch.no_grad():

කණ්ඩායම්මානය එක් කර අප භාවිතා කරන උපාංගයට යන්න

768            data = x_image.unsqueeze(0).to(conf.device)
769            generated_y = conf.generator_xy(data)

ජනනයකරන ලද රූපය පෙන්වන්න.

772        plot_image(generated_y[0].cpu())
773
774
775if __name__ == '__main__':
776    train()

ඇගයීම()