Chapter Eight

8 Flash Attention


This chapter covers

  • Why standard attention becomes memory-bound and how to spot that shift.
  • Core ideas behind Flash Attention: tiling, online softmax, and SRAM residency.
  • Restructuring loops and work distribution so tiles cooperate efficiently.
  • Building a fused Flash Attention kernel with WMMA tensor cores.
  • Evaluating Flash Attention across versions and hardware generations.
  • Deciding when Flash Attention is the right optimization for your model.

8.1 Why memory bottlenecks attention

Tensor cores gave us raw compute throughput. But attention has a different bottleneck: the sheer volume of data shuffled between on-chip SRAM and off-chip HBM. Understanding this shift is the first step toward the Flash Attention algorithm.
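To see why the data volume, not the arithmetic, dominates, it helps to count bytes. The sketch below is a back-of-envelope estimate (not a measurement) of the HBM traffic for one attention head: standard attention materializes the full N×N score matrix in HBM, and that intermediate is written after Q·Kᵀ, read and rewritten by softmax, and read again by the P·V product. The traffic model here is a simplifying assumption, ignoring caches and kernel-specific access patterns.

```python
def naive_attention_traffic(seq_len: int, head_dim: int, bytes_per_elem: int = 2):
    """Back-of-envelope HBM traffic for one head of standard attention.

    Standard attention materializes the full seq_len x seq_len score
    matrix in HBM: S = Q @ K^T, P = softmax(S), O = P @ V.
    Returns (input/output bytes, score-matrix bytes).
    """
    n, d = seq_len, head_dim
    qkv_bytes = 3 * n * d * bytes_per_elem    # read Q, K, V once
    out_bytes = n * d * bytes_per_elem        # write O once
    # S is written by Q @ K^T, read and rewritten by softmax,
    # then read again by P @ V: four full passes over N x N.
    score_bytes = 4 * n * n * bytes_per_elem
    return qkv_bytes + out_bytes, score_bytes

for n in (1024, 8192):
    io, score = naive_attention_traffic(n, head_dim=64)
    print(n, score / io)  # ratio grows linearly with sequence length
```

At N = 1024 with d = 64, the score-matrix traffic is already 16× the Q/K/V/O traffic; at N = 8192 it is 128×. Under this model the ratio is simply N/d, which is why keeping the score tiles resident in SRAM (rather than round-tripping them through HBM) is the heart of Flash Attention.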

8.1.1 From compute-bound to memory-bound

In the previous chapter, we used tensor cores to unlock immense computational throughput. Instructions like WMMA dramatically accelerate the massive matrix multiplications that form the backbone of deep learning models, and our focus was almost entirely on the raw arithmetic: ensuring that every GPU cycle was spent doing useful math. But as we move on to optimizing complex operations like the attention mechanism, we encounter a different kind of performance barrier, one set not by how fast we can compute but by how fast we can move data to the compute units.
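A standard way to spot this shift is the roofline model: compare a kernel's arithmetic intensity (FLOPs per byte of HBM traffic) against the machine balance (peak FLOP rate divided by peak bandwidth). The sketch below applies this to one head of standard attention; the traffic model and the A100-class figures (roughly 312 TFLOP/s FP16 tensor-core throughput over ~2 TB/s of HBM bandwidth) are illustrative assumptions, not measurements.

```python
def attention_intensity(seq_len: int, head_dim: int, bytes_per_elem: int = 2):
    """FLOPs per byte of HBM traffic for one head of standard attention.

    FLOPs: two n x n x d matmuls (Q @ K^T and P @ V), each 2*n*n*d,
    ignoring the comparatively cheap softmax arithmetic.
    Bytes: Q, K, V, O plus four full passes over the materialized
    n x n score matrix (write, softmax read+write, read).
    """
    n, d = seq_len, head_dim
    flops = 2 * (2 * n * n * d)
    hbm_bytes = (4 * n * d + 4 * n * n) * bytes_per_elem
    return flops / hbm_bytes

# Machine balance for an A100-class GPU (illustrative spec numbers):
machine_balance = 312e12 / 2e12   # ~156 FLOPs per byte
ai = attention_intensity(4096, head_dim=64)
print(ai, machine_balance)        # intensity sits far below the balance
```

For large N this intensity saturates at d/2, about 32 FLOPs per byte for a 64-dimensional head, several times below the machine balance. No amount of extra tensor-core throughput helps such a kernel: it is memory-bound, and the only fix is to move fewer bytes, which is exactly what Flash Attention does.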

8.1.2 The high-bandwidth memory bottleneck in standard attention

8.2 Designing Flash Attention

8.2.1 Building intuition: why this problem is harder than it looks

8.2.2 The Flash Attention solution: tiling and online softmax

8.2.3 Tiling: breaking the problem down

8.2.4 Online softmax: computing global statistics iteratively

8.3 Implementing Flash Attention

8.3.1 Dissecting the naive implementation

8.3.2 Understanding the work distribution strategy

8.3.3 The three-level loop hierarchy

8.3.4 The fused Flash Attention kernel

8.3.5 Benchmarking our kernel

8.4 The evolution of Flash Attention

8.4.1 Flash Attention 1 and 2

8.4.2 Flash Attention 3: asynchrony and low precision

8.4.3 Flash Attention 4: complexity for Blackwell

8.5 Summary