මෙය ලබා දෙන යූ-නෙට් හි භාවිතා කරන ට්රාන්ස්ෆෝමර් මොඩියුලය ක්රියාත්මක කරයි
අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.
19from typing import Optional
20
21import torch
22import torch.nn.functional as F
23from torch import nn26class SpatialTransformer(nn.Module):channels
විශේෂාංග සිතියමේ නාලිකා ගණනn_heads
අවධානය යොමු ප්රධානීන් සංඛ්යාව වේn_layers
ට්රාන්ස්ෆෝමර් ස්ථර ගණනd_cond
යනු කොන්දේසි සහිත කාවැද්දිවල ප්රමාණයයි31 def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):38 super().__init__()ආරම්භක කණ්ඩායම් සාමාන්යකරණය
40 self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)මූලික කැටි ගැසිම
42 self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)ට්රාන්ස්ෆෝමර් ස්ථර
45 self.transformer_blocks = nn.ModuleList(
46 [BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond) for _ in range(n_layers)]
47 )අවසාන කැටි ගැසිම
50 self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)x
හැඩයේ විශේෂාංග සිතියමයි[batch_size, channels, height, width]
cond
හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
52 def forward(self, x: torch.Tensor, cond: torch.Tensor):හැඩය ලබා ගන්න[batch_size, channels, height, width]
58 b, c, h, w = x.shapeඅවශේෂ සම්බන්ධතාවය සඳහා
60 x_in = xසාමාන්ය කරන්න
62 x = self.norm(x)මූලික කැටි ගැසිම
64 x = self.proj_in(x)සිට සම්ප්රේෂණය කර නැවත හැඩගස්වා[batch_size, channels, height, width]
ගන්න[batch_size, height * width, channels]
67 x = x.permute(0, 2, 3, 1).view(b, h * w, c)ට්රාන්ස්ෆෝමර් ස්ථර යොදන්න
69 for block in self.transformer_blocks:
70 x = block(x, cond)නැවත හැඩගස්වා සිට සම්ප්රේෂණය[batch_size, height * width, channels]
කරන්න[batch_size, channels, height, width]
73 x = x.view(b, h, w, c).permute(0, 3, 1, 2)අවසාන කැටි ගැසිම
75 x = self.proj_out(x)අවශේෂ එකතු කරන්න
77 return x + x_in80class BasicTransformerBlock(nn.Module):d_model
ආදාන කාවැද්දීමේ ප්රමාණයයිn_heads
අවධානය යොමු ප්රධානීන් සංඛ්යාව වේd_head
අවධානය යොමු හිසෙහි ප්රමාණයයිd_cond
යනු කොන්දේසි සහිත කාවැද්දීම් වල ප්රමාණයයි85 def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):92 super().__init__()ස්වයං අවධානය ස්ථරය හා පෙර-සම්මත ස්ථරය
94 self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
95 self.norm1 = nn.LayerNorm(d_model)හරස් අවධානය ස්ථරය සහ පෙර-සම්මත ස්ථරය
97 self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
98 self.norm2 = nn.LayerNorm(d_model)Feed-ඉදිරි ජාලය සහ පෙර-සම්මත ස්ථරය
100 self.ff = FeedForward(d_model)
101 self.norm3 = nn.LayerNorm(d_model)x
හැඩයේ ආදාන කාවැද්දීම් වේ[batch_size, height * width, d_model]
cond
හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
103 def forward(self, x: torch.Tensor, cond: torch.Tensor):ස්වයං අවධානය
109 x = self.attn1(self.norm1(x)) + xකන්ඩිෂනේෂන් සමඟ හරස් අවධානය
111 x = self.attn2(self.norm2(x), cond=cond) + xFeed-ඉදිරි ජාලය
113 x = self.ff(self.norm3(x)) + x115 return x118class CrossAttention(nn.Module):125 use_flash_attention: bool = Falsed_model
ආදාන කාවැද්දීමේ ප්රමාණයයිn_heads
අවධානය යොමු ප්රධානීන් සංඛ්යාව වේd_head
අවධානය යොමු හිසෙහි ප්රමාණයයිd_cond
යනු කොන්දේසි සහිත කාවැද්දීම් වල ප්රමාණයයිis_inplace
මතකය ඉතිරි කර ගැනීම සඳහා අවධානය softmax ගණනය inplace ඉටු කිරීමට යන්න නියම127 def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):136 super().__init__()
137
138 self.is_inplace = is_inplace
139 self.n_heads = n_heads
140 self.d_head = d_headඅවධානය පරිමාණ සාධකය
143 self.scale = d_head ** -0.5විමසුම්, යතුර සහ අගය සිතියම්
146 d_attn = d_head * n_heads
147 self.to_q = nn.Linear(d_model, d_attn, bias=False)
148 self.to_k = nn.Linear(d_cond, d_attn, bias=False)
149 self.to_v = nn.Linear(d_cond, d_attn, bias=False)අවසාන රේඛීය ස්ථරය
152 self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))සැකසුම ෆ්ලෑෂ් අවධානය. ෆ්ලෑෂ් අවධානය භාවිතා කරනු ලබන්නේ එය ස්ථාපනය කර ඇත්නම් සහCrossAttention.use_flash_attention
එය සකසා ඇත්නම් පමණිTrue
.
157 try:ක්ලෝනකරණය කිරීමෙන් ඔබට ෆ්ලෑෂ් අවධානය ස්ථාපනය කළ හැකිය Github repo, https://github.com/HazyResearch/flash-attention ඉන්පසු ධාවනයpython setup.py install
161 from flash_attn.flash_attention import FlashAttention
162 self.flash = FlashAttention()පරිමාණ තිත් නිෂ්පාදන අවධානය සඳහා පරිමාණය සකසන්න.
164 self.flash.softmax_scale = self.scaleඑය ස්ථාපනය කර නොමැතිNone
නම් සකසන්න
166 except ImportError:
167 self.flash = Nonex
හැඩයේ ආදාන කාවැද්දීම් වේ[batch_size, height * width, d_model]
cond
හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
169 def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):None
අපි ස්වයං අවධානය යොමුcond
කරන්නේ නම්
176 has_cond = cond is not None
177 if not has_cond:
178 cond = xවිමසුම, යතුර සහ අගය දෛශික ලබා ගන්න
181 q = self.to_q(x)
182 k = self.to_k(cond)
183 v = self.to_v(cond)ෆ්ලෑෂ් අවධානය ලබා ගත හැකි නම් සහ හිස ප්රමාණය අඩු හෝ සමාන නම් භාවිතා කරන්න128
186 if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
187 return self.flash_attention(q, k, v)එසේ නොමැති නම්, සාමාන්ය අවධානයට වැටීම
189 else:
190 return self.normal_attention(q, k, v)q
හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
k
හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
v
හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
192 def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):අනුක්රමික අක්ෂය ඔස්සේ කණ්ඩායම් ප්රමාණය සහ මූලද්රව්ය ගණන ලබා ගන්න (width * height
)
202 batch_size, seq_len, _ = q.shapeෆ්ලෑෂ් අවධානය සඳහාv
දෛශිකq
k
, හැඩයේ තනි ටෙන්සරයක් ලබා ගැනීමට[batch_size, seq_len, 3, n_heads * d_head]
206 qkv = torch.stack((q, k, v), dim=2)හිස් බෙදන්න
208 qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)ෆ්ලෑෂ් අවධානය හිස් ප්රමාණ සඳහා ක්රියා කරයි32
128
,64
සහ, එබැවින් මෙම ප්රමාණයට සරිලන පරිදි හිස් පෑඩ් කළ යුතුය.
212 if self.d_head <= 32:
213 pad = 32 - self.d_head
214 elif self.d_head <= 64:
215 pad = 64 - self.d_head
216 elif self.d_head <= 128:
217 pad = 128 - self.d_head
218 else:
219 raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')හිස් පෑඩ් කරන්න
222 if pad:
223 qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)අවධානය ගණනය කරන්න මෙය හැඩයේ ආතතිකයක් ලබා දෙයි[batch_size, seq_len, n_heads, d_padded]
228 out, _ = self.flash(qkv)අමතර හිස ප්රමාණය කපා
230 out = out[:, :, :, :self.d_head]නැවත හැඩගස්වන්න[batch_size, seq_len, n_heads * d_head]
232 out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)රේඛීය ස්ථරයක්[batch_size, height * width, d_model]
සමඟ සිතියම
235 return self.to_out(out)q
හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
k
හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
v
හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
237 def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):හැඩයේ හිස් වලට බෙදන්න[batch_size, seq_len, n_heads, d_head]
247 q = q.view(*q.shape[:2], self.n_heads, -1)
248 k = k.view(*k.shape[:2], self.n_heads, -1)
249 v = v.view(*v.shape[:2], self.n_heads, -1)අවධානය ගණනය කරන්න
252 attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scaleසොෆ්ට්මැක්ස් ගණනය කරන්න
256 if self.is_inplace:
257 half = attn.shape[0] // 2
258 attn[half:] = attn[half:].softmax(dim=-1)
259 attn[:half] = attn[:half].softmax(dim=-1)
260 else:
261 attn = attn.softmax(dim=-1)අවධානය ප්රතිදානය ගණනය
265 out = torch.einsum('bhij,bjhd->bihd', attn, v)නැවත හැඩගස්වන්න[batch_size, height * width, n_heads * d_head]
267 out = out.reshape(*out.shape[:2], -1)රේඛීය ස්ථරයක්[batch_size, height * width, d_model]
සමඟ සිතියම
269 return self.to_out(out)272class FeedForward(nn.Module):d_model
ආදාන කාවැද්දීමේ ප්රමාණයයිd_mult
සැඟවුණු ස්ථර ප්රමාණය සඳහා බහුකාර්ය සාධකයකි277 def __init__(self, d_model: int, d_mult: int = 4):282 super().__init__()
283 self.net = nn.Sequential(
284 GeGLU(d_model, d_model * d_mult),
285 nn.Dropout(0.),
286 nn.Linear(d_model * d_mult, d_model)
287 )289 def forward(self, x: torch.Tensor):
290 return self.net(x)293class GeGLU(nn.Module):300 def __init__(self, d_in: int, d_out: int):
301 super().__init__()ඒකාබද්ධ රේඛීය ප්රක්ෂේපණ සහ
303 self.proj = nn.Linear(d_in, d_out * 2)305 def forward(self, x: torch.Tensor):ලබා ගන්න
307 x, gate = self.proj(x).chunk(2, dim=-1)309 return x * F.gelu(gate)