Flash Attention 原理详解
通过交互式可视化,深入理解 Flash Attention 的核心技术:内存瓶颈、Online Softmax、与分块矩阵乘法。
标准 Attention 的内存瓶颈
在深入 Flash Attention 的代码实现之前,我们必须先回答一个底层问题:为什么标准的注意力机制公式 在现代 上跑得还不够快?
登录以继续阅读
这是一篇付费内容,请登录您的账户以访问完整内容。
通过交互式可视化,深入理解 Flash Attention 的核心技术:内存瓶颈、Online Softmax、与分块矩阵乘法。
在深入 Flash Attention 的代码实现之前,我们必须先回答一个底层问题:为什么标准的注意力机制公式 在现代 上跑得还不够快?
这是一篇付费内容,请登录您的账户以访问完整内容。
GPU首先我们需要建立一个极其重要的概念:在 GPU 架构中,所有的计算(如矩阵加法、乘法)都必须在靠近核心的 SRAM(Shared Memory,共享内存) 中进行。
这意味着:哪怕你的显存(HBM)有 80GB 那么大,数据也必须先被“搬运”到几十 MB 大小的 SRAM 中,才能被计算核心处理。
在 PyTorch 等深度学习框架的“朴素”实现中,Attention 的计算过程被拆分成了多个独立的算子(Op)。这导致了一个严重的效率问题:
GPU 把 和 从 HBM 搬到 SRAM,计算出分数矩阵 。# 标准 Attention 实现的 IO 噩梦
def standard_attention(Q, K, V):
# 1. HBM -> SRAM(计算) -> HBM(存 S)
S = Q @ K.T
# 2. HBM(读 S) -> SRAM(计算) -> HBM(存 P)
P = softmax(S)
# 3. HBM(读 P) -> SRAM(计算) -> HBM(存 O)
O =
这种 “搬进来 -> 算一下 -> 踢出去 -> 再搬回来” 的反复 I/O 往返,就是性能最大的杀手。
你可能会问,存回 HBM 再读回来会有多大影响?
| 存储类型 | 容量示例 (A100) | 带宽 | 速度比喻 |
|---|---|---|---|
SRAM (共享内存) | ~20 MB | ~19 TB/s | F1 赛车 🏎️ |
HBM (显存) | 40~80 GB | ~1.5 TB/s | 普通轿车 🚗 |
SRAM 的带宽通常比 HBM 高出约 10 倍以上。
由于这个速度鸿沟的存在,如果算法不断在 HBM 和 SRAM 之间搬运中间数据,就会出现一种尴尬的局面:
GPU 强大的计算核心大部分时间都在“空转”,苦苦等待数据从缓慢的 HBM 运送过来。
这种状态被称为 I/O 受限(IO-bound),即计算能力被内存传输速度拖了后腿。
量化感受一下:在 standard Attention 中,读写 的中间矩阵 和 所花费的时间,远远超过了实际进行矩阵乘法计算的时间。
既然 SRAM 这么快,为什么不把整个 Attention 矩阵都放在 SRAM 里算完再走?这里涉及到了物理与经济的刚性制约:
SRAM 的存储密度极低,导致成本极高。根据相关资料(如 FlashAttention 论文引用的背景):
80GB 容量的 SRAM 存储器,成本可能高达 13,000 美元(估算数量级)。HBM 仅需 2,000 美元。在实际硬件中,A100 的 HBM 可以达到 80GB,但 SRAM 通常只有 192 KB / SM(每个流式多处理器)。由于这种容量限制,你无法一次性把整个 的注意力矩阵塞进 SRAM。当序列长度 增加时,中间矩阵的大小呈 爆炸式增长。
Flash Attention 的核心逻辑就在于:既然 SRAM 贵且小但快,HBM 大且便宜但慢,那么我们就必须放弃“全量读写”的幻想。
我们需要引入两个核心思想:
SRAM 的“小方块”。SRAM 内部,一气呵成完成 、Softmax 和 的乘法,只在最后一步将最终结果 写回 HBM。通过这种方式,我们根本不生成(也不写入 HBM)那个巨大的 中间矩阵。
| 方法 | HBM 读写量 | 复杂度 |
|---|---|---|
Standard Attention | 随着序列变长,IO 爆炸 | |
Flash Attention | 线性增长,极大节省带宽 |
在前文中我们发现了硬件的物理瓶颈(SRAM 容量限制)。为了突破这个瓶颈,我们需要从算法层面进行彻底的重构。本质上,这是一个数学问题:
如何在一个不断流动的数据流上,计算需要全局信息的统计量(如 Softmax)?
这正是 在线算法(Online Algorithms) 的用武之地。
标准的 Softmax 计算是一个典型的离线算法(Offline Algorithm)。它要求我们在开始计算之前,必须“看到”完整的输入数据。
为了保证数值稳定性(防止 溢出),我们必须分三步走:
死锁:只要有一个数据块没加载进来,我们就无法确定 ,也就无法计算任何一个 。这意味着我们必须把所有数据反复在 HBM 和 SRAM 之间搬运。
能不能一边读数据,一边就把计算做完?
这就需要引入一种 动态修正(Dynamic Correction) 的技巧。其核心思想是:先基于当前已知的局部信息进行计算,一旦发现更有价值的新信息(更大的最大值),就通过数学变换修正之前的错误结果。
假设我们维护两个流式统计量:
当我们遇到一个新的数据块,且发现其最大值 时,利用指数性质,我们可以这样更新:
这个公式告诉我们:不需要回头重读旧数据,只需要乘上一个衰减因子,就能把历史结果“对齐”到新的基准上。
让我们看看这种“流式处理”是如何工作的:
结果验证:最终得到的 和 ,与离线全量计算的结果完全一致。
这种数学上的 “单次遍历(One-Pass)” 特性,完美契合了 GPU 的硬件需求:
第一性原理:Flash Attention 并没有发明新的数学,它只是巧妙地利用了在线 Softmax 算法,打破了“必须全量数据存在显存”的物理限制,将 的显存读写转换为了 的流式计算。
在解决了“如何在线计算 Softmax”的数学问题后,我们还需要解决物理问题:如何把巨大的矩阵塞进极小的 SRAM 中?
答案就是 分块矩阵乘法(Block Matrix Multiplication)。它是 Flash Attention 能够实现并行加速的物理基础。
假设我们要计算 ,其中矩阵均为 。
SRAM 只有 20MB,无法一次性装入这么大的矩阵。分而治之:分块的核心思想是,大矩阵的乘法 = 子矩阵乘法之和。我们可以把大矩阵切成无数个 的小块,每次只把几块搬进 SRAM 计算,算完再累加。
这个演示组件展示了当 GPU 计算矩阵乘法时,数据是如何被“切块”和“分发”的。
请点击紫色矩阵 中的任意一个块(例如 ):
全局依赖关系(主视图): 你会看到 的对应行和 的对应列亮起。这代表了计算该结果所需的全部数据。
SRAM 的分步执行(底部面板): 底部面板展示了这些数据是如何分批进入 SRAM 的。我们并没有一次性读取整行整列,而是分成了两次独立的加载与计算。
Tiling 的本质:将一个巨大的矩阵运算,拆解为无数个可以在 SRAM 小缓存中独立完成的微型运算。
既然我们已经理解了 的分块原理,现在让我们把这个逻辑应用到 Attention 的流水线中:。
在这个阶段,我们先暂时忽略 Softmax 的复杂性,专注于矩阵块是如何流动的。这是通往 Flash Attention 的必经之路:理解 Attention 计算的分块线性性。
Flash Attention 的演进过程中出现了两种不同的分块循环策略。我们可以通过这两个伪代码来对比它们的 IO 差异。
假设输入维度:
(8, 128)(2, 128)# 1. 初始化 HBM 中的 O 为全 0 (必须!)
O = HBM_ZEROS(8, 128)
# 外循环:遍历 K, V 的 4 个 Block
FOR j in 0..3:
# 把 K[j], V[j] 加载到 SRAM (Cache)
Kj =
# 1. HBM 中的 O 不需要初始化!(可以是随机垃圾值)
# 外循环:遍历 Q 的 4 个 Block
FOR i in 0..3:
Qi = LOAD_TO_SRAM(Q[i])
# [优化] 在 SRAM 中初始化累加器为 0
# 这是纯寄存器/SRAM 操作,速度极快
Oi_acc = SRAM_ZEROS(2
这个可视化组件现在支持切换 两种循序策略,请尝试点击顶部的 V1 / V2 按钮体验差异:
V2 模式(默认,Outer Q loop):
核心洞察: Flash Attention V2 之所以快,不仅仅是因为分块,更是因为它通过颠倒循环顺序(从 Outer K 改为 Outer Q),使得输出 的计算完全局部化在 SRAM 中,消除了昂贵的显存读写瓶颈。
前面的推导有一个"小小的"遗漏:我们假装 Softmax 不存在。
现在,让我们把 Softmax 请回来,看看它会带来什么麻烦。
我们尝试在每个 Block 内部做一个"局部 Softmax":
# 尝试直接将 Softmax 放入 V2 循环结构
# 外循环:遍历 Q 的 4 个 Block
FOR i in 0..3:
Qi = LOAD_TO_SRAM(Q[i])
# 在 SRAM 中初始化累加器
Oi_acc = SRAM_ZEROS(2, 128
问题来了:每个 是用自己的rowmax归一化的。
比如 用的是 ,而 用的是 。但真正的 Softmax 需要用整行的全局最大值!
这就引出了 Flash Attention 的核心问题:
当我处理完 Block ,进入 Block 时,如果发现了一个更大的 max,我该怎么"修正"之前已经算好的结果?
还记得 Online Softmax 吗?它的核心思想就是:当发现新的更大值时,回去乘以一个修正因子。
现在我们把这个思想应用到 Attention 的分块计算中。
我们为固定的 Block,遍历所有 Block。在遍历开始前,先初始化:
m = [-∞, -∞] # 每行的 running max
l = [0, 0] # 每行的 running sum(exp)
Oi_acc = ZEROES(2, 128现在我们展示一次迭代的完整逻辑(处理第 个 K-Block):
# ---- 保存旧状态 ----
m_prev = m
l_prev = l
Oi_acc_prev = Oi_acc
# ---- Step 1: 更新 running max ----
Sij = Qi @ Kj.T # 当前 Block 的 Score
m =
理解 diag(exp(...)) 操作:
这个看起来复杂的矩阵操作,其实就是对每一行独立地乘以修正因子。展开来看:
# 对于 Block 内部的每一行 row:
Oi_acc[row] = Oi_acc_prev[row] * exp(m_prev[row] - m[row]) + (Pij @ Vj)[row]把所有行叠加在一起,就等价于矩阵形式。
具体数值例子:
# 假设 Oi_acc_prev 是 2×4 矩阵:
Oi_acc_prev = [[1.0, 2.0, 3.0, 4.0], # Row 0
[5.0, 6.0, 7.0, 8.0]]
直觉:
这正是 Online Softmax 的核心思想!
当所有 K-Block 都处理完毕后,用累积的 做最终归一化:
O_final = diag(1 / l) @ O
# ↑ 除以总 sum, 得到真正的 Softmax 加权结果Flash Attention 的本质:
exp 值累加起来,最后一次性除。exp(m_old - m_new) 去修正之前的所有结果。SRAMGPU 只能被迫将这个巨型矩阵 从 SRAM “踢”出去,写回到慢速的 HBM 中。GPU 又得重新去 HBM 把刚才存进去的 再搬回 SRAM。V1 模式(旧逻辑,Outer K loop):
调试视角:
直接把这些"各自为政"的 累加起来,结果是错的!