スイッチトランス

これは、論文の「スイッチトランスフォーマー:シンプルで効率的なスパース性を備えた1兆パラメータモデルへのスケーリングのミニチュアPyTorch実装です。私たちの実装には数百万のパラメーターしかなく、モデルの並列分散トレーニングは行いません。シングルGPUトレーニングを行いますが、論文に記載されているようにスイッチングという概念を実装しています

Switch Transformer は、トークンに基づいてパラメーターを切り替えることにより、トークンごとに異なるパラメーターを使用します。したがって、各トークンで選択されるパラメータはごくわずかです。そのため、より多くのパラメーターを使用できますが、計算コストは少なくなります

切り替えは、各トランスブロックの位置ごとのフィードフォワードネットワーク (FFN) で行われます。位置単位のフィードフォワードネットワークは、連続して完全に接続された2つの層で構成されています。スイッチトランスには複数のFFN(複数のエキスパート)がいて、ルーターに基づいてどれを使用するかを選択しました。出力はFFNを選択する確率のセットで、最も確率の高いものを選んで評価します。つまり、基本的に、計算コストは単一の FFN を使用する場合と同じです。私たちの実装では、FFNが多い場合や大きい場合は、すべて1つのGPUで実行されるため、うまく並列化できません。分散型セットアップでは、それぞれの FFN(それぞれが非常に大きい)を異なるデバイスに配置することになります

この論文では、エキスパート(FFN)間で負荷を分散するための別の損失用語を紹介し、ルーティングのバランスが取れていない場合のトークンのドロップについて論じています。

これは、Tiny Shakespeareデータセットのスイッチトランスをトレーニングするためのトレーニングコードとノートブックです