රූප අවකාශය සහ ගුප්ත අවකාශය අතර සිතියම් ගත කිරීම සඳහා භාවිතා කරන ස්වයංක්රීය එන්කෝඩර් ආකෘතිය මෙය ක්රියාත්මක කරයි.
අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.
18from typing import List
19
20import torch
21import torch.nn.functional as F
22from torch import nn25class Autoencoder(nn.Module):encoder
ආකේතකය වේdecoder
විකේතකය වේemb_channels
යනු ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ මානයන් ගණනz_channels
කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ32 def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):39 super().__init__()
40 self.encoder = encoder
41 self.decoder = decoderඅවකාශය කාවැද්දීමේ සිට ප්රමාණාත්මක කාවැද්දීමේ අවකාශ අවස්ථා දක්වා සිතියම් ගත කිරීම (මධ්යන්ය සහ ලොග් විචලතාව)
44 self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ සිට නැවත කාවැද්දීම අවකාශය දක්වා සිතියමට සංකෝචනය
47 self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)img
හැඩය සහිත රූප ටෙන්සරයයි[batch_size, img_channels, img_height, img_width]
49 def encode(self, img: torch.Tensor) -> 'GaussianDistribution':හැඩය සහිත කාවැද්දීම් ලබා ගන්න[batch_size, z_channels * 2, z_height, z_height]
56 z = self.encoder(img)ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ මොහොත ලබා ගන්න
58 moments = self.quant_conv(z)බෙදා හැරීම ආපසු ලබා දෙන්න
60 return GaussianDistribution(moments)z
හැඩය සහිත ගුප්ත නිරූපණයයි[batch_size, emb_channels, z_height, z_height]
62 def decode(self, z: torch.Tensor):ප්රමාණාත්මක නිරූපණයෙන් අවකාශය කාවැද්දීම සඳහා සිතියම
69 z = self.post_quant_conv(z)හැඩයේ රූපය විකේතනය කරන්න[batch_size, channels, height, width]
71 return self.decoder(z)74class Encoder(nn.Module):channels
පළමු සංවහන ස්ථරයේ නාලිකා ගණන වේchannel_multipliers
පසුකාලීන බ්ලොක් වල නාලිකා සංඛ්යාව සඳහා බහුකාර්ය සාධක වේn_resnet_blocks
එක් එක් විභේදනයේ රෙස්නෙට් ස්ථර ගණන වේin_channels
යනු රූපයේ ඇති නාලිකා ගණනz_channels
කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ79 def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
80 in_channels: int, z_channels: int):89 super().__init__()විවිධ විභේදන වල කුට්ටි ගණන. එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ විභේදනය අඩකින් යුක්ත වේ
93 n_resolutions = len(channel_multipliers)රූපය සිතියම් ගත කරන මූලික කැටි ගැසුණු ස්ථරයchannels
96 self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)එක් එක් ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන
99 channels_list = [m * channels for m in [1] + channel_multipliers]ඉහළ මට්ටමේ කුට්ටි ලැයිස්තුව
102 self.down = nn.ModuleList()ඉහළ මට්ටමේ කුට්ටි සාදන්න
104 for i in range(n_resolutions):සෑම ඉහළ මට්ටමේ බ්ලොක් එකක්ම බහු රෙස්නෙට් බ්ලොක් සහ පහළ-නියැදීම් වලින් සමන්විත වේ
106 resnet_blocks = nn.ModuleList()රෙස්නෙට් බ්ලොක් එකතු කරන්න
108 for _ in range(n_resnet_blocks):
109 resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
110 channels = channels_list[i + 1]ඉහළ මට්ටමේ බ්ලොක්
112 down = nn.Module()
113 down.block = resnet_blocksඅන්තිම හැර එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ පහළ-නියැදීම
115 if i != n_resolutions - 1:
116 down.downsample = DownSample(channels)
117 else:
118 down.downsample = nn.Identity()120 self.down.append(down)අවධානය යොමු කරන අවසාන රෙස්නෙට් බ්ලොක්
123 self.mid = nn.Module()
124 self.mid.block_1 = ResnetBlock(channels, channels)
125 self.mid.attn_1 = AttnBlock(channels)
126 self.mid.block_2 = ResnetBlock(channels, channels)කැටි ගැසීමකින් අවකාශය කාවැද්දීම සඳහා සිතියම
129 self.norm_out = normalization(channels)
130 self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)img
හැඩය සහිත රූප ටෙන්සරයයි[batch_size, img_channels, img_height, img_width]
132 def forward(self, img: torch.Tensor):ආරම්භක කැටි ගැස්මchannels
සමඟ සිතියම
138 x = self.conv_in(img)ඉහළ මට්ටමේ කුට්ටි
141 for down in self.down:රෙස්නෙට් බ්ලොක්
143 for block in down.block:
144 x = block(x)පහළ-නියැදීම්
146 x = down.downsample(x)අවධානය යොමු කරන අවසාන රෙස්නෙට් බ්ලොක්
149 x = self.mid.block_1(x)
150 x = self.mid.attn_1(x)
151 x = self.mid.block_2(x)අවකාශය කාවැද්දීම සඳහා සාමාන්යකරණය කර සිතියම් ගත කරන්න
154 x = self.norm_out(x)
155 x = swish(x)
156 x = self.conv_out(x)159 return x162class Decoder(nn.Module):channels
අවසාන සංවහන ස්ථරයේ නාලිකා ගණන වේchannel_multipliers
පෙර බ්ලොක් වල නාලිකා ගණන සඳහා බහුකාර්ය සාධක, ප්රතිලෝම අනුපිළිවෙලn_resnet_blocks
එක් එක් විභේදනයේ රෙස්නෙට් ස්ථර ගණන වේout_channels
යනු රූපයේ ඇති නාලිකා ගණනz_channels
කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ167 def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
168 out_channels: int, z_channels: int):177 super().__init__()විවිධ විභේදන වල කුට්ටි ගණන. එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ විභේදනය අඩකින් යුක්ත වේ
181 num_resolutions = len(channel_multipliers)ප්රතිලෝම අනුපිළිවෙලෙහි එක් එක් ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන
184 channels_list = [m * channels for m in channel_multipliers]ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන
187 channels = channels_list[-1]කාවැද්දීමේ අවකාශය සිතියම් ගත කරන මූලික කැටි ගැස්වීමේ ස්ථරයchannels
190 self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)අවධානය සහිත රෙස්නෙට් බ්ලොක්
193 self.mid = nn.Module()
194 self.mid.block_1 = ResnetBlock(channels, channels)
195 self.mid.attn_1 = AttnBlock(channels)
196 self.mid.block_2 = ResnetBlock(channels, channels)ඉහළ මට්ටමේ කුට්ටි ලැයිස්තුව
199 self.up = nn.ModuleList()ඉහළ මට්ටමේ කුට්ටි සාදන්න
201 for i in reversed(range(num_resolutions)):සෑම ඉහළ මට්ටමේ බ්ලොක් එකක්ම බහු රෙස්නෙට් බ්ලොක් සහ ඉහළ නියැදීම් වලින් සමන්විත වේ
203 resnet_blocks = nn.ModuleList()රෙස්නෙට් බ්ලොක් එකතු කරන්න
205 for _ in range(n_resnet_blocks + 1):
206 resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
207 channels = channels_list[i]ඉහළ මට්ටමේ බ්ලොක්
209 up = nn.Module()
210 up.block = resnet_blocksපළමුවැන්න හැර එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ ඉහළට නියැදීම
212 if i != 0:
213 up.upsample = UpSample(channels)
214 else:
215 up.upsample = nn.Identity()මුරපොලට අනුකූල වීමට සූදානම් වන්න
217 self.up.insert(0, up)සංකෝචනය සමඟ රූප අවකාශයට සිතියම
220 self.norm_out = normalization(channels)
221 self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)z
හැඩය සහිත කාවැද්දීම tensor වේ[batch_size, z_channels, z_height, z_height]
223 def forward(self, z: torch.Tensor):ආරම්භක කැටි ගැස්මchannels
සමඟ සිතියම
229 h = self.conv_in(z)අවධානය සහිත රෙස්නෙට් බ්ලොක්
232 h = self.mid.block_1(h)
233 h = self.mid.attn_1(h)
234 h = self.mid.block_2(h)ඉහළ මට්ටමේ කුට්ටි
237 for up in reversed(self.up):රෙස්නෙට් බ්ලොක්
239 for block in up.block:
240 h = block(h)ඉහළට නියැදීම
242 h = up.upsample(h)රූප අවකාශයට සාමාන්යකරණය කර සිතියම් ගත කරන්න
245 h = self.norm_out(h)
246 h = swish(h)
247 img = self.conv_out(h)250 return img253class GaussianDistribution:parameters
හැඩයේ කාවැද්දීම පිළිබඳ විචලනයන්ගේ මාධ්යයන් සහ ලොග් වේ[batch_size, z_channels * 2, z_height, z_height]
258 def __init__(self, parameters: torch.Tensor):භේදය මධ්යන්යය සහ විචලතාව ලඝු-සටහන
264 self.mean, log_var = torch.chunk(parameters, 2, dim=1)විචල්යයන්ගේ ලොග් දැමීම
266 self.log_var = torch.clamp(log_var, -30.0, 20.0)සම්මත අපගමනය ගණනය කරන්න
268 self.std = torch.exp(0.5 * self.log_var)270 def sample(self):බෙදාහැරීමෙන් නියැදිය
272 return self.mean + self.std * torch.randn_like(self.std)275class AttnBlock(nn.Module):channels
යනු නාලිකා ගණන280 def __init__(self, channels: int):284 super().__init__()කණ්ඩායම් සාමාන්යකරණය
286 self.norm = normalization(channels)විමසුම්, යතුර සහ අගය සිතියම්
288 self.q = nn.Conv2d(channels, channels, 1)
289 self.k = nn.Conv2d(channels, channels, 1)
290 self.v = nn.Conv2d(channels, channels, 1)අවසාන කැටි ගැසුණු ස්ථරය
292 self.proj_out = nn.Conv2d(channels, channels, 1)අවධානය පරිමාණ සාධකය
294 self.scale = channels ** -0.5x
හැඩයේ ආතතිකය වේ[batch_size, channels, height, width]
296 def forward(self, x: torch.Tensor):සාමාන්ය කරන්නx
301 x_norm = self.norm(x)විමසුම, යතුර සහ දෛශික කාවැද්දීම් ලබා ගන්න
303 q = self.q(x_norm)
304 k = self.k(x_norm)
305 v = self.v(x_norm)විමසුම, ප්රධාන සහ දෛශික කාවැද්දීම්[batch_size, channels, height, width]
වෙත නැවත සකස් කරන්න[batch_size, channels, height * width]
310 b, c, h, w = q.shape
311 q = q.view(b, c, h * w)
312 k = k.view(b, c, h * w)
313 v = v.view(b, c, h * w)ගණනය කරන්න
316 attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
317 attn = F.softmax(attn, dim=2)ගණනය කරන්න
320 out = torch.einsum('bij,bcj->bci', attn, v)නැවත සකස් කරන්න[batch_size, channels, height, width]
323 out = out.view(b, c, h, w)අවසාන කැටි ගැසුණු ස්ථරය
325 out = self.proj_out(out)අවශේෂ සම්බන්ධතාවය එක් කරන්න
328 return x + out331class UpSample(nn.Module):channels
යනු නාලිකා ගණන335 def __init__(self, channels: int):339 super().__init__()කැටි ගැසීමේ සිතියම්කරණය
341 self.conv = nn.Conv2d(channels, channels, 3, padding=1)x
හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
343 def forward(self, x: torch.Tensor):සාධකයක් අනුව ඉහළ නියැදිය
348 x = F.interpolate(x, scale_factor=2.0, mode="nearest")කැටි ගැසිම යොදන්න
350 return self.conv(x)353class DownSample(nn.Module):channels
යනු නාලිකා ගණන357 def __init__(self, channels: int):361 super().__init__()ක සාධකයක් විසින් පහළ-නියැදි කිරීමට stride දිග සමග convolution
363 self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)x
හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
365 def forward(self, x: torch.Tensor):පෑඩින් එකතු කරන්න
370 x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)කැටි ගැසිම යොදන්න
372 return self.conv(x)375class ResnetBlock(nn.Module):in_channels
යනු ආදානයේ නාලිකා ගණනout_channels
නිමැවුමේ නාලිකා ගණන වේ379 def __init__(self, in_channels: int, out_channels: int):384 super().__init__()පළමු සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
386 self.norm1 = normalization(in_channels)
387 self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)දෙවන සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
389 self.norm2 = normalization(out_channels)
390 self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)in_channels
අවශේෂ සම්බන්ධතාවය සඳහා ස්තරයout_channels
සිතියම්ගත කිරීම
392 if in_channels != out_channels:
393 self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
394 else:
395 self.nin_shortcut = nn.Identity()x
හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
397 def forward(self, x: torch.Tensor):402 h = xපළමු සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
405 h = self.norm1(h)
406 h = swish(h)
407 h = self.conv1(h)දෙවන සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
410 h = self.norm2(h)
411 h = swish(h)
412 h = self.conv2(h)සිතියම සහ අවශේෂ එකතු කරන්න
415 return self.nin_shortcut(x) + h418def swish(x: torch.Tensor):424 return x * torch.sigmoid(x)427def normalization(channels: int):433 return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)