Backward Pass


Implement Flash Attention gradients with recomputation for memory-efficient training.

Companion Code
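
The article builds the backward pass as Triton kernels; as a rough orientation, here is a minimal plain-PyTorch sketch of the recomputation idea (illustrative names, not the article's kernels): the forward pass saves only Q, K, and V, and the backward pass rebuilds the attention weights instead of storing them. Unlike the blocked FlashAttention kernels, this sketch materializes the full attention matrix during recomputation; it only shows the save-less/recompute pattern and the gradient formulas.

```python
import math
import torch

class RecomputedAttention(torch.autograd.Function):
    """Illustrative only: recompute attention weights in backward instead of saving them."""

    @staticmethod
    def forward(ctx, q, k, v):
        scale = 1.0 / math.sqrt(q.shape[-1])
        s = (q @ k.transpose(-2, -1)) * scale      # S = QK^T / sqrt(d)
        p = torch.softmax(s, dim=-1)               # P = softmax(S)
        o = p @ v                                  # O = PV
        ctx.save_for_backward(q, k, v)             # note: P is deliberately not saved
        ctx.scale = scale
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v = ctx.saved_tensors
        scale = ctx.scale
        # Recompute S and P from the saved inputs (the memory-for-compute trade).
        s = (q @ k.transpose(-2, -1)) * scale
        p = torch.softmax(s, dim=-1)
        # Standard attention gradients.
        dv = p.transpose(-2, -1) @ do                          # dV = P^T dO
        dp = do @ v.transpose(-2, -1)                          # dP = dO V^T
        ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))     # softmax backward
        dq = (ds @ k) * scale                                  # dQ = dS K / sqrt(d)
        dk = (ds.transpose(-2, -1) @ q) * scale                # dK = dS^T Q / sqrt(d)
        return dq, dk, dv
```

A quick sanity check is to compare its gradients against a straightforward softmax-attention implementation (or run torch.autograd.gradcheck in double precision). The real kernels additionally avoid ever materializing P by recomputing it block by block from the L and M statistics saved in the forward kernel.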



Table of Contents

Why Custom Backward?
Limits of PyTorch Autograd
Recomputation Strategy
Backward Math
Forward Recap
Gradients
1. ∂L/∂V
2. ∂L/∂P
3. ∂L/∂S (softmax)
4. ∂L/∂Q and ∂L/∂K
Full Gradient Flow
Implementation
Forward Kernel: Save L and M
Backward Kernel
Key Details
1. Loop Order Reversal
2. Atomic Add for dQ
3. Recompute P, Don’t Store
PyTorch autograd.Function Wrapper
Performance Validation
Correctness
Memory Comparison
Tradeoffs and Optimizations
Recomputation Cost
Further Optimization Ideas
Summary