මෙය ලබා දෙන යූ-නෙට් ක්රියාත්මක කරයි
අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.
18import math
19from typing import List
20
21import numpy as np
22import torch
23import torch.nn as nn
24import torch.nn.functional as F
25
26from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer29class UNetModel(nn.Module):in_channels
ආදාන විශේෂාංග සිතියමේ නාලිකා ගණන වේout_channels
ප්රතිදාන විශේෂාංග සිතියමේ නාලිකා ගණන වේchannels
ආකෘතිය සඳහා මූලික නාලිකා ගණන වේn_res_blocks
එක් එක් මට්ටමේ අවශේෂ කුට්ටි ගණනattention_levels
අවධානය යොමු කළ යුතු මට්ටම් වේchannel_multipliers
එක් එක් මට්ටම් සඳහා නාලිකා ගණන සඳහා බහුකාර්ය සාධක වේn_heads
ට්රාන්ස්ෆෝමර්වල අවධානය යොමු කිරීමේ හිස් සංඛ්යාව34 def __init__(
35 self, *,
36 in_channels: int,
37 out_channels: int,
38 channels: int,
39 n_res_blocks: int,
40 attention_levels: List[int],
41 channel_multipliers: List[int],
42 n_heads: int,
43 tf_layers: int = 1,
44 d_cond: int = 768):54 super().__init__()
55 self.channels = channelsමට්ටම් ගණන
58 levels = len(channel_multipliers)ප්රමාණ කාල කාවැද්දීම්
60 d_time_emb = channels * 4
61 self.time_embed = nn.Sequential(
62 nn.Linear(channels, d_time_emb),
63 nn.SiLU(),
64 nn.Linear(d_time_emb, d_time_emb),
65 )U-Net ආදාන අඩක්
68 self.input_blocks = nn.ModuleList()ආදානය සිතියම් ගත කරන මූලික ව්යාවච්ඡාවchannels
. විවිධ මොඩියුලවල විවිධ ඉදිරි ක්රියාකාරී අත්සන් ඇති බැවින් කුට්ටිTimestepEmbedSequential
මොඩියුලයේ ඔතා ඇත; නිදසුනක් ලෙස, කැටි ගැසීමේදී විශේෂාංග සිතියම පමණක් පිළිගන්නා අතර අවශේෂ කොටස් විශේෂාංග සිතියම සහ වේලාව කාවැද්දීම පිළිගනී. TimestepEmbedSequential
ඒ අනුව ඔවුන් අමතයි.
75 self.input_blocks.append(TimestepEmbedSequential(
76 nn.Conv2d(in_channels, channels, 3, padding=1)))යූ-නෙට් හි ආදාන භාගයේ එක් එක් බ්ලොක් එකේ නාලිකා ගණන
78 input_block_channels = [channels]එක් එක් මට්ටමේ නාලිකා ගණන
80 channels_list = [channels * m for m in channel_multipliers]මට්ටම් සකස් කරන්න
82 for i in range(levels):අවශේෂ කුට්ටි සහ අවධානය එක් කරන්න
84 for _ in range(n_res_blocks):පෙර නාලිකා සංඛ්යාවේ සිට වර්තමාන මට්ටමේ නාලිකා ගණන දක්වා අවශේෂ බ්ලොක් සිතියම්
87 layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
88 channels = channels_list[i]ට්රාන්ස්ෆෝමර් එකතු කරන්න
90 if i in attention_levels:
91 layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))යූ-නෙට් හි ආදාන භාගයට ඒවා එක් කර එහි ප්රතිදානයේ නාලිකා ගණන නිරීක්ෂණය කරන්න
94 self.input_blocks.append(TimestepEmbedSequential(*layers))
95 input_block_channels.append(channels)අවසාන වශයෙන් හැර අනෙක් සියලුම මට්ටම්වල පහළ නියැදිය
97 if i != levels - 1:
98 self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
99 input_block_channels.append(channels)යූ-නෙට් මැද
102 self.middle_block = TimestepEmbedSequential(
103 ResBlock(channels, d_time_emb),
104 SpatialTransformer(channels, n_heads, tf_layers, d_cond),
105 ResBlock(channels, d_time_emb),
106 )යූ-නෙට් හි දෙවන භාගය
109 self.output_blocks = nn.ModuleList([])ප්රතිලෝම අනුපිළිවෙලින් මට්ටම් සකස් කරන්න
111 for i in reversed(range(levels)):අවශේෂ කුට්ටි සහ අවධානය එක් කරන්න
113 for j in range(n_res_blocks + 1):පෙර නාලිකා සංඛ්යාවෙන් අවශේෂ බ්ලොක් සිතියම් සහ යූ-නෙට් හි ආදාන භාගයේ සිට වත්මන් මට්ටමේ නාලිකා ගණන දක්වා මඟ හැරීමේ සම්බන්ධතා.
117 layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
118 channels = channels_list[i]ට්රාන්ස්ෆෝමර් එකතු කරන්න
120 if i in attention_levels:
121 layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))අන්තිම අවශේෂ කොටස හැර අවසාන අවශේෂ කොටසින් පසු සෑම මට්ටමකම ඉහළට නියැදිය. අපි ආපසු හැරවීමට පුනරාවර්තනය කරන බව සලකන්න; i.e. අවසානi == 0
වේ.
125 if i != 0 and j == n_res_blocks:
126 layers.append(UpSample(channels))යූ-නෙට් හි ප්රතිදාන භාගයට එක් කරන්න
128 self.output_blocks.append(TimestepEmbedSequential(*layers))අවසාන සාමාන්යකරණය සහ කැටි කිරීම
131 self.out = nn.Sequential(
132 normalization(channels),
133 nn.SiLU(),
134 nn.Conv2d(channels, out_channels, 3, padding=1),
135 )time_steps
හැඩයේ කාල පියවර වේ[batch_size]
max_period
කාවැද්දීම් වල අවම සංඛ්යාතය පාලනය කරයි.137 def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):; නාලිකා අඩක් පාපය වන අතර අනෙක් භාගය කෝස් වේ,
145 half = self.channels // 2147 frequencies = torch.exp(
148 -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
149 ).to(device=time_steps.device)151 args = time_steps[:, None].float() * frequencies[None]සහ
153 return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)x
හැඩයේ ආදාන විශේෂාංග සිතියමයි[batch_size, channels, width, height]
time_steps
හැඩයේ කාල පියවර වේ[batch_size]
cond
හැඩයේ කන්ඩිෂනේෂන්[batch_size, n_cond, d_cond]
155 def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):මඟ හැරීමේ සම්බන්ධතා සඳහා ආදාන අර්ධ ප්රතිදානයන් ගබඩා කිරීම
162 x_input_block = []කාලය පියවර කාවැද්දීම් ලබා ගන්න
165 t_emb = self.time_step_embedding(time_steps)
166 t_emb = self.time_embed(t_emb)U-Net ආදාන අඩක්
169 for module in self.input_blocks:
170 x = module(x, t_emb, cond)
171 x_input_block.append(x)යූ-නෙට් මැද
173 x = self.middle_block(x, t_emb, cond)U-Net ප්රතිදාන අඩක්
175 for module in self.output_blocks:
176 x = torch.cat([x, x_input_block.pop()], dim=1)
177 x = module(x, t_emb, cond)අවසාන සාමාන්යකරණය සහ කැටි කිරීම
180 return self.out(x)මෙම අනුක්රමික මොඩියුලයට විවිධ මොඩියුලයන් උරා බොනnn.Conv
SpatialTransformer
අතර ගැලපෙන අත්සන් සමඟ ඒවා අමතන්නResBlock
183class TimestepEmbedSequential(nn.Sequential):191 def forward(self, x, t_emb, cond=None):
192 for layer in self:
193 if isinstance(layer, ResBlock):
194 x = layer(x, t_emb)
195 elif isinstance(layer, SpatialTransformer):
196 x = layer(x, cond)
197 else:
198 x = layer(x)
199 return x202class UpSample(nn.Module):channels
යනු නාලිකා ගණන207 def __init__(self, channels: int):211 super().__init__()කැටි ගැසීමේ සිතියම්කරණය
213 self.conv = nn.Conv2d(channels, channels, 3, padding=1)x
හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
215 def forward(self, x: torch.Tensor):සාධකයක් අනුව ඉහළ නියැදිය
220 x = F.interpolate(x, scale_factor=2, mode="nearest")කැටි ගැසිම යොදන්න
222 return self.conv(x)225class DownSample(nn.Module):channels
යනු නාලිකා ගණන230 def __init__(self, channels: int):234 super().__init__()ක සාධකයක් විසින් පහළ-නියැදි කිරීමට stride දිග සමග convolution
236 self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)x
හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
238 def forward(self, x: torch.Tensor):කැටි ගැසිම යොදන්න
243 return self.op(x)246class ResBlock(nn.Module):channels
ආදාන නාලිකා ගණනd_t_emb
කාලරාමු කාවැද්දීම් වල ප්රමාණයout_channels
පිටතට ඇති නාලිකා ගණන වේ. `නාලිකාවලට පෙරනිමි.251 def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):257 super().__init__()out_channels
නිශ්චිතව දක්වා නැත
259 if out_channels is None:
260 out_channels = channelsපළමු සාමාන්යකරණය සහ කැටි ගැසිම
263 self.in_layers = nn.Sequential(
264 normalization(channels),
265 nn.SiLU(),
266 nn.Conv2d(channels, out_channels, 3, padding=1),
267 )කාල පියවර කාවැද්දීම්
270 self.emb_layers = nn.Sequential(
271 nn.SiLU(),
272 nn.Linear(d_t_emb, out_channels),
273 )අවසාන කැටි ගැසුණු ස්ථරය
275 self.out_layers = nn.Sequential(
276 normalization(out_channels),
277 nn.SiLU(),
278 nn.Dropout(0.),
279 nn.Conv2d(out_channels, out_channels, 3, padding=1)
280 )channels
අවශේෂ සම්බන්ධතාවය සඳහා ස්තරයout_channels
සිතියම්ගත කිරීම
283 if out_channels == channels:
284 self.skip_connection = nn.Identity()
285 else:
286 self.skip_connection = nn.Conv2d(channels, out_channels, 1)x
හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
t_emb
හැඩයේ කාල පියවර කාවැද්දීම් වේ[batch_size, d_t_emb]
288 def forward(self, x: torch.Tensor, t_emb: torch.Tensor):මූලික කැටි ගැසිම
294 h = self.in_layers(x)කාල පියවර කාවැද්දීම්
296 t_emb = self.emb_layers(t_emb).type(h.dtype)කාල පියවර කාවැද්දීම් එකතු කරන්න
298 h = h + t_emb[:, :, None, None]අවසාන කැටි ගැසිම
300 h = self.out_layers(h)මඟ හැරීමේ සම්බන්ධතාවය එක් කරන්න
302 return self.skip_connection(x) + h305class GroupNorm32(nn.GroupNorm):310 def forward(self, x):
311 return super().forward(x.float()).type(x.dtype)314def normalization(channels):320 return GroupNorm32(32, channels)සයිනොසොයිඩල් කාල පියවර කාවැද්දීම් පරීක්ෂා කරන්න
323def _test_time_embeddings():327 import matplotlib.pyplot as plt
328
329 plt.figure(figsize=(15, 5))
330 m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
331 channel_multipliers=[],
332 n_heads=1, tf_layers=1, d_cond=1)
333 te = m.time_step_embedding(torch.arange(0, 1000))
334 plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
335 plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
336 plt.title("Time embeddings")
337 plt.show()341if __name__ == '__main__':
342 _test_time_embeddings()