系统工程
Flash Attention
深入理解 Flash Attention 的原理与 Triton 实现
概述
Flash Attention 是一种高效的注意力机制实现,通过 Tiling(分块计算)和 Online Softmax 技术,将内存读写复杂度从 降低到 ,显著提升了 Transformer 模型的训练和推理速度。
本系列假设你已掌握 GPU 编程基础。如果对 SIMT、Shared Memory 等概念不熟悉,建议先学习 GPU 编程基础。
章节内容
Flash Attention 原理详解
通过交互式可视化,深入理解内存瓶颈、Online Softmax 与分块矩阵乘法
从朴素实现到 Auto-Tuning
编写第一个 Flash Attention Kernel,并利用 Auto-Tune 进行性能优化
Block Pointer 与多维支持
从单序列扩展到 Batch/Head 并行,并使用 Block Pointer 简化指针管理
Causal Masking 优化
为自回归模型实现因果注意力机制,通过跳过上三角计算实现 ~2x 加速
Grouped Query Attention
实现 GQA/MQA 支持,让多个 Query Head 共享 KV,优化 KV Cache 内存占用
反向传播实现
实现 Flash Attention 的梯度计算,通过 Recomputation 实现内存高效的训练
为什么需要学这些?
| 你想做的事 | 需要的知识 |
|---|---|
| 理解为什么标准 Attention 慢 | HBM vs SRAM、IO-bound 概念 |
| 实现自己的 Attention Kernel | Online Softmax、Tiling |
| 优化 Kernel 性能 | Autotune、Pipeline、Block Pointer |
| 支持长序列推理 | 理解 的内存优化 |
CookLLM文档