මෙය PyTorch හි රොටරි ස්ථානීය කාවැද්දීම් (කඹය) ක්රියාත්මක කිරීමයි.
රොටරිස්ථානීය කාවැද්දීම් (කඹය) ටෝකන වල ස්ථාන තොරතුරු භ්රමණ අනුකෘතියක් සමඟ කේතනය කරන අතර එය ස්වාභාවිකවම පැහැදිලි සාපේක්ෂ ස්ථාන යැපීමක් ඇතුළත් කරයි.
කුඩාෂේක්ස්පියර් දත්ත කට්ටලය මත කඹය සහිත ට්රාන්ස්ෆෝමර් ආකෘතියක් පුහුණු කිරීම සඳහා පුහුණු කේතය මෙන්න.
25import torch
26from torch import nn
27
28from labml.logger import inspect
29from labml_nn.transformers.mha import MultiHeadAttentionරොටරිකේතීකරණය 2D තලය තුළ භ්රමණය වීමෙන් විශේෂාංග යුගල පරිවර්තනය කරයි. එනම්, එය විශේෂාංග යුගල ලෙස සංවිධානය කරයි. සෑම යුගලයක්ම 2D තලයක ඛණ්ඩාංකයක් ලෙස සැලකිය හැකි අතර, කේතීකරණය ටෝකනයේ පිහිටීම අනුව කෝණයකින් එය භ්රමණය වේ.
ස්ථානයේ සිටින ඕනෑම ප්රධානියකුගේ යතුරේ හෝ විමසුමේ ලක්ෂණ දෙකක් වේවා . නැතහොත් සරල බව සඳහා උපකල්පනය කර ඇත්තේ ලක්ෂණ දෙකක් පමණි. එවිට පරිවර්තනය වන්නේ,
නියත කෝණයක් කොහෙද? අනෙක් අංග යුගල ඒ හා සමානව පරිවර්තනය වේ.
විශේෂාංගයුගලයක් සඳහා, ස්ථාන දෙකක් අතර තිත් නිෂ්පාදන අවධානය ලකුණු කරයි
මෙයින්පෙනී යන්නේ තිත් නිෂ්පාදන අවධානය සඳහා භ්රමණ කේතීකරණ සාපේක්ෂ අවධානයක් ලබා දෙන බවයි.
විශේෂාංගයුගල වශයෙන් කාණ්ඩගත කර ඉහත පරිදි හසුරුවනු ලැබේ. ඔවුන් එක් එක් යුගල සඳහා වෙනස් භාවිතා කරයි.
කඩදාසියෝජනා කරන්නේ විශේෂාංග යුගල සඳහා භාවිතා කිරීමයි.
අපිවිශේෂාංගය සමඟ විශේෂාංගය යුගල කරමු . එබැවින් පිහිටීම සඳහා අපි පරිවර්තනය කරමු
කිරීමට
32class RotaryPositionalEmbeddings(nn.Module):d
යනු විශේෂාංග ගණන base
ගණනය කිරීම සඳහා භාවිතා කරන නියතය 119 def __init__(self, d: int, base: int = 10_000):124 super().__init__()
125
126 self.base = base
127 self.d = d
128 self.cos_cached = None
129 self.sin_cached = Noneහැඹිලි සහ අගයන්
131 def _build_cache(self, x: torch.Tensor):හැඹිලියදැනටමත් ඉදිකර ඇත්නම් ආපසු යන්න
136 if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
137 returnඅනුක්රමිකදිග ලබා ගන්න
140 seq_len = x.shape[0]143 theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)ස්ථානදර්ශක සාදන්න [0, 1, ..., seq_len - 1]
146 seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)ස්ථානදර්ශකයේ නිෂ්පාදිතය ගණනය කරන්න
149 idx_theta = torch.einsum('n,d->nd', seq_idx, theta)පේළියසඳහා අපට ඇති පරිදි සංයුක්ත කරන්න
153 idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)ඒවාහැඹිලිය
156 self.cos_cached = idx_theta2.cos()[:, None, None, :]
157 self.sin_cached = idx_theta2.sin()[:, None, None, :]159 def _neg_half(self, x: torch.Tensor):161 d_2 = self.d // 2ගණනයකරන්න
164 return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)x
යනු යතුරක හිසෙහි ටෙන්සර් හෝ හැඩය සහිත විමසුමකි [seq_len, batch_size, n_heads, d]
166 def forward(self, x: torch.Tensor):හැඹිලි සහ අගයන්
171 self._build_cache(x)විශේෂාංගබෙදන්න, අපට භ්රමණ කාවැද්දීම් යෙදීමට තෝරා ගත හැක්කේ අර්ධ විශේෂාංග සමූහයකට පමණි.
174 x_rope, x_pass = x[..., :self.d], x[..., self.d:]ගණනයකරන්න
178 neg_half_x = self._neg_half(x_rope)190 x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])193 return torch.cat((x_rope, x_pass), dim=-1)196class RotaryPEMultiHeadAttention(MultiHeadAttention):203 def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
204 super().__init__(heads, d_model, dropout_prob)භ්රමණස්ථානීය කාවැද්දීමේ ස්ථර
207 d_rope = int(self.d_k * rope_percentage)
208 self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
209 self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)211 def get_scores(self, query: torch.Tensor, key: torch.Tensor):කඹයසමඟ තිත් නිෂ්පාදන ගණනය කරන්න
217 return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))සරලඋදාහරණයක් සමඟ කඹය පරීක්ෂා කිරීම
220def _test_rotary():224 x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
225 x = x[:, None, None, :]
226 inspect(x)
227
228 rotary_pe = RotaryPositionalEmbeddings(3)
229 inspect(rotary_pe(x))
230
231
232if __name__ == '__main__':
233 _test_rotary()