රෙට්රෝආකෘතිය

RETROසඳහා ආදර්ශ අර්ථ දැක්වීම මෙයයි.

View Run

16import math
17from typing import Set
18
19import torch
20from torch import nn
21
22from labml.logger import inspect

කඹය කාවැද්දීම්

අපිස්වයං අවධානය ස්ථර භමණ තත්ත්වය කාවැද්දීම් භාවිතා කරන්න. ස්ථානීය තොරතුරු කාවැද්දීම් තුළට කාවැදී ඇති අතර එම නිසා ඒවා පොදු අවධානයට ලක් නොකරමු. හේතු නොවන ස්වයං අවධානයට පැහැදිලි ස්ථානීය තොරතුරු අවශ්ය වන්නේ එයට අනුමාන කළ නොහැකි බැවිනි.

25class RotaryPositionalEmbeddings(nn.Module):
  • d යනු විශේෂාංග ගණන
  • base ගණනය කිරීම සඳහා භාවිතා කරන නියතය
36    def __init__(self, d: int, base: int = 10_000):
41        super().__init__()

43        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
  • x යනු යතුරක හිසෙහි ටෙන්සර් හෝ හැඩය සහිත විමසුමකි [ batch_size, seq_len, n_heads, d]
45    def forward(self, x: torch.Tensor):

හැඩයඋපුටා ගන්න

50        batch_size, seq_len, n_heads, d = x.shape

53        d_2 = d // 2

ස්ථානදර්ශක සාදන්න [0, 1, ..., seq_len - 1]

56        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

ස්ථානදර්ශකයේ නිෂ්පාදිතය ගණනය කරන්න

59        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)

පේළියසඳහා අපට ඇති පරිදි සංයුක්ත කරන්න

63        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

ගණනයකරන්න

67        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

ගණනයකරන්න

සඳහා

79        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

82        return rx

ස්වයංඅවධානය ස්ථරය

මෙයහේතු සහ හේතු නොවන බහු-හිස සහිත ස්වයං අවධානයඅදාළ වේ.

85class SelfAttention(nn.Module):
  • d_model ට්රාන්ස්ෆෝමර් කාවැද්දීම් වල විශේෂාංග ගණන වේ
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • d_k යනු හිසකට ඇති ලක්ෂණ ගණන
  • is_causal මෙය හේතුකාරක අවධානය (මැස්සෙඩ්) යන්න පෙන්නුම් කරයි
92    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
99        super().__init__()
100
101        self.is_causal = is_causal
102        self.n_heads = n_heads
103        self.d_k = d_k

සොෆ්ට්මැක්ස්වලට පෙර අවධානය පරිමාණය කිරීම

106        self.scale = 1 / math.sqrt(self.d_k)

විමසුම, යතුරු සහ අගය හිස් සඳහා රේඛීය ස්ථර.

109        self.query = nn.Linear(d_model, n_heads * d_k)
110        self.key = nn.Linear(d_model, n_heads * d_k)
111        self.value = nn.Linear(d_model, n_heads * d_k)

පූර්වසම්මත ස්තරය. කඩදාසි වෙනුවට RMSNorm භාවිතා කරයි.

114        self.norm = nn.LayerNorm(d_model)

අවධානයසම්භාවිතාව සඳහා සොෆ්ට්මැක්ස්

117        self.softmax = nn.Softmax(dim=-1)

රොටරිස්ථානීය කාවැද්දීම්

120        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)

අවසානරේඛීය ස්ථරය

123        self.output = nn.Linear(n_heads * d_k, d_model)

හේතුකාරකඅවධානය සඳහා අවධානය යොමු කිරීමේ ස්තරය ආවරණය කරන්න

  • attn හැඩයේ අවධානය යොමු කිරීමේ අනුකෘතියකි [batch_size, n_heads, seq_len, seq_len]
125    def mask_attention(self, attn: torch.Tensor):

හේතුනොවන අවධානය සඳහා ආවරණ නොමැත

133        if not self.is_causal:
134            return attn

ත්රිකෝණාකාරවෙස් මුහුණක් සාදන්න

137        mask = torch.tril(attn.new_ones(attn.shape[-2:]))

වෙස්මුහුණමගින් පෙරහන් කරන්න

139        return attn.masked_fill(mask == 0, float('-inf'))
  • h යනු හැඩයේ ට්රාන්ස්ෆෝමර් කාවැද්දීම් වේ [batch_size, seq_len, d_model]
141    def forward(self, h: torch.Tensor):

අවශේෂසම්බන්ධතාවය

147        h_res = h

පූර්වසාමාන්යකරණය

150        h = self.norm(h)

විමසුම, යතුර සහ අගයන් ලබා ගෙන ඒවා හිස් වලට බෙදන්න. මේවාට හැඩයන් ඇත [batch_size, seq_len, n_heads, d_k]

154        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
155        q = self.query(h).view(mh_shape)
156        k = self.key(h).view(mh_shape)
157        v = self.value(h).view(mh_shape)

භ්රමණස්ථානීය කාවැද්දීම් යොදන්න

160        q = self.rotary_pe(q)
161        k = self.rotary_pe(k)

අවධානයගණනය කරන්න

164        attn = torch.einsum('bihd,bjhd->bhij', q, k)

විසින්එය පරිමාණය

166        attn = attn * self.scale

එයහේතු අවධානයක් නම් වෙස් මුහුණු යොදන්න

169        attn = self.mask_attention(attn)

අවධානයසම්භාවිතාව ගණනය කරන්න

172        attn = self.softmax(attn)

වටිනාකම්ලබා ගන්න

175        h = torch.einsum("bhij,bjhd->bihd", attn, v)

හැඩයෙන්වෙනස් [batch_size, seq_len, n_heads, d_k] කරන්න [batch_size, seq_len, n_heads * d_k]

179        h = h.reshape(*h.shape[:-2], -1)

අවසානරේඛීය ස්ථරය යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, seq_len, d_model]

183        h = self.output(h)

අවශේෂසම්බන්ධතාවය එක් කරන්න

186        return h + h_res

හරස්අවධානය ස්ථරය

මෙයඉහත අර්ථ දක්වා ඇති ස්වයං අවධානය ස්ථරයට සමාන වේ, එය විමසුම් වලට වඩා වෙනස් කාවැද්දීම් කට්ටලයකින් යතුරු සහ අගයන් ලබා ගනී.

ආදානකුට්ටි මත පදනම්ව නැවත ලබා ගත් කුට්ටි කේතනය කිරීම සඳහා මෙය එන්කෝඩරයේ භාවිතා වේ.

අපිමෙහි කිසිදු පැහැදිලි ස්ථානීය කාවැද්දීමක් භාවිතා නොකරමු. ආකෘතියට කාවැද්දීම් වල ස්ථානීය තොරතුරු ව්යංගයෙන් නියෝජනය කළ හැකි යැයි අපි උපකල්පනය කරමු.

189class CrossAttention(nn.Module):
  • d_model ට්රාන්ස්ෆෝමර් කාවැද්දීම් වල විශේෂාංග ගණන වේ
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • d_k යනු හිසකට ඇති ලක්ෂණ ගණන
203    def __init__(self, d_model: int, n_heads: int, d_k: int):
209        super().__init__()
210
211        self.n_heads = n_heads
212        self.d_k = d_k

සොෆ්ට්මැක්ස්වලට පෙර අවධානය පරිමාණය කිරීම

215        self.scale = 1 / math.sqrt(self.d_k)

විමසුම, යතුරු සහ අගය හිස් සඳහා රේඛීය ස්ථර.

218        self.query = nn.Linear(d_model, n_heads * d_k)
219        self.key = nn.Linear(d_model, n_heads * d_k)
220        self.value = nn.Linear(d_model, n_heads * d_k)

විමසුම්කාවැද්දීම් සඳහා පූර්ව සම්මත ස්තරය. කඩදාසි වෙනුවට RMSNorm භාවිතා කරයි.

223        self.norm = nn.LayerNorm(d_model)

අවධානයසම්භාවිතාව සඳහා සොෆ්ට්මැක්ස්

226        self.softmax = nn.Softmax(dim=-1)

අවසානරේඛීය ස්ථරය

229        self.output = nn.Linear(n_heads * d_k, d_model)
  • e හැඩයෙන් යුත් ළඟම අසල්වැසියාගේ කුට්ටිය කාවැද්දීම් ලබා ගත හැකිය [batch_size, chunks, neighbors, neighbor_len, d_model]
  • h ආසන්නතම අසල්වැසියන් හැඩයෙන් ලබා ගන්නා ලද ආදාන කුට්ටි [batch_size, chunks, chunk_len, d_model] වේ. මෙය දැනටමත් සාමාන්යකරණය වී ඇත.
231    def forward(self, e: torch.Tensor, h: torch.Tensor):

අවශේෂසම්බන්ධතාවය

240        e_res = e

ලබාගත් කුට්ටි සාමාන්යකරණය කරන්න

243        e = self.norm(e)

ලබාගත් කුට්ටි වලින් විමසුම ලබා ගන්න

246        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)

ආදානකුට්ටි වලින් යතුරු සහ අගයන් ලබා ගන්න

248        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
249        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

සියලුමකුට්ටි සඳහා අවධානය ලකුණු ගණනය කරන්න. ලබා ගත් සෑම අසල්වැසියෙකුම එය නැවත ලබා ගත් මුල් කුට්ටිය කෙරෙහි අවධානය යොමු කරනු ඇත. මෙම හැඩය ඇත [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]

254        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)

පරිමාණඅවධානය ලකුණු

256        attn = attn * self.scale

අවසානමානය හරහා සොෆ්ට්මැක්ස් ගණනය කරන්න

259        attn = self.softmax(attn)

වටිනාකම්එකතු කරන්න

262        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)

හැඩයෙන්වෙනස් [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] කරන්න [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]

266        e = e.reshape(*e.shape[:-2], -1)

අවසානරේඛීය ස්ථරය යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, chunks, neighbors, neighbor_len, d_model]

270        e = self.output(e)

අවශේෂසම්බන්ධතාවය එක් කරන්න

273        return e + e_res

තැළුණුහරස් අවධානය ස්ථරය

මෙයඉහත අර්ථ දක්වා ඇති හරස් අවධානය ස්ථරයට සමාන වේ.

මෙයනැවත ලබා ගත් අසල්වැසියා කුට්ටි වෙත අවධානය යොමු කිරීම සඳහා විකේතකය තුළ භාවිතා වේ.

අපිමෙහි කිසිදු පැහැදිලි ස්ථානීය කාවැද්දීමක් භාවිතා නොකරමු. ආකෘතියට කාවැද්දීම් වල ස්ථානීය තොරතුරු ව්යංගයෙන් නියෝජනය කළ හැකි යැයි අපි උපකල්පනය කරමු.

276class ChunkedCrossAttention(nn.Module):
  • d_model ට්රාන්ස්ෆෝමර් කාවැද්දීම් වල විශේෂාංග ගණන වේ
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • d_k යනු හිසකට ඇති ලක්ෂණ ගණන
  • chunk_len යනු කුට්ටියක දිග
288    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
296        super().__init__()
297
298        self.chunk_len = chunk_len
299        self.n_heads = n_heads
300        self.d_k = d_k

සොෆ්ට්මැක්ස්වලට පෙර අවධානය පරිමාණය කිරීම

303        self.scale = 1 / math.sqrt(self.d_k)

විමසුම, යතුරු සහ අගය හිස් සඳහා රේඛීය ස්ථර.

306        self.query = nn.Linear(d_model, n_heads * d_k)
307        self.key = nn.Linear(d_model, n_heads * d_k)
308        self.value = nn.Linear(d_model, n_heads * d_k)

විමසුම්කාවැද්දීම් සඳහා පූර්ව සම්මත ස්තරය. කඩදාසි වෙනුවට RMSNorm භාවිතා කරයි.

311        self.norm = nn.LayerNorm(d_model)

අවධානයසම්භාවිතාව සඳහා සොෆ්ට්මැක්ස්

314        self.softmax = nn.Softmax(dim=-1)

අවසානරේඛීය ස්ථරය

317        self.output = nn.Linear(n_heads * d_k, d_model)

h හැඩයේ ආදාන [batch_size, seq_len, d_model] e කාවැද්දීම් යනු හැඩයේ ආසන්නතම අසල්වැසියන් ලබා ගත හැකිය [batch_size, chunks, neighbors, neighbor_len, d_model]

319    def forward(self, h: torch.Tensor, e: torch.Tensor):

හැඩයලබා ගන්න

326        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

කුට්ටිනොමැති නම් අවධානය යොමු නොකරයි (නියැදීමේදී කෙටි යෙදවුම් සඳහා)

329        if chunks == 0:
330            return h

අවශේෂසම්බන්ධතාවය

333        h_res = h

පළමු chunk_len - 1 කාවැද්දීම් ඉවත් කරන්න. අතීත ටෝකන භාවිතයෙන් පමණක් ලබා ගත් සහ කේතනය කර ඇති අසල්වැසියන් කෙරෙහි ආදානය අවධානය යොමු කරයි; එවිට තොරතුරු කාන්දු වීමක් සිදු නොවේ. පළමු කුට්ටියේ සිට ලබාගත් අසල්වැසියන් පළමු කුට්ටියෙන් තොරතුරු ලැබෙනු ඇත. එබැවින් අනුක්රමය වමට මාරු කිරීමෙන් තොරතුරු පමණක් දකුණට ගලා යන බවට chunk_len - 1 අපි වග බලා ගනිමු.

341        h = h[:, self.chunk_len - 1:]

පූර්වසම්මතය

343        h = self.norm(h)

ආදානයකුට්ටි වලට බෙදීමට හැකිවන පරිදි හිස් කාවැද්දීම් අවසානය දක්වා එක් කරන්න

345        if h.shape[1] < chunks * self.chunk_len:
346            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)

ආදානයකුට්ටි බවට නැවත සකස් කරන්න.

348        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)

ආදානයෙන්විමසුම ලබා ගන්න

351        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)

ලබාගත් අසල්වැසියන්ගෙන් යතුරු සහ වටිනාකම් ලබා ගන්න

353        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
354        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

ආදානකුට්ටි සඳහා අවධානය ලකුණු ගණනය කරන්න. සෑම කුට්ටියක්ම කලින් කුට්ටිය විසින් ලබා ගන්නා ලද අසල්වැසියන් කෙරෙහි අවධානය යොමු කරනු ඇත. මෙම හැඩය ඇත [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]

359        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)

පරිමාණඅවධානය ලකුණු

361        attn = attn * self.scale

අවසානමානයන් දෙකට වඩා සොෆ්ට්මැක්ස් යොදන්න neighbors, neighbor_len

364        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)

වටිනාකම්එකතු කරන්න

367        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)

හැඩයෙන්වෙනස් [batch_size, chunks, chunk_len, n_heads, d_k] කරන්න [batch_size, chunks * chunk_len, n_heads * d_k]

371        h = h.reshape(batch_size, chunks * self.chunk_len, -1)

අවසානරේඛීය ස්ථරය යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, chunks * chunk_len, d_model]

375        h = self.output(h)

වමට chunk_len - 1 ශුන්ය කාවැද්දීම එක් කරන්න; එනම් දකුණු එය ආපසු මාරු කරන්න

378        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)

අවශේෂසම්බන්ධතාවය ඉවත් කර එකතු කරන්න

381        return h[:, :h_res.shape[1]] + h_res

ස්ථාන-නැණවත්පෝෂණය ඉදිරි ස්ථරය

මෙයරේඛීය ස්ථර දෙකක් සහ මැද සක්රිය කිරීමකින් සමන්විත වේ.

384class FeedForward(nn.Module):
  • d_model ට්රාන්ස්ෆෝමර් කාවැද්දීම් වල විශේෂාංග ගණන වේ
  • d_ff සැඟවුණු ස්ථරයේ අංක ලක්ෂණ වේ
  • 391    def __init__(self, d_model: int, d_ff: int):
    397        super().__init__()

    රේඛීයස්ථර දෙක

    400        self.lin1 = nn.Linear(d_model, d_ff)
    401        self.lin2 = nn.Linear(d_ff, d_model)

    Reluසක්රිය කිරීම

    404        self.act = nn.ReLU()

    පෙර-සම්මතස්තරය

    407        self.norm = nn.LayerNorm(d_model)

    h හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, d_model]

    409    def forward(self, h: torch.Tensor):

    අවශේෂ

    415        h_res = h

    පූර්වසම්මතය

    417        h = self.norm(h)

    පළමුරේඛීය ස්ථරය

    419        h = self.lin1(h)

    සක්‍රීයකිරීම

    421        h = self.act(h)

    දෙවනරේඛීය ස්ථරය

    423        h = self.lin2(h)

    අවශේෂසම්බන්ධතාවය එක් කරන්න

    426        return h + h_res

    ළඟමඅසල්වැසි ආකේතකය

    මෙමමොඩියුලය ලබා ගත් ආසන්නතම අසල්වැසියන් සංකේතවත් කරයි

    429class NearestNeighborEncoder(nn.Module):
    • chunk_len යනු කුට්ටියක දිග
    • n_layer එන්කෝඩරයේ ස්ථර ගණන
    • ca_layers හරස් අවධානය ඇති ස්ථර වේ
    • d_model යනු කාවැද්දීම් වල විශේෂාංග ගණන
    • n_heads අවධානය යොමු ස්ථර වල හිස් ගණන
    • d_k අවධානය යොමු ප්රධානීන් ප්රමාණය
    • d_ff යනු පෝෂක ඉදිරි ජාලයේ සැඟවුණු ස්ථර වල ප්රමාණයයි
    436    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
    437                 d_model: int, n_heads: int, d_k: int, d_ff: int):
    448        super().__init__()
    449        self.ca_layers = ca_layers
    450        self.chunk_len = chunk_len

    හරස්අවධානය ස්ථර

    452        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])

    ද්වි-දිශානුගතස්වයං අවධානය ස්ථර

    454        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])

    ඉදිරිස්ථර පෝෂණය කරන්න

    456        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

    සඳහාපූර්ව සාමාන්යකරණ ස්තරය

    459        self.norm_h = nn.LayerNorm(d_model)
    • e ලබා ගත් ළඟම අසල්වැසියන්ගේ ටෝකන් කාවැද්දීම්, හැඩයෙන් [batch_size, chunks, neighbors, neighbor_len, d_model]
    • h යනු ආදාන ටෝකන කාවැද්දීම්, හැඩයෙන් [batch_size, seq_len, d_model]

    කුට්ටි සහ අසල්වාසීන් සමාන්තරව සකස් කරනු ලැබේ.

    461    def forward(self, e: torch.Tensor, h: torch.Tensor):

    හැඩයලබා ගන්න

    474        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

    477        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)

    පූර්වසම්මතය

    480        h_split = self.norm_h(h_split)

    හරස්අවධානය ස්ථරයේ දර්ශකය තබා ගන්න

    483        p_ca = 0

    සියලුමස්ථර සඳහා

    485        for p in range(len(self.attn)):

    ද්වි-දිශානුගතස්වයං අවධානය

    488            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)

    හරස්අවධානය යොමු කරන්නේ නම්

    491            if p in self.ca_layers:

    493                e = self.ca[p_ca](e, h_split)

    හරස්අවධානය දර්ශකය වැඩි කරන්න

    495                p_ca += 1

    ඉදිරිස්ථරය පෝෂණය කරන්න

    498            e = self.ffw[p](e)

    ආපසු

    501        return e

    රෙට්රෝආකෘතිය

    මෙයරෙට්රෝ විකේතකය

    504class RetroModel(nn.Module):
    • v_vocab යනු වචන මාලාවේ ටෝකන ගණන
    • d_model යනු කාවැද්දීම් වල විශේෂාංග ගණන
    • n_layers යනු විකේතකයේ ස්ථර ගණන
    • ca_layers හරස් අවධානය ඇති ස්ථර වේ
    • chunk_len යනු කුට්ටියක දිග
    • n_heads අවධානය යොමු ස්ථර වල හිස් ගණන
    • d_k අවධානය යොමු ප්රධානීන් ප්රමාණය
    • d_ff යනු පෝෂක ඉදිරි ජාලයේ සැඟවුණු ස්ථර වල ප්රමාණයයි
    • encoder ළඟම අසල්වැසියා එන්කෝඩරයයි
    511    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
    512                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
    524        super().__init__()
    525
    526        self.ca_layers = ca_layers
    527        self.encoder = encoder

    ටෝකන්කාවැද්දීම ස්ථරය

    530        self.emb = nn.Embedding(n_vocab, d_model)

    කපනලද හරස් අවධානය ස්ථර

    532        self.cca = nn.ModuleList(
    533            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])

    අවධානයස්ථර

    535        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])

    ඉදිරිස්ථර පෝෂණය කරන්න

    537        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

    කියවීමේස්ථරය

    539        self.read = nn.Linear(d_model, n_vocab)

    ආසන්නතමඅසල්වැසියාගේ කාවැද්දීම් සඳහා පූර්ව සාමාන්යකරණ ස්තරය

    543        self.norm_e = nn.LayerNorm(d_model)
    • x ආදාන අනුක්රමය, හැඩයෙන් [batch_size, seq_len]
    • ret හැඩයෙන් ලබා ගත් අසල්වැසියන් වේ [batch_size, chunks, neighbors, neighbor_len]
    545    def forward(self, x: torch.Tensor, ret: torch.Tensor):

    ආදානකාවැද්දීම් ලබා ගන්න

    554        h = self.emb(x)

    ලබාගත් අසල්වාසීන්ගේ කාවැද්දීම් .

    ආදානසහ අසල්වැසියන් සඳහා අපි එකම කාවැද්දීම් භාවිතා කරමු

    560        ret_emb = self.emb(ret)

    කපනලද හරස් අවධානය ස්ථරයේ දර්ශකය තබා ගන්න

    563        p_ca = 0

    සියලුමස්ථර සඳහා

    565        for p in range(len(self.attn)):

    හේතුකාරකස්වයං අවධානය

    567            h = self.attn[p](h)

    පළමු ස්ථරයට පෙර එන්කෝඩර් කාවැද්දීම් ලබා ගන්න

    571            if self.ca_layers and p == min(self.ca_layers):

    අපි එන්කෝඩරයට කාවැද්දීම් සම්මත කළෙමු.

    575                e = self.encoder(ret_emb, h)

    එන්කෝඩර්කාවැද්දීම් සාමාන්යකරණය කරන්න

    577                e = self.norm_e(e)

    කුරුස-හරස්අවධානය නම්

    580            if p in self.ca_layers:

    582                h = self.cca[p_ca](h, e)

    වර්ධකකපන ලද හරස් අවධානය දර්ශකය

    584                p_ca += 1

    587            h = self.ffw[p](h)

    590        return self.read(h)

    ව්යාජදත්ත සමඟ ආකෘතිය පරීක්ෂා කරන්න

    593def _test():
    597    chunk_len = 4
    598    d_model = 8
    599    d_ff = 32
    600    n_heads = 2
    601    d_k = 4
    602
    603    device = torch.device('cuda:0')
    604
    605    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
    606                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
    607
    608    m.to(device)
    609    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    610    ret = [
    611        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    612        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    613    ]
    614    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
    615
    616    inspect(res)

    620if __name__ == '__main__':
    621    _test()