from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.fast_weights import DPFP
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list


class FastWeightsAttention(Module):
    """
    Fast weights attention with an update rule gated by a learned beta.

    Maintains a per-head fast weight matrix `weights` across time steps.
    At each step the new (value, key) association is written into the fast
    weights after subtracting the value currently retrieved for that key,
    scaled by a sigmoid gate.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` for multi-headed attention
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

        # Interpolation gate: one sigmoid-activated scalar per head
        self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                  nn.Sigmoid())

        # Non-linear feature map applied to queries and keys (e.g. DPFP)
        self.phi = phi

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        """
        * `x` is the input for the current step
        * `weights` is the fast weight memory carried over from the previous
          step, or `None` at the first step (then initialized to zeros)

        Returns the attention output and the updated fast weights.
        """
        # Apply the feature map phi to the projected queries and keys
        query = self.phi(self.query(x))
        key = self.phi(self.key(x))
        value = self.value(x)

        # Start with zero fast weights on the first step
        if weights is None:
            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))

        # Value currently stored in the fast weights for this key
        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

        # Gate controlling how strongly the new association overwrites the old one
        beta = self.gate(x)

        # Write the gated difference (new value - retrieved value) into the fast weights
        weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)

        # Retrieve the output for the current query
        x = torch.einsum('bhvk,bhk->bhv', weights, query)

        # Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)

        # Output layer
        return self.output(x), weights
class FastWeightsAttentionTransformerLayer(Module):
    """
    A single transformer layer that combines fast weights attention with a
    position-wise feed-forward network, using residual connections.
    """

    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        # Transformer size
        self.size = d_model
        # Fast weights attention module
        self.attn = attn
        # Position-wise feed-forward network
        self.feed_forward = feed_forward
        # Dropout applied to both sub-layer outputs
        self.dropout = nn.Dropout(dropout_prob)

        # Normalization layers
        # NOTE(review): `norm_self_attn` is created but not applied in
        # `forward` in the visible code — confirm against upstream whether a
        # pre-attention normalization was intended.
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        """
        * `x` is the input for the current step
        * `weights` is the fast weight memory for this layer (or `None`)

        Returns the layer output and the updated fast weights.
        """
        attn, weights = self.attn(x, weights)
        # Add the self-attention results
        x = x + self.dropout(attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x, weights
class FastWeightsAttentionTransformer(Module):
    """
    Stack of fast weights attention transformer layers, applied step by step
    along the sequence axis while threading each layer's fast weight memory
    through time.
    """

    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x_seq: torch.Tensor):
        """
        * `x_seq` is the input sequence with the sequence dimension first

        Returns the normalized stack of per-step outputs.
        """
        # Split the input into a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
        # List to store the outputs
        res = []
        # One fast weight memory per layer, initialized lazily on first step
        weights = [None for _ in range(len(self.layers))]

        # For each input step
        for x in x_seq:
            # Run through each layer
            for i, layer in enumerate(self.layers):
                # Get layer output and its updated fast weights
                x, weights[i] = layer(x, weights[i])

            res.append(x)

        # Stack the output tensors
        res = torch.stack(res)
        # Normalize the output
        return self.norm(res)