8 Flash attention
This chapter covers
- Why standard attention becomes memory-bound and how to spot that shift.
- Core ideas behind Flash Attention: tiling, online softmax, and SRAM residency.
- Restructuring loops and work distribution so tiles cooperate efficiently.
- Building a fused Flash Attention kernel with WMMA tensor cores.
- Evaluating Flash Attention across versions and hardware generations.
- Deciding when Flash Attention is the right optimization for your model.
8.1 Why memory bottlenecks attention
Tensor cores gave us raw compute throughput. But attention has a different bottleneck: the sheer volume of data shuffled between on-chip SRAM and off-chip HBM. Understanding this shift from compute-bound to memory-bound execution is the first step toward the Flash Attention algorithm.
8.1.1 From compute-bound to memory-bound
In the last chapter, we explored tensor cores and unlocked immense computational throughput using specialized hardware. Instructions like WMMA dramatically accelerate the massive matrix multiplications that form the backbone of deep learning models. Our focus was almost entirely on optimizing the raw arithmetic, ensuring that every GPU cycle was spent doing useful math. But as we move into optimizing complex operations like the attention mechanism, we run into a different kind of performance barrier: one set not by arithmetic throughput, but by how many bytes we move to and from HBM.
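To make that shift concrete, the short host-side sketch below estimates the arithmetic intensity (FLOPs per byte of HBM traffic) of standard, unfused attention for a single head and compares it to a roofline ridge point. The sequence length (4096), head dimension (64), and the roughly A100-class peak figures are illustrative assumptions for this sketch, not values taken from this chapter.

```cpp
// Back-of-the-envelope estimate: is standard attention compute-bound or
// memory-bound? All sizes and peak figures below are illustrative assumptions.
#include <cstdio>

int main() {
    const double N = 4096.0;          // sequence length (assumed)
    const double d = 64.0;            // head dimension (assumed)
    const double bytesPerElem = 2.0;  // fp16

    // FLOPs: S = Q*K^T and O = P*V are each 2*N*N*d multiply-adds.
    double flops = 2.0 * 2.0 * N * N * d;

    // HBM traffic for the unfused algorithm: Q, K, V, O each move once, but
    // the N x N score matrix is written, re-read for softmax, written back
    // as P, and re-read for the final matmul.
    double elems = 4.0 * N * d        // Q, K, V, O
                 + 4.0 * N * N;       // S out, S in, P out, P in
    double bytes = elems * bytesPerElem;

    double intensity = flops / bytes; // FLOP per byte of HBM traffic

    // Illustrative A100-class peaks: ~312 TFLOP/s fp16 tensor cores,
    // ~2 TB/s HBM bandwidth. Ridge point = peak FLOP/s / peak bytes/s.
    double ridge = 312e12 / 2e12;

    printf("FLOPs               : %.2e\n", flops);
    printf("HBM bytes           : %.2e\n", bytes);
    printf("Arithmetic intensity: %.1f FLOP/byte\n", intensity);
    printf("Ridge point         : %.1f FLOP/byte\n", ridge);
    printf("%s\n", intensity < ridge ? "=> memory-bound" : "=> compute-bound");
    return 0;
}
```

Under these assumptions the intensity lands around 32 FLOP/byte, far below a ridge point near 156 FLOP/byte, so the kernel spends most of its time waiting on HBM rather than on the tensor cores.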