反馈变压器

这是 PyTorch 对《使用反馈存储器访问序列变压器中的更高层次表示》一文的 PyT orch 实现。

普通的变压器会并行处理代币。每个变压器层都注意前一层的输出。反馈变压器注意前面步骤中所有层的输出。因此,这会增加重复性,我们需要逐个代币进行处理。这会显著减慢训练速度(大约 5 到 10 倍,具体取决于序列长度)。但是,在预测反馈变换器时,速度更快,因为如果你缓存了内存向量,你可以预测下一个标记。

为了加快训练速度,本文讨论了从短序列长度开始并逐渐增加序列长度的问题。他们还讨论了使用预训练的并行变压器作为起点。

原始反馈变压器不保留所有层的输出。相反,它保留所有图层输出的加权总和。这减少了预测期间用于缓存的内存。这个文件的前半部分实现了这一点。

更新后的反馈变压器在各层之间共享用于计算密钥和值的权重。然后,我们只计算每个步骤的键和值一次,并将其缓存。这个文件的后半部分实现了这一点。我们实现了一个自定义 PyTorch 函数来提高性能。

这是训练代码和一本用于在 Tiny Shakespeare 数据集上训练反馈转换器的笔记本。

Colab 笔记本

Open In Colab