English· Español· Deutsch· Nederlands· Français· 日本語· ქართული· 繁體中文· 简体中文· Português· Русский· العربية· हिन्दी· Italiano· 한국어· Polski· Svenska· Türkçe· Українська· Tiếng Việt· Bahasa Indonesia

un

guest
1 / ?
back to lessons

查詢、鍵、值

從相同輸入的三個線性映射

在嵌入(活動 4)之後,每個位置攜帶一個 768 維的向量 x_t。注意力開始時產生 x 的三個不同投影:


Q (查詢): 這個位置想要知道什麼?


K (key): 這個位置提供給其他位置什麼?


V (value): 如果被關注,這個位置會傳遞什麼內容?


每個投影都來自一個學習到的權重矩陣:


Q = x · W_Q     # W_Q 形狀:(d_model, d_k)
K = x · W_K     # W_K 形狀:(d_model, d_k)
V = x · W_V     # W_V 形狀:(d_model, d_k)

三個矩陣,皆透過反向傳播訓練。模型學習:在這個位置,哪個查詢能最佳檢索有用的過去上下文?哪個鍵能良好宣傳此位置的內容?若被選中,哪個值能提供內容?


縮放點積注意力


圖書館類比

想像一個圖書館卡片目錄。你走進去時腦中有一個主題(你的查詢)。每張卡片列出關鍵字(一個)。當你的主題匹配卡片的關鍵字時,你取出書的內容(一個)。注意力機制會對每個詞元並行執行此操作:每個位置查詢每個其他位置,評排名對齊度,並檢索值向量的加權組合。


ANDREA-120M 維度


數量備註
d_model768每個位置的向量大小
n_head12平行注意力頭
d_k64每個頭的維度 (= d_model / n_head)
T1024上下文長度

d_k = d_model / n_head = 768 / 12 = 64。每個頭看到完整的 768 維空間中的 64 維切片。Activity 6 (grow_a_language_model_multi_head) 詳細介紹了每個頭的分裂。

計算 d_k

為兩個 ANDREA 變體計算 d_k。(a) ANDREA-12M: d_model = 384, n_head = 12。(b) ANDREA-480M: d_model = 1536, n_head = 24。為每個顯示公式 d_k = d_model / n_head。

為什麼要除以 sqrt(d_k)

分數矩陣

一旦 Q 和 K 存在(每個形狀 (T, d_k)),注意力計算一個分數矩陣:


scores = Q · K^T     # shape: (T, T)

scores[i, j] = 位置 i 的查詢與位置 j 的鍵對齊的強度。每個 (i, j) 對會得到一個分數:1024 × 1024 = 1,048,576 個分數,每個注意力頭每個前向傳遞。


為什麼要除以

兩個隨機 d 維單位向量的點積,其大小量級為 sqrt(d)。若不進行縮放,分數會隨著 d_k 增長:


- d_k = 64:典型的點積數量級為 8。

- d_k = 256:典型的點積數量級為 16。

- d_k = 4096:典型的點積數量級為 64。


大的分數會產生尖峰的 softmax(一個位置主導,其他地方梯度消失)。訓練停滯。縮放修正了幅度:


scaled_scores = (Q · K^T) / sqrt(d_k)

對於 ANDREA-120M,sqrt(d_k) = sqrt(64) = 8。每個分數都會除以 8。無論 d_k 為何,大小都會保持大致單位尺度。Softmax 保持良好行為。梯度流動順暢。


Vaswani 的原始理由

來自 Attention Is All You Need (2017):「對於 d_k 的較大值,向量點積的大小會變大,將 softmax 函數推入梯度極小的區域。」sqrt(d_k) 除數可抵消這種增長。


程式碼視圖

microgpt_cuda.cu 內部,這種縮放以字面除法形式出現:


scores[i][j] = dot(Q[i], K[j]) * (1.0f / sqrtf(d_k));

每個分數一個浮點數乘法。便宜。關鍵。

在 d_model = 4096 時的縮放

假設一個研究團隊建置了 ANDREA-2B,d_model = 4096 & n_head = 32。(a) 計算 d_k。(b) 計算 sqrt(d_k)。(c) 用一句話解釋如果團隊在這個規模忘記除以 sqrt(d_k) 會發生什麼。

為什麼位置 i 不能看到位置 j > i

從生成而生的約束

ANDREA 一次生成一個 token。在推論時,位置 0 產生第一個 token,然後位置 1 看到位置 0 的輸出並產生第二個 token,依此類推。模型在生成過程中永遠無法存取未來的 token。


訓練必須反映這一點。如果在訓練期間位置 5 可以關注位置 6,模型會學到一個捷徑:「透過讀取 token 6 來預測 token 6」。在推論時,這個捷徑消失(token 6 尚未存在)。模型的訓練與推論行為會出現災難性的分歧。


一個遮罩

因果遮罩會阻擋從任何位置 i 到任何位置 j > i 的關注。實作方式:在 j > i 的地方將 scaled_scores[i][j] 設為 -infinity。經過 softmax 後,那些條目變成 exp(-inf) = 0。遮罩乾淨地將對未來位置的關注歸零。


for i in range(T):
for j in range(T):
if j > i:
scaled_scores[i][j] = -1e9   # 實際上等同於 -inf

經過 softmax(按行)後,每行總和為 1,但只有 [0, i] 的項目攜帶機率質量。位置 i 只混合來自過去位置的資訊。


可視化遮罩

應用遮罩後的得分矩陣形狀 (T, T) 看起來像下三角結構:


遮罩後的 scaled_scores,逐行 softmax:

row 0:  [1.0, 0,   0,   0,   ...]   # 只看到自己
row 1:  [0.4, 0.6, 0,   0,   ...]   # 看到位置 0, 1
第 2 行: [0.2, 0.3, 0.5, 0,   ...]   # 能看到 0, 1, 2
第 3 行: [0.1, 0.2, 0.3, 0.4, ...]   # 能看到 0, 1, 2, 3
...

每行嚴格的下三角概率分佈。未來保持不可見。


為什麼僅解碼器 Transformer 需要這個

僅解碼器模型如 ANDREA、GPT 和 LLaMA 都共享一個目標:從過去預測下一個 token。因果遮罩使該目標能夠並行訓練:每個位置同時計算自己的下一個 token 預測,且沒有位置透過偷看前方而作弊。

遮罩與風味

活動 2 (簡介) 涵蓋了三種 transformer 風味:僅編碼器、編碼器-解碼器、僅解碼器。(a) 哪種風味使用因果遮罩?(b) 用一句話說明為什麼另一種風味(僅編碼器,如 BERT)不會使用因果遮罩。(c) 未遮罩的編碼器訓練的目標是什麼?

從分數到輸出

Softmax:分數轉換為機率

遮罩後的縮放分數仍然橫跨實數範圍。Softmax 將每一行轉換為機率分佈:


A[i][j] = exp(scaled_scores[i][j]) / sum_k exp(scaled_scores[i][k])

產生三個特性:


- A[i][j] >= 0 對所有 (i, j) 皆成立。

- sum_j A[i][j] = 1 對每一行 i 皆成立。

- 較大的原始分數產生較大的機率(單調遞增)。


第 i 行的機率向量告訴模型:計算其輸出時,第 i 個位置應多少注意每個先前位置?


加權 V 總和

位置 i 的最終注意力輸出:


output[i] = sum_j A[i][j] · V[j]

每個值向量 V[j] 會被注意力機率 A[i][j] 加權,然後求和。位置 i 的輸出會結合來自每個先前位置的值向量,並依相關性加權。


以矩陣形式,一次處理所有位置:


Attention(Q, K, V) = softmax(mask(Q · K^T / sqrt(d_k))) · V

一行程式碼。一整個注意力機制。Vaswani 等人在 2017 年寫下了這一行;變形器自此以來並未根本改變。


每個頭的輸出形狀

一個注意力頭的輸出:形狀 (T, d_k)。對於 ANDREA-120M:(1024, 64)。所有 12 個頭並行計算;它們的輸出串聯成 (1024, 768) 並饋入最終線性投影 (W_O),然後進入 transformer 區塊的 MLP。


活動 6 (grow_a_language_model_multi_head) 涵蓋多頭分割。活動 7 (grow_a_language_model_transformer_block) 涵蓋注意力周圍的一切:殘差連接、層歸一化、MLP。

合成一個管線

用你自己的話合成一個完整的注意力管線。逐步說明單個位置 i(例如序列中的位置 5)從輸入向量 x_5 到注意力輸出[5] 的過程。依序命名四個操作:(1) 投影到 Q/K/V,(2) 計算與所有位置的縮放分數,(3) 應用因果遮罩 + softmax,(4) 由機率加權求和 V 向量。一段簡短的段落。