本章在之前对注意力机制的基础上,介绍几种改进的注意力机制,包括多查询注意力(MQA)、分组查询注意力(GQA)以及FlashAttention等技术,这些技术在提升模型性能和推理效率方面发挥了重要作用。
1. 原始注意力机制(Scaled Dot-Product Attention)
🧠 核心思想
通过计算 Query 与 Key 的点积相似度,得到每个位置对序列中其他位置的注意力权重,并用该权重加权 Value,得到上下文信息。
🧩 公式
🔍 步骤
- 对输入向量 q, k, v 分别映射为 Q, K, V;
- 计算注意力得分矩阵;
- 经过 softmax 归一化;
- 加权求和得到输出。
💡 特点
- 全局注意力,全序列交互;
- 复杂度 O(n2),n为序列的长度;
- 高性能但难以扩展到长序列。
2. 多头注意力(Multi-Head Attention, MHA)
🧠 核心思想
将注意力分成多个头(head),让模型学习不同的语义子空间表示。
🧩 公式
MHA(Q, K, V) = Concat(head1, …, headh)WO
其中 headi = Attention(QWiQ, KWiK, VWiV)
💡 特点
- 多头捕获不同关系模式;
- 并行性强;
- 缺点:推理阶段每个 head 都要维护独立的 KV 缓存,显存开销大。
1 | class MHA(nn.Module): |
如上述代码所示,维护的kv cache结构是 ((B, H, Spastk, dimhead), (B, H, Spastv, dimhead)),即对于每一个head来说,都需要维护之前所有token对应的K和V,显存开销极大。
3. 单查询多头注意力(Multi-Query Attention, MQA)
🧠 核心思想
所有 Query 头共享一组 Key/Value 投影,用于加速推理。
🧩 公式
Qi = XWiQ, K = XWK, V = XWV
💡 优点
- 显存节省:KV 缓存只存一份;
- 推理速度显著提高;
- 参数量与 MHA 几乎相同。
⚠️ 缺点
- 不同 head 共用同一 K/V,表达能力降低;
- 模型多样性下降。
🔗 应用
- Google PaLM
- Efficient Transformers
1 | class MQA(nn.Module): |
如上述代码所示,kv cache的形状是( (B, S_past_k, dim_head), (B, S_past_v, dim_head) ),与H无关。
相当于有H个head,但只需要保存一份KV。
4. 分组查询注意力(Grouped-Query Attention, GQA)
🧠 核心思想
介于 MHA 与 MQA 之间的折中: 多个 Query head 共享同一组 K/V(按组共享)。
🧩 公式
Qi = XWiQ, (Ki, Vi) = XWjK, XWjV, i ∈ groupj
💡 特点
内存效率:KV 缓存数量为组数 g;
灵活性高:可调节组数 g 控制性能与显存之间的折中;
兼容性强:几乎可直接替换 MHA 层。
🔗 应用
- LLaMA 2 / 3
- Mistral
- Gemma
- Falcon 180B
1 | class GQA(nn.Module): |
代码中主要是kv cahce部分,只需要存储num_group个KV实例,一般num_group为4,8,16。
并且当num_group = num_head时,就等价于MHA;
当num_group = 1时,就等价于MQA。
MHA MQA GQA MLA 对比

如图所示,前三种方法的核心区别在于查询头 (Q) 与键/值头 (K/V) 之间的关系,这直接决定了推理时 KV Cache 的大小:
- MHA (多头注意力): 这是标准配置。如图所示,每一个 Q 头都一对一匹配一组独立的 K/V 头。如果有 8 个 Q 头,就有 8 组 K/V。这提供了最强的模型表达能力,但代价是 KV 缓存最大。
- MQA (多查询注意力): 这是一种激进的优化。所有的 Q 头共享同一组 K/V 头。这使得 KV 缓存降到了最低(仅 1 组),极大提升了推理速度和吞吐量,但可能会牺牲一些模型精度。
- GQA (分组查询注意力): 这是 MHA 和 MQA 之间的完美平衡。它将 Q 头分组(如图中 2 个 Q 为一组),组内共享一组 K/V 头。这在实现接近 MHA 质量的同时,将 KV 缓存显著减小(例如,从 8 组降到 4 组),是目前(如 Llama 2/3)的主流选择。
第四种:多头潜在注意力 (Multi-Head Latent Attention, MLA),与前三种方法通过 减少 K/V 头的数量 来节省显存的思路不同,MLA 提出了一种全新的优化哲学:压缩 K/V 缓存的表示。
MLA 假设,由 MHA 生成的完整 K/V 状态是高度冗余的。我们不需要通过“共享”或“分组”来减少头的数量,而是可以学习一个更小、信息密度更高的“潜在表示”(Latent Representation)来作为 K/V 缓存的“摘要”。
从图示中可以清晰地看到:
- 模型在内部依然计算出了完整的、与 MHA 相同的多组 K/V 头(图中有 8 组)。
- 在将它们写入缓存之前,这些完整的 K/V 头会经过一个“投影 (projection)”步骤。
- 这个投影步骤会生成一个“压缩的潜在 KV (Compressed Latent KV)”。
- 只有这个压缩后的表示才会被真正缓存起来,用于后续的推理步骤。
这种方法的实现通常涉及以下几个步骤:
- 投影 (Projection): 这个“投影”层通常是一个可学习的参数矩阵(例如一个简单的线性层或小型前馈网络)。它的作用是降低维度。例如,它可以将 8 个 K/V 头的表示“浓缩”成 2 个“潜在” K/V 头的表示,或者保持头的数量不变,但减小每个头的特征维度。
- 缓存 (Caching): 在推理的每一步,模型计算出
N组 K/V 状态,通过这个投影层,得到一个更小的L组(或维度更低)的潜在 K/V 状态。这个潜在状态被写入缓存,从而节省了大量 VRAM。
MLA 的优势在于,它理论上保留了所有 MHA 原始 K/V 头的信息(通过学习如何压缩它们),而不是像 GQA 或 MQA 那样在结构上就减少了 K/V 头的数量。
最后,我们可以用一个表格来清晰地总结这四种方法在推理时对 KV 缓存的处理:
| 方法 | K/V 头配置 | KV 缓存大小 | 核心思想 |
|---|---|---|---|
| MHA | N 组 Q, N 组 K/V |
最大 (N 倍) | 不优化:每个 Q 头有独立的 K/V |
| MQA | N 组 Q, 1 组 K/V |
最小 (1 倍) | 全局共享:所有 Q 头共享 1 组 K/V |
| GQA | N 组 Q, G 组 K/V |
中等 (G 倍) | 分组共享:每组 Q 头共享 1 组 K/V |
| MLA | N 组 Q, N 组 K/V -> 压缩 |
小 (L 倍) | 压缩:缓存 K/V 状态的一个低维“潜在”投影 |
5. FlashAttention(高效注意力计算)
前面我们讨论的 MQA, GQA 等方法,主要优化的是推理时 KV Cache 的显存占用。而 FlashAttention 则着眼于一个更根本的问题:注意力计算本身的速度和显存瓶颈。
它是一种 I/O 感知(I/O-aware)的注意力算法,不修改注意力计算的数学本质,但通过巧妙的工程实现,使其在 GPU 上的运行速度更快、显存占用更低。
1. 核心瓶颈:内存带宽 (The I/O Bottleneck)
我们先回顾标准注意力的计算:O = Softmax(QKT)V。
在 GPU 上执行此操作的标准方法(例如 PyTorch 的
nn.MultiHeadAttention)存在一个巨大瓶颈:
- 计算 S = QKT:
假设序列长度为
N,特征维度为d。这会产生一个巨大的中间矩阵 S,大小为 N × N。 - HRAM 读写: 这个 N × N 的矩阵 S 必须被写入 GPU 的 HRAM(高带宽内存,即 VRAM 显存),然后再由 Softmax 操作读出。
- Softmax 操作: 计算 Softmax(S),结果再次写入 HRAM。
- 计算 O = Softmax(S)V: 再次从 HRAM 中读取 Softmax(S),与 V 相乘,最后将结果 O 写入 HRAM。
问题在于:
- HRAM 很慢: 相比于 GPU 核心的片上 SRAM(超高速缓存),HRAM 的读写速度要慢几个数量级。
- I/O 成为瓶颈: 注意力计算变成了内存带宽限制 (memory-bound) 的操作。GPU 核心(ALU)大部分时间都在“等待”数据从 HRAM 传来或写入 HRAM,而不是在真正地“计算”(FLOPs)。
- 显存占用: 当序列长度
N很长时(例如 8K、16K),这个 N × N 的 S 矩阵本身就会占用巨量的显存(例如,N=8K时,一个float32矩阵 S 就需要 8192 × 8192 × 4 bytes ≈ 256 MB,如果有 32 个头,就是 8 GB,这还只是一个层!)。
2. FlashAttention 的核心原理
FlashAttention 的核心思路是:彻底避免将 N × N 的注意力矩阵 S 写入 HRAM。
它通过两大技术实现了这一点:
Kernel Fusion (核函数融合):
FlashAttention 将标准注意力中的多次 HRAM 读写(QKT、Softmax、OV)融合成一个单独的 GPU Kernel(核函数)。
这个 Kernel 一次性从 HRAM 中读取 Q, K, V,然后将所有中间计算(S 矩阵、Softmax)全部在高速的 SRAM 中完成,最后只将最终的输出 O 写回到 HRAM。
Tiling (分块/瓦片):
SRAM 的容量非常小(例如,NVIDIA A100 只有 192KB/SM),无法容纳整个 N × N 的矩阵。
FlashAttention 使用分块策略:它将 Q, K, V 矩阵分割成更小的块(Tiles)。Kernel 加载 Q 的一个块,然后迭代地加载 K 和 V 的相应块。
一个形象的比喻:
- 标准注意力:就像一个厨师(GPU 核心)。做一道菜(注意力计算),每切完一种配料(QKT),就把它放回 500 米外的冰箱(HRAM),要用时(Softmax)再跑 500 米去取,做完下一步(Softmax(S))再放回冰箱,最后(OV)再跑 500 米去取… 厨师大部分时间都在跑步。
- FlashAttention:厨师(GPU 核心)把做 一份菜 所需的所有配料(Q, K, V 的一个“块”)一次性拿到自己的小工作台(SRAM)上。在工作台上完成所有步骤(S 块、Softmax、 O 块),直到做出最终的成品(O 块)的一部分,再把它放到餐桌上(HRAM)。厨师几乎不需要离开工作台。
3. 关键实现:分块与在线 Softmax
这里最大的技术难点是 Softmax。Softmax 需要对 一整行 数据进行归一化(xi/∑(xj))。如果我只加载了 S 矩阵的一个块(Sij),我怎么知道这一行的“最大值”和“总和”呢?
FlashAttention 使用了一种Online Softmax的数值稳定算法:
- 分块计算: 假设 Q 块与 K 的第 1 个块 K1 计算,得到了 S 矩阵的第 1 个块 S1。
- 计算局部 Softmax:在 SRAM 中计算 S1 的 Softmax(并记录当前的行最大值 m1 和行总和 l1),然后与 V1 相乘,得到一个临时的输出 O1。
- 迭代更新: 当 Q 块与 K 的第 2 个块 K2 计算,得到 S2。
- 寻找新全局值: 找到 S1 和 S2 共同的行最大值 mnew。
- 重新缩放 (Rescale):
- 缩放旧值: 将之前计算的 O1 按照新的最大值 mnew 和旧的最大值 m1 之间的差值进行缩放。
- 计算新值: 正常计算 S2 的 Softmax(使用 mnew)并与 V2 相乘,得到 O2。
- 累加: 最终的输出块是 O = (scaled O1) + O2。
通过这种方式,FlashAttention 可以在不访问完整 S 矩阵的情况下,分块迭代地计算出与标准 Softmax 完全相同的精确结果。
此外,在训练(反向传播)时,它不在前向传播中存储 N × N 的 S 矩阵,而是在反向传播中从 Q, K 重新计算 S 的块。这是一种典型的“用计算换带宽”的策略,因为在 SRAM 中的重计算(FLOPs)远比从 HRAM 读写(I/O)要快。
4. 核心优势 (The Advantages)
- 更快的速度 (Faster): FlashAttention 将注意力计算从 I/O 限制转变为计算限制 (FLOPs-bound)。由于它极大地减少了对 HRAM 的读写次数,其速度比标准注意力快得多(论文中提到最高 7.6 倍,FlashAttention-2 中更快)。
- 更低的显存 (Memory Efficient):
- 训练时: 不再需要物化 N × N 的注意力矩阵。这是最大的胜利。这使得模型可以训练更长的序列(Long Context)而不会 OOM (Out of Memory)。
- 推理时: 在处理长 Prompt(Prefill 阶段)时,FlashAttention 同样因为不产生 N × N 矩阵而节省了大量显存并极大提速。
- 完全等价 (Numerically Exact): 这至关重要!FlashAttention 不是一个近似算法(如稀疏注意力、线性注意力)。它是一种硬件感知的 实现,其计算结果与标准注意力在数值上是完全等价的。
总结来说,FlashAttention(及其后续版本 FlashAttention-2,以及FlashAttention-3)已经成为大模型训练和推理的事实标准。它通过 I/O 优化的方式,在不牺牲模型精度的前提下,同时解决了注意力计算的速度和显存两大瓶颈。
pytorch内置的scaled_dot_product_attention(SDPA)函数已经帮我们集成了FlashAttention的思想
6. 线性注意力(Linear Attention)
FlashAttention 优化了 如何计算 标准注意力,但没有改变其
O(N2)
的复杂度本质。当序列长度 N 达到数十万甚至上百万时,即使有
FlashAttention,计算和内存开销依然是无法承受的。
而线性注意力 (Linear Attention) 则从一个更激进的角度出发:它直接修改注意力的数学公式,旨在将计算和内存复杂度从 O(N2) 降低到 O(N)。
这是一个根本性的改变,其核心代价是:它不再是精确的 Softmax 注意力,而是一种近似。这是一种典型的用模型性能换取极致效率的权衡。
1. 核心思想:改变计算顺序
我们再次回到标准注意力的公式:
这个公式的瓶颈在于 S = QKT 这个 N × N 的矩阵。如果我们能避免计算它,问题就解决了。
观察这个公式,如果我们能先计算 KTV,再把它与 Q 相乘,会怎么样?
- KT 的维度是 (d × N), V 的维度是 (N × dv)。
- KTV
的结果是一个维度为 (d × dv)
的小矩阵,其大小与序列长度
N无关! - 然后用 Q (维度 N × d) 与这个 (d × dv) 的矩阵相乘,最终得到 (N × dv) 的输出,整个过程的复杂度是 O(N)。
问题在于,Softmax 函数阻止了我们这样做。 Softmax 需要作用于 QKT 的每一行,它的存在使得矩阵乘法不满足结合律,我们不能简单地把括号从 (QKT)V 移动到 Q(KTV)。
线性注意力的核心突破就是:用一个可以满足结合律的核函数 (Kernel Function) 来替代 Softmax。
我们将注意力公式一般化为:
其中 sim(Qi, Kj)
是 Qi 和
Kj
的相似度函数。在标准注意力中,
线性注意力的做法是,将相似度函数 sim(Q, K) 分解为两个独立作用于 Q 和 K 的函数 ϕ(⋅) 的点积: sim(Q, K) = ϕ(Q)ϕ(K)T
这样,注意力计算就变成了(忽略分母的归一化项): O = (ϕ(Q)ϕ(K)T)V
由于矩阵乘法满足结合律,我们现在可以重新组合它: O = ϕ(Q)(ϕ(K)TV)
这正是我们想要的!我们成功地将计算顺序改变了:
- 计算 ϕ(K)TV: 这是一个 (d × dv) 的矩阵。复杂度为 O(N ⋅ d ⋅ dv)。
- 计算 ϕ(Q): 复杂度为 O(N ⋅ d)。
- 计算 ϕ(Q) × (ϕ(K)TV): 得到最终输出 O。复杂度为 O(N ⋅ d ⋅ dv)。
总的计算和内存复杂度都变成了 O(N),与序列长度
N 呈线性关系!
2. 如何替换 Softmax?
那么,这个神奇的函数 ϕ(⋅) 应该是什么样呢?这是不同线性注意力变体(如 Performer, Linear Transformer, RWKV 等)的核心区别所在。一个简单而常见的选择是:
ϕ(x) = elu(x) + 1
其中 elu 是一个标准的激活函数 (Exponential Linear
Unit)。使用 elu + 1
的好处是保证了输出值非负,这在一定程度上模拟了 Softmax 中
exp 函数的作用。
当然,为了让结果更接近
Softmax,我们还需要加上之前忽略的归一化项。完整的公式变为:
分母和分子都可以利用结合律来高效计算。
3. 线性注意力的优势与劣势
优势 (Advantages)
- 线性复杂度 (O(N)): 这是最大的优势。它使得处理极长的序列(例如整本书、数小时的音频)在计算上成为可能。
- 可表示为 RNN 形式:
这是线性注意力一个极其强大的特性。由于计算可以分解,它可以被写成一个循环神经网络(RNN)的形式。
- 在推理(生成)时,每生成一个新 token,我们不需要重新计算整个序列的注意力。
- 我们只需要维护一个很小的状态(即那个 (d × dv) 的矩阵 ∑(ϕ(K)TV)),然后用新的 qnew 与之交互,并用新的 knew, vnew 更新这个状态。
- 这使得 autoregressive 推理的每一步都是 O(1) 的复杂度,速度极快,且不随已生成序列的长度而变慢。这正是 RWKV 这类架构的核心优势。
劣势 (Disadvantages)
- 性能损失 (Performance Degradation): 这是最主要的代价。线性注意力的核函数 ϕ 终究只是对 Softmax 的一种近似。Softmax 的指数特性使其能够“聚焦”到非常少数的关键 token 上(即注意力得分非常尖锐),而线性注意力的多项式或类线性核函数则更“平滑”,难以实现这种尖锐的聚焦能力。这通常会导致在标准语言模型基准测试(如困惑度 PPL)上表现不如标准注意力。
- 表达能力受限 (Limited Expressive Power): 理论和实践都表明,线性注意力的表达能力弱于标准注意力,对于需要精确捕捉长距离依赖关系中特定“点对点”关系的任务,可能会力不从心。
- 训练稳定性 (Training Stability): 某些线性注意力的变体可能在训练时不如标准注意力稳定。
7. 稀疏注意力(Sparse Attention)
前面我们看到,FlashAttention 优化了标准注意力的计算过程,而线性注意力则修改了注意力的数学公式。稀疏注意力(Sparse Attention)走了第三条路:它保留了 Softmax 注意力的核心形式,但基于一个关键假设来减少计算量:大部分的注意力得分都很低,是没必要计算的。
稀疏注意力的核心思想是:与其让每个 token 关注(Attend to)序列中的所有其他 token(即稠密连接),不如只让它关注一个预先定义好的、稀疏的子集。
这是一种从 O(N2)
的“All-to-All”交互,到 O(Nlog N) 或
1. 核心瓶颈与动机 (The Motivation)
在标准的自注意力中,一个长度为 N 的序列会产生一个 N × N
的注意力矩阵。然而,大量研究和实践发现,这个矩阵通常是稀疏的。这意味着,对于一个给定的
token,它真正有意义的注意力权重往往只集中在少数几个其他的 token 上。
例如,在句子 “The quick brown fox jumps over the lazy dog” 中,“jumps” 这个词可能主要关注 “fox” 和 “over”,而与句子末尾的 “dog” 关系不大。计算 “jumps” 和 “dog” 之间的精确注意力得分,可能是一种计算资源的浪费。
稀疏注意力的动机就是:既然最终的注意力矩阵是稀疏的,我们能不能在计算之前就“跳过”那些不重要的 token 对,从而避免 O(N2) 的计算?
2. 实现稀疏性的关键:注意力模式 (Sparsity Patterns)
如何决定哪些 token 对是“重要的”并需要计算呢?这是各种稀疏注意力模型(如 Longformer、BigBird 等)创新的核心。它们设计了不同的固定稀疏模式 (Fixed Sparsity Patterns) 来近似完整的注意力。
最常见的几种模式包括:
a) 滑动窗口注意力 (Sliding Window / Local Attention)
- 思想:一个 token 主要与其邻近的 token 相关。这在语言和视觉中都是一个非常强的先验知识。
- 实现:每个 token 只关注其左边和右边
w个 token(w是窗口大小)。 - 复杂度:每个 token 只计算
2w个得分,总复杂度为 O(N ⋅ w)。如果w是一个小的常数,复杂度就是线性的 O(N)。 - 问题:这种模式完全切断了远距离的依赖关系。窗口之外的 token 之间无法直接交换信息。
b) 扩张/空洞滑动窗口 (Dilated / Strided Sliding Window)
- 思想:为了在不增加计算量的情况下扩大感受野(Receptive Field),我们可以让窗口内存在“间隙”。
- 实现:窗口内的 token
不是连续的,而是以一定的步长(dilation rate)跳跃式选择。例如,关注位置
i-4, i-2, i, i+2, i+4。 - 优势:通过在不同层或不同头使用不同的扩张率,模型可以捕捉到不同尺度的远距离依赖。
c) 全局注意力 (Global Attention)
- 思想:序列中有一些特殊的“明星” token,它们应该有能力关注所有其他 token,也应该被所有其他 token 关注。
- 实现:预先选择少数几个 token 作为全局
token(例如,
[CLS]token)。这些 token 的注意力计算是稠密的(O(N)),但由于数量很少,总的额外开销不大。 - 优势:这是解决滑动窗口无法传递长距离信息的关键。信息可以通过全局 token 在序列的不同部分之间传递。
d) 随机注意力 (Random Attention)
- 思想:为了弥补固定模式可能错过的连接,为每个 token 额外增加几个随机选择的 token 进行关注。
- 实现:每个 token 除了关注其固定模式内的 token
外,还随机采样
r个序列中的其他 token。 - 优势:增加了连接的鲁棒性,理论上保证了信息可以在序列的任意两点间以较短的路径传播。
组合模式 (Combined Patterns) 最成功的稀疏注意力模型,如 Longformer 和 BigBird,通常会组合上述多种模式。例如,Longformer 就结合了滑动窗口注意力和全局注意力,使得模型既能高效处理局部信息,又能通过全局 token 捕捉长距离依赖。
3. 优势与劣势
优势 (Advantages)
- 更高的计算效率:将复杂度从 O(N2) 降低到
O(Nlog N) 或
,显著减少了计算量和内存需求。 - 支持更长的上下文:这是最直接的好处。像 Longformer 这样的模型可以将 Transformer 的上下文长度从 512 或 1024 扩展到 4096 甚至更长。
- 保留了强大的表达能力:与线性注意力相比,稀疏注意力仍然使用 Softmax,保留了其“聚焦”能力,因此在很多任务上比线性注意力的性能损失更小。
劣势 (Disadvantages)
- 是近似,非精确:这是与 FlashAttention 最大的不同。稀疏注意力是一种近似,它基于“大多数连接不重要”的假设。如果任务中存在一个关键的、但不符合预设稀疏模式的长距离依赖,模型就可能捕捉不到。
- 实现复杂:与标准的稠密矩阵乘法不同,稀疏注意力需要专门的 CUDA Kernel 来高效实现。如果不进行底层优化,在 GPU 上对稀疏矩阵进行操作(需要处理各种索引)可能比直接计算稠密矩阵还要慢。
- 模式依赖:模型的性能可能依赖于所选择的稀疏模式,需要根据任务进行调整,缺乏通用性。
要真正从稀疏注意力中获益,关键在于其实现方式。其核心原则是从一开始就避免计算和存储完整的 N × N 注意力矩阵。以下是实现这一目标的主流思路:
任何先计算出稠密注意力矩阵,再用掩码使其稀疏的方法,都无法降低 O(N2) 的计算复杂度,因此是无效的。高效的实现必须在计算注意力得分之前,就只处理预设模式中的 (Query, Key) 对。
方法一:数据重排与批处理矩阵乘法 (框架级实现)
这是一种在 PyTorch 或 TensorFlow 等高级框架中不需编写底层代码即可实现的方法,思路如下:
- 确定索引:对于序列中的每一个 Query qi,根据稀疏模式(如滑动窗口)计算出它需要交互的所有 Key kj 的索引。
- 收集数据 (Gather):创建一个新的、更小的 Key 矩阵和 Value 矩阵。例如,对于滑动窗口注意力,这个新矩阵的维度会是
(N, window_size, d)。这一步通过高效的索引和数据复制操作(如torch.gather或unfold)完成。- 批处理计算 (Batched Compute):将原始的
Q矩阵与这个新的、小型的K_windowed矩阵进行批处理矩阵乘法 (Batched Matrix Multiplication)。这会将 N 个独立的、小规模的注意力计算并行化,总的计算量是 O(N ⋅ window_size),而非 O(N2)。这种方法虽然会产生一个中间的、重排后的
K_windowed矩阵,但已经成功地将计算和内存复杂度从二次方降低到了近似线性。##### 方法二:融合的自定义 CUDA Kernel (性能极致化)
这是像 Longformer 和 FlashAttention 等 SOTA 实现所采用的方法,追求极致的硬件效率:
- 核函数融合 (Kernel Fusion):将“确定索引”、“收集数据”和“计算得分”等多个步骤融合到一个单独的 GPU 核函数中。
- I/O 感知计算:该核函数直接在 GPU 上运行。对于每一个 Query qi:
- 将其加载到 GPU 核心的高速缓存 SRAM 中。
- 根据稀疏模式,直接从慢速的 HRAM (显存) 中按需读取对应的 kj 和 vj 到 SRAM。
- 所有计算(点积、Softmax、与 Value 相乘)都在高速的 SRAM 内部完成。
- 只将最终的输出 oi 写回到 HRAM。
这种方法借鉴了 FlashAttention 的 I/O 感知思想,完全避免了在 HRAM 中创建任何中间矩阵(如
K_windowed),最大限度地减少了对内存带宽的占用,是当前实现稀疏注意力的最高效方式。
我们现在有了四种主要的注意力机制,可以用一张表格来清晰地对比它们:
| 特性 | 标准 Attention | FlashAttention | 线性 Attention | 稀疏 Attention |
|---|---|---|---|---|
| 核心思想 | 稠密的 All-to-All | I/O 感知的工程优化 | 修改数学公式近似 Softmax | 预设模式跳过部分计算 |
| 等价性 | - | 完全等价 | 不等价 (近似) | 不等价 (近似) |
| 复杂度 | O(N2) | O(N2) | O(N) | O(Nlog N) 或 |
| 优势 | 表达能力强,是黄金标准 | 速度快,显存低,无损精度 | 极致效率,推理O(1),支持超长序列 | 平衡效率与性能,支持长序列 |
| 劣势 | 速度慢,显存占用高 | 复杂度本质未变 | 损失模型性能 | 实现复杂,可能丢失关键信息 |
| 代表应用 | 早期 Transformer (BERT, GPT-2) | 当前所有主流大模型 (Llama, GPT-4) | RWKV, Performer | Longformer, BigBird |
稀疏注意力是解决 Transformer 长序列问题的里程碑式工作。它在理论上证明了我们可以通过聪明的近似来大幅降低计算复杂度,催生了一系列成功的长上下文模型。
然而,随着 FlashAttention 的出现,情况发生了变化。FlashAttention 极大地优化了精确注意力的计算效率,使得在当前硬件(如 A100/H100)上,处理中等长度(如 4K-16K)的稠密注意力变得非常快。这使得稀疏注意力的工程优势在一定程度上被削弱了。不过,稀疏化的思想仍然极具价值,尤其是在未来需要处理数十万甚至上百万长度序列的场景下,降低计算复杂度的量级仍然是不可或缺的。