මෙයකඩදාසි ට්රාන්ස්ෆෝමර්-එක්ස්එල් වෙතින් සාපේක්ෂ බහු-ශීර්ෂ අවධානය ක්රියාත්මක කිරීම: පයිටෝර්ච් හි ස්ථාවර දිග සන්දර්භයකින් ඔබ්බට අවධානය යොමු කරන භාෂා ආකෘති .
16import torch
17from torch import nn
18
19from labml.logger import inspect
20from labml_nn.transformers.mha import MultiHeadAttentionමෙමක්රමය තීරු මගින් අනුකෘතියක පේළිය මාරු කරයි.
ආදානයනම් [[1, 2 ,3], [4, 5 ,6], [7, 8, 9]]
, මාරු කළ ප්රති result ලය වනු [[1, 2 ,3], [0, 4, 5], [9, 0, 7]]
ඇත. ඉතාමැනවින් අපි පහළ ත්රිකෝණය වසං කළ යුතු නමුත් එය අපගේ අරමුණ සඳහා හරි.
23def shift_right(x: torch.Tensor):ශුන්යතීරුවක් සංයුක්ත කරන්න
33 zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
34 x_padded = torch.cat([x, zero_pad], dim=1)අවසානයේසිට අතිරික්ත මූලද්රව්ය නැවත සකස් කර ඉවත් කරන්න
37 x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
38 x = x_padded[:-1].view_as(x)41 return xඅපි බහු-ප්රධාන අවධානය මොඩියුලය අභිබවා යන බැවින් අපට අවශ්ය වන්නේ get_scores
ක්රමය ලිවීමට පමණි.
44class RelativeMultiHeadAttention(MultiHeadAttention):52 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):ලකුණුගණනය කිරීමේදී අපි පැහැදිලිවම එය ඇතුළත් කර ඇති බැවින් රේඛීය පරිවර්තනයන්ට නැඹුරුවක් අවශ්ය නොවේ. කෙසේ වෙතත් පක්ෂග්රාහීව සිටීම අර්ථවත් value
විය හැකිය.
56 super().__init__(heads, d_model, dropout_prob, bias=False)සාපේක්ෂතනතුරු ගණන
59 self.P = 2 ** 12විමසුමටසාපේක්ෂව යතුර සඳහා සාපේක්ෂ ස්ථානීය කාවැද්දීම්. අපට කාවැද්දීම් අවශ්ය වන්නේ යතුරු විමසීමට පෙර හෝ පසුව විය හැකි බැවිනි.
63 self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)විමසුමටසාපේක්ෂව යතුර සඳහා සාපේක්ෂ ස්ථානීය කාවැද්දීමේ නැඹුරුව.
65 self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)විමසුමසඳහා ස්ථානීය කාවැද්දීම් විමසුමේ පිහිටුමෙන් ස්වාධීන වේ
67 self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)නිරපේක්ෂඅවධානයෙන්
මුල්කාවැද්දීම් වල රේඛීය පරිවර්තනයන් වන අතර නිරපේක්ෂ ස්ථානීය කේතීකරණයේ රේඛීය පරිවර්තනයන් වේ .
විමසුම්තත්ත්වය නොසලකා දී ඇති යතුරක් කෙරෙහි අවධානය යොමු කිරීම සමාන විය යුතු බව ඔවුහු පෙන්වා දෙති. එබැවින් නියතයක් සමඟ ප්රතිස්ථාපනය කරන්න .
දෙවනහා තෙවන කොන්දේසි සඳහා සාපේක්ෂ ස්ථානීය කේතන හඳුන්වා දෙනු ලැබේ. ඒ නිසා හා සමඟ ප්රතිස්ථාපනය වේ.
69 def get_scores(self, query: torch.Tensor, key: torch.Tensor):108 key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]110 key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]112 query_pos_bias = self.query_pos_bias[None, None, :, :]117 ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)119 b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)121 d = key_pos_bias[None, :, None, :]ලබාගැනීමට පේළි මාරු කරන්න
124 bd = shift_right(b + d)අමතරතනතුරු ඉවත් කරන්න
126 bd = bd[:, -key.shape[0]:]මුදලආපසු දෙන්න
134 return ac + bd137def _test_shift_right():
138 x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
139 inspect(x)
140 inspect(shift_right(x))
141
142 x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
143 inspect(x[:, :, 0, 0])
144 inspect(shift_right(x)[:, :, 0, 0])
145
146 x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
147 inspect(x[:, :, 0, 0])
148 inspect(shift_right(x)[:, :, 0, 0])
149
150
151if __name__ == '__main__':
152 _test_shift_right()