More Attention is all you need
Zhongjun Qiu 元婴开发者

本章在之前对注意力机制的基础上,介绍几种改进的注意力机制,包括多查询注意力(MQA)、分组查询注意力(GQA)以及FlashAttention等技术,这些技术在提升模型性能和推理效率方面发挥了重要作用。

1. 原始注意力机制(Scaled Dot-Product Attention)

🧠 核心思想

通过计算 Query 与 Key 的点积相似度,得到每个位置对序列中其他位置的注意力权重,并用该权重加权 Value,得到上下文信息。

🧩 公式

🔍 步骤

  1. 对输入向量 q, k, v 分别映射为 Q, K, V;
  2. 计算注意力得分矩阵;
  3. 经过 softmax 归一化;
  4. 加权求和得到输出。

💡 特点

  • 全局注意力,全序列交互;
  • 复杂度 O(n2)n为序列的长度;
  • 高性能但难以扩展到长序列。

2. 多头注意力(Multi-Head Attention, MHA)

🧠 核心思想

将注意力分成多个头(head),让模型学习不同的语义子空间表示。

🧩 公式

MHA(Q, K, V) = Concat(head1, …, headh)WO

其中 headi = Attention(QWiQ, KWiK, VWiV)

💡 特点

  • 多头捕获不同关系模式;
  • 并行性强;
  • 缺点:推理阶段每个 head 都要维护独立的 KV 缓存,显存开销大。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class MHA(nn.Module):
"""
Multi-Head Attention Module With KV-Cache

Args:
args: ModelArgs
model arguments
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.num_head = args.num_head
assert self.dim % self.num_head == 0
self.dim_head = self.dim // self.num_head
self.dropout = args.dropout
self.max_seq_len = args.max_seq_len

self.W_q = nn.Linear(self.dim, self.dim, bias=False)
self.W_k = nn.Linear(self.dim, self.dim, bias=False)
self.W_v = nn.Linear(self.dim, self.dim, bias=False)
self.atten_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.W_o = nn.Linear(self.dim, self.dim, bias=False)

def forward(
self,
q: Tensor,
k: Tensor,
v: Tensor,
mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: # return (out, kv_cache)
"""
q: (B, S_q, dim)
S_q may be 1 (in inference) or T (in training)
k: (B, S_k, dim)
v: (B, S_v, dim)
S_k and S_v may be 1 (in inference) or T (in training)
S_k and S_v must be equal, but may differ from S_q
mask: (B, 1, S_q, S_k)
kv_cache: ( (B, H, S_past_k, dim_head), (B, H, S_past_v, dim_head) )
"""
B = q.size(0)
Q: Tensor = self.W_q(q)
K: Tensor = self.W_k(k)
V: Tensor = self.W_v(v)

# Reshape (B, S, dim) -> (B, H, S, dim_head), H * dim_head = dim
Q = Q.reshape(B, -1, self.num_head, self.dim_head).permute(0, 2, 1, 3)
K_new = K.reshape(B, -1, self.num_head, self.dim_head).permute(0, 2, 1, 3)
V_new = V.reshape(B, -1, self.num_head, self.dim_head).permute(0, 2, 1, 3)

if kv_cache is not None:
# In inference mode
# past_K/V: (B, H, S_past, dim_head)
past_K, past_V = kv_cache
# (B, H, S_past, dim_head) + (B, H, S_k, dim_head) -> (B, H, S_past + S_k, dim_head)
K = torch.cat([past_K, K_new], dim=2)
V = torch.cat([past_V, V_new], dim=2)
# new cache
kv_cache = (K, V)
else:
# In training mode
K = K_new
V = V_new

# Q @ K.T -> (B, H, S_q, S_k_total)
# Inference: (S_q=1, S_k_total=i): (B,H,1,dim) @ (B,H,dim,i) -> (B,H,1,i)
# Training: (S_q=T, S_k_total=T): (B,H,T,dim) @ (B,H,dim,T) -> (B,H,T,T)
attention = Q @ K.transpose(-1, -2) / math.sqrt(self.dim_head)
if mask is not None:
attention = attention.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(attention.float(), -1).to(q.dtype)
attention: Tensor = self.atten_dropout(attention)
# Inference: (B,H,1,i) @ (B,H,i,dim) -> (B,H,1,dim)
# Training: (B,H,T,T) @ (B,H,T,dim) -> (B,H,T,dim)
out = attention @ V

# (B, H, S_q, dim_head) -> (B, S_q, H, dim_head) -> (B, S_q, dim)
out = out.permute(0, 2, 1, 3).reshape(B, -1, self.dim)
out = self.resid_dropout(self.W_o(out))

return out, kv_cache

如上述代码所示,维护的kv cache结构是 ((B, H, Spastk, dimhead), (B, H, Spastv, dimhead)),即对于每一个head来说,都需要维护之前所有token对应的K和V,显存开销极大。

3. 单查询多头注意力(Multi-Query Attention, MQA)

🧠 核心思想

所有 Query 头共享一组 Key/Value 投影,用于加速推理。

🧩 公式

Qi = XWiQ,  K = XWK,  V = XWV

💡 优点

  • 显存节省:KV 缓存只存一份;
  • 推理速度显著提高;
  • 参数量与 MHA 几乎相同。

⚠️ 缺点

  • 不同 head 共用同一 K/V,表达能力降低;
  • 模型多样性下降。

🔗 应用

  • Google PaLM
  • Efficient Transformers
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class MQA(nn.Module):
"""
Multi-Query Attention Module With KV-Cache

Args:
args: ModelArgs
model arguments
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.num_head = args.num_head
assert self.dim % self.num_head == 0
self.dim_head = self.dim // self.num_head
self.dropout = args.dropout
self.max_seq_len = args.max_seq_len

self.W_q = nn.Linear(self.dim, self.dim, bias=False)
# k/v share the same projection in MQA
# Thus, the number of k/v projection matrices is 1 instead of num_head
self.W_k = nn.Linear(self.dim, self.dim_head, bias=False)
self.W_v = nn.Linear(self.dim, self.dim_head, bias=False)
self.atten_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.W_o = nn.Linear(self.dim, self.dim, bias=False)

def forward(
self,
q: Tensor,
k: Tensor,
v: Tensor,
mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: # return (out, kv_cache)
"""
q: (B, S_q, dim)
S_q may be 1 (in inference) or T (in training)
k: (B, S_k, dim)
v: (B, S_v, dim)
S_k and S_v may be 1 (in inference) or T (in training)
S_k and S_v must be equal, but may differ from S_q
mask: (B, 1, S_q, S_k)
kv_cache: ( (B, S_past_k, dim_head), (B, S_past_v, dim_head) )
"""
B = q.size(0)
Q: Tensor = self.W_q(q)
K: Tensor = self.W_k(k)
V: Tensor = self.W_v(v)

# Reshape (B, S, dim) -> (B, H, S, dim_head)
Q = Q.reshape(B, -1, self.num_head, self.dim_head).permute(0, 2, 1, 3)
# (B, S, dim_head)
K_new, V_new = K, V
# (B, 1, S, dim_head)
K = K.reshape(B, -1, 1, self.dim_head).permute(0, 2, 1, 3)
V = V.reshape(B, -1, 1, self.dim_head).permute(0, 2, 1, 3)

if kv_cache is not None:
# In inference mode
past_K, past_V = kv_cache
K_new = torch.cat([past_K, K_new], dim=1)
V_new = torch.cat([past_V, V_new], dim=1)
kv_cache = (K_new, V_new)

attention = Q @ K.transpose(-1, -2) / math.sqrt(self.dim_head)
if mask is not None:
attention = attention.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(attention.float(), -1).to(q.dtype)
attention: Tensor = self.atten_dropout(attention)
out = attention @ V

out = out.permute(0, 2, 1, 3).reshape(B, -1, self.dim)
out = self.resid_dropout(self.W_o(out))

return out, kv_cache

如上述代码所示,kv cache的形状是( (B, S_past_k, dim_head), (B, S_past_v, dim_head) ),与H无关。

相当于有H个head,但只需要保存一份KV。

4. 分组查询注意力(Grouped-Query Attention, GQA)

🧠 核心思想

介于 MHA 与 MQA 之间的折中: 多个 Query head 共享同一组 K/V(按组共享)。

🧩 公式

Qi = XWiQ,  (Ki, Vi) = XWjK, XWjV,  i ∈ groupj

💡 特点

  • 内存效率:KV 缓存数量为组数 g

  • 灵活性高:可调节组数 g 控制性能与显存之间的折中;

  • 兼容性强:几乎可直接替换 MHA 层。

🔗 应用

  • LLaMA 2 / 3
  • Mistral
  • Gemma
  • Falcon 180B
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class GQA(nn.Module):
"""
Group-Query Attention Module With KV-Cache

Args:
args: ModelArgs
model arguments
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.num_head = args.num_head
self.num_group = args.num_group
assert self.num_group <= self.num_head and self.num_head % self.num_group == 0
assert self.dim % self.num_head == 0
self.dim_head = self.dim // self.num_head
self.dropout = args.dropout
self.max_seq_len = args.max_seq_len

self.W_q = nn.Linear(self.dim, self.dim, bias=False)
# k/v share the same projection in MQA
# Thus, the number of k/v projection matrices is num_gruop instead of num_head
self.W_k = nn.Linear(self.dim, self.num_group * self.dim_head, bias=False)
self.W_v = nn.Linear(self.dim, self.num_group * self.dim_head, bias=False)
self.atten_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.W_o = nn.Linear(self.dim, self.dim, bias=False)


def forward(
self,
q: Tensor,
k: Tensor,
v: Tensor,
mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: # return (out, kv_cache)
"""
q: (B, S_q, dim)
S_q may be 1 (in inference) or T (in training)
k: (B, S_k, dim)
v: (B, S_v, dim)
S_k and S_v may be 1 (in inference) or T (in training)
S_k and S_v must be equal, but may differ from S_q
mask: (B, 1, S_q, S_k)
kv_cache: ( (B, num_group, S_past_k, dim_head), (B, num_group, S_past_v, dim_head) )
"""
B = q.size(0)
Q: Tensor = self.W_q(q)
K: Tensor = self.W_k(k)
V: Tensor = self.W_v(v)

# Reshape (B, S, dim) -> (B, H, S, dim_head)
Q = Q.reshape(B, -1, self.num_head, self.dim_head).permute(0, 2, 1, 3)
# (B, num_group, S, dim_head)
K_new = K.reshape(B, -1, self.num_group, self.dim_head).permute(0, 2, 1, 3)
V_new = V.reshape(B, -1, self.num_group, self.dim_head).permute(0, 2, 1, 3)

if kv_cache is not None:
# In inference mode
past_K, past_V = kv_cache
# (B, num_group, S_past, dim_head) + (B, num_group, S_k, dim_head)
# -> (B, num_group, S_past + S_k, dim_head)
K = torch.cat([past_K, K_new], dim=2)
V = torch.cat([past_V, V_new], dim=2)
else:
# In training mode
K = K_new
V = V_new
kv_cache = (K, V)

# expand K/V to (B, H, S, dim_head)
K = K.repeat_interleave(self.num_head // self.num_group, dim=1)
V = V.repeat_interleave(self.num_head // self.num_group, dim=1)

attention = Q @ K.transpose(-1, -2) / math.sqrt(self.dim_head)
if mask is not None:
attention = attention.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(attention.float(), -1).to(q.dtype)
attention: Tensor = self.atten_dropout(attention)
out = attention @ V

out = out.permute(0, 2, 1, 3).reshape(B, -1, self.dim)
out = self.resid_dropout(self.W_o(out))

return out, kv_cache

代码中主要是kv cahce部分,只需要存储num_group个KV实例,一般num_group为4,8,16。

并且当num_group = num_head时,就等价于MHA;

当num_group = 1时,就等价于MQA。

MHA MQA GQA MLA 对比

image

如图所示,前三种方法的核心区别在于查询头 (Q) 与键/值头 (K/V) 之间的关系,这直接决定了推理时 KV Cache 的大小:

  • MHA (多头注意力): 这是标准配置。如图所示,每一个 Q 头都一对一匹配一组独立的 K/V 头。如果有 8 个 Q 头,就有 8 组 K/V。这提供了最强的模型表达能力,但代价是 KV 缓存最大
  • MQA (多查询注意力): 这是一种激进的优化。所有的 Q 头共享同一组 K/V 头。这使得 KV 缓存降到了最低(仅 1 组),极大提升了推理速度和吞吐量,但可能会牺牲一些模型精度。
  • GQA (分组查询注意力): 这是 MHA 和 MQA 之间的完美平衡。它将 Q 头分组(如图中 2 个 Q 为一组),组内共享一组 K/V 头。这在实现接近 MHA 质量的同时,将 KV 缓存显著减小(例如,从 8 组降到 4 组),是目前(如 Llama 2/3)的主流选择。

第四种:多头潜在注意力 (Multi-Head Latent Attention, MLA),与前三种方法通过 减少 K/V 头的数量 来节省显存的思路不同,MLA 提出了一种全新的优化哲学:压缩 K/V 缓存的表示

MLA 假设,由 MHA 生成的完整 K/V 状态是高度冗余的。我们不需要通过“共享”或“分组”来减少头的数量,而是可以学习一个更小、信息密度更高的“潜在表示”(Latent Representation)来作为 K/V 缓存的“摘要”。

从图示中可以清晰地看到:

  1. 模型在内部依然计算出了完整的、与 MHA 相同的多组 K/V 头(图中有 8 组)。
  2. 在将它们写入缓存之前,这些完整的 K/V 头会经过一个“投影 (projection)”步骤。
  3. 这个投影步骤会生成一个“压缩的潜在 KV (Compressed Latent KV)”。
  4. 只有这个压缩后的表示才会被真正缓存起来,用于后续的推理步骤。

这种方法的实现通常涉及以下几个步骤:

  • 投影 (Projection): 这个“投影”层通常是一个可学习的参数矩阵(例如一个简单的线性层或小型前馈网络)。它的作用是降低维度。例如,它可以将 8 个 K/V 头的表示“浓缩”成 2 个“潜在” K/V 头的表示,或者保持头的数量不变,但减小每个头的特征维度。
  • 缓存 (Caching): 在推理的每一步,模型计算出 N 组 K/V 状态,通过这个投影层,得到一个更小的 L 组(或维度更低)的潜在 K/V 状态。这个潜在状态被写入缓存,从而节省了大量 VRAM。

MLA 的优势在于,它理论上保留了所有 MHA 原始 K/V 头的信息(通过学习如何压缩它们),而不是像 GQA 或 MQA 那样在结构上就减少了 K/V 头的数量。

最后,我们可以用一个表格来清晰地总结这四种方法在推理时对 KV 缓存的处理:

方法 K/V 头配置 KV 缓存大小 核心思想
MHA N 组 Q, N 组 K/V 最大 (N 倍) 不优化:每个 Q 头有独立的 K/V
MQA N 组 Q, 1 组 K/V 最小 (1 倍) 全局共享:所有 Q 头共享 1 组 K/V
GQA N 组 Q, G 组 K/V 中等 (G 倍) 分组共享:每组 Q 头共享 1 组 K/V
MLA N 组 Q, N 组 K/V -> 压缩 小 (L 倍) 压缩:缓存 K/V 状态的一个低维“潜在”投影

5. FlashAttention(高效注意力计算)

前面我们讨论的 MQA, GQA 等方法,主要优化的是推理时 KV Cache 的显存占用。而 FlashAttention 则着眼于一个更根本的问题:注意力计算本身的速度和显存瓶颈

它是一种 I/O 感知(I/O-aware)的注意力算法,不修改注意力计算的数学本质,但通过巧妙的工程实现,使其在 GPU 上的运行速度更快、显存占用更低

1. 核心瓶颈:内存带宽 (The I/O Bottleneck)

我们先回顾标准注意力的计算:O = Softmax(QKT)V

在 GPU 上执行此操作的标准方法(例如 PyTorch 的 nn.MultiHeadAttention)存在一个巨大瓶颈:

  1. 计算 S = QKT: 假设序列长度为 N,特征维度为 d。这会产生一个巨大的中间矩阵 S,大小为 N × N
  2. HRAM 读写: 这个 N × N 的矩阵 S 必须被写入 GPU 的 HRAM(高带宽内存,即 VRAM 显存),然后再由 Softmax 操作读出。
  3. Softmax 操作: 计算 Softmax(S),结果再次写入 HRAM。
  4. 计算 O = Softmax(S)V: 再次从 HRAM 中读取 Softmax(S),与 V 相乘,最后将结果 O 写入 HRAM。

问题在于:

  • HRAM 很慢: 相比于 GPU 核心的片上 SRAM(超高速缓存),HRAM 的读写速度要慢几个数量级。
  • I/O 成为瓶颈: 注意力计算变成了内存带宽限制 (memory-bound) 的操作。GPU 核心(ALU)大部分时间都在“等待”数据从 HRAM 传来或写入 HRAM,而不是在真正地“计算”(FLOPs)。
  • 显存占用: 当序列长度 N 很长时(例如 8K、16K),这个 N × NS 矩阵本身就会占用巨量的显存(例如,N=8K 时,一个 float32 矩阵 S 就需要 8192 × 8192 × 4 bytes ≈ 256 MB,如果有 32 个头,就是 8 GB,这还只是一个层!)。

2. FlashAttention 的核心原理

FlashAttention 的核心思路是:彻底避免将 N × N 的注意力矩阵 S 写入 HRAM。

它通过两大技术实现了这一点:

  1. Kernel Fusion (核函数融合):

    FlashAttention 将标准注意力中的多次 HRAM 读写(QKT、Softmax、OV)融合成一个单独的 GPU Kernel(核函数)。

    这个 Kernel 一次性从 HRAM 中读取 Q, K, V,然后将所有中间计算(S 矩阵、Softmax)全部在高速的 SRAM 中完成,最后只将最终的输出 O 写回到 HRAM。

  2. Tiling (分块/瓦片):

    SRAM 的容量非常小(例如,NVIDIA A100 只有 192KB/SM),无法容纳整个 N × N 的矩阵。

    FlashAttention 使用分块策略:它将 Q, K, V 矩阵分割成更小的块(Tiles)。Kernel 加载 Q 的一个块,然后迭代地加载 KV 的相应块。

    一个形象的比喻:

    • 标准注意力:就像一个厨师(GPU 核心)。做一道菜(注意力计算),每切完一种配料(QKT),就把它放回 500 米外的冰箱(HRAM),要用时(Softmax)再跑 500 米去取,做完下一步(Softmax(S))再放回冰箱,最后(OV)再跑 500 米去取… 厨师大部分时间都在跑步。
    • FlashAttention:厨师(GPU 核心)把做 一份菜 所需的所有配料(Q, K, V 的一个“块”)一次性拿到自己的小工作台(SRAM)上。在工作台上完成所有步骤(S 块、Softmax、 O 块),直到做出最终的成品(O 块)的一部分,再把它放到餐桌上(HRAM)。厨师几乎不需要离开工作台。

3. 关键实现:分块与在线 Softmax

这里最大的技术难点是 Softmax。Softmax 需要对 一整行 数据进行归一化(xi/∑(xj))。如果我只加载了 S 矩阵的一个块(Sij),我怎么知道这一行的“最大值”和“总和”呢?

FlashAttention 使用了一种Online Softmax的数值稳定算法:

  1. 分块计算: 假设 Q 块与 K 的第 1 个块 K1 计算,得到了 S 矩阵的第 1 个块 S1
  2. 计算局部 Softmax:在 SRAM 中计算 S1 的 Softmax(并记录当前的行最大值 m1 和行总和 l1),然后与 V1 相乘,得到一个临时的输出 O1
  3. 迭代更新: 当 Q 块与 K 的第 2 个块 K2 计算,得到 S2
  4. 寻找新全局值: 找到 S1S2 共同的行最大值 mnew
  5. 重新缩放 (Rescale):
    • 缩放旧值: 将之前计算的 O1 按照新的最大值 mnew 和旧的最大值 m1 之间的差值进行缩放。
    • 计算新值: 正常计算 S2 的 Softmax(使用 mnew)并与 V2 相乘,得到 O2
  6. 累加: 最终的输出块是 O = (scaled O1) + O2

通过这种方式,FlashAttention 可以在不访问完整 S 矩阵的情况下,分块迭代地计算出与标准 Softmax 完全相同的精确结果。

此外,在训练(反向传播)时,它不在前向传播中存储 N × NS 矩阵,而是在反向传播中从 Q, K 重新计算 S 的块。这是一种典型的“用计算换带宽”的策略,因为在 SRAM 中的重计算(FLOPs)远比从 HRAM 读写(I/O)要快。

4. 核心优势 (The Advantages)

  1. 更快的速度 (Faster): FlashAttention 将注意力计算从 I/O 限制转变为计算限制 (FLOPs-bound)。由于它极大地减少了对 HRAM 的读写次数,其速度比标准注意力快得多(论文中提到最高 7.6 倍,FlashAttention-2 中更快)。
  2. 更低的显存 (Memory Efficient):
    • 训练时: 不再需要物化 N × N 的注意力矩阵。这是最大的胜利。这使得模型可以训练更长的序列(Long Context)而不会 OOM (Out of Memory)。
    • 推理时: 在处理长 Prompt(Prefill 阶段)时,FlashAttention 同样因为不产生 N × N 矩阵而节省了大量显存并极大提速。
  3. 完全等价 (Numerically Exact): 这至关重要!FlashAttention 不是一个近似算法(如稀疏注意力、线性注意力)。它是一种硬件感知的 实现,其计算结果与标准注意力在数值上是完全等价的。

总结来说,FlashAttention(及其后续版本 FlashAttention-2,以及FlashAttention-3)已经成为大模型训练和推理的事实标准。它通过 I/O 优化的方式,在不牺牲模型精度的前提下,同时解决了注意力计算的速度和显存两大瓶颈。

pytorch内置的scaled_dot_product_attention(SDPA)函数已经帮我们集成了FlashAttention的思想

6. 线性注意力(Linear Attention)

FlashAttention 优化了 如何计算 标准注意力,但没有改变其 O(N2) 的复杂度本质。当序列长度 N 达到数十万甚至上百万时,即使有 FlashAttention,计算和内存开销依然是无法承受的。

线性注意力 (Linear Attention) 则从一个更激进的角度出发:它直接修改注意力的数学公式,旨在将计算和内存复杂度从 O(N2) 降低到 O(N)

这是一个根本性的改变,其核心代价是:它不再是精确的 Softmax 注意力,而是一种近似。这是一种典型的用模型性能换取极致效率的权衡。

1. 核心思想:改变计算顺序

我们再次回到标准注意力的公式:

这个公式的瓶颈在于 S = QKT 这个 N × N 的矩阵。如果我们能避免计算它,问题就解决了。

观察这个公式,如果我们能先计算 KTV,再把它与 Q 相乘,会怎么样?

  • KT 的维度是 (d × N)V 的维度是 (N × dv)
  • KTV 的结果是一个维度为 (d × dv) 的小矩阵,其大小与序列长度 N 无关
  • 然后用 Q (维度 N × d) 与这个 (d × dv) 的矩阵相乘,最终得到 (N × dv) 的输出,整个过程的复杂度是 O(N)

问题在于,Softmax 函数阻止了我们这样做。 Softmax 需要作用于 QKT 的每一行,它的存在使得矩阵乘法不满足结合律,我们不能简单地把括号从 (QKT)V 移动到 Q(KTV)

线性注意力的核心突破就是:用一个可以满足结合律的核函数 (Kernel Function) 来替代 Softmax

我们将注意力公式一般化为:

其中 sim(Qi, Kj)QiKj 的相似度函数。在标准注意力中,

线性注意力的做法是,将相似度函数 sim(Q, K) 分解为两个独立作用于 QK 的函数 ϕ(⋅) 的点积: sim(Q, K) = ϕ(Q)ϕ(K)T

这样,注意力计算就变成了(忽略分母的归一化项): O = (ϕ(Q)ϕ(K)T)V

由于矩阵乘法满足结合律,我们现在可以重新组合它: O = ϕ(Q)(ϕ(K)TV)

这正是我们想要的!我们成功地将计算顺序改变了:

  1. 计算 ϕ(K)TV: 这是一个 (d × dv) 的矩阵。复杂度为 O(N ⋅ d ⋅ dv)
  2. 计算 ϕ(Q): 复杂度为 O(N ⋅ d)
  3. 计算 ϕ(Q) × (ϕ(K)TV): 得到最终输出 O。复杂度为 O(N ⋅ d ⋅ dv)

总的计算和内存复杂度都变成了 O(N),与序列长度 N 呈线性关系!

2. 如何替换 Softmax?

那么,这个神奇的函数 ϕ(⋅) 应该是什么样呢?这是不同线性注意力变体(如 Performer, Linear Transformer, RWKV 等)的核心区别所在。一个简单而常见的选择是:

ϕ(x) = elu(x) + 1

其中 elu 是一个标准的激活函数 (Exponential Linear Unit)。使用 elu + 1 的好处是保证了输出值非负,这在一定程度上模拟了 Softmax 中 exp 函数的作用。

当然,为了让结果更接近 Softmax,我们还需要加上之前忽略的归一化项。完整的公式变为:

分母和分子都可以利用结合律来高效计算。

3. 线性注意力的优势与劣势

优势 (Advantages)

  1. 线性复杂度 (O(N)): 这是最大的优势。它使得处理极长的序列(例如整本书、数小时的音频)在计算上成为可能。
  2. 可表示为 RNN 形式: 这是线性注意力一个极其强大的特性。由于计算可以分解,它可以被写成一个循环神经网络(RNN)的形式。
    • 在推理(生成)时,每生成一个新 token,我们不需要重新计算整个序列的注意力。
    • 我们只需要维护一个很小的状态(即那个 (d × dv) 的矩阵 ∑(ϕ(K)TV)),然后用新的 qnew 与之交互,并用新的 knew, vnew 更新这个状态。
    • 这使得 autoregressive 推理的每一步都是 O(1) 的复杂度,速度极快,且不随已生成序列的长度而变慢。这正是 RWKV 这类架构的核心优势。

劣势 (Disadvantages)

  1. 性能损失 (Performance Degradation): 这是最主要的代价。线性注意力的核函数 ϕ 终究只是对 Softmax 的一种近似。Softmax 的指数特性使其能够“聚焦”到非常少数的关键 token 上(即注意力得分非常尖锐),而线性注意力的多项式或类线性核函数则更“平滑”,难以实现这种尖锐的聚焦能力。这通常会导致在标准语言模型基准测试(如困惑度 PPL)上表现不如标准注意力。
  2. 表达能力受限 (Limited Expressive Power): 理论和实践都表明,线性注意力的表达能力弱于标准注意力,对于需要精确捕捉长距离依赖关系中特定“点对点”关系的任务,可能会力不从心。
  3. 训练稳定性 (Training Stability): 某些线性注意力的变体可能在训练时不如标准注意力稳定。

7. 稀疏注意力(Sparse Attention)

前面我们看到,FlashAttention 优化了标准注意力的计算过程,而线性注意力则修改了注意力的数学公式。稀疏注意力(Sparse Attention)走了第三条路:它保留了 Softmax 注意力的核心形式,但基于一个关键假设来减少计算量:大部分的注意力得分都很低,是没必要计算的

稀疏注意力的核心思想是:与其让每个 token 关注(Attend to)序列中的所有其他 token(即稠密连接),不如只让它关注一个预先定义好的、稀疏的子集。

这是一种从 O(N2) 的“All-to-All”交互,到 O(Nlog N) 的“Some-to-Some”交互的转变。

1. 核心瓶颈与动机 (The Motivation)

在标准的自注意力中,一个长度为 N 的序列会产生一个 N × N 的注意力矩阵。然而,大量研究和实践发现,这个矩阵通常是稀疏的。这意味着,对于一个给定的 token,它真正有意义的注意力权重往往只集中在少数几个其他的 token 上。

例如,在句子 “The quick brown fox jumps over the lazy dog” 中,“jumps” 这个词可能主要关注 “fox” 和 “over”,而与句子末尾的 “dog” 关系不大。计算 “jumps” 和 “dog” 之间的精确注意力得分,可能是一种计算资源的浪费。

稀疏注意力的动机就是:既然最终的注意力矩阵是稀疏的,我们能不能在计算之前就“跳过”那些不重要的 token 对,从而避免 O(N2) 的计算?

2. 实现稀疏性的关键:注意力模式 (Sparsity Patterns)

如何决定哪些 token 对是“重要的”并需要计算呢?这是各种稀疏注意力模型(如 Longformer、BigBird 等)创新的核心。它们设计了不同的固定稀疏模式 (Fixed Sparsity Patterns) 来近似完整的注意力。

最常见的几种模式包括:

a) 滑动窗口注意力 (Sliding Window / Local Attention)

  • 思想:一个 token 主要与其邻近的 token 相关。这在语言和视觉中都是一个非常强的先验知识。
  • 实现:每个 token 只关注其左边和右边 w 个 token(w 是窗口大小)。
  • 复杂度:每个 token 只计算 2w 个得分,总复杂度为 O(N ⋅ w)。如果 w 是一个小的常数,复杂度就是线性的 O(N)
  • 问题:这种模式完全切断了远距离的依赖关系。窗口之外的 token 之间无法直接交换信息。

b) 扩张/空洞滑动窗口 (Dilated / Strided Sliding Window)

  • 思想:为了在不增加计算量的情况下扩大感受野(Receptive Field),我们可以让窗口内存在“间隙”。
  • 实现:窗口内的 token 不是连续的,而是以一定的步长(dilation rate)跳跃式选择。例如,关注位置 i-4, i-2, i, i+2, i+4
  • 优势:通过在不同层或不同头使用不同的扩张率,模型可以捕捉到不同尺度的远距离依赖。

c) 全局注意力 (Global Attention)

  • 思想:序列中有一些特殊的“明星” token,它们应该有能力关注所有其他 token,也应该被所有其他 token 关注。
  • 实现:预先选择少数几个 token 作为全局 token(例如,[CLS] token)。这些 token 的注意力计算是稠密的(O(N)),但由于数量很少,总的额外开销不大。
  • 优势:这是解决滑动窗口无法传递长距离信息的关键。信息可以通过全局 token 在序列的不同部分之间传递。

d) 随机注意力 (Random Attention)

  • 思想:为了弥补固定模式可能错过的连接,为每个 token 额外增加几个随机选择的 token 进行关注。
  • 实现:每个 token 除了关注其固定模式内的 token 外,还随机采样 r 个序列中的其他 token。
  • 优势:增加了连接的鲁棒性,理论上保证了信息可以在序列的任意两点间以较短的路径传播。

组合模式 (Combined Patterns) 最成功的稀疏注意力模型,如 LongformerBigBird,通常会组合上述多种模式。例如,Longformer 就结合了滑动窗口注意力全局注意力,使得模型既能高效处理局部信息,又能通过全局 token 捕捉长距离依赖。

3. 优势与劣势

优势 (Advantages)

  1. 更高的计算效率:将复杂度从 O(N2) 降低到 O(Nlog N),显著减少了计算量和内存需求。
  2. 支持更长的上下文:这是最直接的好处。像 Longformer 这样的模型可以将 Transformer 的上下文长度从 512 或 1024 扩展到 4096 甚至更长。
  3. 保留了强大的表达能力:与线性注意力相比,稀疏注意力仍然使用 Softmax,保留了其“聚焦”能力,因此在很多任务上比线性注意力的性能损失更小。

劣势 (Disadvantages)

  1. 是近似,非精确:这是与 FlashAttention 最大的不同。稀疏注意力是一种近似,它基于“大多数连接不重要”的假设。如果任务中存在一个关键的、但不符合预设稀疏模式的长距离依赖,模型就可能捕捉不到。
  2. 实现复杂:与标准的稠密矩阵乘法不同,稀疏注意力需要专门的 CUDA Kernel 来高效实现。如果不进行底层优化,在 GPU 上对稀疏矩阵进行操作(需要处理各种索引)可能比直接计算稠密矩阵还要慢。
  3. 模式依赖:模型的性能可能依赖于所选择的稀疏模式,需要根据任务进行调整,缺乏通用性。

要真正从稀疏注意力中获益,关键在于其实现方式。其核心原则是从一开始就避免计算和存储完整的 N × N 注意力矩阵。以下是实现这一目标的主流思路:

任何先计算出稠密注意力矩阵,再用掩码使其稀疏的方法,都无法降低 O(N2) 的计算复杂度,因此是无效的。高效的实现必须在计算注意力得分之前,就只处理预设模式中的 (Query, Key) 对。

方法一:数据重排与批处理矩阵乘法 (框架级实现)

这是一种在 PyTorch 或 TensorFlow 等高级框架中不需编写底层代码即可实现的方法,思路如下:

  1. 确定索引:对于序列中的每一个 Query qi,根据稀疏模式(如滑动窗口)计算出它需要交互的所有 Key kj 的索引。
  2. 收集数据 (Gather):创建一个新的、更小的 Key 矩阵和 Value 矩阵。例如,对于滑动窗口注意力,这个新矩阵的维度会是 (N, window_size, d)。这一步通过高效的索引和数据复制操作(如 torch.gatherunfold)完成。
  3. 批处理计算 (Batched Compute):将原始的 Q 矩阵与这个新的、小型的 K_windowed 矩阵进行批处理矩阵乘法 (Batched Matrix Multiplication)。这会将 N 个独立的、小规模的注意力计算并行化,总的计算量是 O(N ⋅ window_size),而非 O(N2)

这种方法虽然会产生一个中间的、重排后的 K_windowed 矩阵,但已经成功地将计算和内存复杂度从二次方降低到了近似线性。

##### 方法二:融合的自定义 CUDA Kernel (性能极致化)

这是像 Longformer 和 FlashAttention 等 SOTA 实现所采用的方法,追求极致的硬件效率:

  1. 核函数融合 (Kernel Fusion):将“确定索引”、“收集数据”和“计算得分”等多个步骤融合到一个单独的 GPU 核函数中。
  2. I/O 感知计算:该核函数直接在 GPU 上运行。对于每一个 Query qi
    • 将其加载到 GPU 核心的高速缓存 SRAM 中。
    • 根据稀疏模式,直接从慢速的 HRAM (显存) 中按需读取对应的 kjvj 到 SRAM。
    • 所有计算(点积、Softmax、与 Value 相乘)都在高速的 SRAM 内部完成。
    • 只将最终的输出 oi 写回到 HRAM。

这种方法借鉴了 FlashAttention 的 I/O 感知思想,完全避免了在 HRAM 中创建任何中间矩阵(如 K_windowed),最大限度地减少了对内存带宽的占用,是当前实现稀疏注意力的最高效方式。

我们现在有了四种主要的注意力机制,可以用一张表格来清晰地对比它们:

特性 标准 Attention FlashAttention 线性 Attention 稀疏 Attention
核心思想 稠密的 All-to-All I/O 感知的工程优化 修改数学公式近似 Softmax 预设模式跳过部分计算
等价性 - 完全等价 不等价 (近似) 不等价 (近似)
复杂度 O(N2) O(N2) O(N) O(Nlog N)
优势 表达能力强,是黄金标准 速度快,显存低,无损精度 极致效率,推理O(1),支持超长序列 平衡效率与性能,支持长序列
劣势 速度慢,显存占用高 复杂度本质未变 损失模型性能 实现复杂,可能丢失关键信息
代表应用 早期 Transformer (BERT, GPT-2) 当前所有主流大模型 (Llama, GPT-4) RWKV, Performer Longformer, BigBird

稀疏注意力是解决 Transformer 长序列问题的里程碑式工作。它在理论上证明了我们可以通过聪明的近似来大幅降低计算复杂度,催生了一系列成功的长上下文模型。

然而,随着 FlashAttention 的出现,情况发生了变化。FlashAttention 极大地优化了精确注意力的计算效率,使得在当前硬件(如 A100/H100)上,处理中等长度(如 4K-16K)的稠密注意力变得非常快。这使得稀疏注意力的工程优势在一定程度上被削弱了。不过,稀疏化的思想仍然极具价值,尤其是在未来需要处理数十万甚至上百万长度序列的场景下,降低计算复杂度的量级仍然是不可或缺的。

 REWARD AUTHOR
 Comments
Comment plugin failed to load
Loading comment plugin