这是纸质《开关变形金刚:以简单高效的稀疏度扩展到万亿个参数模型》的微型 PyTorch 实现。我们的实现只有几百万个参数,不对并行分布式训练进行建模。它进行单个 GPU 训练,但我们实现了论文中描述的切换概念。
Switch Transformer 通过根据令牌在参数之间切换,为每个令牌使用不同的参数。因此,只为每个代币选择了一小部分参数。因此,您可以拥有更多参数,但计算成本更低。
切换发生在每个变压器模块的位置前馈网络 (FFN) 上。位置前馈网络由两个按顺序完全连接的层组成。在交换机变压器中,我们有多个 FFN(多位专家),我们根据路由器选择使用哪一个。输出是一组用于选择 FFN 的概率,我们选择概率最高的概率,然后仅对其进行评估。因此,从本质上讲,计算成本与拥有单个 FFN 相同。在我们的实现中,当你有许多或大型 FFN 时,这种并行化效果不佳,因为这一切都发生在单个 GPU 上。在分布式设置中,你会将每个 FFN(每个都很大)放在不同的设备上。
本文引入了另一个损失术语来平衡专家(FFN)之间的负载,并讨论了路由不平衡时丢弃代币的问题。
这是训练代码和一本用于在 Tiny Shakespeare 数据集上训练开关变压器的笔记本。