රොටරිස්ථානීය කාවැද්දීම් (කඹය)

මෙය PyTorch හි රොටරි ස්ථානීය කාවැද්දීම් (කඹය) ක්රියාත්මක කිරීමයි.

රොටරිස්ථානීය කාවැද්දීම් (කඹය) ටෝකන වල ස්ථාන තොරතුරු භ්රමණ අනුකෘතියක් සමඟ කේතනය කරන අතර එය ස්වාභාවිකවම පැහැදිලි සාපේක්ෂ ස්ථාන යැපීමක් ඇතුළත් කරයි.

කුඩාෂේක්ස්පියර් දත්ත කට්ටලය මත කඹය සහිත ට්රාන්ස්ෆෝමර් ආකෘතියක් පුහුණු කිරීම සඳහා පුහුණු කේතය මෙන්න.

View Run

25import torch
26from torch import nn
27
28from labml.logger import inspect
29from labml_nn.transformers.mha import MultiHeadAttention

කඹයමොඩියුලය

රොටරිකේතීකරණය 2D තලය තුළ භ්රමණය වීමෙන් විශේෂාංග යුගල පරිවර්තනය කරයි. එනම්, එය විශේෂාංග යුගල ලෙස සංවිධානය කරයි. සෑම යුගලයක්ම 2D තලයක ඛණ්ඩාංකයක් ලෙස සැලකිය හැකි අතර, කේතීකරණය ටෝකනයේ පිහිටීම අනුව කෝණයකින් එය භ්රමණය වේ.

විශේෂාංගයුගලයක් සඳහා

ස්ථානයේ සිටින ඕනෑම ප්රධානියකුගේ යතුරේ හෝ විමසුමේ ලක්ෂණ දෙකක් වේවා . නැතහොත් සරල බව සඳහා උපකල්පනය කර ඇත්තේ ලක්ෂණ දෙකක් පමණි. එවිට පරිවර්තනය වන්නේ,

නියත කෝණයක් කොහෙද? අනෙක් අංග යුගල ඒ හා සමානව පරිවර්තනය වේ.

අවධානයසාපේක්ෂ වේ

විශේෂාංගයුගලයක් සඳහා, ස්ථාන දෙකක් අතර තිත් නිෂ්පාදන අවධානය ලකුණු කරයි

මෙයින්පෙනී යන්නේ තිත් නිෂ්පාදන අවධානය සඳහා භ්රමණ කේතීකරණ සාපේක්ෂ අවධානයක් ලබා දෙන බවයි.

සියලුමවිශේෂාංග සඳහා

විශේෂාංගයුගල වශයෙන් කාණ්ඩගත කර ඉහත පරිදි හසුරුවනු ලැබේ. ඔවුන් එක් එක් යුගල සඳහා වෙනස් භාවිතා කරයි.

කඩදාසියෝජනා කරන්නේ විශේෂාංග යුගල සඳහා භාවිතා කිරීමයි.

අපිවිශේෂාංගය සමඟ විශේෂාංගය යුගල කරමු . එබැවින් පිහිටීම සඳහා අපි පරිවර්තනය කරමු

කිරීමට

32class RotaryPositionalEmbeddings(nn.Module):
  • d යනු විශේෂාංග ගණන
  • base ගණනය කිරීම සඳහා භාවිතා කරන නියතය
119    def __init__(self, d: int, base: int = 10_000):
124        super().__init__()
125
126        self.base = base
127        self.d = d
128        self.cos_cached = None
129        self.sin_cached = None

හැඹිලි සහ අගයන්

131    def _build_cache(self, x: torch.Tensor):

හැඹිලියදැනටමත් ඉදිකර ඇත්නම් ආපසු යන්න

136        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
137            return

අනුක්රමිකදිග ලබා ගන්න

140        seq_len = x.shape[0]

143        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

ස්ථානදර්ශක සාදන්න [0, 1, ..., seq_len - 1]

146        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

ස්ථානදර්ශකයේ නිෂ්පාදිතය ගණනය කරන්න

149        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

පේළියසඳහා අපට ඇති පරිදි සංයුක්ත කරන්න

153        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

ඒවාහැඹිලිය

156        self.cos_cached = idx_theta2.cos()[:, None, None, :]
157        self.sin_cached = idx_theta2.sin()[:, None, None, :]
159    def _neg_half(self, x: torch.Tensor):

161        d_2 = self.d // 2

ගණනයකරන්න

164        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
  • x යනු යතුරක හිසෙහි ටෙන්සර් හෝ හැඩය සහිත විමසුමකි [seq_len, batch_size, n_heads, d]
166    def forward(self, x: torch.Tensor):

හැඹිලි සහ අගයන්

171        self._build_cache(x)

විශේෂාංගබෙදන්න, අපට භ්රමණ කාවැද්දීම් යෙදීමට තෝරා ගත හැක්කේ අර්ධ විශේෂාංග සමූහයකට පමණි.

174        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

ගණනයකරන්න

178        neg_half_x = self._neg_half(x_rope)

ගණනයකරන්න

සඳහා

190        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

193        return torch.cat((x_rope, x_pass), dim=-1)

භ්රමණස්ථානීය කාවැද්දීම් සහිත බහු-හිස අවධානය

මුල් ට්රාන්ස්ෆෝමරයෙන් අපි බහු-හිස අවධානයඅභිබවා යමු.

196class RotaryPEMultiHeadAttention(MultiHeadAttention):
203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
204        super().__init__(heads, d_model, dropout_prob)

භ්රමණස්ථානීය කාවැද්දීමේ ස්ථර

207        d_rope = int(self.d_k * rope_percentage)
208        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
209        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

විමසුම්සහ යතුරු අතර ලකුණු ගණනය කරන්න

211    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

කඹයසමඟ තිත් නිෂ්පාදන ගණනය කරන්න

217        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))

සරලඋදාහරණයක් සමඟ කඹය පරීක්ෂා කිරීම

220def _test_rotary():
224    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
225    x = x[:, None, None, :]
226    inspect(x)
227
228    rotary_pe = RotaryPositionalEmbeddings(3)
229    inspect(rotary_pe(x))
230
231
232if __name__ == '__main__':
233    _test_rotary()