English· Español· Deutsch· Nederlands· Français· 日本語· ქართული· 繁體中文· 简体中文· Português· Русский· العربية· हिन्दी· Italiano· 한국어· Polski· Svenska· Türkçe· Українська· Tiếng Việt· Bahasa Indonesia

un

访客
1 / ?
返回课程列表

查询、键、值

来自同一输入的三个线性映射

在嵌入(活动 4)之后,每个位置携带一个 768 维向量 x_t。注意力首先产生 x 的三个不同投影:


Q(查询): 这个位置想知道什么?


K (键): 这个位置向其他位置提供什么?


V (值): 如果被关注,这个位置会传递什么内容?


每个投影都来自一个可学习的权重矩阵:


Q = x · W_Q     # W_Q 形状:(d_model, d_k)
K = x · W_K     # W_K 形状:(d_model, d_k)
V = x · W_V     # W_V 形状:(d_model, d_k)

三个矩阵,均通过反向传播训练。模型学习:在该位置,哪个查询最能检索有用的过去上下文?哪个键能很好地宣传该位置的内容?如果被选中,哪个值能提供内容?


Scaled dot-product attention


图书馆类比

想象一个图书馆卡片目录。你带着一个主题走进去(你的查询)。每张卡片都列出关键词(一个)。当你的主题匹配卡片的关键词时,你取出书籍内容(一个)。注意力机制会并行地为每个标记执行此操作:每个位置查询每个其他位置,排名对齐,并检索值向量的加权组合。


ANDREA-120M 维度


数量备注
d_model768每个位置的向量大小
n_head12并行注意力头
d_k64每个头的维度 (= d_model / n_head)
T1024上下文长度

d_k = d_model / n_head = 768 / 12 = 64。每个头看到完整的768维空间中的64维切片。活动6 (grow_a_language_model_multi_head) 详细介绍了每个头的拆分。

计算 d_k

为两个 ANDREA 变体计算 d_k。(a) ANDREA-12M: d_model = 384, n_head = 12。(b) ANDREA-480M: d_model = 1536, n_head = 24。为每个变体展示公式 d_k = d_model / n_head。

为什么除以 sqrt(d_k)

一个分数矩阵

一旦 Q 和 K 存在(每个形状 (T, d_k)),注意力计算一个分数矩阵:


scores = Q · K^T     # shape: (T, T)

scores[i, j] = 位置 i 的查询与位置 j 的键对齐的强度。每个 (i, j) 对得到一个分数:每个注意力头每个前向传播 1024 × 1024 = 1,048,576 个分数。


为什么需要除法

两个随机 d 维单位向量的点积幅度大约为 sqrt(d)。不进行缩放,分数会随着 d_k 增长:


- d_k = 64:典型的点积数量级为 8。

- d_k = 256:典型的点积数量级为 16。

- d_k = 4096:典型的点积数量级为 64。


大的分数会产生尖锐的 softmax(一个位置主导,其他地方梯度消失)。训练停滞。缩放修复了幅度:


scaled_scores = (Q · K^T) / sqrt(d_k)

对于 ANDREA-120M,sqrt(d_k) = sqrt(64) = 8。每个分数都被除以 8。无论 d_k 如何,幅度大致保持在单位尺度。Softmax 保持良好行为。梯度流动顺畅。


Vaswani 的原始理由

来自 Attention Is All You Need (2017):'对于较大的 d_k 值,向量点积的幅度会变得很大,将 softmax 函数推入梯度极小的区域。' sqrt(d_k) 除数可以抵消这种增长。


代码视图

microgpt_cuda.cu 中,这种缩放表现为字面除法:


scores[i][j] = dot(Q[i], K[j]) * (1.0f / sqrtf(d_k));

每个分数一个浮点乘法。廉价。关键。

在 d_model = 4096 时的缩放

假设一个研究团队构建了 ANDREA-2B,d_model = 4096 & n_head = 32。(a) 计算 d_k。(b) 计算 sqrt(d_k)。(c) 用一句话解释如果团队忘记在此规模下除以 sqrt(d_k) 会发生什么。

为什么位置 i 不能看到位置 j > i

生成产生的约束

ANDREA 一次生成一个 token。在推理时,位置 0 生成第一个 token,然后位置 1 查看位置 0 的输出并生成第二个 token,依此类推。模型在生成过程中永远无法访问未来的 token。


训练必须反映这一点。如果在训练期间位置 5 可以关注位置 6,模型就会学会一个捷径:“通过读取 token 6 来预测 token 6”。在推理时,这个捷径消失了(token 6 还不存在)。模型的训练与推理行为会灾难性地分歧。


一个掩码

因果掩码阻止任何位置 i 对任何位置 j > i 的关注。实现方式:在 j > i 的地方将 scaled_scores[i][j] 设置为 -infinity。经过 softmax 后,这些条目变为 exp(-inf) = 0。掩码干净地将对未来位置的关注置零。


for i in range(T):
for j in range(T):
if j > i:
scaled_scores[i][j] = -1e9   # 实际上是 -inf

经过 softmax(按行)后,每行总和为 1,但只有 [0, i] 的条目携带概率质量。位置 i 只从过去的位置混合信息。


可视化掩码

应用掩码后的得分矩阵形状 (T, T) 看起来像一个下三角结构:


掩码应用后的 scaled_scores,按行 softmax:

第 0 行:  [1.0, 0,   0,   0,   ...]   # 仅看到自身
第 1 行:  [0.4, 0.6, 0,   0,   ...]   # 看到位置 0, 1
第 2 行:  [0.2, 0.3, 0.5, 0,   ...]   # 可见 0, 1, 2
第 3 行:  [0.1, 0.2, 0.3, 0.4, ...]   # 可见 0, 1, 2, 3
...

每行严格的下三角概率分布。未来位置保持不可见。


为什么仅解码器 Transformer 需要这个

仅解码器模型如 ANDREA、GPT 和 LLaMA 都有一个共同目标:从过去预测下一个标记。因果掩码使该目标可以并行训练:每个位置同时计算自己的下一个标记预测,且没有位置通过偷看前方作弊。

掩码 & 风格

活动 2(介绍)涵盖了三种 transformer 风格:仅编码器、编码器-解码器、仅解码器。(a) 哪种风格使用因果掩码?(b) 用一句话说明为什么另一种风格(仅编码器,如 BERT)不会使用因果掩码。(c) 未掩码的编码器训练的目标是什么?

从分数到输出

Softmax:分数到概率

掩码后的缩放分数仍然在实数范围内变化。Softmax 将每一行转换为概率分布:


A[i][j] = exp(scaled_scores[i][j]) / sum_k exp(scaled_scores[i][k])

由此产生三个特性:


- 对于所有 (i, j),A[i][j] >= 0。

- 对于每一行 i,sum_j A[i][j] = 1。

- 较大的原始分数产生较大的概率(单调递增)。


第 i 行的概率向量告诉模型:计算其输出时,位置 i 应该多少关注每个先前位置?


加权 V 求和

位置 i 的最终注意力输出:


output[i] = sum_j A[i][j] · V[j]

每个值向量 V[j] 由注意力概率 A[i][j] 加权,然后求和。位置 i 的输出结合了来自每个先前位置的值向量,按相关性加权。


以矩阵形式,一次处理所有位置:


Attention(Q, K, V) = softmax(mask(Q · K^T / sqrt(d_k))) · V

一行代码。整个注意力机制。Vaswani 等人在 2017 年写下了这一行;自那时起,Transformer 的根本结构就没有改变。


每个头的输出形状

一个注意力头的输出:形状 (T, d_k)。对于 ANDREA-120M:(1024, 64)。所有 12 个头并行计算;它们的输出连接成 (1024, 768) 并输入到最终线性投影 (W_O),然后进入 transformer 块的 MLP。


活动 6 (grow_a_language_model_multi_head) 涵盖多头拆分。活动 7 (grow_a_language_model_transformer_block) 涵盖注意力周围的一切:残差连接、层归一化、MLP。

合成一个管道

用你自己的话合成一个完整的注意力管道。逐步说明从输入向量 x_5 到注意力输出[5] 的单个位置 i(例如序列中的位置 5)发生了什么。按顺序命名四个操作:(1) 投影到 Q/K/V,(2) 计算与所有位置的缩放分数,(3) 应用因果掩码 + softmax,(4) 用概率加权求和 V 向量。一段简短的段落。