මෙය PyTorch කඩදාසි ක්රියාත්මක කිරීමකි අවධානය නිදහස් ට්රාන්ස්ෆෝමර් .
මෙමලිපිය ස්වයං අවධානය ස්තරය නව කාර්යක්ෂම මෙහෙයුමකින් ප්රතිස්ථාපනය කරයි, එය මතක සංකීර්ණතාවයක් ඇත , අනුක්රමයේ දිග කොතැනද ? කාවැද්දීම් වල මානය.
කඩදාසිAFT සහ AFT හඳුන්වා දෙයි AFT සහ AFT-conv. මෙන්න අපි ස්වයංක්රීය ප්රතිගාමී ආකෘතියක් තුළ සමීප ටෝකන කෙරෙහි අවධානය යොමු කරන AFT- දේශීය ක්රියාත්මක කර ඇත්තෙමු.
AFT( MHAසමාන) පළමු විමසුම බවට කාවැද්දීම් පරිවර්තනය , සමග ප්රධාන හා අගය tensors ඉගෙන ගත් බර. එක් එක් ස්ථානය සඳහා ප්රතිදානය පහත සඳහන් මෙහෙයුම සමඟ ගණනය කරනු ලැබේ.
,මූලද්රව්ය-wise ානවන්ත නිෂ්පාදනයක් කොහෙද , -ෙර්ඛීය නොවන වේ (සිග්මෝයිඩ්) හා යුගල-නැණවත් තත්ත්වය අගතීන් උගත් න්යාසය වේ.
මෙයින්අදහස් කරන්නේ අපි අගයන් බරිත සාමාන්යය ගෙන විමසුම මගින් ඒවා ගුණ කරන බවයි. මෙමඟින් MHA අවශ්ය අවධානය යොමු කිරීමේ අනුකෘතිය ගණනය කිරීමේ අවශ්යතාවය ඉවත් කරන අතර එම නිසා මතක අවශ්යතාවය අඩු කරයි.
AFTදේශීය වශයෙන් උගත් යුගල-නැණවත් ස්ථාන අගතීන් පමණක් දේශීයව අදාළ වේ:
,දේශීය කවුළු ප්රමාණය කොහේද?
දේශීයකවුළුවෙන් පිටත වුවද AFT මෙහෙයුම තවමත් වෙනත් ප්රදේශවලින් යතුරු වටිනාකම් යුගල භාවිතා කරයි. දේශීය කවුළුවෙන් පිටත කාවැද්දීම් සම්පූර්ණයෙන්ම නොපෙනෙන දේශීය ට්රාන්ස්ෆෝමර් වලට වඩා මෙය වෙනස් වේ.
AFTදේශීය ආකෘතියක් සඳහා පුහුණු කේතය මෙන්න.
61from typing import Optional
62
63import torch
64from torch import nn
65
66from labml_helpers.module import Module69class AFTLocal(Module):d_model
යනු query
, key
සහ value
දෛශිකවල ඇති ලක්ෂණ ගණන වේ. seq_len
වේ local_window_size
දේශීය කවුළු ප්රමාණයයි bias
සඳහා පරිවර්තනයන් සඳහා නැඹුරුව පරාමිතිය තිබිය යුතුද යන්න , සහ . 88 def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):96 super().__init__()දේශීයකවුළු ප්රමාණය
99 self.local_window_size = local_window_sizeමේවාපරිණාමනය කරයි query
, key
සහ value
දෛශික.
101 self.query = nn.Linear(d_model, d_model, bias=bias)
102 self.key = nn.Linear(d_model, d_model, bias=bias)
103 self.value = nn.Linear(d_model, d_model, bias=bias)යුගල-නැණවත්ස්ථානීය අගතීන්
105 self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)සඳහාමාස්ක්
107 self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)සක්රීයකිරීම
109 self.activation = nn.Sigmoid()ප්රතිදානස්ථරය
111 self.output = nn.Linear(d_model, d_model)113 @staticmethod
114 def create_local_mask(seq_len, local_window_size):ඒවාටමුල පුරන්න
130 local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)ශුන්ය කරන්න
132 local_mask = torch.tril(local_mask, local_window_size - 1)ශුන්ය කරන්න
134 local_mask = torch.triu(local_mask, -(local_window_size - 1))137 return local_mask query
, key
සහ value
විමසුම, යතුරසහ වටිනාකමසඳහා ටෝකන් කාවැද්දීම් එකතු කිරීම ගබඩා කරන ආතතීන් වේ. ඒවායේ හැඩය ඇත [seq_len, batch_size, d_model]
.
mask
හැඩය ඇති [seq_len, seq_len, batch_size]
අතර කණ්ඩායම සඳහා b
, ස්ථානයේ විමසුමට ප්රවේශය i
තිබේද යන්න mask[i, j, b]
දක්වයි ස්ථානයේ ප්රධාන-අගය j
.
139 def forward(self, *,
140 query: torch.Tensor,
141 key: torch.Tensor,
142 value: torch.Tensor,
143 mask: Optional[torch.Tensor] = None):query
, key
value
සහ හැඩය [seq_len, batch_size, d_model]
155 seq_len, _, _ = query.shape
156
157 if mask is not None:mask
හැඩය ඇත [seq_len_q, seq_len_k, batch_size]
, එහිදී පළමු මානය විමසුම් මානයක් වේ. විමසුම මානයක් සමාන වේ නම් එය විකාශනය කරනු ඇත.
161 assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
162 assert mask.shape[1] == key.shape[0]
163 assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]විමසුම, යතුර සහ අගය කාවැද්දීම් පරිවර්තනය කරන්න
166 query = self.query(query)
167 key = self.key(key)
168 value = self.value(value)181 pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
182 pos_bias = pos_bias.unsqueeze(-1)
183 pos_bias.masked_fill_(~mask, float('-inf'))අපිගණනය කර , වෙන වෙනම සහ අනුකෘති ගුණ කිරීමක් කරන්නෙමු. අපි පැහැදිලි කිරීම සඳහා einsum භාවිතා කරමු.
සොෆ්ට්මැක්ස්ගණනය ස්ථාවර කිරීම සඳහා ඝාතකයන් ගණනය කිරීමට පෙර අපි අඩු කරන්නෙමු.
විශාල නම් විශාල වන අතර ගණනය කිරීම අස්ථායී වේ. numerator සහ නිකායකය සිට ඝාතීය ගණනය කිරීමට පෙර නියතයක් අඩු කිරීම අවලංගු වනු ඇත. හා ගණනය ස්ථාවර උදව් විය හැක. එබැවින් අපි ගණනය කිරීම ස්ථාවර කිරීමට අඩු කරමු.
205 max_key = key.max(dim=0, keepdims=True)[0]
206 max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]209 exp_key = torch.exp(key - max_key)211 exp_pos_bias = torch.exp(pos_bias - max_pos_bias)සංඛ්යාකොටස
214 num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)මෙමහරය කොටසක්
216 den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)ප්රතිදාන
221 y = self.activation(query) * num / denප්රතිදානස්ථරය
224 return self.output(y)දේශීයවෙස් මුහුණ පරීක්ෂා කරන්න
227def _test_local_mask():231 from labml.logger import inspect
232 inspect(AFTLocal.create_local_mask(10, 4))236if __name__ == '__main__':
237 _test_local_mask()