මෙය PyTorch ක්රියාත්මක කිරීම කඩදාසි ධූරාවලි ට්රාන්ස්ෆෝමර් වඩාත් කාර්යක්ෂම භාෂා ආකෘති වේ.
දිගුඅනුපිළිවෙලවල් කාර්යක්ෂමව හැසිරවීමට මෙම ලිපිය ධූරාවලි ට්රාන්ස්ෆෝමර් ගෘහ නිර්මාණ ශිල්පයක් හඳුන්වා දෙයි. ට්රාන්ස්ෆෝමර් ස්ථර වල පළමු භාගය පහළට නියැදි ටෝකන සහ දෙවන භාගය එකම විභේදනයේ ස්ථර අතර සෘජු මඟ හැරීමේ සම්බන්ධතා සහිත සාම්පල. දර්ශන කාර්යයන් සඳහා මෙය යූ-නෙට් වලට ටිකක් සමාන ය.
ඔවුන්විවිධ ඉහළ නියැදීම් සහ පහළ-නියැදි ශිල්පීය ක්රම උත්සාහ කරන අතර ඔවුන් පැය වීදුරු ආකෘතිය ලෙස හඳුන්වන හොඳම ක්රියාකාරී සහ පහළ නියැදි ශිල්පීය ක්රම සහිත ආකෘතියක් ගොඩනඟයි.
මෙන්නඅපි සරල බව සඳහා සරලම ඉහළ නියැදීම් සහ පහළ නියැදි ක්රම ක්රියාත්මක කර ඇත. පසුව වඩාත් සංකීර්ණ (හා වඩා හොඳ ක්රියාකාරී) ක්රියාත්මක කිරීම් එකතු කිරීම අපි සලකා බලමු.
පැයවීදුරු ආකෘතිය සඳහා පුහුණු කේතය මෙන්න.
30from typing import List
31
32import torch
33from torch import nn
34
35from labml_helpers.module import Module
36from labml_nn.transformers import MultiHeadAttention, TransformerLayer
37from labml_nn.transformers.feed_forward import FeedForward
38from labml_nn.transformers.utils import subsequent_maskමෙමආකෘතිය නැවත නැවතත් ස්ථර මැදට එකතු කරන අතර පහළට නියැදීමෙන් අනුක්රමය කෙටි කරයි. තවත් පැයක ආකෘතියකින් සැකසූ කෙටි අනුක්රමය සාමාන්ය ට්රාන්ස්ෆෝමර් ස්ථර දෙකක් අතර සැන්ඩ්විච් කර ඇත. (ට්රාන්ස්ෆෝමර් ස්තරය ස්වයං අවධානය යොමු කරන තට්ටුවක් සහ ස්ථාන-නැණවත් පෝෂක-ඉදිරි ස්ථරයක් ඇත).
41class HourGlass(Module):n_heads
යනු බහු-හිස අවධානය යොමු කරන ස්ථරවල හිස් ගණන d_model
ටෝකන් කාවැද්දීම් වල ප්රමාණයයි dropout
අතහැර දැමීමේ සම්භාවිතාව d_ff
ස්ථාන-නැණවත් පෝෂක-ඉදිරි ස්ථර වල සැඟවුණු ස්ථරයේ මානයන් වේ shortening_factors
කෙටි කිරීමේ සාධක ලැයිස්තුවයි51 def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):59 super().__init__()පහළ-නියැදීම්පෙර ට්රාන්ස්ෆෝමර් ස්ථරය
62 self.pre = TransformerLayer(d_model=d_model,64 self_attn=MultiHeadAttention(n_heads, d_model, dropout),66 feed_forward=FeedForward(d_model, d_ff, dropout),68 dropout_prob=dropout)ස්වයංක්රීය-ප්රතිගාමීවෙස්
70 self.mask = AutoregressiveMask()කෙටිකිරීමේ සාධකය (හෝ පහළ-නියැදි අනුපාතය)
73 k = shortening_factors[0]අනාගතටෝකනවල සිට අතීත ටෝකන වෙත තොරතුරු කාන්දු නොවන බවට වග බලා ගැනීම සඳහා අපි ටෝකන දකුණට මාරු කරමු.
78 self.shift_right = ShiftRight(k - 1)කෙටිකිරීම හෝ පහළ-නියැදි ස්තරය. අපි සරලම ආකෘතිය භාවිතා කරමු - සාමාන්ය තටාක. කඩදාසි පෙන්නුම් කරන්නේ අවධානය පදනම් කරගත් නියැදීම් අප තවමත් ක්රියාත්මක කර නොමැති හොඳම ක්රියා කරන බවයි.
81 self.shortening = AvgPoolShortening(k)තවත්කෙටි කිරීමක් නොමැති නම් (පැය වීදුරුව මැද)
84 if len(shortening_factors) == 1:මැදස්ථරය තවත් ට්රාන්ස්ෆෝමර් ස්ථරයකි
86 self.shortened = TransformerLayer(d_model=d_model,
87 self_attn=MultiHeadAttention(n_heads, d_model, dropout),
88 feed_forward=FeedForward(d_model, d_ff, dropout),
89 dropout_prob=dropout)ස්වයංක්රීයආවරණ
91 self.mask_short = AutoregressiveMask()
92 self.hour_glass = None
93 else:තවත්පැය වීදුරු ආකෘතියක් නැවත නැවත ඇතුල් කරන්න
95 self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])Up-නියැදිස්ථරය. අපි සරල බව සඳහා බොළඳ ඉහළ නියැදීම් භාවිතා කරන අතර කඩදාසි පෙන්නුම් කරන්නේ නියැදීම් පදනම් කරගත් අවධානය වඩා හොඳින් ක්රියාත්මක වන බවයි.
99 self.up_sampling = NaiveUpSampling(k)ඉහළ-නියැදීම්පසු අවසන් ට්රාන්ස්ෆෝමර් ස්ථරය
102 self.post = TransformerLayer(d_model=d_model,
103 self_attn=MultiHeadAttention(n_heads, d_model, dropout),
104 feed_forward=FeedForward(d_model, d_ff, dropout),
105 dropout_prob=dropout)107 def forward(self, x: torch.Tensor):ආරම්භකට්රාන්ස්ෆෝමර් ස්ථරය
110 x = self.pre(x=x, mask=self.mask(x))මාරුකිරීම සහ කෙටි කිරීම
113 x_short = self.shortening(self.shift_right(x))අපිපැය වීදුරුවේ කේන්ද්රයේ සිටින්නේ නම්,
117 if self.hour_glass is None:මධ්යස්ථානයස්ථරය ට්රාන්ස්ෆෝමර්
120 x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))122 else:124 x_short = self.hour_glass(x_short)කෙටිකළ අනුපිළිවෙල ඉහළට සාම්පල කර මඟ හැරීමේ සම්බන්ධතාවයක් එක් කරන්න
128 x = x + self.up_sampling(x, x_short)අවසන්ට්රාන්ස්ෆෝමර් ස්ථරය
131 x = self.post(x=x, mask=self.mask(x))134 return x137class ShiftRight(Module):shift
විසින් මාරු කිරීමට පියවර ගණන වේ144 def __init__(self, shift: int):148 super().__init__()ඍණාත්මකවිය නොහැක
150 assert shift >= 0152 self.shift = shiftx
හැඩයේ ආතතිකාරයකි [seq_len, ...]
154 def forward(self, x: torch.Tensor):මාරුවමුල් පිටපත ආපසු ලබා දෙන්නේ නම්
159 if self.shift == 0:
160 return xZerosවමට ඇප්පෙන්ඩ් කළ යුතුය
162 prefix = x.new_zeros([self.shift, *x.shape[1:]])ශුන්යයසංයුක්ත කර අයිතිය ටන්ක කරන්න
164 return torch.cat([prefix, x[:-self.shift]])167class AvgPoolShortening(Module):k
කෙටි කිරීමේ සාධකයයි174 def __init__(self, k: int):178 super().__init__()සාමාන්යතටාක ස්ථරය
180 self.pool = nn.AvgPool1d(k, ceil_mode=True)x
හැඩයෙන් යුක්ත වේ [seq_len, batch_size, d_model]
182 def forward(self, x: torch.Tensor):තටාකස්ථරය හැඩය පිළිගන්නා [batch_size, d_model, seq_len]
බැවින් අපි අක්ෂ permute.
188 return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)191class NaiveUpSampling(Module):k
කෙටි කිරීමේ සාධකයයි198 def __init__(self, k: int):202 super().__init__()
203 self.k = kx
පහළ-නියැදීම් පෙර කාවැද්දීම් සමග tensor වේ x_short
යනු ඉහළ dens නත්වයේ ආතතියයි (ඉහළට සාම්පල කළ යුතු) නිරූපණයකි205 def forward(self, x: torch.Tensor, x_short: torch.Tensor):අනුක්රමිකමානය හරහා නැවත කරන්න
211 expanded = torch.repeat_interleave(x_short, self.k, dim=0)අවසානයේඅමතර කාවැද්දීම් ඉවත් කරන්න
213 expanded = expanded[:x.shape[0]]216 return expanded219class AutoregressiveMask(Module):224 def __init__(self):
225 super().__init__()
226 self.mask = None228 def forward(self, x: torch.Tensor):අපවිසින් නිර්මාණය කර නොමැති නම් හෝ ප්රමාණ වෙනස් වී ඇත්නම් වෙස්මුහුණක් සාදන්න
230 if self.mask is None or self.mask.size(0) != len(x):පසුකාලීන වෙස්මුහුණ, අනාගත ටෝකන දැකීමෙන් ටෝකන වසං කරනු ඇත
232 self.mask = subsequent_mask(len(x)).to(x.device)235 return self.maskමෙයඒකාබද්ධ කළ යුතු අනුයාත ටෝකන කාවැද්දීම් වලට අනුකූල වන අතර එය තනි ටෝකන කාවැද්දීමේ ප්රමාණයට සිතියම් ගත කිරීම සඳහා රේඛීය පරිවර්තනයක් සිදු කරයි.
238class LinearPoolingShortening(Module):246 def __init__(self):
247 super().__init__()
248 raise NotImplementedError251class AttentionBasedShortening(Module):263 def __init__(self):
264 super().__init__()
265 raise NotImplementedError268class LinearUpSampling(Module):275 def __init__(self):
276 super().__init__()
277 raise NotImplementedError280class AttentionBasedUpSampling(Module):292 def __init__(self):
293 super().__init__()
294 raise NotImplementedError