勾配スパイクの発生源
穏やかなミニバッチと衝撃的なもの
ほとんどのミニバッチは、合理的な大きさの勾配を生成します。データに大まかに適合したモデルに対するクロスエントロピー損失は狭い範囲に留まり、バックプロパゲーションはその信号を同程度の大きさの勾配として後ろに運びます。
一部のミニバッチはそうではありません。勾配スパイクの3つの発生源:
1. 外れ値の例。 極めて稀なトークン組み合わせを持つ単一のシーケンスが、平均から遠い損失と平均から遠い勾配を生み出します。
2. 数値の端数ケース。 ほぼゼロのsoftmax分母、NaNを生むlayernorm、FP16のオーバーフロー。それぞれが典型的なものより桁違いに大きな勾配を生み出します。
3. 分布のシフト。 単一の訓練実行中にデータソースを切り替えると、モデルに新しい分布でショックを与えます。ANDREAのbanditは7〜42ステップごとにソースの重みを再シャッフルします。各切り替えは小さな分布シフトです。
ANDREA-120M v1: スパイクのカスケード
v1には勾配クリッピングがありませんでした。banditからの7〜42ステップごとのソース遷移が、モデルにrepo-docs(リスト構造)の短いバーストを、gutenberg(長文散文)を、hermes3-general(Q&A)を供給しました。各遷移が勾配スパイクを生み、各スパイクが120Mスケールで重みを退化吸引子に押し込みました。
重要な実証的事実。 ANDREA-12M はクリッピングなしで同じバンディットに耐えました。小さな重み行列は勾配ショックに対して頑健です。1つの悪いバッチが12Mのパラメータを暴走アトラクタに押し込むことはできませんが、120Mのパラメータでは可能です。モデルがスケールするにつれてクリッピングの重要性が増します。
グローバル L2 ノルムクリッピング
2つの選択肢:Per-Tensor または Global
勾配の大きさを制限する2つの方法:
Per-tensor クリッピング。 各勾配テンソールを独立してクリップします。埋め込み勾配は自身のノルムでクリップされ、注意勾配は自身のノルムでクリップされます。シンプルですが、相対的なスケールを歪めます:1つのテンソルの小さなスパイク(今やゼロ勾配)が、もう1つの巨大な勾配(未処理)とペアになります。
グローバル L2 ノルムクリッピング。 すべての勾配を1つの大きなベクトルとして扱います。すべてのパラメータにわたる総 L2 ノルムを計算します。ノルムが max_norm を超えた場合、すべての勾配を同じ係数でスケーリングします。テンソル間の相対的な大きさを保持します。
ANDREA はグローバルを使用します。Pascanu ら (2013) は、トランスフォーマーの訓練においてグローバルクリッピングがテンソルごとのクリッピングを上回ることを実証的に示しました。
数学
グローバル L2 ノルムを計算します:
norm = sqrt(sum over all params of g_i^2)
norm <= max_norm の場合、勾配はそのまま通過します。norm > max_norm の場合、すべての勾配を max_norm / norm でスケーリングします:
g_i_clipped = g_i * (max_norm / norm)
スケーリング後、新しいノルムは正確に max_norm になります。ANDREA は max_norm = 1.0 を使用します。
スケールファクターの計算
勾配ノルム計算に3つのカーネルが必要な理由
ナイーブなアルゴリズムはGPUで実行できない
グローバルL2ノルム計算の擬似コード:
total = 0
各パラメータ p に対して:
p.grad の各要素 g に対して:
total += g * g
norm = sqrt(total)
GPU 上では、この単純なループは2つの理由で失敗します:
1. 逐次蓄積。 単一の total 蓄積器はすべてのスレッドが他のすべてのスレッドを待機させるため、GPUの並列性を損ないます。
2. 異種テンソル。 ANDREA-120Mには形状が大きく異なるテンソルがあります:埋め込み (8449 x 768)、注意機構 QKV (768 x 768)、layernorm (768)。1つのカーネルではすべての形状を効率的に処理できません。
ANDREAの3カーネル・パイプライン
作業を microgpt_cuda.cu の3つのCUDAカーネルに分割します:
カーネル1: k_grad_norm_partial。 各パラメータテンソルに対して、平方和の部分和を計算します。各スレッドブロックがテンソルのチャンクを縮約し、結果を小さなスクラッチバッファに書き込みます。並列性:チャンクごとに1ブロック、全テンソルにわたって数百のブロック。
カーネル 2: k_grad_norm_final. スクラッチバッファを単一のスカラーに縮約します。その平方根を取ります。小さなカーネルで、マイクロ秒単位で実行されます。
カーネル 3: k_grad_scale. norm > max_norm の場合、scale = max_norm / norm を計算し、すべての勾配要素に scale を掛けます。すべての勾配テンソルに1回パス、恥ずかしいほど並列化可能です。
順序が重要: Pre-Adam
クリッピングパイプラインは、AdamW が m、v、または任意のパラメータを更新する前に実行されます。なぜですか?
クリップされた勾配がAdamWの指数移動平均に供給されます。スパイクがm & vに流れるのを許すと、それらの走行平均を破損させ、スパイク後の多くのステップで回復を遅くします。Adam前のクリッピングにより、スパイクの影響を単一の不良ステップに限定します。
なぜ3つのカーネルなのか、1つではないのか?
ノークリッピングがv1を殺した方法
バンディットソースが7〜42ステップごとに遷移
ANDREAのバンディットはフェーズで動作します。各フェーズは7、14、21、28、または42ステップ(ランダムに選択)続きます。各フェーズ境界で、ソースウェイトがシフトします:例えばrepo-docsが0.1から0.6に跳ね上がり、gutenbergが0.4から0.1に低下、hermes3-generalが0.5から0.7に上昇するなど。
各遷移はモデルへの分布ショックです。ロスが一時的にスパイクします。グラディエントもそれに伴いスパイク:gutenberg風の散文に対してロスを最小化していたモデルが、今やrepo-docs風のリスト構造を見ることになり、グラディエントは典型的な大きさの10倍や100倍の矯正信号を運びます。
v1の失敗モード
クリッピングなしで、それらの10-100倍のグラディエンスパイクがAdamWのm & v平均に流れ込みました。AdamWのスムージングにより、スパイクの影響は実際の悪いバッチの後でも多くのステップにわたって持続しました。v1のバニラAdam(weight decayなし)と組み合わせると、スパイク駆動の重み更新がフェーズごとに累積し、重みが劣化アトラクタに漂いました:1つのトークンのロジットがsoftmaxを支配し、サンプル出力がそのトークンになり、トレーニングコンテキストにそのトークンが含まれ、グラディエントがそのトークンを強化しました。再現ロックイン。
v2の安定性
v2ではmax_norm = 1.0のクリッピングが追加され、AdamW & LR warmupが導入されました。m & vへのスパイクの影響は制限され、重みのドリフト速度はピーク時でもlr max_norm = 0.0003 1.0 = 0.0003を超えません。フェーズ遷移は依然としてスパイクを生みますが、それらのスパイクはオプティマイザに到達する前にキャップされます。
結果:v2(データフィルタv2.5 & v3の洗練後)は、事実の想起、多段落のコヒーレンス、および生物学 & 信号処理サンプルの外部評価で9.5/10を達成しました。
容量-脆性カップリング
同じ bandit。同じデータ。同じハイパーパラメータを除くクリッピング以外。なぜ12Mはクリッピングなしで生き残ったのに、120Mは崩壊したのか?
2つの相乗要因:
1. より大きな重み行列はより多くのアトラクターを保存する。 768x768の注意プロジェクションは590Kパラメータを持つ;パラメータあたりのわずかなドリフトでも注意の動作に意味のある変化を生む。384x384の注意プロジェクションは147Kパラメータを持ち、より制約された部分空間に留まる。
2. 層が多いほど乗法的相互作用が増える。 v3は12のトランスフォーマー層を持つ(12Mの6に対して)。スパイクは12層の相乗的な非線形性を通って伝播する;各層は前の層のドリフトを増幅できる。
脆性は容量とともに増大する。クリッピングはあるスケール閾値以上で必須になる;ANDREAはその閾値を12Mと120Mパラメータの間に置く。
v1 カスケードの診断
クリッピングが他に適用される場所は?
関連活動
クリッピングに関連する3つの兄弟活動:
- Activity 10: AdamW. クリッピングはAdamWの m & v をスパイク汚染から保護します。クリッピングなしでは、1つの悪いバッチが50ステップ以上のオプティマイザ状態を破損します。
- Activity 11: LR warmup. Warmupは lr を減衰させ、クリッピングは g を減衰させます。組み合わせると:ステップ1で、最悪ケースのパラメータ更新は lr_after_warmup max_norm = 1.5e-7 1.0 = 1.5e-7 となり、どちらのガードもなしの 0.0003 * 50 = 0.015 に比べて100,000倍の削減です。
- アクティビティ 14: 多腕バンディット。 バンディット段階の長さ(7〜42ステップ)は、特定のソースが支配しないように特に短く設定されています。クリッピングが頻繁な遷移を安全にします。
クリッピングはトランスフォーマー訓練における最も安価な安定性向上策です:3つの小さなCUDAカーネル、ステップあたりマイクロ秒単位、120M+モデルが収束するか崩壊するかを決定的に左右します。