RoPer 由 Georges Harik (@gharik) 创作,这个实现基于他的原始代码。

具有相对距离的旋转位置嵌入 (RoPer)

旋转位置嵌入(ROPE)包括注意力分数计算中的相对位置。但是,嵌入本身不会获得任何位置信息,除了它可以隐含地从因果关注中获得的信息之外。

RoPer 将相对位置信息显式添加到值嵌入中。具体来说,它添加了它关注的代币的相对位置。我们使用相同的旋转位置嵌入来旋转注意的值,然后,在取加权总和之后,我们向相反的方向旋转决赛。这相当于相对于当前位置旋转每个值(注意之前)。

以下是使用RoPer在算术加法上训练变压器模型的训练代码,我们可以看到与RoPe相比有了显著改进。

嵌入中的相对距离

对于任何头部,让注意力从一个位置到另一个位置,并成为位置上的价值嵌入。让我们将单个功能表示为

通常情况下,我们会取值嵌入的权重和

这不会将有关位置的任何距离信息明确添加到最终结果中

RoPer 将诸如 RoPe 之类的功能配对并进行变换对于一对它会将它们转换为。让我们一起捐赠变换后的要素。然后,它将加权总和沿相反的方向旋转注意

请注意,

转换后的最终输出是,

请注意

让我们展开第一个学期

同样,我们可以显示第二个项等于,

这给了,

也就是说,相对于当前位置旋转的值的加权平均值。

这是一个使用 RoPer 执行算术加法任务的实验

118from typing import Optional
119
120import torch
121
122from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention

向相反方向旋转的绳索模块

这继承了 Ro Pe 旋转实现并改变了方向。

125class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
  • x 是位于键或带有形状的查询开头的 Tensor[seq_len, batch_size, n_heads, d]
132    def forward(self, x: torch.Tensor):

缓存

137        self._build_cache(x)

拆分特征,我们可以选择仅将旋转嵌入应用于部分特征集。

140        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

计算

144        neg_half_x = self._neg_half(x_rope)

计算

对于

160        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])

163        return torch.cat((x_rope, x_pass), dim=-1)

通过旋转定位嵌入实现多头关注

我们超越了原装变压器的多头注意力

166class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
173    def __init__(self, heads: int, d_model: int,
174                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
175                 dropout_prob: float = 0.0):
176        super().__init__(heads, d_model, rope_percentage, dropout_prob)

旋转位置嵌入层

179        d_rope_value = int(self.d_k * rope_value_percentage)
180
181        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
182        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)

query keyvalue 是存储查询向量集合的张量。它们有形状[seq_len, batch_size, d_model]

mask 有形状[seq_len, seq_len, batch_size]mask[i, j, b] 指示是否为批量查询b ,位置处的查询i 有权访问位置处的键值j

184    def forward(self, *,
185                query: torch.Tensor,
186                key: torch.Tensor,
187                value: torch.Tensor,
188                mask: Optional[torch.Tensor] = None):

querykey 并且value 有形状[seq_len, batch_size, d_model]

200        seq_len, batch_size, _ = query.shape
201
202        if mask is not None:
203            mask = self.prepare_mask(mask, query.shape, key.shape)

准备querykeyvalue 进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]

207        query = self.query(query)
208        key = self.key(key)
209        value = self.value(value)

计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]

213        scores = self.get_scores(query, key)

音阶分数

216        scores *= self.scale

涂抹面膜

219        if mask is not None:
220            scores = scores.masked_fill(mask == 0, float('-inf'))

关注按键序列维度

224        attn = self.softmax(scores)

申请退学

227        attn = self.dropout(attn)

在获取加权总和之前旋转值嵌入,使其包含位置信息

230        value = self.value_rotary_pe(value)

乘以值

234        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

向相反方向旋转,使每个嵌入保持相对位置

237        x = self.value_reverse_rotary_pe(x)

保存任何其他计算的注意力

240        self.attn = attn.detach()

连接多个头

243        x = x.reshape(seq_len, batch_size, -1)

输出层

246        return self.output(x)