圧縮変圧器

これは PyTorch の長距離シーケンスモデリング用の圧縮トランスフォーマーの実装です

これはTransformer XLの拡張版で、過去の記憶を圧縮して注意範囲を広げています。つまり、最も遠いメモリがメモリに圧縮されます。ここで、は圧縮率です

圧縮操作

圧縮操作は次のように定義されます。この論文では複数の選択肢を紹介していますが、最良の結果が得られると思われる1次元の畳み込みのみを実装しています。各レイヤーには個別の圧縮操作があります。ここで、はレイヤー番号です。

トレーニング用圧縮操作

BPTTによるトレーニング圧縮では、非常に大きな計算グラフ(多くのタイムステップ)を維持する必要があるため、この論文では自動エンコーディング損失と注意再構成損失を提案しています自動エンコーディング損失は、圧縮されたメモリから元のメモリをデコードし、損失を計算します。アテンション再構成損失では、圧縮メモリと非圧縮メモリでマルチヘッドアテンションの結果を計算し、それらの間の平均二乗誤差を求めます。後者の方が良い結果が得られるため、ここでは後者を実装しました。

この実装ではレイヤー前の正規化を使用しますが、ペーパーではレイヤー後の正規化を使用します。前層ノルムはFFNやセルフアテンション前の層ノルムを行い、残差接続でのパススルーは正規化されません。これは標準的な変圧器の設定ではより安定しているはずです

Tiny Shakespeareデータセットで圧縮トランスフォーマーモデルをトレーニングするためのトレーニングコードとノートブックは次のとおりです

Open In Colab