1import math
2
3import torch
4from torch import nn
5
6from labml_helpers.module import Module
7from labml_nn.transformers import MultiHeadAttention10class SpatialDepthWiseConvolution(Module):d_k
は各ヘッドのチャンネル数17 def __init__(self, d_k: int, kernel_size: int = 3):21 super().__init__()
22 self.kernel_size = kernel_sizeConv1d
PyTorchのモジュールを使用しています。グループの数をチャネル数と同じになるように設定し、チャネルごとに (異なるカーネルで) 個別の畳み込みを行います。両側にパディングを追加し、kernel_size - 1
後で一番適切な結果になるようにトリミングします
27 rng = 1 / math.sqrt(kernel_size)
28 self.kernels = nn.Parameter(torch.zeros((kernel_size, d_k)).uniform_(-rng, rng))x
形がある [seq_len, batch_size, heads, d_k]
30 def forward(self, x: torch.Tensor):35 res = x * self.kernels[0].view(1, 1, 1, -1)
36
37 for i in range(1, len(self.kernels)):
38 res[i:] += x[:-i] * self.kernels[i].view(1, 1, 1, -1)
39
40 return resMulti-Head Attentionの当初の実装を拡張し、クエリ、キー、バリュープロジェクションに空間深度方向のコンボリューションを追加します。
43class MultiDConvHeadAttention(MultiHeadAttention):51 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
52 super().__init__(heads, d_model, dropout_prob)Multi-Head Attention は、クエリ、キー、バリュープロジェクションモジュールself.query
self.key
、およびを作成します。self.value
それぞれに空間深度方向の畳み込み層を組み合わせて、、、を置き換えますself.query
。self.key
self.value
59 self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
60 self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
61 self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))