LogoCookLLM文档
LogoCookLLM文档
首页CookLLM

原理精讲

词元化
Tokenization 基础BPE 算法详解GPT 系列 TokenizerBPE 训练工程化
模型架构
Attention 机制详解
Engram
GPU 编程基础
GPU 架构基础张量布局Triton 入门:向量加法
FlashAttention
Flash Attention 原理详解从朴素实现到 Auto-TuningBlock Pointer 与多维支持Causal Masking 优化Grouped Query Attention反向传播实现

动手训练

系统工程FlashAttention

反向传播实现

会员专享

实现 Flash Attention 的梯度计算,通过 Recomputation 实现内存高效的训练。

配套代码

登录以继续阅读

这是一篇付费内容,请登录您的账户以访问完整内容。

Grouped Query Attention

实现 GQA/MQA 支持,让多个 Query Head 共享 KV,优化 KV Cache 内存占用。

目录

为什么需要自定义反向传播?
PyTorch 自动微分的局限性
Recomputation 策略
Attention 反向传播的数学原理
前向传播回顾
梯度推导
1. ∂L∂V\frac{\partial \mathcal{L}}{\partial \mathbf{V}}∂V∂L​ (最简单)
2. ∂L∂P\frac{\partial \mathcal{L}}{\partial \mathbf{P}}∂P∂L​ (中间梯度)
3. ∂L∂S\frac{\partial \mathcal{L}}{\partial \mathbf{S}}∂S∂L​ (Softmax 反向传播)
4. ∂L∂Q\frac{\partial \mathcal{L}}{\partial \mathbf{Q}}∂Q∂L​ 和 ∂L∂K\frac{\partial \mathcal{L}}{\partial \mathbf{K}}∂K∂L​
完整的梯度计算流程
代码实现解析
Forward Kernel 的修改
Backward Kernel 实现
关键实现细节
1. 循环顺序的反转
2. Atomic Add 处理 dQ
3. 重新计算 P 而非保存
torch.autograd.Function 封装
性能验证
数值正确性测试
内存占用对比
设计权衡与优化方向
Recomputation 的开销
进一步优化方向
总结