1import math
2
3import torch
4from torch import nn
5
6from labml_helpers.module import Module
7from labml_nn.transformers import MultiHeadAttention

空间深度明智卷积

这其实比较慢

10class SpatialDepthWiseConvolution(Module):
  • d_k 是每个 head 中的通道数
17    def __init__(self, d_k: int, kernel_size: int = 3):
21        super().__init__()
22        self.kernel_size = kernel_size

我们使用 PyTorch 的Conv1d 模块。我们将组的数量设置为等于通道数,以便它对每个通道进行单独的卷积(使用不同的内核)。我们在两边添加填充,然后裁剪最右边的kernel_size - 1 结果

27        rng = 1 / math.sqrt(kernel_size)
28        self.kernels = nn.Parameter(torch.zeros((kernel_size, d_k)).uniform_(-rng, rng))

x 有形状[seq_len, batch_size, heads, d_k]

30    def forward(self, x: torch.Tensor):
35        res = x * self.kernels[0].view(1, 1, 1, -1)
36
37        for i in range(1, len(self.kernels)):
38            res[i:] += x[:-i] * self.kernels[i].view(1, 1, 1, -1)
39
40        return res

多 dconv-Head 注意力 (MDHA)

我们扩展了最初的 M ulti-Head Attention 实现,并将空间深度卷积添加到查询、键和值投影中。

43class MultiDConvHeadAttention(MultiHeadAttention):
51    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
52        super().__init__(heads, d_model, dropout_prob)

Multi-Head Attention 将创建查询、键和价值投影模块self.query self.key 、和self.value

我们将空间深度卷积层组合到每个层上,并替换self.query self.key 、和self.value

59        self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
60        self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
61        self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))