අවධානයරහිත ට්රාන්ස්ෆෝමරයක්

මෙය PyTorch කඩදාසි ක්රියාත්මක කිරීමකි අවධානය නිදහස් ට්රාන්ස්ෆෝමර් .

මෙමලිපිය ස්වයං අවධානය ස්තරය නව කාර්යක්ෂම මෙහෙයුමකින් ප්රතිස්ථාපනය කරයි, එය මතක සංකීර්ණතාවයක් ඇත , අනුක්රමයේ දිග කොතැනද ? කාවැද්දීම් වල මානය.

කඩදාසිAFT සහ AFT හඳුන්වා දෙයි AFT සහ AFT-conv. මෙන්න අපි ස්වයංක්රීය ප්රතිගාමී ආකෘතියක් තුළ සමීප ටෝකන කෙරෙහි අවධානය යොමු කරන AFT- දේශීය ක්රියාත්මක කර ඇත්තෙමු.

අවධානයනිදහස් ට්රාන්ස්ෆෝමර්

AFT( MHAසමාන) පළමු විමසුම බවට කාවැද්දීම් පරිවර්තනය , සමග ප්රධාන හා අගය tensors ඉගෙන ගත් බර. එක් එක් ස්ථානය සඳහා ප්රතිදානය පහත සඳහන් මෙහෙයුම සමඟ ගණනය කරනු ලැබේ.

,මූලද්රව්ය-wise ානවන්ත නිෂ්පාදනයක් කොහෙද , -ෙර්ඛීය නොවන වේ (සිග්මෝයිඩ්) හා යුගල-නැණවත් තත්ත්වය අගතීන් උගත් න්යාසය වේ.

මෙයින්අදහස් කරන්නේ අපි අගයන් බරිත සාමාන්යය ගෙන විමසුම මගින් ඒවා ගුණ කරන බවයි. මෙමඟින් MHA අවශ්ය අවධානය යොමු කිරීමේ අනුකෘතිය ගණනය කිරීමේ අවශ්යතාවය ඉවත් කරන අතර එම නිසා මතක අවශ්යතාවය අඩු කරයි.

AFTදේශීය

AFTදේශීය වශයෙන් උගත් යුගල-නැණවත් ස්ථාන අගතීන් පමණක් දේශීයව අදාළ වේ:

,දේශීය කවුළු ප්රමාණය කොහේද?

දේශීයකවුළුවෙන් පිටත වුවද AFT මෙහෙයුම තවමත් වෙනත් ප්රදේශවලින් යතුරු වටිනාකම් යුගල භාවිතා කරයි. දේශීය කවුළුවෙන් පිටත කාවැද්දීම් සම්පූර්ණයෙන්ම නොපෙනෙන දේශීය ට්රාන්ස්ෆෝමර් වලට වඩා මෙය වෙනස් වේ.

AFTදේශීය ආකෘතියක් සඳහා පුහුණු කේතය මෙන්න.

View Run

61from typing import Optional
62
63import torch
64from torch import nn
65
66from labml_helpers.module import Module

AFTදේශීය මෙහෙයුම

කොහෙද,

69class AFTLocal(Module):
  • d_model යනු query , key සහ value දෛශිකවල ඇති ලක්ෂණ ගණන වේ.
  • seq_len වේ
  • local_window_size දේශීය කවුළු ප්රමාණයයි
  • bias සඳහා පරිවර්තනයන් සඳහා නැඹුරුව පරාමිතිය තිබිය යුතුද යන්න , සහ .
88    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
96        super().__init__()

දේශීයකවුළු ප්රමාණය

99        self.local_window_size = local_window_size

මේවාපරිණාමනය කරයි query , key සහ value දෛශික.

101        self.query = nn.Linear(d_model, d_model, bias=bias)
102        self.key = nn.Linear(d_model, d_model, bias=bias)
103        self.value = nn.Linear(d_model, d_model, bias=bias)

යුගල-නැණවත්ස්ථානීය අගතීන්

105        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)

සඳහාමාස්ක්

107        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)

සක්‍රීයකිරීම

109        self.activation = nn.Sigmoid()

ප්රතිදානස්ථරය

111        self.output = nn.Linear(d_model, d_model)

දේශීයවෙස් මුහුණ සාදන්න

මෙයවෙස් මුහුණක් නිර්මාණය කරයි

113    @staticmethod
114    def create_local_mask(seq_len, local_window_size):

ඒවාටමුල පුරන්න

130        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

ශුන්ය කරන්න

132        local_mask = torch.tril(local_mask, local_window_size - 1)

ශුන්ය කරන්න

134        local_mask = torch.triu(local_mask, -(local_window_size - 1))

137        return local_mask

query , key සහ value විමසුම, යතුරසහ වටිනාකමසඳහා ටෝකන් කාවැද්දීම් එකතු කිරීම ගබඩා කරන ආතතීන් වේ. ඒවායේ හැඩය ඇත [seq_len, batch_size, d_model] .

mask හැඩය ඇති [seq_len, seq_len, batch_size] අතර කණ්ඩායම සඳහා b , ස්ථානයේ විමසුමට ප්රවේශය i තිබේද යන්න mask[i, j, b] දක්වයි ස්ථානයේ ප්රධාන-අගය j .

139    def forward(self, *,
140                query: torch.Tensor,
141                key: torch.Tensor,
142                value: torch.Tensor,
143                mask: Optional[torch.Tensor] = None):

query , key value සහ හැඩය [seq_len, batch_size, d_model]

155        seq_len, _, _ = query.shape
156
157        if mask is not None:

mask හැඩය ඇත [seq_len_q, seq_len_k, batch_size] , එහිදී පළමු මානය විමසුම් මානයක් වේ. විමසුම මානයක් සමාන වේ නම් එය විකාශනය කරනු ඇත.

161            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
162            assert mask.shape[1] == key.shape[0]
163            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

විමසුම, යතුර සහ අගය කාවැද්දීම් පරිවර්තනය කරන්න

166        query = self.query(query)
167        key = self.key(key)
168        value = self.value(value)

ලබාගන්න

වෙස්මුහුණභාවිතා කිරීම

181        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
182        pos_bias = pos_bias.unsqueeze(-1)
183        pos_bias.masked_fill_(~mask, float('-inf'))

අපිගණනය කර , වෙන වෙනම සහ අනුකෘති ගුණ කිරීමක් කරන්නෙමු. අපි පැහැදිලි කිරීම සඳහා einsum භාවිතා කරමු.

සොෆ්ට්මැක්ස්ගණනය ස්ථාවර කිරීම සඳහා ඝාතකයන් ගණනය කිරීමට පෙර අපි අඩු කරන්නෙමු.

විශාල නම් විශාල වන අතර ගණනය කිරීම අස්ථායී වේ. numerator සහ නිකායකය සිට ඝාතීය ගණනය කිරීමට පෙර නියතයක් අඩු කිරීම අවලංගු වනු ඇත. හා ගණනය ස්ථාවර උදව් විය හැක. එබැවින් අපි ගණනය කිරීම ස්ථාවර කිරීමට අඩු කරමු.

205        max_key = key.max(dim=0, keepdims=True)[0]
206        max_pos_bias = pos_bias.max(dim=1,  keepdims=True)[0]

209        exp_key = torch.exp(key - max_key)

211        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)

සංඛ්යාකොටස

214        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)

මෙමහරය කොටසක්

216        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)

ප්රතිදාන

221        y = self.activation(query) * num / den

ප්රතිදානස්ථරය

224        return self.output(y)

දේශීයවෙස් මුහුණ පරීක්ෂා කරන්න

227def _test_local_mask():
231    from labml.logger import inspect
232    inspect(AFTLocal.create_local_mask(10, 4))

236if __name__ == '__main__':
237    _test_local_mask()