LogoCookLLM文档
LogoCookLLM文档
首页CookLLM - LLM 系统课程

核心课程

基础知识
GPU 编程基础
FlashAttention
Flash Attention 原理详解
并行策略
量化技术
激活检查点
CPU 卸载
推理优化
缩放定律
数据工程
对齐微调
系统工程FlashAttention

Flash Attention 原理详解

通过交互式可视化,深入理解 Flash Attention 的核心技术:内存瓶颈、Online Softmax、与分块矩阵乘法。

标准 Attention 的内存瓶颈

在深入 Flash Attention 的代码实现之前,我们必须先回答一个底层问题:为什么标准的注意力机制公式 Softmax(QKT)VSoftmax(QK^T)VSof 在现代 上跑得还不够快?

登录以继续阅读

这是一篇付费内容,请登录您的账户以访问完整内容。

t
ma
x
(
Q
KT
)
V
GPU

GPU 内存层级:SRAM 与 HBM

首先我们需要建立一个极其重要的概念:在 GPU 架构中,所有的计算(如矩阵加法、乘法)都必须在靠近核心的 SRAM(Shared Memory,共享内存) 中进行。

这意味着:哪怕你的显存(HBM)有 80GB 那么大,数据也必须先被“搬运”到几十 MB 大小的 SRAM 中,才能被计算核心处理。

标准实现的逻辑陷阱

在 PyTorch 等深度学习框架的“朴素”实现中,Attention 的计算过程被拆分成了多个独立的算子(Op)。这导致了一个严重的效率问题:

  1. 第一步(QKTQK^TQKT):GPU 把 QQQ 和 KKK 从 HBM 搬到 SRAM,计算出分数矩阵 SSS。
  2. 中间结果过大:由于 SSS 的形状是 (N,N)(N, N),对于长序列来说,这个矩阵大到 根本存不下。
# 标准 Attention 实现的 IO 噩梦
def standard_attention(Q, K, V):
    # 1. HBM -> SRAM(计算) -> HBM(存 S)
    S = Q @ K.T

    # 2. HBM(读 S) -> SRAM(计算) -> HBM(存 P)
    P = softmax(S)

    # 3. HBM(读 P) -> SRAM(计算) -> HBM(存 O)
    O =

这种 “搬进来 -> 算一下 -> 踢出去 -> 再搬回来” 的反复 I/O 往返,就是性能最大的杀手。

SRAM 与 HBM 的带宽差异

你可能会问,存回 HBM 再读回来会有多大影响?

速度差异对比

存储类型容量示例 (A100)带宽速度比喻
SRAM (共享内存)~20 MB~19 TB/sF1 赛车 🏎️
HBM (显存)40~80 GB~1.5 TB/s普通轿车 🚗

SRAM 的带宽通常比 HBM 高出约 10 倍以上。

瓶颈本质:IO 受限 (IO-bound)

由于这个速度鸿沟的存在,如果算法不断在 HBM 和 SRAM 之间搬运中间数据,就会出现一种尴尬的局面:

GPU 强大的计算核心大部分时间都在“空转”,苦苦等待数据从缓慢的 HBM 运送过来。

这种状态被称为 I/O 受限(IO-bound),即计算能力被内存传输速度拖了后腿。

量化感受一下:在 standard Attention 中,读写 N×NN \times NN×N 的中间矩阵 SSS 和 PPP 所花费的时间,远远超过了实际进行矩阵乘法计算的时间。

SRAM 的容量限制

既然 SRAM 这么快,为什么不把整个 Attention 矩阵都放在 SRAM 里算完再走?这里涉及到了物理与经济的刚性制约:

物理制约与成本

SRAM 的存储密度极低,导致成本极高。根据相关资料(如 FlashAttention 论文引用的背景):

  • 成本:制造一个 80GB 容量的 SRAM 存储器,成本可能高达 13,000 美元(估算数量级)。
  • 对比:同样容量的 HBM 仅需 2,000 美元。

容量极限

在实际硬件中,A100 的 HBM 可以达到 80GB,但 SRAM 通常只有 192 KB / SM(每个流式多处理器)。由于这种容量限制,你无法一次性把整个 N×NN \times NN×N 的注意力矩阵塞进 SRAM。当序列长度 NNN 增加时,中间矩阵的大小呈 O(N2)O(N^2)O(N2) 爆炸式增长。

核心思路:IO 复杂度优化

Flash Attention 的核心逻辑就在于:既然 SRAM 贵且小但快,HBM 大且便宜但慢,那么我们就必须放弃“全量读写”的幻想。

我们需要引入两个核心思想:

  1. 分块(Tiling):将数据切分成能塞进 SRAM 的“小方块”。
  2. 算子融合(Kernel Fusion):在 SRAM 内部,一气呵成完成 QKTQK^TQKT、Softmax 和 VVV 的乘法,只在最后一步将最终结果 OOO 写回 HBM。

避免中间矩阵落盘

通过这种方式,我们根本不生成(也不写入 HBM)那个巨大的 N×NN \times NN×N 中间矩阵。

方法HBM 读写量复杂度
Standard AttentionO(N2)O(N^2)O(N2)随着序列变长,IO 爆炸
Flash AttentionO(N)O(N)O(N)线性增长,极大节省带宽

Online Softmax 原理

在前文中我们发现了硬件的物理瓶颈(SRAM 容量限制)。为了突破这个瓶颈,我们需要从算法层面进行彻底的重构。本质上,这是一个数学问题:

如何在一个不断流动的数据流上,计算需要全局信息的统计量(如 Softmax)?

这正是 在线算法(Online Algorithms) 的用武之地。

离线算法的局限性

标准的 Softmax 计算是一个典型的离线算法(Offline Algorithm)。它要求我们在开始计算之前,必须“看到”完整的输入数据。

为了保证数值稳定性(防止 exe^xex 溢出),我们必须分三步走:

  1. 全局扫描找最大值:我们需要先遍历所有数据,找到全局最大值 mmm。
  2. 全局求和:用这个 mmm 作为基准,计算所有元素的指数和 lll。
  3. 全局归一化:最后再次遍历,做除法。
yi=exi−m∑exj−my_i = \frac{e^{x_i - m}}{\sum e^{x_j - m}}yi​=∑exj​−m

死锁:只要有一个数据块没加载进来,我们就无法确定 mmm,也就无法计算任何一个 exi−me^{x_i - m}exi​−m。这意味着我们必须把所有数据反复在 HBM 和 SRAM 之间搬运。

在线算法与动态修正

能不能一边读数据,一边就把计算做完?

这就需要引入一种 动态修正(Dynamic Correction) 的技巧。其核心思想是:先基于当前已知的局部信息进行计算,一旦发现更有价值的新信息(更大的最大值),就通过数学变换修正之前的错误结果。

假设我们维护两个流式统计量:

  • m(i)m^{(i)}m(i):当前的局部最大值。
  • l(i)l^{(i)}l(i):当前的局部指数和。

修正公式推导

当我们遇到一个新的数据块,且发现其最大值 mnew>moldm_{new} > m_{old}mnew​>mold​ 时,利用指数性质,我们可以这样更新:

lnew=lold×emold−mnew+exnew−mnew新局部和=旧局部和×衰减因子+新项贡献\begin{aligned} l_{new} &= l_{old} \times e^{m_{old} - m_{new}} + e^{x_{new} - m_{new}} \\ \text{新局部和} &= \text{旧局部和} \times \text{衰减因子} + \text{新项贡献} \end{aligned}lnew​新局部和​

这个公式告诉我们:不需要回头重读旧数据,只需要乘上一个衰减因子,就能把历史结果“对齐”到新的基准上。

数值演示:以序列 [3, 2, 5, 1] 为例

让我们看看这种“流式处理”是如何工作的:

  1. 看到 3:
    • 当前最大 m=3m=3m=3,累加和 l=e3−3=1l=e^{3-3}=1l=e3−3=1。
  2. 看到 2:
    • 2<32 < 32<3,无需修正。
    • 。

结果验证:最终得到的 m=5m=5m=5 和 l≈1.20l \approx 1.20l≈1.20,与离线全量计算的结果完全一致。

Flash Attention 的数学原理总结

这种数学上的 “单次遍历(One-Pass)” 特性,完美契合了 GPU 的硬件需求:

  • 分块并行:我们可以把长序列切分成无数个小块(Tiles)。
  • 独立计算:每个 Block 只需要算出自己的局部 mmm 和 lll。
  • 最终合并:最后只需要把这些局部的统计量通过类似的公式“归约(Reduce)”在一起。

第一性原理:Flash Attention 并没有发明新的数学,它只是巧妙地利用了在线 Softmax 算法,打破了“必须全量数据存在显存”的物理限制,将 O(N2)O(N^2)O(N2) 的显存读写转换为了 O(N)O(N)O(N) 的流式计算。


分块矩阵乘法 (Tiling)

在解决了“如何在线计算 Softmax”的数学问题后,我们还需要解决物理问题:如何把巨大的矩阵塞进极小的 SRAM 中?

答案就是 分块矩阵乘法(Block Matrix Multiplication)。它是 Flash Attention 能够实现并行加速的物理基础。

为什么要进行“分块”?

假设我们要计算 C=A×BC = A \times BC=A×B,其中矩阵均为 10000×1000010000 \times 1000010000×10000。

  • 核心不足:如果并行计算每一个元素,需要 1 亿个核心,这不现实。
  • 内存限制:SRAM 只有 20MB,无法一次性装入这么大的矩阵。

分而治之:分块的核心思想是,大矩阵的乘法 = 子矩阵乘法之和。我们可以把大矩阵切成无数个 128×128128 \times 128128×128 的小块,每次只把几块搬进 SRAM 计算,算完再累加。

可视化演示:分块计算流程

这个演示组件展示了当 GPU 计算矩阵乘法时,数据是如何被“切块”和“分发”的。

×
=
A(0,0)
×
B(0,0)
+
A(0,1)
×
B(1,0)
=
C(0,0)

观察重点

请点击紫色矩阵 CCC 中的任意一个块(例如 C(0,0)C(0,0)C(0,0)):

  1. 全局依赖关系(主视图): 你会看到 AAA 的对应行和 BBB 的对应列亮起。这代表了计算该结果所需的全部数据。

  2. SRAM 的分步执行(底部面板): 底部面板展示了这些数据是如何分批进入 SRAM 的。我们并没有一次性读取整行整列,而是分成了两次独立的加载与计算。

    • Term 1: 加载 AAA 的第一个分块和 BBB 的第一个分块 →\rightarrow→ 算出部分结果。
    • Term 2: 加载 AAA 的第二个分块和 BBB 的第二个分块 累加到最终结果。

Tiling 的本质:将一个巨大的矩阵运算,拆解为无数个可以在 SRAM 小缓存中独立完成的微型运算。

Tiling 与 Attention 的结合

既然我们已经理解了 A×BA \times BA×B 的分块原理,现在让我们把这个逻辑应用到 Attention 的流水线中:Q×KT→S×V→OQ \times K^T \rightarrow S \times V \rightarrow OQ×KT→S×V→O。

在这个阶段,我们先暂时忽略 Softmax 的复杂性,专注于矩阵块是如何流动的。这是通往 Flash Attention 的必经之路:理解 Attention 计算的分块线性性。

循环策略对比:V1 vs V2

Flash Attention 的演进过程中出现了两种不同的分块循环策略。我们可以通过这两个伪代码来对比它们的 IO 差异。

假设输入维度:

  • Q/K/V/O: 形状 (8, 128)
  • 切块 (Tiling): 切成 4 个 Block,每个 Block 形状 (2, 128)
Flash Attention V1 (Outer loop over K)
# 1. 初始化 HBM 中的 O 为全 0 (必须!)
O = HBM_ZEROS(8, 128)

# 外循环:遍历 K, V 的 4 个 Block
FOR j in 0..3:
    # 把 K[j], V[j] 加载到 SRAM (Cache)
    Kj =




















Flash Attention V2 (Outer loop over Q)
# 1. HBM 中的 O 不需要初始化!(可以是随机垃圾值)

# 外循环:遍历 Q 的 4 个 Block
FOR i in 0..3:
    Qi = LOAD_TO_SRAM(Q[i])

    # [优化] 在 SRAM 中初始化累加器为 0
    # 这是纯寄存器/SRAM 操作,速度极快
    Oi_acc = SRAM_ZEROS(2

















Click specific Q row to start (Pivot Q)
Q
Kᵀ
S
×
V
=
O
SRAM Hold
Q0
STREAM
K0
V0
K1
V1
K2
V2
K3
V3
➡
Accumulate
O0
✔ Writes to HBM: 1

图解交互指南

这个可视化组件现在支持切换 两种循序策略,请尝试点击顶部的 V1 / V2 按钮体验差异:

  1. V2 模式(默认,Outer Q loop):

    • 点击左侧蓝色的 Q 行 (QiQ_iQi​)。
    • 观察:QiQ_iQi​ 被锁定在 SRAM 中(高亮),而 K,VK, VK,V 所有的块像流水一样流过。
    • 结果:OiO_i 像一个蓄水池,在 SRAM 内一次性累加完成,最终。

核心洞察: Flash Attention V2 之所以快,不仅仅是因为分块,更是因为它通过颠倒循环顺序(从 Outer K 改为 Outer Q),使得输出 OOO 的计算完全局部化在 SRAM 中,消除了昂贵的显存读写瓶颈。

分块 Attention 的 Softmax 修正

前面的推导有一个"小小的"遗漏:我们假装 Softmax 不存在。

现在,让我们把 Softmax 请回来,看看它会带来什么麻烦。

朴素实现:局部 Softmax 的局限

我们尝试在每个 Block 内部做一个"局部 Softmax":

Naive Block-wise Attention (Mathematically Incorrect)
# 尝试直接将 Softmax 放入 V2 循环结构

# 外循环:遍历 Q 的 4 个 Block
FOR i in 0..3:
    Qi = LOAD_TO_SRAM(Q[i])

    # 在 SRAM 中初始化累加器
    Oi_acc = SRAM_ZEROS(2, 128















问题来了:每个 PijP_{ij}Pij​ 是用自己的rowmax归一化的。

比如 P11P_{11}P11​ 用的是 max⁡(,而 用的是 。但真正的 Softmax 需要用整行的全局最大值!

这就引出了 Flash Attention 的核心问题:

当我处理完 Block jjj,进入 Block j+1j+1j+1 时,如果发现了一个更大的 max,我该怎么"修正"之前已经算好的结果?

解决方案:在线重缩放 (Online Rescaling)

还记得 Online Softmax 吗?它的核心思想就是:当发现新的更大值时,回去乘以一个修正因子。

现在我们把这个思想应用到 Attention 的分块计算中。

初始化

我们为固定的 QiQ_iQi​ Block,遍历所有 KjK_jKj​ Block。在遍历开始前,先初始化:

Initialization (Before Loop)
m = [-∞, -∞]       # 每行的 running max
l = [0, 0]         # 每行的 running sum(exp)
Oi_acc = ZEROES(2, 128

内循环: 遍历 K-Blocks

现在我们展示一次迭代的完整逻辑(处理第 jjj 个 K-Block):

Inner Loop: Process K_j
# ---- 保存旧状态 ----
m_prev = m
l_prev = l
Oi_acc_prev = Oi_acc

# ---- Step 1: 更新 running max ----
Sij = Qi @ Kj.T                        # 当前 Block 的 Score
m =










理解 diag(exp(...)) 操作:

这个看起来复杂的矩阵操作,其实就是对每一行独立地乘以修正因子。展开来看:

# 对于 Block 内部的每一行 row:
Oi_acc[row] = Oi_acc_prev[row] * exp(m_prev[row] - m[row]) + (Pij @ Vj)[row]

把所有行叠加在一起,就等价于矩阵形式。

具体数值例子:

# 假设 Oi_acc_prev 是 2×4 矩阵:
Oi_acc_prev = [[1.0, 2.0, 3.0, 4.0],   # Row 0
               [5.0, 6.0, 7.0, 8.0]]   




















直觉:

  • Row 0 在上一轮"高估"了自己的重要性(用了较小的 max = 5),现在发现真正的 max 是 7,所以需要缩小。
  • Row 1 的 max 没变,所以保持原样。

这正是 Online Softmax 的核心思想!

最终步骤: 归一化

当所有 K-Block 都处理完毕后,用累积的 lll 做最终归一化:

Final Normalization
O_final = diag(1 / l) @ O
#         ↑ 除以总 sum, 得到真正的 Softmax 加权结果

Flash Attention 的本质:

  1. 延迟归一化:我们不在每一步都除以 sum,而是把所有 exp 值累加起来,最后一次性除。
  2. 动态重缩放:每当发现更大的 max,就用 exp(m_old - m_new) 去修正之前的所有结果。
  3. 数学等价:最终结果与标准 Attention 完全一致,但内存占用从 O(N2)O(N^2)O(N2) 降到了 O(N)O(N)O(N)。

FlashAttention:高效注意力机制

深入理解 FlashAttention 的原理与实现

parallelism

parallelism module

目录

标准 Attention 的内存瓶颈
GPU 内存层级:SRAM 与 HBM
标准实现的逻辑陷阱
SRAM 与 HBM 的带宽差异
速度差异对比
瓶颈本质:IO 受限 (IO-bound)
SRAM 的容量限制
物理制约与成本
容量极限
核心思路:IO 复杂度优化
避免中间矩阵落盘
Online Softmax 原理
离线算法的局限性
在线算法与动态修正
修正公式推导
数值演示:以序列 [3, 2, 5, 1] 为例
Flash Attention 的数学原理总结
分块矩阵乘法 (Tiling)
为什么要进行“分块”?
可视化演示:分块计算流程
观察重点
Tiling 与 Attention 的结合
循环策略对比:V1 vs V2
图解交互指南
分块 Attention 的 Softmax 修正
朴素实现:局部 Softmax 的局限
解决方案:在线重缩放 (Online Rescaling)
初始化
内循环: 遍历 K-Blocks
最终步骤: 归一化
(N,N)
SRAM
  • 被迫回写:GPU 只能被迫将这个巨型矩阵 SSS 从 SRAM “踢”出去,写回到慢速的 HBM 中。
  • 反复折腾:到了下一步计算 Softmax(S)Softmax(S)Softmax(S) 时,GPU 又得重新去 HBM 把刚才存进去的 SSS 再搬回 SRAM。
  • P
    @
    V
    return O
    exi​−m
    ​
    =lold​×emold​−mnew​
    l=1+e2−3≈1.36l = 1 + e^{2-3} \approx 1.36
    l=1+e2−3≈1.36
  • 看到 5(出现更大值!):
    • 新最大值 m=5m=5m=5。
    • 修正历史:将之前的 1.361.361.36 乘以衰减因子 e3−5e^{3-5}e3−5。
    • 加入新值:加上 e5−5e^{5-5}e5−5。
    • 。
  • 看到 1:
    • 1<51 < 51<5,无需修正。
    • l=1.18+e1−5≈1.20l = 1.18 + e^{1-5} \approx 1.20l=1.18+e1−5≈1.20。
  • →\rightarrow→
    LOAD_TO_SRAM(K[j])
    Vj = LOAD_TO_SRAM(V[j])
    # 内循环:遍历 Q 的 4 个 Block
    FOR i in 0..3:
    Qi = LOAD_TO_SRAM(Q[i])
    # [IO 瓶颈] 必须从 HBM 读取之前的中间结果
    # 因为我们这次只计算了部分贡献,需要累加到旧值上
    Oi = LOAD_FROM_HBM(O[i])
    # 计算当前 K 块带来的增量
    Sij = Qi @ Kj.T
    Oi = Oi + Sij @ Vj
    # [IO 瓶颈] 必须立即写回 HBM
    # 因为 SRAM 放不下所有的 O,且马上要处理下一个 Q
    WRITE_TO_HBM(O[i], Oi)
    END FOR
    END FOR
    # 总结:O 的每个 Block 被读写了 4 次 (等于 K 的分块数)
    ,
    128
    )
    # 内循环:遍历 K, V 的 4 个 Block
    FOR j in 0..3:
    Kj = LOAD_TO_SRAM(K[j])
    Vj = LOAD_TO_SRAM(V[j])
    # 所有的累加都在 SRAM 内部完成!
    Sij = Qi @ Kj.T
    Oi_acc = Oi_acc + Sij @ Vj
    END FOR
    # [优化] 内循环结束后,只写一次 HBM (覆盖写)
    # 这时 Oi_acc 已经是最终结果了
    WRITE_TO_HBM(O[i], Oi_acc)
    END FOR
    # 总结:O 的每个 Block 只被写了 1 次 (这也是 V2 提速的关键原因)
    Oi​
    只写一次 HBM
  • V1 模式(旧逻辑,Outer K loop):

    • 点击顶部绿色的 K 列 (KjK_jKj​)。
    • 观察:Kj,VjK_j, V_jKj​,Vj​ 被锁定在 SRAM 中,而 QQ 所有的块流过。
  • 调试视角:

    • 点击中间黄色的 SSS 块,可以查看具体的单步计算:Si,j=Qi⋅KjTS_{i,j} = Q_i \cdot K_j^TSi,j​=Qi​⋅K。
  • )
    # 内循环:遍历 K, V 的 4 个 Block
    FOR j in 0..3:
    Kj = LOAD_TO_SRAM(K[j])
    Vj = LOAD_TO_SRAM(V[j])
    Sij = Qi @ Kj.T
    # [问题所在] 这里的 Softmax 只是局部的! 因为我们不知道全局的最大值和全局的分母 (Sum)
    Pij_local = exp(Sij - MAX(Sij))
    Oi_acc += Pij_local @ Vj
    END FOR
    # 写入结果 (虽然结果是错的)
    WRITE_TO_HBM(O[i], Oi_acc)
    S11)\max(S_{11})
    max(S11​)
    P12P_{12}P12​
    max⁡(S12)\max(S_{12})max(S12​)

    直接把这些"各自为政"的 PijP_{ij}Pij​ 累加起来,结果是错的!

    )
    # 累积输出
    max
    (rowmax(Sij), m_prev)
    # 取新旧 max 的较大值
    # ---- Step 2: 更新 running sum (带重缩放) ----
    l = rowsum(exp(Sij - m)) + l_prev * exp(m_prev - m)
    # ↑ 关键: 对旧的 sum 乘以修正因子!
    # ---- Step 3: 计算局部 "概率" ----
    Pij = exp(Sij - m)
    # ---- Step 4: 更新输出 (带重缩放) ----
    Oi_acc = diag(exp(m_prev - m)) @ Oi_acc_prev + Pij @ Vj
    # ↑ 关键: 对旧的输出乘以修正因子!
    # Row 1
    # 假设:
    m_prev = [5.0, 6.0] # 上一轮每行的 max
    m = [7.0, 6.0] # 这一轮每行的 max (Row 0 变大了, Row 1 没变)
    # Step 1: 计算每行的修正因子 (向量)
    correction = exp(m_prev - m)
    = exp([5-7, 6-6])
    = exp([-2, 0])
    = [0.135, 1.0]
    # Step 2: 从向量创建对角矩阵
    diag(correction) = [[0.135, 0 ],
    [0, 1.0 ]]
    # Step 3: 对角矩阵 × Oi_acc_prev = 每行乘以对应的修正因子
    [[0.135, 0 ] [[1, 2, 3, 4], [[0.135, 0.27, 0.41, 0.54],
    [0, 1.0 ]] @ [5, 6, 7, 8]] = [5.0, 6.0, 7.0, 8.0 ]]
    #
    # Row 0: 所有元素 × 0.135 (缩小!)
    # Row 1: 所有元素 × 1.0 (不变)
    +
    exnew​−mnew​
    =旧局部和×衰减因子+新项贡献
    ​
    l=(1.36×e−2)+1≈1.18l = (1.36 \times e^{-2}) + 1 \approx 1.18
    l=(1.36×e−2)+1≈1.18
    Q
  • 结果:对于流过的每一个 QiQ_iQi​,我们都需要更新对应的 OiO_iOi​。由于 OiO_iOi​ 无法常驻 SRAM,只能产生大量的 Partial Update(紫色小块),导致反复读写 HBM。
  • j
    T
    ​