ස්ථාවර විසරණය සඳහා යූ-නෙට්

මෙය ලබා දෙන යූ-නෙට් ක්රියාත්මක කරයි

අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.

18import math
19from typing import List
20
21import numpy as np
22import torch
23import torch.nn as nn
24import torch.nn.functional as F
25
26from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer

යූ-නෙට් ආකෘතිය

29class UNetModel(nn.Module):
  • in_channels ආදාන විශේෂාංග සිතියමේ නාලිකා ගණන වේ
  • out_channels ප්රතිදාන විශේෂාංග සිතියමේ නාලිකා ගණන වේ
  • channels ආකෘතිය සඳහා මූලික නාලිකා ගණන වේ
  • n_res_blocks එක් එක් මට්ටමේ අවශේෂ කුට්ටි ගණන
  • attention_levels අවධානය යොමු කළ යුතු මට්ටම් වේ
  • channel_multipliers එක් එක් මට්ටම් සඳහා නාලිකා ගණන සඳහා බහුකාර්ය සාධක වේ
  • n_heads ට්රාන්ස්ෆෝමර්වල අවධානය යොමු කිරීමේ හිස් සංඛ්යාව
34    def __init__(
35            self, *,
36            in_channels: int,
37            out_channels: int,
38            channels: int,
39            n_res_blocks: int,
40            attention_levels: List[int],
41            channel_multipliers: List[int],
42            n_heads: int,
43            tf_layers: int = 1,
44            d_cond: int = 768):
54        super().__init__()
55        self.channels = channels

මට්ටම් ගණන

58        levels = len(channel_multipliers)

ප්රමාණ කාල කාවැද්දීම්

60        d_time_emb = channels * 4
61        self.time_embed = nn.Sequential(
62            nn.Linear(channels, d_time_emb),
63            nn.SiLU(),
64            nn.Linear(d_time_emb, d_time_emb),
65        )

U-Net ආදාන අඩක්

68        self.input_blocks = nn.ModuleList()

ආදානය සිතියම් ගත කරන මූලික ව්යාවච්ඡාවchannels . විවිධ මොඩියුලවල විවිධ ඉදිරි ක්රියාකාරී අත්සන් ඇති බැවින් කුට්ටිTimestepEmbedSequential මොඩියුලයේ ඔතා ඇත; නිදසුනක් ලෙස, කැටි ගැසීමේදී විශේෂාංග සිතියම පමණක් පිළිගන්නා අතර අවශේෂ කොටස් විශේෂාංග සිතියම සහ වේලාව කාවැද්දීම පිළිගනී. TimestepEmbedSequential ඒ අනුව ඔවුන් අමතයි.

75        self.input_blocks.append(TimestepEmbedSequential(
76            nn.Conv2d(in_channels, channels, 3, padding=1)))

යූ-නෙට් හි ආදාන භාගයේ එක් එක් බ්ලොක් එකේ නාලිකා ගණන

78        input_block_channels = [channels]

එක් එක් මට්ටමේ නාලිකා ගණන

80        channels_list = [channels * m for m in channel_multipliers]

මට්ටම් සකස් කරන්න

82        for i in range(levels):

අවශේෂ කුට්ටි සහ අවධානය එක් කරන්න

84            for _ in range(n_res_blocks):

පෙර නාලිකා සංඛ්යාවේ සිට වර්තමාන මට්ටමේ නාලිකා ගණන දක්වා අවශේෂ බ්ලොක් සිතියම්

87                layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
88                channels = channels_list[i]

ට්රාන්ස්ෆෝමර් එකතු කරන්න

90                if i in attention_levels:
91                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

යූ-නෙට් හි ආදාන භාගයට ඒවා එක් කර එහි ප්රතිදානයේ නාලිකා ගණන නිරීක්ෂණය කරන්න

94                self.input_blocks.append(TimestepEmbedSequential(*layers))
95                input_block_channels.append(channels)

අවසාන වශයෙන් හැර අනෙක් සියලුම මට්ටම්වල පහළ නියැදිය

97            if i != levels - 1:
98                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
99                input_block_channels.append(channels)

යූ-නෙට් මැද

102        self.middle_block = TimestepEmbedSequential(
103            ResBlock(channels, d_time_emb),
104            SpatialTransformer(channels, n_heads, tf_layers, d_cond),
105            ResBlock(channels, d_time_emb),
106        )

යූ-නෙට් හි දෙවන භාගය

109        self.output_blocks = nn.ModuleList([])

ප්රතිලෝම අනුපිළිවෙලින් මට්ටම් සකස් කරන්න

111        for i in reversed(range(levels)):

අවශේෂ කුට්ටි සහ අවධානය එක් කරන්න

113            for j in range(n_res_blocks + 1):

පෙර නාලිකා සංඛ්යාවෙන් අවශේෂ බ්ලොක් සිතියම් සහ යූ-නෙට් හි ආදාන භාගයේ සිට වත්මන් මට්ටමේ නාලිකා ගණන දක්වා මඟ හැරීමේ සම්බන්ධතා.

117                layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
118                channels = channels_list[i]

ට්රාන්ස්ෆෝමර් එකතු කරන්න

120                if i in attention_levels:
121                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

අන්තිම අවශේෂ කොටස හැර අවසාන අවශේෂ කොටසින් පසු සෑම මට්ටමකම ඉහළට නියැදිය. අපි ආපසු හැරවීමට පුනරාවර්තනය කරන බව සලකන්න; i.e. අවසානi == 0 වේ.

125                if i != 0 and j == n_res_blocks:
126                    layers.append(UpSample(channels))

යූ-නෙට් හි ප්රතිදාන භාගයට එක් කරන්න

128                self.output_blocks.append(TimestepEmbedSequential(*layers))

අවසාන සාමාන්යකරණය සහ කැටි කිරීම

131        self.out = nn.Sequential(
132            normalization(channels),
133            nn.SiLU(),
134            nn.Conv2d(channels, out_channels, 3, padding=1),
135        )

සයිනොසොයිඩල් කාල පියවර කාවැද්දීම් සාදන්න

  • time_steps හැඩයේ කාල පියවර වේ[batch_size]
  • max_period කාවැද්දීම් වල අවම සංඛ්යාතය පාලනය කරයි.
137    def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):

; නාලිකා අඩක් පාපය වන අතර අනෙක් භාගය කෝස් වේ,

145        half = self.channels // 2

147        frequencies = torch.exp(
148            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
149        ).to(device=time_steps.device)

151        args = time_steps[:, None].float() * frequencies[None]

සහ

153        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
  • x හැඩයේ ආදාන විශේෂාංග සිතියමයි[batch_size, channels, width, height]
  • time_steps හැඩයේ කාල පියවර වේ[batch_size]
  • cond හැඩයේ කන්ඩිෂනේෂන්[batch_size, n_cond, d_cond]
155    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):

මඟ හැරීමේ සම්බන්ධතා සඳහා ආදාන අර්ධ ප්රතිදානයන් ගබඩා කිරීම

162        x_input_block = []

කාලය පියවර කාවැද්දීම් ලබා ගන්න

165        t_emb = self.time_step_embedding(time_steps)
166        t_emb = self.time_embed(t_emb)

U-Net ආදාන අඩක්

169        for module in self.input_blocks:
170            x = module(x, t_emb, cond)
171            x_input_block.append(x)

යූ-නෙට් මැද

173        x = self.middle_block(x, t_emb, cond)

U-Net ප්රතිදාන අඩක්

175        for module in self.output_blocks:
176            x = torch.cat([x, x_input_block.pop()], dim=1)
177            x = module(x, t_emb, cond)

අවසාන සාමාන්යකරණය සහ කැටි කිරීම

180        return self.out(x)

විවිධ යෙදවුම් සහිත මොඩියුල සඳහා අනුක්රමික කොටස

මෙම අනුක්රමික මොඩියුලයට විවිධ මොඩියුලයන් උරා බොනnn.Conv SpatialTransformer අතර ගැලපෙන අත්සන් සමඟ ඒවා අමතන්නResBlock

183class TimestepEmbedSequential(nn.Sequential):
191    def forward(self, x, t_emb, cond=None):
192        for layer in self:
193            if isinstance(layer, ResBlock):
194                x = layer(x, t_emb)
195            elif isinstance(layer, SpatialTransformer):
196                x = layer(x, cond)
197            else:
198                x = layer(x)
199        return x

දක්වා-නියැදීම් ස්ථරය

202class UpSample(nn.Module):
  • channels යනු නාලිකා ගණන
207    def __init__(self, channels: int):
211        super().__init__()

කැටි ගැසීමේ සිතියම්කරණය

213        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
215    def forward(self, x: torch.Tensor):

සාධකයක් අනුව ඉහළ නියැදිය

220        x = F.interpolate(x, scale_factor=2, mode="nearest")

කැටි ගැසිම යොදන්න

222        return self.conv(x)

පහළ-නියැදි ස්ථරය

225class DownSample(nn.Module):
  • channels යනු නාලිකා ගණන
230    def __init__(self, channels: int):
234        super().__init__()

ක සාධකයක් විසින් පහළ-නියැදි කිරීමට stride දිග සමග convolution

236        self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
238    def forward(self, x: torch.Tensor):

කැටි ගැසිම යොදන්න

243        return self.op(x)

රෙස්නෙට් බ්ලොක්

246class ResBlock(nn.Module):
  • channels ආදාන නාලිකා ගණන
  • d_t_emb කාලරාමු කාවැද්දීම් වල ප්රමාණය
  • out_channels පිටතට ඇති නාලිකා ගණන වේ. `නාලිකාවලට පෙරනිමි.
251    def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
257        super().__init__()

out_channels නිශ්චිතව දක්වා නැත

259        if out_channels is None:
260            out_channels = channels

පළමු සාමාන්යකරණය සහ කැටි ගැසිම

263        self.in_layers = nn.Sequential(
264            normalization(channels),
265            nn.SiLU(),
266            nn.Conv2d(channels, out_channels, 3, padding=1),
267        )

කාල පියවර කාවැද්දීම්

270        self.emb_layers = nn.Sequential(
271            nn.SiLU(),
272            nn.Linear(d_t_emb, out_channels),
273        )

අවසාන කැටි ගැසුණු ස්ථරය

275        self.out_layers = nn.Sequential(
276            normalization(out_channels),
277            nn.SiLU(),
278            nn.Dropout(0.),
279            nn.Conv2d(out_channels, out_channels, 3, padding=1)
280        )

channels අවශේෂ සම්බන්ධතාවය සඳහා ස්තරයout_channels සිතියම්ගත කිරීම

283        if out_channels == channels:
284            self.skip_connection = nn.Identity()
285        else:
286            self.skip_connection = nn.Conv2d(channels, out_channels, 1)
  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
  • t_emb හැඩයේ කාල පියවර කාවැද්දීම් වේ[batch_size, d_t_emb]
  • 288    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):

    මූලික කැටි ගැසිම

    294        h = self.in_layers(x)

    කාල පියවර කාවැද්දීම්

    296        t_emb = self.emb_layers(t_emb).type(h.dtype)

    කාල පියවර කාවැද්දීම් එකතු කරන්න

    298        h = h + t_emb[:, :, None, None]

    අවසාන කැටි ගැසිම

    300        h = self.out_layers(h)

    මඟ හැරීමේ සම්බන්ධතාවය එක් කරන්න

    302        return self.skip_connection(x) + h

    float32 වාත්තු සමග කණ්ඩායම් සාමාන්යකරණය

    305class GroupNorm32(nn.GroupNorm):
    310    def forward(self, x):
    311        return super().forward(x.float()).type(x.dtype)

    කණ්ඩායම් සාමාන්යකරණය

    මෙය උපකාරක ශ්රිතයක් වන අතර ස්ථාවර කණ්ඩායම් සංඛ්යාවක් ඇත..

    314def normalization(channels):
    320    return GroupNorm32(32, channels)

    සයිනොසොයිඩල් කාල පියවර කාවැද්දීම් පරීක්ෂා කරන්න

    323def _test_time_embeddings():
    327    import matplotlib.pyplot as plt
    328
    329    plt.figure(figsize=(15, 5))
    330    m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
    331                  channel_multipliers=[],
    332                  n_heads=1, tf_layers=1, d_cond=1)
    333    te = m.time_step_embedding(torch.arange(0, 1000))
    334    plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
    335    plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
    336    plt.title("Time embeddings")
    337    plt.show()

    341if __name__ == '__main__':
    342    _test_time_embeddings()