# Flash Attention

Deeply understand Flash Attention principles and their Triton implementation.
## Overview
Flash Attention is an IO-aware attention implementation that uses tiling and online softmax to avoid materializing the full N×N attention matrix, reducing extra memory from O(N²) to O(N) in sequence length and cutting HBM traffic, which greatly speeds up Transformer training and inference.
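The online softmax trick mentioned above can be illustrated with a minimal NumPy sketch (purely illustrative, not the Triton kernel): scores are consumed block by block while a running max `m` and running normalizer `l` are maintained, so the full score vector never needs to be normalized in one pass.

```python
import numpy as np

def online_softmax(scores, n_blocks=4):
    # Running max (m) and running sum of exponentials (l).
    m, l = -np.inf, 0.0
    for block in np.array_split(scores, n_blocks):
        m_new = max(m, block.max())
        # Rescale the old normalizer to the new max, then add this block.
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    # Normalize with the globally accumulated statistics.
    return np.exp(scores - m) / l

x = np.random.randn(16)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), ref)
```

The same rescale-and-accumulate idea is what lets Flash Attention compute softmax over K/V tiles without ever holding the whole attention row in memory.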
This series assumes basic GPU programming knowledge. If concepts such as SIMT and shared memory are unfamiliar, start with GPU Programming Basics.
## Chapters
- **Flash Attention Principles**: Interactive visualizations to understand memory bottlenecks, online softmax, and tiled matmul
- **From Naive to Auto-Tuning**: Write your first Flash Attention kernel and optimize it with auto-tuning
- **Block Pointers and Multi-Dim Support**: Scale from a single sequence to batch/head parallelism and simplify pointer management
- **Causal Masking Optimization**: Implement causal attention and skip upper-triangular compute for a ~2x speedup
- **Grouped Query Attention**: Add GQA/MQA support by sharing KV across query heads to reduce KV cache memory
- **Backward Pass**: Implement Flash Attention gradients using recomputation for memory-efficient training
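The causal-masking chapter's ~2x speedup comes from skipping K/V tiles that lie entirely above the diagonal. A toy sketch (illustrative only, with made-up sizes) counts how many tiles can be skipped:

```python
# Toy illustration: with causal masking, a K/V block whose first index
# lies past the Q block's last index is fully masked and can be skipped.
seq_len, block = 8, 2
n_blocks = seq_len // block
computed = skipped = 0
for qi in range(n_blocks):
    q_end = qi * block + block - 1       # last query index in this Q tile
    for ki in range(n_blocks):
        if ki * block > q_end:
            skipped += 1                 # whole tile above the diagonal
        else:
            computed += 1                # tile has unmasked entries
print(computed, skipped)  # 10 computed, 6 skipped
```

Roughly half of all tiles fall above the diagonal as the sequence grows, which is where the ~2x saving comes from.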
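The KV sharing behind GQA can also be sketched in a few lines of NumPy (a hand-rolled illustration with hypothetical sizes, not the kernel's layout): each query head `h` attends against KV head `h // group`, so 8 query heads here share only 2 K/V tensors.

```python
import numpy as np

# Hypothetical sizes: 8 query heads share 2 KV heads (group size 4).
n_q_heads, n_kv_heads, seq, d = 8, 2, 10, 16
group = n_q_heads // n_kv_heads

q = np.random.randn(n_q_heads, seq, d)
k = np.random.randn(n_kv_heads, seq, d)
v = np.random.randn(n_kv_heads, seq, d)

out = np.empty_like(q)
for h in range(n_q_heads):
    kv = h // group                      # query head h reads KV head h // group
    s = q[h] @ k[kv].T / np.sqrt(d)      # scaled dot-product scores
    p = np.exp(s - s.max(-1, keepdims=True))
    p /= p.sum(-1, keepdims=True)        # row-wise softmax
    out[h] = p @ v[kv]
```

With `n_kv_heads == 1` this degenerates to MQA; the KV cache shrinks by the group factor either way.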
## Why Learn This?
| What You Want to Do | What You Need |
|---|---|
| Understand why standard attention is slow | HBM vs SRAM, IO-bound concept |
| Build your own attention kernel | Online softmax, tiling |
| Optimize kernel performance | Autotune, pipeline, block pointers |
| Support long-sequence inference | Memory optimization techniques |