මෙය PyTorch ක්රියාත්මක කිරීම/නිබන්ධනයක් වන අතර එය Cycle-Consistent adversarial Networks භාවිතා කරමින් PyTorch Image පරිවර්තනය නොකළ අනුරූපය-රූප පරිවර්තනය .
මම එරික්ලින්ඩර්නොරන්/පයිටෝර්ච්-ගන්වෙතින් කේත කෑලි ගෙන ඇත්තෙමි. ඔබට වෙනත් GAN වෙනස්කම් පරීක්ෂා කිරීමට අවශ්ය නම් එය ඉතා හොඳ සම්පතකි.
චක්රයGAN රූපය-කිරීමට රූප පරිවර්තනය කරයි. ලබා දී ඇති බෙදාහැරීමෙන් රූපයක් තවත් එකකට පරිවර්තනය කිරීම සඳහා ආකෘතියක් පුහුණු කරයි, පවසන්නේ, A සහ B පන්තියේ රූප යම් බෙදාහැරීමක් පිළිබඳ රූප යම් ශෛලියක හෝ ස්වභාවධර්මයේ රූප වැනි දේවල් විය හැකිය. ආකෘති A සහ B අතර යුගල රූප අවශ්ය නොවේ එක් එක් පන්තියේ රූප සමූහයක් පමණක් ප්රමාණවත් වේ. රූප මෝස්තර, ආලෝකකරණ වෙනස්වීම්, රටා වෙනස්වීම් යනාදිය වෙනස් කිරීම සඳහා මෙය ඉතා හොඳින් ක්රියා කරයි. නිදසුනක් ලෙස ගිම්හානය ශීත to තුව දක්වා වෙනස් කිරීම, ඡායාරූප වලට පින්තාරු කිරීමේ ශෛලිය සහ අශ්වයන් සීබ්රා වෙත.
චක්රයGAN උත්පාදක ආකෘති දෙකක් සහ වෙනස්කම් කිරීමේ ආකෘති දෙකක් දුම්රිය කරයි. එක් උත්පාදක යන්ත්රයක් A සිට B දක්වාත් අනෙක B සිට A දක්වාත් පරිවර්තනය කරයි වෙනස්කම් කරන්නන් ජනනය කරන ලද රූප සැබෑ ලෙස පෙනෙනවාද යන්න පරීක්ෂා කරයි.
මෙමගොනුවේ ආදර්ශ කේතය මෙන්ම පුහුණු කේතයද අඩංගු වේ. අපට ගූගල් කොලැබ් සටහන් පොතක් ද ඇත.
36import itertools
37import random
38import zipfile
39from typing import Tuple
40
41import torch
42import torch.nn as nn
43import torchvision.transforms as transforms
44from PIL import Image
45from torch.utils.data import DataLoader, Dataset
46from torchvision.transforms import InterpolationMode
47from torchvision.utils import make_grid
48
49from labml import lab, tracker, experiment, monit
50from labml.configs import BaseConfigs
51from labml.utils.download import download_file
52from labml.utils.pytorch import get_modules
53from labml_helpers.device import DeviceConfigs
54from labml_helpers.module import Moduleඋත්පාදකයන්ත්රය අවශේෂ ජාලයකි.
57class GeneratorResNet(Module):62 def __init__(self, input_channels: int, n_residual_blocks: int):
63 super().__init__()මෙමපළමු කොටස සංකෝචනය වන අතර රූපය විශේෂාංග සිතියමකට සිතියම් ගත කරයි. නිමැවුම් විශේෂාංග සිතියමට සමාන උස හා පළල ඇත්තේ අපට පෑඩින් එකක් ඇති බැවිනි . පරාවර්තන පුරවන එය දාරවල වඩා හොඳ රූපයේ ගුණාත්මක බවක් ලබා දෙන නිසා භාවිතා වේ.
inplace=True
තුළ මතකය ටිකක් ReLU
ඉතිරි කරයි.
71 out_features = 64
72 layers = [
73 nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
74 nn.InstanceNorm2d(out_features),
75 nn.ReLU(inplace=True),
76 ]
77 in_features = out_featuresඅපි2 ක stride සමග convolutions දෙකක් සමග පහළ-ආදර්ශ
81 for _ in range(2):
82 out_features *= 2
83 layers += [
84 nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
85 nn.InstanceNorm2d(out_features),
86 nn.ReLU(inplace=True),
87 ]
88 in_features = out_featuresඅපිමෙය හරහා ගන්නෙමු n_residual_blocks
. මෙම මොඩියුලය පහත අර්ථ දක්වා ඇත.
92 for _ in range(n_residual_blocks):
93 layers += [ResidualBlock(out_features)]එවිටඑහි ප්රතිඵලයක් ලක්ෂණය සිතියම මුල් රූපයේ උස හා පළල ගැලපෙන දක්වා sampled ඇත.
97 for _ in range(2):
98 out_features //= 2
99 layers += [
100 nn.Upsample(scale_factor=2),
101 nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
102 nn.InstanceNorm2d(out_features),
103 nn.ReLU(inplace=True),
104 ]
105 in_features = out_featuresඅවසානවශයෙන් අපි විශේෂාංග සිතියම RGB රූපයකට සිතියම් ගත කරමු
108 layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]ස්ථරසමඟ අනුක්රමික මොඩියුලයක් සාදන්න
111 self.layers = nn.Sequential(*layers)බරආරම්භ කරන්න
114 self.apply(weights_init_normal)116 def forward(self, x):
117 return self.layers(x)මෙයඅවශේෂ කොටස වන අතර, කැටි ගැසුණු ස්ථර දෙකක් ඇත.
120class ResidualBlock(Module):125 def __init__(self, in_features: int):
126 super().__init__()
127 self.block = nn.Sequential(
128 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
129 nn.InstanceNorm2d(in_features),
130 nn.ReLU(inplace=True),
131 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
132 nn.InstanceNorm2d(in_features),
133 nn.ReLU(inplace=True),
134 )136 def forward(self, x: torch.Tensor):
137 return x + self.block(x)මෙයවෙනස්කම් කරන්නා ය.
140class Discriminator(Module):145 def __init__(self, input_shape: Tuple[int, int, int]):
146 super().__init__()
147 channels, height, width = input_shapeවෙනස්කම්කරන්නාගේ ප්රතිදානය ද සම්භාවිතාවන්ගේ සිතියමකි, රූපයේ එක් එක් කලාපය සැබෑ හෝ ජනනය වේද යන්න
151 self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
152
153 self.layers = nn.Sequential(මෙමඑක් එක් බ්ලොක් 2 ක සාධකයක් මගින් උස හා පළල හැකිලෙනු ඇත
155 DiscriminatorBlock(channels, 64, normalize=False),
156 DiscriminatorBlock(64, 128),
157 DiscriminatorBlock(128, 256),
158 DiscriminatorBlock(256, 512),නිමැවුම්උස සහ පළල කර්නලය සමඟ සමානව තබා ගැනීම සඳහා ඉහළින් සහ වමේ ශුන්ය පෑඩ්
161 nn.ZeroPad2d((1, 0, 1, 0)),
162 nn.Conv2d(512, 1, kernel_size=4, padding=1)
163 )බරආරම්භ කරන්න
166 self.apply(weights_init_normal)168 def forward(self, img):
169 return self.layers(img)මෙයවෙනස්කම් කිරීමේ බ්ලොක් මොඩියුලයයි. එය convolution, විකල්ප සාමාන්යකරණය, සහ කාන්දු RELU කරන්නේ.
එයආදාන විශේෂාංග සිතියමේ උස හා පළල අඩකින් හැකිලී යයි.
172class DiscriminatorBlock(Module):180 def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
181 super().__init__()
182 layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
183 if normalize:
184 layers.append(nn.InstanceNorm2d(out_filters))
185 layers.append(nn.LeakyReLU(0.2, inplace=True))
186 self.layers = nn.Sequential(*layers)188 def forward(self, x: torch.Tensor):
189 return self.layers(x)කැටිගැසුණු ස්ථර බර ආරම්භ කරන්න
192def weights_init_normal(m):196 classname = m.__class__.__name__
197 if classname.find("Conv") != -1:
198 torch.nn.init.normal_(m.weight.data, 0.0, 0.02)අළුපරිමාණයෙන් නම් රූපයක් පටවා RGB වෙත වෙනස් කරන්න.
201def load_image(path: str):205 image = Image.open(path)
206 if image.mode != 'RGB':
207 image = Image.new("RGB", image.size).paste(image)
208
209 return image212class ImageDataset(Dataset):217 @staticmethod
218 def download(dataset_name: str):URL
223 url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'බාගතෆෝල්ඩරය
225 root = lab.get_data_path() / 'cycle_gan'
226 if not root.exists():
227 root.mkdir(parents=True)ගමනාන්තයබාගන්න
229 archive = root / f'{dataset_name}.zip'බාගතගොනුව (සාමාන්යයෙන් ~ 100MB)
231 download_file(url, archive)සංරක්ෂිතයඋපුටා ගන්න
233 with zipfile.ZipFile(archive, 'r') as f:
234 f.extractall(root)dataset_name
දත්ත සමුදාය නම transforms_
රූප පරිවර්තනය කිරීමේ කට්ටලය වේ mode
එක්කෝ train
හෝ test
236 def __init__(self, dataset_name: str, transforms_, mode: str):දත්තසමුදාය මාර්ගය
245 root = lab.get_data_path() / 'cycle_gan' / dataset_nameඅස්ථානගතවුවහොත් බාගත කරන්න
247 if not root.exists():
248 self.download(dataset_name)රූපපරිවර්තනය කරයි
251 self.transform = transforms.Compose(transforms_)රූපමාර්ග ලබා ගන්න
254 path_a = root / f'{mode}A'
255 path_b = root / f'{mode}B'
256 self.files_a = sorted(str(f) for f in path_a.iterdir())
257 self.files_b = sorted(str(f) for f in path_b.iterdir())259 def __getitem__(self, index):රූපයුගලයක් ආපසු එවන්න. මෙම යුගල එකට එකතු වන අතර ඒවා පුහුණුවීම්වල යුගල මෙන් ක්රියා නොකරයි. එබැවින් අපි සෑම විටම එකම යුගලයක් ලබා දීම හරි ය.
263 return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
264 "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}266 def __len__(self):දත්තකට්ටුවේ පින්තූර ගණන
268 return max(len(self.files_a), len(self.files_b))වෙනස්කම්කරන්නා පුහුණු කිරීම සඳහා නැවත ධාවනය කිරීමේ බෆරය භාවිතා කරයි. ජනනය කරන ලද පින්තූර නැවත ධාවනය කිරීමේ බෆරයට එකතු කර එයින් සාම්පල ලබා ගනී.
නැවතධාවනය වන බෆරය අලුතින් එකතු කරන ලද රූපය සම්භාවිතාවයකින් නැවත ලබා දෙයි . එසේ නොමැතිනම්, එය පැරණි ජනනය කරන ලද රූපයක් යවන අතර පැරණි රූපය අලුතින් ජනනය කරන ලද රූපය සමඟ ප්රතිස්ථාපනය කරයි.
ආදර්ශදෝලනය අඩු කිරීම සඳහා මෙය සිදු කෙරේ.
271class ReplayBuffer:285 def __init__(self, max_size: int = 50):
286 self.max_size = max_size
287 self.data = []රූපයක්එකතු කරන්න/ලබා ගන්න
289 def push_and_pop(self, data: torch.Tensor):291 data = data.detach()
292 res = []
293 for element in data:
294 if len(self.data) < self.max_size:
295 self.data.append(element)
296 res.append(element)
297 else:
298 if random.uniform(0, 1) > 0.5:
299 i = random.randint(0, self.max_size - 1)
300 res.append(self.data[i].clone())
301 self.data[i] = element
302 else:
303 res.append(element)
304 return torch.stack(res)307class Configs(BaseConfigs):DeviceConfigs
තිබේ නම් GPU එකක් තෝරා ගනු ඇත
311 device: torch.device = DeviceConfigs()අධිපරාමිතීන්
314 epochs: int = 200
315 dataset_name: str = 'monet2photo'
316 batch_size: int = 1
317
318 data_loader_workers = 8
319
320 learning_rate = 0.0002
321 adam_betas = (0.5, 0.999)
322 decay_start = 100කඩදාසියෝජනා කරන්නේ සෘණ ලොග් වීමේ සම්භාවිතාව වෙනුවට අවම චතුරස්ර අලාභයක් භාවිතා කිරීමයි, එය වඩා ස්ථායී බව සොයාගෙන ඇත.
326 gan_loss = torch.nn.MSELoss()L1අලාභය චක්රීය අලාභය සහ අනන්යතාවය නැතිවීම සඳහා භාවිතා කරයි
329 cycle_loss = torch.nn.L1Loss()
330 identity_loss = torch.nn.L1Loss()රූපමානයන්
333 img_height = 256
334 img_width = 256
335 img_channels = 3උත්පාදකයන්ත්රයේ අවශේෂ කොටස් ගණන
338 n_residual_blocks = 9පාඩුසංගුණක
341 cyclic_loss_coefficient = 10.0
342 identity_loss_coefficient = 5.
343
344 sample_interval = 500ආකෘති
347 generator_xy: GeneratorResNet
348 generator_yx: GeneratorResNet
349 discriminator_x: Discriminator
350 discriminator_y: Discriminatorප්රශස්තකරණය
353 generator_optimizer: torch.optim.Adam
354 discriminator_optimizer: torch.optim.Adamඉගෙනුම්අනුපාත කාලසටහන්
357 generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
358 discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLRදත්තකාරකයන්
361 dataloader: DataLoader
362 valid_dataloader: DataLoaderපරීක්ෂණකට්ටලයෙන් සාම්පල ජනනය කර ඒවා සුරකින්න
364 def sample_images(self, n: int):366 batch = next(iter(self.valid_dataloader))
367 self.generator_xy.eval()
368 self.generator_yx.eval()
369 with torch.no_grad():
370 data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
371 gen_y = self.generator_xy(data_x)
372 gen_x = self.generator_yx(data_y)x-අක්ෂය ඔස්සේ රූප සකස් කරන්න
375 data_x = make_grid(data_x, nrow=5, normalize=True)
376 data_y = make_grid(data_y, nrow=5, normalize=True)
377 gen_x = make_grid(gen_x, nrow=5, normalize=True)
378 gen_y = make_grid(gen_y, nrow=5, normalize=True)Y-අක්ෂය ඔස්සේ රූප සකස් කරන්න
381 image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)සාම්පලපෙන්වන්න
384 plot_image(image_grid)386 def initialize(self):390 input_shape = (self.img_channels, self.img_height, self.img_width)ආකෘතිසාදන්න
393 self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
394 self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
395 self.discriminator_x = Discriminator(input_shape).to(self.device)
396 self.discriminator_y = Discriminator(input_shape).to(self.device)මෙමoptmizers නිර්මාණය
399 self.generator_optimizer = torch.optim.Adam(
400 itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
401 lr=self.learning_rate, betas=self.adam_betas)
402 self.discriminator_optimizer = torch.optim.Adam(
403 itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
404 lr=self.learning_rate, betas=self.adam_betas)ඉගෙනුම්අනුපාත කාලසටහන් සාදන්න. ඉගෙනුම් අනුපාතය decay_start
එපොච් තෙක් පැතලි වන අතර පසුව පුහුණුව අවසානයේ රේඛීයව අඩු වේ.
409 decay_epochs = self.epochs - self.decay_start
410 self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
411 self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
412 self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
413 self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)රූපපරිවර්තනයන්
416 transforms_ = [
417 transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
418 transforms.RandomCrop((self.img_height, self.img_width)),
419 transforms.RandomHorizontalFlip(),
420 transforms.ToTensor(),
421 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
422 ]පුහුණුදත්ත පැටවුම
425 self.dataloader = DataLoader(
426 ImageDataset(self.dataset_name, transforms_, 'train'),
427 batch_size=self.batch_size,
428 shuffle=True,
429 num_workers=self.data_loader_workers,
430 )වලංගුදත්ත පැටවුම
433 self.valid_dataloader = DataLoader(
434 ImageDataset(self.dataset_name, transforms_, "test"),
435 batch_size=5,
436 shuffle=True,
437 num_workers=self.data_loader_workers,
438 )විසඳීමටඅපගේ ඉලක්කය:
කොහෙන්ද, පින්තූර පරිවර්තනය කරයි , පින්තූර පරිවර්තනය කරයි , රූප තිබේදැයි පරීක්ෂා කරයි අභ්යවකාශයේ සිට, රූප අභ්යවකාශයේ සිට නම් පරීක්ෂණ සහ
යනු මුල් GAN කඩදාසි වලින් උත්පාදක අහිතකර අලාභයයි.
යනු චක්රීය අලාභයයි, එහිදී අපි සමාන වීමට හා සමාන වීමට උත්සාහ කරමු . මූලික වශයෙන් ජනක යන්ත්ර දෙක (පරිවර්තනයන්) ශ්රේණිගතව යෙදුවහොත් එය මුල් රූපය ආපසු ලබා දිය යුතුය. මෙම ලිපියේ ප්රධාන දායකත්වය මෙයයි. මුල් රූපයට සමාන වෙනත් බෙදාහැරීමේ රූපයක් ජනනය කිරීමට එය ජනක යන්ත්ර පුහුණු කරයි. මෙම අලාභය නොමැතිව බෙදා හැරීමෙන් ඕනෑම දෙයක් ජනනය කළ හැකිය. දැන් එය බෙදා හැරීමෙන් යමක් ජනනය කළ යුතු නමුත් තවමත් ගුණාංග ඇත , එවිට වැනි දෙයක් නැවත ජනනය කළ හැකිය .
අනන්යතාව අහිමි වේ. ආදානය සහ ප්රතිදානය අතර වර්ණ සංයුතිය ආරක්ෂා කර ගැනීම සඳහා සිතියම්කරණය දිරිමත් කිරීම සඳහා මෙය භාවිතා කරන ලදී.
විසඳීමට , වෙනස්කම් කරන්නන් සහ ශ්රේණිය මත නගින්නේ සිටිය යුතුය,
බව සෘණ ලොග්-සම්භාවිතාව අහිමි මත බැස ඇත.
පුහුණුවස්ථාවර කිරීම සඳහා negative ණාත්මක ලොග්- සම්භාවිතාව පරමාර්ථය අවම වශයෙන් කොටු වූ අලාභයක් මගින් ප්රතිස්ථාපනය කරන ලදි - වෙනස්කම් කරන්නාගේ අවම කොටු දෝෂය, 1 සමඟ සැබෑ රූප ලේබල් කිරීම සහ 0 සමඟ රූප ජනනය කිරීම. ඒ නිසා අපි ඵලය අනුක්රමික මත බැස කිරීමට අවශ්ය,
අපිජනක යන්ත්ර සඳහා අඩු චතුරස්රයන් ද භාවිතා කරමු. ජනක යන්ත්ර ශ්රේණිය මතට බැස යා යුතුය,
අපි generator_xy
generator_yx
සඳහා සහ භාවිතා කරමු . අපි discriminator_x
discriminator_y
සඳහා සහ භාවිතා කරමු .
440 def run(self):ජනනයසාම්පල තබා ගැනීමට බෆර නැවත ධාවනය
542 gen_x_buffer = ReplayBuffer()
543 gen_y_buffer = ReplayBuffer()එපොච්හරහා ලූප්
546 for epoch in monit.loop(self.epochs):දත්තකට්ටලය හරහා ලූප
548 for i, batch in monit.enum('Train', self.dataloader):උපාංගයටරූප ගෙනයන්න
550 data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)සැබෑලේබල වලට සමාන වේ
553 true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
554 device=self.device, requires_grad=False)ව්යාජලේබල වලට සමාන වේ
556 false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
557 device=self.device, requires_grad=False)ජනකයන්ත්ර පුහුණු කරන්න. මෙය ජනනය කරන ලද රූප නැවත ලබා දෙයි.
561 gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)වෙනස්කම්කරන්නන් පුහුණු කරන්න
564 self.optimize_discriminator(data_x, data_y,
565 gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
566 true_labels, false_labels)පුහුණුසංඛ්යාලේඛන සුරකින්න සහ ගෝලීය පියවර කවුන්ටරය වැඩි කරන්න
569 tracker.save()
570 tracker.add_global_step(max(len(data_x), len(data_y)))කාලපරතරයන් පින්තූර සුරකින්න
573 batches_done = epoch * len(self.dataloader) + i
574 if batches_done % self.sample_interval == 0:පින්තූරනියැදි කිරීමේදී ආකෘති සුරකින්න
576 experiment.save_checkpoint()නියැදිරූප
578 self.sample_images(batches_done)ඉගෙනුම්අනුපාත යාවත්කාලීන කරන්න
581 self.generator_lr_scheduler.step()
582 self.discriminator_lr_scheduler.step()නවපේළිය
584 tracker.new_line()586 def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):පුහුණුප්රකාරයට වෙනස් කරන්න
592 self.generator_xy.train()
593 self.generator_yx.train()අනන්යතාවයනැතිවීම
598 loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
599 self.identity_loss(self.generator_xy(data_y), data_y))රූප ජනනය කරන්න
602 gen_y = self.generator_xy(data_x)
603 gen_x = self.generator_yx(data_y)GANපාඩුව
608 loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
609 self.gan_loss(self.discriminator_x(gen_x), true_labels))චක්රයඅහිමි
616 loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
617 self.cycle_loss(self.generator_xy(gen_x), data_y))මුළුඅලාභය
620 loss_generator = (loss_gan +
621 self.cyclic_loss_coefficient * loss_cycle +
622 self.identity_loss_coefficient * loss_identity)ප්රශස්තකරණයේපියවරක් ගන්න
625 self.generator_optimizer.zero_grad()
626 loss_generator.backward()
627 self.generator_optimizer.step()ලොග්පාඩු
630 tracker.add({'loss.generator': loss_generator,
631 'loss.generator.cycle': loss_cycle,
632 'loss.generator.gan': loss_gan,
633 'loss.generator.identity': loss_identity})ජනනයකරන ලද රූප ආපසු
636 return gen_x, gen_y638 def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
639 gen_x: torch.Tensor, gen_y: torch.Tensor,
640 true_labels: torch.Tensor, false_labels: torch.Tensor):653 loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
654 self.gan_loss(self.discriminator_x(gen_x), false_labels) +
655 self.gan_loss(self.discriminator_y(data_y), true_labels) +
656 self.gan_loss(self.discriminator_y(gen_y), false_labels))ප්රශස්තකරණයේපියවරක් ගන්න
659 self.discriminator_optimizer.zero_grad()
660 loss_discriminator.backward()
661 self.discriminator_optimizer.step()ලොග්පාඩු
664 tracker.add({'loss.discriminator': loss_discriminator})667def train():වින්යාසයන්සාදන්න
672 conf = Configs()අත්හදාබැලීමක් සාදන්න
674 experiment.create(name='cycle_gan')වින්යාසයන්ගණනය කරන්න. එය ගණනය කරනු conf.run
ඇති අතර එයට අවශ්ය අනෙකුත් සියලුම වින්යාස.
677 experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
678 conf.initialize()සුරැකීමසහ පැටවීම සඳහා ආකෘති ලියාපදිංචි කරන්න. get_modules
ශබ්ද කෝෂයක් ලබා nn.Modules
දෙයි conf
. ඔබට ආකෘති වල අභිරුචි ශබ්ද කෝෂයක් ද නියම කළ හැකිය.
683 experiment.add_pytorch_models(get_modules(conf))අත්හදාබැලීම ආරම්භ කර නරඹන්න
685 with experiment.start():පුහුණුවක්රියාත්මක කරන්න
687 conf.run()690def plot_image(img: torch.Tensor):694 from matplotlib import pyplot as pltආතකයCPU වෙත ගෙන යන්න
697 img = img.cpu()සාමාන්යකරණයසඳහා රූපයේ විනාඩි සහ උපරිම අගයන් ලබා ගන්න
699 img_min, img_max = img.min(), img.max()අපිHWC වෙත මානයන් අනුපිළිවෙල වෙනස් කළ යුතුය.
703 img = img.permute(1, 2, 0)රූපයපෙන්වන්න
705 plt.imshow(img)අපටඅක්ෂ අවශ්ය නොවේ
707 plt.axis('off')සංදර්ශකය
709 plt.show()712def evaluate():පුහුණුධාවනයෙන් UUID ධාවනය කරන්න
717 trained_run_uuid = 'f73c1164184711eb9190b74249275441'වින්යාසවස්තුව සාදන්න
719 conf = Configs()අත්හදාබැලීම සාදන්න
721 experiment.create(name='cycle_gan_inference')පුහුණුවසඳහා සකසා ඇති අධි පරාමිතීන් පටවන්න
723 conf_dict = experiment.load_configs(trained_run_uuid)වින්යාසයන්ගණනය කරන්න. අපි ජනක යන්ත්ර නියම කර ඇති අතර 'generator_xy', 'generator_yx'
එමඟින් එය පටවන්නේ ඒවා සහ ඒවායේ පරායත්තතාවයන් පමණි. වින්යාස කිරීම device
හා ගණනය img_channels
කරනු ලැබේ, මේවා අවශ්ය වන බැවින් generator_xy
සහ generator_yx
.
ඔබටවෙනත් පරාමිතීන් අවශ්ය නම් dataset_name
ඔබ ඒවා මෙහි සඳහන් කළ යුතුය. ඔබ කිසිවක් සඳහන් නොකරන්නේ නම්, දත්ත පැටවුම් ඇතුළුව සියලු වින්යාසයන් ගණනය කරනු ලැබේ. වින්යාසයන් සහ ඒවායේ පරායත්තතා ගණනය කිරීම ඔබ අමතන විට සිදුවනු ඇත experiment.start
732 experiment.configs(conf, conf_dict)
733 conf.initialize()සුරැකීමසහ පැටවීම සඳහා ආකෘති ලියාපදිංචි කරන්න. get_modules
ශබ්ද කෝෂයක් ලබා nn.Modules
දෙයි conf
. ඔබට ආකෘති වල අභිරුචි ශබ්ද කෝෂයක් ද නියම කළ හැකිය.
738 experiment.add_pytorch_models(get_modules(conf))කුමනධාවනයෙන් පූරණය කළ යුතු දැයි සඳහන් කරන්න. ඔබ කතා කරන විට පැටවීම ඇත්ත වශයෙන්ම සිදුවනු ඇත experiment.start
741 experiment.load(trained_run_uuid)අත්හදාබැලීම ආරම්භ කරන්න
744 with experiment.start():රූපපරිවර්තනයන්
746 transforms_ = [
747 transforms.ToTensor(),
748 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
749 ]ඔබේමදත්ත පූරණය කරන්න. මෙන්න අපි පරීක්ෂණ කට්ටලය උත්සාහ කරමු. මම යෝස්මයිට් ඡායාරූප සමඟ උත්සාහ කරමින් සිටියෙමි, ඒවා නියමයි. ඔබට ඇමතුමේ ගණනය කිරීමට අවශ්ය වූ දෙයක් dataset_name
ලෙස ඔබ නියම කර ඇත්නම් conf.dataset_name
, ඔබට භාවිතා කළ
experiment.configs
755 dataset = ImageDataset(conf.dataset_name, transforms_, 'train')දත්තකට්ටලයෙන් රූපයක් ලබා ගන්න
757 x_image = dataset[10]['x']රූපයපෙන්වන්න
759 plot_image(x_image)ඇගයීම්මාදිලිය
762 conf.generator_xy.eval()
763 conf.generator_yx.eval()අපටඅනුක්රමික අවශ්ය නොවේ
766 with torch.no_grad():කණ්ඩායම්මානය එක් කර අප භාවිතා කරන උපාංගයට යන්න
768 data = x_image.unsqueeze(0).to(conf.device)
769 generated_y = conf.generator_xy(data)ජනනයකරන ලද රූපය පෙන්වන්න.
772 plot_image(generated_y[0].cpu())
773
774
775if __name__ == '__main__':
776 train()ඇගයීම()