幅がヘッドに分割される
単一のヘッドは1つのパターンを見る
アクティビティ67では、スケールドドットプロダクトアテンションを扱いました:クエリベクトルQ、キーべクトルK、値ベクトルV;Q·Kᵀ/√d_kを計算し、マスク、softmax、Vに重み付け。一つのヘッドは一つの関係パターンを学習します:主語-動詞の一致、句読点のペアリング、または役に立たないものかもしれません。
マルチヘッドアテンションは、同じ操作をトークンの表現の異なるスライスで並列に複数回実行します。12個の並列ヘッド。12個の可能な関係パターン。ヘッドは訓練圧力だけで専門化します;建築家がヘッド4に動詞の時制を見るよう指示することはありません。
Split 関係
ANDREA-120M は d_model = 768 & n_head = 12 を設定します。マルチヘッドアテンションは 768 を 12 個の 64 のチャンクに分割します:
head_dim = d_model / n_head
64 = 768 / 12
すべてのヘッドは64次元のベクトルで動作します。splitは明確に適用されます:d_modelはn_headで割り切れなければなりません(余りが0)。この条件に違反する設定は、実行時ではなく設定検証で失敗します。
3つのモデル、3つの分割
| Variant | d_model | n_head | head_dim |
|---|---|---|---|
| ANDREA-12M | 384 | 12 | 32 |
| ANDREA-120M | 768 | 12 | 64 |
| ANDREA-480M | 1536 | 24 | 64 |
注意: ANDREA-12M & ANDREA-120M は n_head=12 を一定に保ち、d_model およびそのため head_dim のみがスケールします。ANDREA-480M はヘッド数を24に倍増し、head_dim=64 を ANDREA-120M に合わせています。
head_dim の計算
ヘッドごとの3つの行列、または1つの大きな行列
ヘッドごとの視点
各ヘッドは独自のクエリ投影、キー投影、値投影を必要とします。ヘッド h の場合:
Q_h = X · W_Q^h ここで W_Q^h の形状は [d_model, head_dim]
K_h = X · W_K^h ただし、W_K^h の形状は [d_model, head_dim]
V_h = X · W_V^h ただし、W_V^h の形状は [d_model, head_dim]
X は入力形状 [batch, seq_len, d_model] を保持します。射影後、Q_h、K_h、V_h はそれぞれ形状 [batch, seq_len, head_dim] を保持します。
統合ビュー
各ヘッドの行列はメモリ上で隣り合って配置されます。形状 [d_model, d_model] の単一の統合行列 W_Q が一度にすべてのヘッドを生成します:
Q_fused = X · W_Q # [batch, seq_len, d_model]
Q_per_head = reshape(Q_fused) # [batch, n_head, seq_len, head_dim]
融合行列積は12回のBLAS呼び出しではなく1回の呼び出しで済みます。CUDAテンソルコアはこのサイズの行列積でピークスループットに達します。ヘッドごとの行列積ではハードウェアを十分に活用できません。
パラメータ数
3つの融合行列 W_Q、W_K、W_V、それぞれ d_model × d_model。出力投影 W_O を加えて、これも d_model × d_model。ANDREA-120M の場合:
1層あたりの注意機構のパラメータ = 4 × 768² = 2,359,296 ≈ 2.36M
12層にわたるパラメータ = 12 × 2.36M ≈ 28.3M
ANDREA-120Mの総パラメータの約4分の1が注意機構のプロジェクションに存在します。残りの4分の3はMLPサブレイヤーと埋め込みに存在します。
プロジェクションの命名
12のベクトルが1つになる
各ヘッドが計算した後
各ヘッドは形状 [batch, seq_len, head_dim] の出力テンソルを生成します。12個のヘッドは12個のそのようなテンソルを生成します。特徴次元に沿った連結により、それらを再びまとめます:
concat_output = concat(head_1, head_2, ..., head_12)
shape = [batch, seq_len, n_head × head_dim]
= [batch, seq_len, 768] # for ANDREA-120M
Concat は split を逆転させます。総特徴次元は d_model に戻ります。次元での情報損失はありません;違いは各チャンクが含む内容にあります:ヘッド 1 のチャンクはヘッド 1 の学習された注意パターンを反映します。
出力投影 W_O
単なる Concatenation だけではヘッドは孤立したままです:ヘッド 4 の出力がヘッド 7 の出力の隣にあり、互いに認識していません。形状 [d_model, d_model] の出力投影 W_O がそれらを混合します:
attention_output = concat_output · W_O
shape = [batch, seq_len, d_model]
W_O の後で、出力次元のそれぞれが12個のすべてのヘッドの学習された線形結合を運ぶようになります。情報は、この単一の行列乗算を通じてヘッド間で自由に流れます。
なぜヘッドが専門化するのか
アーキテクチャ上、何もhead 4が動詞の時制を学習したり、head 9が対応する句読点を学習したりすることを強制しません。専門化は勾配の圧力から生じます:訓練中、重複して寄与するheadは、独自に寄与するheadよりも小さな勾配を受け取ります。何千ステップにわたり、各headは総損失を最も効果的に減少させるニッチに落ち着きます。
経験的に、訓練されたトランスフォーマーは、位置パターン(headが前のトークンを見る)、構文パターン(headが対応する閉じ括弧を見る)、意味パターン(headが最も最近の固有实体を見る)を扱うheadを示します。ラベルなしでこの専門化が訓練されます。W_Oを通じて伝播される訓練シグナルだけで、headが整理されます。
なぜ12個のheadか、より広い1つのheadではないのか
CUDA がヘッドをどのように保存するか
単一のテンソル、リシェイプ済み
ANDREA のトレーニングエンジン microgpt_cuda.cu は、12 個のヘッドそれぞれに別々の 12 個のバッファを割り当てません。1 つの融合テンソルを割り当て、ヘッド次元をストライドパターンとして扱います:
// Q = X · W_Q (1 回の matmul、ヘッド間で融合) の後
// Q の形状は [batch, seq_len, d_model]
// [batch, seq_len, n_head, head_dim] にリシェイプ
// [batch, n_head, seq_len, head_dim] に転置
// 各ヘッドの内側の2次元が現在メモリ上で連続
転置により n_head が seq_len の前に移動します。なぜか? 次の操作 (Q_h · K_h^T) で各ヘッドの seq_len × head_dim スライスがメモリ上で連続である必要があるためです。CUDA の行列乗算は連続テンソルで高速に動作します。
1つのカーネル、多数のヘッド
単一のattention CUDAカーネルは、すべてのヘッドで並列に実行されます。各スレッドブロックが1つの(batch, head)ペアを処理します。ブロック内のスレッドは、seq_len × head_dimタイルで協調します。カーネルは複数のヘッドを処理していることを知りません。起動グリッドが並列性を処理します。
構成はハードウェアを反映します
ANDREA-120Mのn_head=12, head_dim=64の選択は、RTX 4090テンソルコアに適合します。これらはmatmulタイルを16の倍数で好みます。head_dim=64 = 4 × 16はタイル形状に完全に一致します。head_dim=32 (ANDREA-12M)も一致しますが、タイルを十分に活用しません。head_dim=72は一致せず、フォールバックカーネルを強制します。
最終的な全体像
| ステップ | 操作 | 出力形状 |
|---|---|---|
| 1. プロジェクト | Q = X · W_Q (K, V も同様) | [batch, seq, d_model] |
| 2. リシェイプ&転置 | d_model を (n_head, head_dim) に分割 | [batch, n_head, seq, head_dim] |
| 3. ヘッドごとの注意機構 | 各ヘッドでスケールド・ドットプロダクト | [batch, n_head, seq, head_dim] |
| 4. 転置&リシェイプ | (n_head, head_dim) を d_model に結合 | [batch, seq, d_model] |
| 5. 出力投影 | output = concat · W_O | [batch, seq, d_model] |
5つのステップ。3つの行列乗算が入力を触る(Q, K, V 投影)。1つの行列乗算が結合されたヘッドを触る(W_O)。1つの注意カーネルがすべてのヘッドを並列処理。ANDREA-120M は各レイヤーで5つのステップすべてを1回実行、12層深く、毎回のフォワードパスで。