ස්ථාවර විසරණය සඳහා ස්වයංක්රීය ආකේතකය

රූප අවකාශය සහ ගුප්ත අවකාශය අතර සිතියම් ගත කිරීම සඳහා භාවිතා කරන ස්වයංක්රීය එන්කෝඩර් ආකෘතිය මෙය ක්රියාත්මක කරයි.

අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.

18from typing import List
19
20import torch
21import torch.nn.functional as F
22from torch import nn

ඔටෝඑන්කෝඩරය

මෙය එන්කෝඩරය සහ විකේතක මොඩියුල වලින් සමන්විත වේ.

25class Autoencoder(nn.Module):
  • encoder ආකේතකය වේ
  • decoder විකේතකය වේ
  • emb_channels යනු ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ මානයන් ගණන
  • z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ
  • 32    def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
    39        super().__init__()
    40        self.encoder = encoder
    41        self.decoder = decoder

    අවකාශය කාවැද්දීමේ සිට ප්රමාණාත්මක කාවැද්දීමේ අවකාශ අවස්ථා දක්වා සිතියම් ගත කිරීම (මධ්යන්ය සහ ලොග් විචලතාව)

    44        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)

    ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ සිට නැවත කාවැද්දීම අවකාශය දක්වා සිතියමට සංකෝචනය

    47        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

    ගුප්ත නිරූපණයට රූප කේතනය කරන්න

    • img හැඩය සහිත රූප ටෙන්සරයයි[batch_size, img_channels, img_height, img_width]
    49    def encode(self, img: torch.Tensor) -> 'GaussianDistribution':

    හැඩය සහිත කාවැද්දීම් ලබා ගන්න[batch_size, z_channels * 2, z_height, z_height]

    56        z = self.encoder(img)

    ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ මොහොත ලබා ගන්න

    58        moments = self.quant_conv(z)

    බෙදා හැරීම ආපසු ලබා දෙන්න

    60        return GaussianDistribution(moments)

    ගුප්ත නිරූපණයෙන් රූප විකේතනය කරන්න

    • z හැඩය සහිත ගුප්ත නිරූපණයයි[batch_size, emb_channels, z_height, z_height]
    62    def decode(self, z: torch.Tensor):

    ප්රමාණාත්මක නිරූපණයෙන් අවකාශය කාවැද්දීම සඳහා සිතියම

    69        z = self.post_quant_conv(z)

    හැඩයේ රූපය විකේතනය කරන්න[batch_size, channels, height, width]

    71        return self.decoder(z)

    එන්කෝඩර් මොඩියුලය

    74class Encoder(nn.Module):
    • channels පළමු සංවහන ස්ථරයේ නාලිකා ගණන වේ
    • channel_multipliers පසුකාලීන බ්ලොක් වල නාලිකා සංඛ්යාව සඳහා බහුකාර්ය සාධක වේ
    • n_resnet_blocks එක් එක් විභේදනයේ රෙස්නෙට් ස්ථර ගණන වේ
    • in_channels යනු රූපයේ ඇති නාලිකා ගණන
  • z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ
  • 79    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
    80                 in_channels: int, z_channels: int):
    89        super().__init__()

    විවිධ විභේදන වල කුට්ටි ගණන. එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ විභේදනය අඩකින් යුක්ත වේ

    93        n_resolutions = len(channel_multipliers)

    රූපය සිතියම් ගත කරන මූලික කැටි ගැසුණු ස්ථරයchannels

    96        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)

    එක් එක් ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන

    99        channels_list = [m * channels for m in [1] + channel_multipliers]

    ඉහළ මට්ටමේ කුට්ටි ලැයිස්තුව

    102        self.down = nn.ModuleList()

    ඉහළ මට්ටමේ කුට්ටි සාදන්න

    104        for i in range(n_resolutions):

    සෑම ඉහළ මට්ටමේ බ්ලොක් එකක්ම බහු රෙස්නෙට් බ්ලොක් සහ පහළ-නියැදීම් වලින් සමන්විත වේ

    106            resnet_blocks = nn.ModuleList()

    රෙස්නෙට් බ්ලොක් එකතු කරන්න

    108            for _ in range(n_resnet_blocks):
    109                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
    110                channels = channels_list[i + 1]

    ඉහළ මට්ටමේ බ්ලොක්

    112            down = nn.Module()
    113            down.block = resnet_blocks

    අන්තිම හැර එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ පහළ-නියැදීම

    115            if i != n_resolutions - 1:
    116                down.downsample = DownSample(channels)
    117            else:
    118                down.downsample = nn.Identity()

    120            self.down.append(down)

    අවධානය යොමු කරන අවසාන රෙස්නෙට් බ්ලොක්

    123        self.mid = nn.Module()
    124        self.mid.block_1 = ResnetBlock(channels, channels)
    125        self.mid.attn_1 = AttnBlock(channels)
    126        self.mid.block_2 = ResnetBlock(channels, channels)

    කැටි ගැසීමකින් අවකාශය කාවැද්දීම සඳහා සිතියම

    129        self.norm_out = normalization(channels)
    130        self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
    • img හැඩය සහිත රූප ටෙන්සරයයි[batch_size, img_channels, img_height, img_width]
    132    def forward(self, img: torch.Tensor):

    ආරම්භක කැටි ගැස්මchannels සමඟ සිතියම

    138        x = self.conv_in(img)

    ඉහළ මට්ටමේ කුට්ටි

    141        for down in self.down:

    රෙස්නෙට් බ්ලොක්

    143            for block in down.block:
    144                x = block(x)

    පහළ-නියැදීම්

    146            x = down.downsample(x)

    අවධානය යොමු කරන අවසාන රෙස්නෙට් බ්ලොක්

    149        x = self.mid.block_1(x)
    150        x = self.mid.attn_1(x)
    151        x = self.mid.block_2(x)

    අවකාශය කාවැද්දීම සඳහා සාමාන්යකරණය කර සිතියම් ගත කරන්න

    154        x = self.norm_out(x)
    155        x = swish(x)
    156        x = self.conv_out(x)

    159        return x

    විකේතක මොඩියුලය

    162class Decoder(nn.Module):
    • channels අවසාන සංවහන ස්ථරයේ නාලිකා ගණන වේ
    • channel_multipliers පෙර බ්ලොක් වල නාලිකා ගණන සඳහා බහුකාර්ය සාධක, ප්රතිලෝම අනුපිළිවෙල
    • n_resnet_blocks එක් එක් විභේදනයේ රෙස්නෙට් ස්ථර ගණන වේ
    • out_channels යනු රූපයේ ඇති නාලිකා ගණන
  • z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ
  • 167    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
    168                 out_channels: int, z_channels: int):
    177        super().__init__()

    විවිධ විභේදන වල කුට්ටි ගණන. එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ විභේදනය අඩකින් යුක්ත වේ

    181        num_resolutions = len(channel_multipliers)

    ප්රතිලෝම අනුපිළිවෙලෙහි එක් එක් ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන

    184        channels_list = [m * channels for m in channel_multipliers]

    ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන

    187        channels = channels_list[-1]

    කාවැද්දීමේ අවකාශය සිතියම් ගත කරන මූලික කැටි ගැස්වීමේ ස්ථරයchannels

    190        self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)

    අවධානය සහිත රෙස්නෙට් බ්ලොක්

    193        self.mid = nn.Module()
    194        self.mid.block_1 = ResnetBlock(channels, channels)
    195        self.mid.attn_1 = AttnBlock(channels)
    196        self.mid.block_2 = ResnetBlock(channels, channels)

    ඉහළ මට්ටමේ කුට්ටි ලැයිස්තුව

    199        self.up = nn.ModuleList()

    ඉහළ මට්ටමේ කුට්ටි සාදන්න

    201        for i in reversed(range(num_resolutions)):

    සෑම ඉහළ මට්ටමේ බ්ලොක් එකක්ම බහු රෙස්නෙට් බ්ලොක් සහ ඉහළ නියැදීම් වලින් සමන්විත වේ

    203            resnet_blocks = nn.ModuleList()

    රෙස්නෙට් බ්ලොක් එකතු කරන්න

    205            for _ in range(n_resnet_blocks + 1):
    206                resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
    207                channels = channels_list[i]

    ඉහළ මට්ටමේ බ්ලොක්

    209            up = nn.Module()
    210            up.block = resnet_blocks

    පළමුවැන්න හැර එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ ඉහළට නියැදීම

    212            if i != 0:
    213                up.upsample = UpSample(channels)
    214            else:
    215                up.upsample = nn.Identity()

    මුරපොලට අනුකූල වීමට සූදානම් වන්න

    217            self.up.insert(0, up)

    සංකෝචනය සමඟ රූප අවකාශයට සිතියම

    220        self.norm_out = normalization(channels)
    221        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
    • z හැඩය සහිත කාවැද්දීම tensor වේ[batch_size, z_channels, z_height, z_height]
    223    def forward(self, z: torch.Tensor):

    ආරම්භක කැටි ගැස්මchannels සමඟ සිතියම

    229        h = self.conv_in(z)

    අවධානය සහිත රෙස්නෙට් බ්ලොක්

    232        h = self.mid.block_1(h)
    233        h = self.mid.attn_1(h)
    234        h = self.mid.block_2(h)

    ඉහළ මට්ටමේ කුට්ටි

    237        for up in reversed(self.up):

    රෙස්නෙට් බ්ලොක්

    239            for block in up.block:
    240                h = block(h)

    ඉහළට නියැදීම

    242            h = up.upsample(h)

    රූප අවකාශයට සාමාන්යකරණය කර සිතියම් ගත කරන්න

    245        h = self.norm_out(h)
    246        h = swish(h)
    247        img = self.conv_out(h)

    250        return img

    ගවුසියානු බෙදාහැරීම්

    253class GaussianDistribution:
    • parameters හැඩයේ කාවැද්දීම පිළිබඳ විචලනයන්ගේ මාධ්යයන් සහ ලොග් වේ[batch_size, z_channels * 2, z_height, z_height]
    258    def __init__(self, parameters: torch.Tensor):

    භේදය මධ්යන්යය සහ විචලතාව ලඝු-සටහන

    264        self.mean, log_var = torch.chunk(parameters, 2, dim=1)

    විචල්යයන්ගේ ලොග් දැමීම

    266        self.log_var = torch.clamp(log_var, -30.0, 20.0)

    සම්මත අපගමනය ගණනය කරන්න

    268        self.std = torch.exp(0.5 * self.log_var)
    270    def sample(self):

    බෙදාහැරීමෙන් නියැදිය

    272        return self.mean + self.std * torch.randn_like(self.std)

    අවධානය වාරණ

    275class AttnBlock(nn.Module):
    • channels යනු නාලිකා ගණන
    280    def __init__(self, channels: int):
    284        super().__init__()

    කණ්ඩායම් සාමාන්යකරණය

    286        self.norm = normalization(channels)

    විමසුම්, යතුර සහ අගය සිතියම්

    288        self.q = nn.Conv2d(channels, channels, 1)
    289        self.k = nn.Conv2d(channels, channels, 1)
    290        self.v = nn.Conv2d(channels, channels, 1)

    අවසාන කැටි ගැසුණු ස්ථරය

    292        self.proj_out = nn.Conv2d(channels, channels, 1)

    අවධානය පරිමාණ සාධකය

    294        self.scale = channels ** -0.5
    • x හැඩයේ ආතතිකය වේ[batch_size, channels, height, width]
    296    def forward(self, x: torch.Tensor):

    සාමාන්‍ය කරන්නx

    301        x_norm = self.norm(x)

    විමසුම, යතුර සහ දෛශික කාවැද්දීම් ලබා ගන්න

    303        q = self.q(x_norm)
    304        k = self.k(x_norm)
    305        v = self.v(x_norm)

    විමසුම, ප්රධාන සහ දෛශික කාවැද්දීම්[batch_size, channels, height, width] වෙත නැවත සකස් කරන්න[batch_size, channels, height * width]

    310        b, c, h, w = q.shape
    311        q = q.view(b, c, h * w)
    312        k = k.view(b, c, h * w)
    313        v = v.view(b, c, h * w)

    ගණනය කරන්න

    316        attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
    317        attn = F.softmax(attn, dim=2)

    ගණනය කරන්න

    320        out = torch.einsum('bij,bcj->bci', attn, v)

    නැවත සකස් කරන්න[batch_size, channels, height, width]

    323        out = out.view(b, c, h, w)

    අවසාන කැටි ගැසුණු ස්ථරය

    325        out = self.proj_out(out)

    අවශේෂ සම්බන්ධතාවය එක් කරන්න

    328        return x + out

    දක්වා-නියැදීම් ස්ථරය

    331class UpSample(nn.Module):
    • channels යනු නාලිකා ගණන
    335    def __init__(self, channels: int):
    339        super().__init__()

    කැටි ගැසීමේ සිතියම්කරණය

    341        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
    • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
    343    def forward(self, x: torch.Tensor):

    සාධකයක් අනුව ඉහළ නියැදිය

    348        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

    කැටි ගැසිම යොදන්න

    350        return self.conv(x)

    පහළ-නියැදි ස්ථරය

    353class DownSample(nn.Module):
    • channels යනු නාලිකා ගණන
    357    def __init__(self, channels: int):
    361        super().__init__()

    ක සාධකයක් විසින් පහළ-නියැදි කිරීමට stride දිග සමග convolution

    363        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
    • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
    365    def forward(self, x: torch.Tensor):

    පෑඩින් එකතු කරන්න

    370        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

    කැටි ගැසිම යොදන්න

    372        return self.conv(x)

    රෙස්නෙට් බ්ලොක්

    375class ResnetBlock(nn.Module):
    • in_channels යනු ආදානයේ නාලිකා ගණන
    • out_channels නිමැවුමේ නාලිකා ගණන වේ
    379    def __init__(self, in_channels: int, out_channels: int):
    384        super().__init__()

    පළමු සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

    386        self.norm1 = normalization(in_channels)
    387        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)

    දෙවන සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

    389        self.norm2 = normalization(out_channels)
    390        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)

    in_channels අවශේෂ සම්බන්ධතාවය සඳහා ස්තරයout_channels සිතියම්ගත කිරීම

    392        if in_channels != out_channels:
    393            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
    394        else:
    395            self.nin_shortcut = nn.Identity()
    • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
    397    def forward(self, x: torch.Tensor):
    402        h = x

    පළමු සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

    405        h = self.norm1(h)
    406        h = swish(h)
    407        h = self.conv1(h)

    දෙවන සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

    410        h = self.norm2(h)
    411        h = swish(h)
    412        h = self.conv2(h)

    සිතියම සහ අවශේෂ එකතු කරන්න

    415        return self.nin_shortcut(x) + h

    ස්විෂ් සක්රිය කිරීම

    418def swish(x: torch.Tensor):
    424    return x * torch.sigmoid(x)

    කණ්ඩායම් සාමාන්යකරණය

    මෙය උපකාරක ශ්රිතයක් වන අතර ස්ථාවර කණ්ඩායම් ගණන සහeps .

    427def normalization(channels: int):
    433    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)