Attention is all you need
Zhongjun Qiu 元婴开发者

本篇主要讲解了注意力机制(Attention Mechanism)的基本原理和计算过程,包括各种优化技巧,如缩放点积注意力(Scaled Dot-Product Attention)、自注意力(Self-Attention)、掩码自注意力(Masked Self-Attention)和多头注意力(Multi-Head Attention)。

注意力机制

🧠 什么是注意力机制

注意力机制(Attention Mechanism)最早源于计算机视觉领域,其核心思想模拟了人类的视觉系统:当我们观察一幅图像时,我们并不会平均“看待”所有像素,而是会迅速将注意力集中在重点区域,而忽略次要部分。

在自然语言处理(NLP)领域,这个思想同样至关重要。一个句子中的不同词语,其重要性也并非是均等的。通过让模型学会“关注”最相关的词(Token),我们能以更高效、更精确的方式理解和处理文本。

核心思想:从“加权求和”开始

从根本上说,注意力机制是一种动态的加权求和

它的目标是根据当前的任务“动态地”计算出一组“注意力权重”,然后用这组权重去汇集(加权求和)信息。权重越高的部分,代表模型认为它“越重要”。

深入理解:三个核心角色 Q, K, V

为了实现这种动态加权,注意力机制引入了三个核心变量:

  1. Query (Q) —— 查询值:代表当前的任务或“查询意图”。它发问:“我现在需要什么信息?”
  2. Key (K) —— 键值:代表信息库中各项内容的“索引或标签”。它回答:“我这里有什么?”
  3. Value (V) —— 真值:代表信息库中各项内容的“实际信息”。它代表:“这是你真正想要的内容。”

场景比喻:在图书馆中检索信息

想象一下,你想写一篇关于“人工智能在医疗领域的应用”的论文(这就是你的 Query)。

你来到图书馆,图书馆里有成千上万本书。为了方便检索,每本书都有一张“索引卡”,上面写着书名或关键词(这就是 Key)。而书架上对应的每一本书,是它所承载的“真正知识”(这就是 Value)。

你的检索过程如下:

  1. 匹配 (Q vs K):你拿着你的 Query(“AI医疗应用”),去和书架上每一张索引卡(Key)进行对比。
  2. 计算相关性
    • K₁:《深度学习与图像识别》 → 相关性 80%
    • K₂:《自然语言处理综述》 → 相关性 60%
    • K₃:《儿童趣味烹饪》 → 相关性 0%
  3. 计算注意力权重 (Softmax):将这些“相关性”转化为一个总和为 100% 的“注意力分配”:
    • K₁:50%
    • K₂:30%
    • K₃:0%
    • 其他:20%
  4. 加权求和 (Weights × V): 输出结果 = 50% × V₁ + 30% × V₂ + 0% × V₃ + …

你并没有完整阅读所有书籍,而是根据你的“查询意图”(Q),有选择性地“关注”了与 K 最匹配的 V。 这就是注意力机制的核心流程。

👀 注意力机制的计算步骤

在真实的模型中,Q、K、V 都不是文本,而是高维的词向量(Embedding)。向量的“方向”代表了它们的语义。

步骤 1:计算相关性 (Q · K)

我们如何计算 Q 和 K 之间的“相关性”? 在向量空间中,点积(Dot Product)是衡量“方向相似性”的绝佳工具:

  • 两个向量方向越接近(语义越相似),点积越大;
  • 两个向量方向越垂直(语义越无关),点积越接近 0。

为什么注意力计算可以通过 QKT 实现?

我们可以把整个过程理解为一次语义匹配

符号 含义 类比(搜索场景)
Q (Query) 当前要关注的词或片段 搜索关键词
K (Key) 序列中所有词的表示 候选文档索引
QKᵀ Query 与每个 Key 的相似度 搜索相关性打分

假设我们有一个 Query 向量 q 和 3 个 Key 向量 k1, k2, k3,将它们堆叠成矩阵 K。 那么 qKT 本质上就是同时计算 q 与所有 ki 的点积,得到一组“原始相关性分数”: x = qKT = [q ⋅ k1, q ⋅ k2, q ⋅ k3]

步骤 2:归一化为权重 (Softmax)

得到的原始分数 x(例如 [10.5, 3.2, −5.1])并不能直接用作权重。 我们需要一个函数将它们转换为“总和为 1”且“非负”的概率分布。

这就是 Softmax 函数的作用:

经过 Softmax,分数 [10.5, 3.2, −5.1] 可能变成 [0.99, 0.01, 0.00],这组清晰的权重就是“注意力分数”。

步骤 3:加权求和 (Weights · V)

最后,我们将“注意力分数”与对应的 Value (V) 相乘再求和:

Attention = softmax(QKT) × V

这实现了最终目的:根据 Q 与 K 的相关性,对 V 进行加权汇总。

🎯 关键优化:缩放点积注意力 (Scaled Dot-Product Attention)

在实际应用中,当向量的维度 dk 很大时(如 512),QKT 的结果数值可能非常大。 这会导致 Softmax 饱和

  • 输入为 [1, 2, 3] → 输出 [0.09, 0.24, 0.67](平滑)
  • 输入为 [10, 20, 30] → 输出 [0.00, 0.00, 1.00](“赢家通吃”)

Softmax 饱和会使梯度几乎为 0(梯度消失),模型将无法学习。

解决方案: 在送入 Softmax 前,将 QKT 的结果除以一个“缩放因子” (词向量维度)。

这能控制方差,使 Softmax 输出保持在可学习的范围内。

最终公式为:

深入探讨:为什么必须除以 ?而不是其他的数值

“Scaled” 的核心动机是 控制数值范围,让 Softmax 不“失控”

一、统计直观理解

假设两个向量 qk 的每个分量独立且服从均值为 0、方差为 1 的正态分布。

则它们的点积为: 根据统计规律: Var(q ⋅ k) = dk

维度 (dk) 点积标准差 () 点积范围(约)
8 2.8 -5 ~ +5
64 8 -15 ~ +15
512 22.6 -40 ~ +40

dk 很大时,Softmax 输入范围巨大,e40 级别的值会直接导致饱和。

解决方案:除以 将方差重新归一化到 1。

二、Python 实验:Softmax 饱和现象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import numpy as np
import torch
import torch.nn.functional as F

torch.manual_seed(42)

# 模拟未缩放的点积
def simulate_attention_score(dk, n_keys=5):
q = torch.randn(dk)
K = torch.randn(n_keys, dk)
scores = q @ K.T
probs = F.softmax(scores, dim=-1)
return scores.detach().numpy(), probs.detach().numpy()

for dk in [8, 64, 512]:
scores, probs = simulate_attention_score(dk)
print(f"--- d_k = {dk} (未缩放) ---")
print("原始分数:", np.round(scores, 2))
print("Softmax概率:", np.round(probs, 4))
print()
1
2
3
4
5
6
7
8
9
10
11
--- d_k = 8 (未缩放) ---
原始分数: [ 0.84 -0.79 1.12 -0.22 0.15]
Softmax概率: [0.2721 0.0672 0.3531 0.1158 0.1918]

--- d_k = 64 (未缩放) ---
原始分数: [ 6.2 -11.4 13.5 -7.2 4.1 ]
Softmax概率: [0.0002 0. 0.9995 0. 0.0003]

--- d_k = 512 (未缩放) ---
原始分数: [ 49.3 -30.2 11.5 -15.9 -20.3 ]
Softmax概率: [1. 0. 0. 0. 0.]
维度 是否缩放 Softmax 输出 现象
8 [0.27, 0.07, 0.35, 0.11, 0.20] 平滑可学
64 [0.00, 0.00, 0.99, 0.00, 0.00] 饱和危险
512 [1, 0, 0, 0, 0] 完全独裁 ⚠️

三、加上缩放项

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 模拟缩放后的点积
def simulate_scaled_attention_score(dk, n_keys=5):
q = torch.randn(dk)
K = torch.randn(n_keys, dk)
scores = (q @ K.T) / np.sqrt(dk)
probs = F.softmax(scores, dim=-1)
return scores.detach().numpy(), probs.detach().numpy()

for dk in [8, 64, 512]:
scores, probs = simulate_scaled_attention_score(dk)
print(f"--- d_k = {dk} (已缩放) ---")
print("缩放后分数:", np.round(scores, 2))
print("Softmax概率:", np.round(probs, 4))
print()
1
2
3
4
5
6
7
8
9
10
11
--- d_k = 8 (已缩放) ---
缩放后分数: [ 0.27 -0.43 0.36 -0.14 0.1 ]
Softmax概率: [0.2336 0.1085 0.2573 0.1594 0.2412]

--- d_k = 64 (已缩放) ---
缩放后分数: [ 0.84 -0.65 1.22 -0.82 0.46]
Softmax概率: [0.2472 0.0913 0.3558 0.0796 0.2261]

--- d_k = 512 (已缩放) ---
缩放后分数: [ 1.23 -0.41 0.91 -0.62 0.51]
Softmax概率: [0.2907 0.1122 0.2499 0.0956 0.2516]
维度 是否缩放 Softmax 输出 效果
8 [0.23, 0.11, 0.26, 0.16, 0.24] ✅ 完美平滑
64 [0.25, 0.09, 0.36, 0.08, 0.22] ✅ 完美平滑
512 [0.29, 0.11, 0.25, 0.09, 0.26] ✅ 完美平滑

结论: 加上 缩放后,Softmax 输出保持平衡且可学习,梯度稳定,模型训练顺畅

朴素注意力机制的实现

基于上文,我们可以使用 Pytorch 来实现注意力机制的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn.functional as F
import math

def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
'''
Args:
query: 查询值矩阵 (B, L, dim)
key: 键值矩阵 (B, S, dim)
value: 真值矩阵 (B, S, dim)
'''
dim = query.size(-1)

# 1. 计算 Q 与 K 的内积并缩放 (B, L, dim) @ (B, dim, S) -> (B, L, S)
# key.transpose(-2, -1) 将 key 的形状从 (B, S, dim) 变成 (B, dim, S)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dim)

# 2. Softmax (在最后一个维度 S 上进行,得到注意力权重)
p_attn = scores.softmax(dim=-1)


# 3. 加权求和 (B, L, S) @ (B, S, dim) -> (B, L, dim)
return torch.matmul(p_attn, value), p_attn

# --- 创建输入张量 ---

B = 2 # Batch Size
L = 3 # Query 序列长度
S = 8 # Key/Value 序列长度
dim = 5 # 特征维度 dim

# 随机生成输入张量
query = torch.randn(B, L, dim)
key = torch.randn(B, S, dim)
value = torch.randn(B, S, dim)

output, attn_weights = attention(query, key, value)

print("--- 输入张量形状 ---")
print(f"Query (Q) 形状: {query.shape}")
print(f"Key (K) 形状: {key.shape}")
print(f"Value (V) 形状: {value.shape}")
print("-" * 20)

print("--- 输出张量形状 ---")
print(f"Attention Output 形状: {output.shape}")
print(f"Attention Weights 形状: {attn_weights.shape}")
print("-" * 20)

# 打印第一个 Batch 的注意力权重
print("Batch 0的注意力权重 (L=3行, S=4列):")
print(torch.round(attn_weights[0], decimals=4))
1
2
3
4
5
6
7
8
9
10
11
12
13
--- 输入张量形状 ---
Query (Q) 形状: torch.Size([2, 3, 5])
Key (K) 形状: torch.Size([2, 8, 5])
Value (V) 形状: torch.Size([2, 8, 5])
--------------------
--- 输出张量形状 ---
Attention Output 形状: torch.Size([2, 3, 5])
Attention Weights 形状: torch.Size([2, 3, 8])
--------------------
Batch 0的注意力权重 (L=3行, S=4列):
tensor([[0.0617, 0.0692, 0.2200, 0.1603, 0.1279, 0.0740, 0.1822, 0.1048],
[0.0373, 0.0372, 0.2025, 0.1016, 0.1226, 0.1371, 0.2369, 0.1247],
[0.0964, 0.1245, 0.1140, 0.0532, 0.2114, 0.1940, 0.1347, 0.0718]])

这里query = torch.randn(B, L, dim)默认传入的都是词向量,相当于有B个句子,每个句子里有L个token,每个token由dim维向量表示。

Attention Output 形状: torch.Size([2, 3, 5]):表示的是经过注意力计算以后,原始的Query形状不变,但每个句子里的每个token的表示都已经包含了K和V的信息。

Attention Weights 形状: torch.Size([2, 3, 8]):表示的是Query中句子的3个token分别对于Key中8个token的注意力分数是什么。

自注意力(Self-Attention)

在前面的分析中我们提到,注意力机制(Attention) 的核心思想是:

计算两段序列中每个元素之间的相似度,从而找到一个序列中每个元素对另一个序列中各元素的相关程度,再根据这些相关度对输入进行加权,分配注意力权重。

也就是说,注意力机制本质上是一个“相似度分配器”。

在经典的注意力机制中,Query(Q)和 Key(K)、Value(V)来自不同的序列:

  • Q 通常来自 目标序列(待预测序列)
  • K 和 V 通常来自 源序列(输入序列)

这种机制非常适合用于 Encoder-Decoder 架构,例如 Transformer 的 Decoder 模块

  • Q 来自 Decoder 的输入;
  • K 和 V 来自 Encoder 的输出;
  • 从而实现编码信息(输入语义)与历史信息(已生成内容)的融合。

自注意力的概念

在 Transformer 的 Encoder 模块 中,使用的是注意力机制的一个变种——自注意力(Self-Attention)

顾名思义,“自注意力”即模型在同一序列内部,计算每个 token 对序列中其他 token 的注意力分布。 因此,在自注意力中:

  • Q、K、V 都由同一个输入序列经过不同的线性变换得到:

Q = XWQ,  K = XWK,  V = XWV

通过自注意力机制,模型能够捕捉句子内部的依赖关系,例如:

  • “The animal didn’t cross the street because it was too tired.” 模型能理解 it 指代 animal 而不是 street,就是自注意力在起作用。

代码上,自注意力的实现非常简单,只需将输入重复传入三次:

1
attention(x, x, x)

这表示 Q、K、V 全部来源于同一个输入 x

掩码自注意力(Masked Self-Attention)

掩码自注意力(Masked Self-Attention)是在自注意力的基础上,加入掩码(Mask)的机制,用来遮蔽部分 token,使模型只能看到过去的词,而不能看到未来的词

这种做法的主要目的,是在训练语言模型时保证“因果性(Causality)”,即模型在预测第 t 个词时,只能依赖于 1 ∼ t − 1 的词。

为什么需要 Mask

Transformer 语言模型通常通过预测下一个 token 来学习语言规律。 例如,对于句子:

1
<BOS> I like you <EOS>

模型的学习目标是:

  1. 输入 <BOS>,预测 I
  2. 输入 <BOS> I,预测 like
  3. 输入 <BOS> I like,预测 you
  4. 输入 <BOS> I like you,预测 <EOS>

这显然是一个串行过程,无法并行化。

为了让模型在训练时能够同时处理整个序列,Transformer 引入了 Mask 矩阵 —— 通过遮蔽未来的词,让模型虽然一次性看到整个序列,但在注意力计算时仍然只“关注”历史信息。

掩码矩阵的结构

Mask 矩阵本质上是一个上三角矩阵(上三角区域为 0,下三角为 1):

Query 1 2 3 4
1
2
3
4

这个掩码矩阵控制了每个 Query token 能看到的 Key token 范围。

1
2
3
4
5
# 创建下三角矩阵(上三角为 0,下三角含对角线为 1)
mask = torch.tril(torch.ones(1, 1, size, size))

# 将被遮蔽位置替换为负无穷
attention = attention.masked_fill(mask == 0, float('-inf'))
  • masked_fill 会将矩阵中掩码为 0 的部分填充为 -inf

为什么是负无穷?

注意力得分矩阵 S 在计算后会经过 Softmax: 如果将某个位置设为 −∞,那么: e−∞ = 0 即在 Softmax 之后,该位置的注意力权重变为 0,模型将完全忽略未来信息。

这种掩码也被称为:

  • Look-Ahead Mask(前瞻掩码)
  • Causal Mask(因果掩码)

多头注意力(Multi-Head Attention)

单头注意力只能学习一种“关系模式”,例如句法关系或语义关系。 但语言的关联是多层次、多维度的,因此 Transformer 提出了 多头注意力机制(Multi-Head Attention)

MultiHead(Q, K, V) = Concat(head1, ..., headh)WO

其中每个头的计算如下:

headi = Attention(QWiQ, KWiK, VWiV)

每个头都使用不同的线性变换矩阵 WiQ, WiK, WiV,从而学习不同层次的特征。最后再拼接结果并通过 WO 融合输出。

主要是输入仍是一样(B,S,dim),但dim要切分成num_head份,每个head有dim_head=dim/num_head个维度。

然后就是做一次reshape和transpose:(B,S,dim)->(B,S,num_head, dim_head)->(B, num_head, S, dim_head)。这样前两个维度可以不管,当做并行计算的。后两个维度跟之前的一样,序列长度+词向量维度。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class MHA(nn.Module):
"""Multi-Head Attention"""
def __init__(self, args):
super().__init__()
self.dim = args.dim
self.num_head = args.num_head
self.dim_head = self.dim // self.num_head
self.dropout = args.dropout
self.bias = args.bias

self.W_q = nn.Linear(self.dim, self.dim, bias=self.bias)
self.W_k = nn.Linear(self.dim, self.dim, bias=self.bias)
self.W_v = nn.Linear(self.dim, self.dim, bias=self.bias)
self.W_o = nn.Linear(self.dim, self.dim)
self.atten_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)

def forward(self, q, k, v, mask):
B = q.size(0)

Q = self.W_q(q)
K = self.W_k(k)
V = self.W_v(v)

# (B, S, dim) → (B, num_head, S, dim_head)
Q = Q.reshape(B, -1, self.num_head, self.dim_head).permute(0, 2, 1, 3)
K = K.reshape(B, -1, self.num_head, self.dim_head).permute(0, 2, 1, 3)
V = V.reshape(B, -1, self.num_head, self.dim_head).permute(0, 2, 1, 3)

attention = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
attention = attention.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(attention, dim=-1)
attention = self.atten_dropout(attention)

out = attention @ V
out = out.permute(0, 2, 1, 3).reshape(B, -1, self.dim)
out = self.resid_dropout(self.W_o(out))

return out

补充:Tensor 的连续性与内存布局

在 PyTorch 中,Tensor 在内存中通常是行优先(row-major)存储的。也就是说,最后一维的数据在内存中是连续排列的。某些操作(如 transpose()permute())并不会真正移动数据,而是仅改变了视图中的步长(stride)。 因此,这类操作后的 Tensor 在内存中可能不再连续

当我们调用 view() 时,它要求底层内存是连续的;否则会报错。 而 reshape() 则更智能一些:

  • 如果内存是连续的,它等价于 view()
  • 如果内存不连续,它会自动复制数据重新布局,保证结果正确。

因此,在多头注意力的实现中,我们通常使用 reshape() 而不是 view(),以避免内存布局问题

在 PyTorch 中,每个 Tensor 背后都对应着一块实际的内存区域,用来存储元素的值。

1. 行优先存储(Row-major Order)

PyTorch 默认采用 行优先(row-major) 的存储方式(与 NumPy 一致),也就是最后一维的数据在内存中是连续的

假设我们有一个二维 Tensor:

1
2
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])

它的形状为 (2, 3)。 在内存中,它的线性排列顺序是:

1
[1, 2, 3, 4, 5, 6]

我们可以通过 .storage() 来查看底层存储:

1
2
>>> x.storage()
1, 2, 3, 4, 5, 6

这就是所谓的 行优先(Row-major) 存储:

  • 先存完第 0 行 [1, 2, 3]
  • 再存第 1 行 [4, 5, 6]

2. 步长(Stride)的含义

每个 Tensor 除了数据本身外,还有一个重要的属性:stride(步长)

Stride 描述的是:

在访问某一维上相邻两个元素时,在内存中需要跳过多少个元素。

我们来看看刚才的 x

1
2
>>> x.stride()
(3, 1)

解释如下:

  • 第 0 维(行)步长为 3:要移动到下一行,需跳过 3 个元素;
  • 第 1 维(列)步长为 1:移动到下一个列元素,只需跳过 1 个元素。

可视化理解如下:

索引 x[i, j] 内存偏移量
(0, 0) 1 0 × 3 + 0 × 1 = 0
(0, 1) 2 0 × 3 + 1 × 1 = 1
(0, 2) 3 0 × 3 + 2 × 1 = 2
(1, 0) 4 1 × 3 + 0 × 1 = 3
(1, 1) 5 1 × 3 + 1 × 1 = 4
(1, 2) 6 1 × 3 + 2 × 1 = 5

这正对应 [1, 2, 3, 4, 5, 6] 的内存顺序。


3. Transpose 改变了什么?

现在我们对这个 Tensor 做一次转置操作:

1
y = x.transpose(0, 1)

此时 y 的形状变为 (3, 2),但我们看看 stride:

1
2
>>> y.stride()
(1, 3)

含义是:

  • 沿第 0 维(原列)移动时,内存跳 1;
  • 沿第 1 维(原行)移动时,内存跳 3。

换句话说,transpose() 只是改变了 stride 的定义,而没有真正移动数据。 底层内存依然是 [1, 2, 3, 4, 5, 6],只是 PyTorch 解释数据的方式变了。

因此:

1
2
>>> y.is_contiguous()
False

说明转置后的 Tensor 不再是连续的


4. 连续(contiguous)与非连续(non-contiguous)

当一个 Tensor 在所有维度上满足“后一维的步长等于前一维的长度乘以前一维的步长”时,它就是连续的。

否则,Tensor 就是非连续的。

连续的 Tensor 意味着:你可以把它当成一段紧密排列的内存块直接读取。

非连续的 Tensor 意味着:你必须用 stride 计算偏移量才能取值

 REWARD AUTHOR
 Comments
Comment plugin failed to load
Loading comment plugin