9from typing import Optional
10
11import torch
12from torch import nn
13
14from labml_helpers.module import Module
15from labml_nn.transformers.fast_weights import DPFP
16from labml_nn.transformers.feed_forward import FeedForward
17from labml_nn.transformers.mha import PrepareForMultiHeadAttention
18from labml_nn.utils import clone_module_list
class FastWeightsAttention(Module):
    """
    Attention over a single step that reads from, and writes to, a per-head
    weight tensor carried between calls.

    Each call projects the input into query, key and value, corrects the
    carried weights using the difference between the new value and the value
    the weights currently return for the key, and then reads the result out
    with the query.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        """
        :param heads: number of attention heads
        :param d_model: size of the input and output feature vectors
        :param dropout_prob: probability passed to the dropout module
        :param phi: mapping applied to the projected queries and keys
        """
        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        self.heads = heads

        def project(d_out: int):
            # All projections share the input size and head count and have no bias
            return PrepareForMultiHeadAttention(d_model, heads, d_out, bias=False)

        # Transforms the `query` for multi-head attention
        self.query = project(self.d_k)
        # These transform the `key` and `value` for multi-head attention
        self.key = project(self.d_k)
        self.value = project(self.d_k)

        # A single feature per head squashed by a sigmoid; it scales the
        # correction applied to the weights in `forward`
        self.gate = nn.Sequential(project(1), nn.Sigmoid())

        self.phi = phi

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout (NOTE(review): constructed here but not applied in `forward`)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        """
        :param x: input for one step; presumably `[batch_size, d_model]` -
            TODO confirm against `PrepareForMultiHeadAttention`
        :param weights: weights returned by the previous call, or `None` to
            start from zeros
        :return: tuple of the output-layer result and the updated weights
        """
        q = self.phi(self.query(x))
        k = self.phi(self.key(x))
        v = self.value(x)

        # First call: start from all-zero weights indexed as
        # (k.shape[0], k.shape[1], v.shape[2], k.shape[2]), matching the
        # `bhvk` einsum subscripts used below
        if weights is None:
            n_batch, n_heads = k.shape[0], k.shape[1]
            weights = k.new_zeros((n_batch, n_heads, v.shape[2], k.shape[2]))

        # Value the current weights return for this key
        v_stored = torch.einsum('bhvk,bhk->bhv', weights, k)

        # Gate in (0, 1) produced by the sigmoid
        beta = self.gate(x)

        # Add the outer product of the gated value difference and the key
        delta = beta * (v - v_stored)
        weights = weights + torch.einsum('bhv,bhk->bhvk', delta, k)

        # Read out with the query using the updated weights
        out = torch.einsum('bhvk,bhk->bhv', weights, q)

        # Concatenate multiple heads
        n_rows = out.shape[0]
        out = out.reshape(n_rows, -1)

        # Output layer
        return self.output(out), weights
class FastWeightsAttentionTransformerLayer(Module):
    """
    One transformer layer: an attention module that carries weights between
    calls, followed by a feed-forward network, each added back to the input
    through a residual connection with dropout.
    """

    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        """
        :param d_model: transformer feature size
        :param attn: attention module returning `(output, weights)`
        :param feed_forward: feed-forward module
        :param dropout_prob: probability passed to the dropout module
        """
        super().__init__()

        # Transformer size
        self.size = d_model

        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

        # Normalization layers
        # NOTE(review): `norm_self_attn` is created but `forward` does not apply it
        self.norm_self_attn = nn.LayerNorm((d_model,))
        self.norm_ff = nn.LayerNorm((d_model,))

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        """
        :param x: input for one step
        :param weights: weights from the previous step for this layer, or `None`
        :return: tuple of the layer output and the updated weights
        """
        attn_out, weights = self.attn(x, weights)

        # Add the self attention results
        x = x + self.dropout(attn_out)

        # Normalize for feed-forward and pass through the feed-forward network
        ff_out = self.feed_forward(self.norm_ff(x))

        # Add the feed-forward results back
        x = x + self.dropout(ff_out)

        return x, weights
class FastWeightsAttentionTransformer(Module):
    """
    Stack of layer copies applied step by step along the first axis of the
    input, with each layer keeping its own weights across steps.
    """

    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        """
        :param layer: layer to copy
        :param n_layers: number of copies to make
        """
        super().__init__()

        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)

        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x_seq: torch.Tensor):
        """
        :param x_seq: input whose first axis is the sequence axis
        :return: normalized stack of the per-step outputs of the last layer
        """
        # Split the input to a list along the sequence axis
        steps = torch.unbind(x_seq, dim=0)

        # List to store the outputs
        outputs = []

        # One weights slot per layer; `None` until that layer has run once
        memories = [None] * len(self.layers)

        # For each input step
        for hidden in steps:
            # Run through each layer
            for idx, layer in enumerate(self.layers):
                # Get layer output and keep its updated weights for the next step
                hidden, memories[idx] = layer(hidden, memories[idx])

            outputs.append(hidden)

        # Stack the output tensors
        stacked = torch.stack(outputs)

        # Normalize the output
        return self.norm(stacked)