RoPer 由 Georges Harik (@gharik) 创作,这个实现基于他的原始代码。
旋转位置嵌入(ROPE)包括注意力分数计算中的相对位置。但是,嵌入本身不会获得任何位置信息,除了它可以隐含地从因果关注中获得的信息之外。
RoPer 将相对位置信息显式添加到值嵌入中。具体来说,它添加了它关注的代币的相对位置。我们使用相同的旋转位置嵌入来旋转注意的值,然后,在取加权总和之后,我们向相反的方向旋转决赛。这相当于相对于当前位置旋转每个值(注意之前)。
以下是使用RoPer在算术加法上训练变压器模型的训练代码,我们可以看到与RoPe相比有了显著改进。
对于任何头部,让注意力从一个位置到另一个位置,并成为位置上的价值嵌入。让我们将单个功能表示为。
通常情况下,我们会取值嵌入的权重和
这不会将有关位置的任何距离信息明确添加到最终结果中。
RoPer 将诸如 RoPe 之类的功能配对并进行变换对于一对,它会将它们转换为。让我们一起捐赠变换后的要素。然后,它将加权总和沿相反的方向旋转。注意。
请注意,
转换后的最终输出是,
请注意。
让我们展开第一个学期,
同样,我们可以显示第二个项等于,
这给了,
也就是说,相对于当前位置旋转的值的加权平均值。
118from typing import Optional
119
120import torch
121
122from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention125class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):x
是位于键或带有形状的查询开头的 Tensor[seq_len, batch_size, n_heads, d]
132 def forward(self, x: torch.Tensor):缓存和值
137 self._build_cache(x)拆分特征,我们可以选择仅将旋转嵌入应用于部分特征集。
140 x_rope, x_pass = x[..., :self.d], x[..., self.d:]计算
144 neg_half_x = self._neg_half(x_rope)160 x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])163 return torch.cat((x_rope, x_pass), dim=-1)166class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):173 def __init__(self, heads: int, d_model: int,
174 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
175 dropout_prob: float = 0.0):
176 super().__init__(heads, d_model, rope_percentage, dropout_prob)旋转位置嵌入层
179 d_rope_value = int(self.d_k * rope_value_percentage)
180
181 self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
182 self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)query
key
和value
是存储查询、键和值向量集合的张量。它们有形状[seq_len, batch_size, d_model]
。
mask
有形状[seq_len, seq_len, batch_size]
并mask[i, j, b]
指示是否为批量查询b
,位置处的查询i
有权访问位置处的键值j
。
184 def forward(self, *,
185 query: torch.Tensor,
186 key: torch.Tensor,
187 value: torch.Tensor,
188 mask: Optional[torch.Tensor] = None):query
,key
并且value
有形状[seq_len, batch_size, d_model]
200 seq_len, batch_size, _ = query.shape
201
202 if mask is not None:
203 mask = self.prepare_mask(mask, query.shape, key.shape)准备query
,key
并value
进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]
。
207 query = self.query(query)
208 key = self.key(key)
209 value = self.value(value)计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]
。
213 scores = self.get_scores(query, key)音阶分数
216 scores *= self.scale涂抹面膜
219 if mask is not None:
220 scores = scores.masked_fill(mask == 0, float('-inf'))关注按键序列维度
224 attn = self.softmax(scores)申请退学
227 attn = self.dropout(attn)在获取加权总和之前旋转值嵌入,使其包含位置信息
230 value = self.value_rotary_pe(value)乘以值
234 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)向相反方向旋转,使每个嵌入保持相对位置
237 x = self.value_reverse_rotary_pe(x)保存任何其他计算的注意力
240 self.attn = attn.detach()连接多个头
243 x = x.reshape(seq_len, batch_size, -1)输出层
246 return self.output(x)