LLM Inference Engineer · Day 21
Day 21 · Week 3 · Inference & Hardware

FlashAttention: IO-Aware Tiling & Online Softmax

FlashAttention is the attention algorithm behind modern fast LLM inference and training. Its central trick is not a new approximation: it computes exact attention without ever materializing the full T by T score and probability matrices in HBM, the GPU's large but comparatively slow off-chip memory.

Time: ~190 min
Difficulty: Hard
Prerequisite: Day 20
Notebook: day-21
Why This Lesson

Hardware limits shape inference behavior.

Day 6 taught attention as math. Day 16 taught roofline thinking. Day 20 taught that cache memory dominates decode. Day 21 combines them: attention is not just softmax(QK^T)V; it is also a memory movement problem.

FlashAttention is important because it is exact, widely deployed, and conceptually reusable. Once you understand the online softmax update, other IO-aware kernels become much less mysterious.

Learning Objectives

What you should be able to do today.

  1. State why standard attention writes O(T^2) data to HBM.
  2. Compute standard versus FlashAttention HBM traffic for a concrete T and d.
  3. Derive online softmax with a running max and denominator.
  4. Verify online softmax numerically against standard softmax.
  5. Walk through the FlashAttention tile loop.
  6. Use PyTorch scaled_dot_product_attention and know when it can dispatch to a flash backend.
Math Notation Cheatsheet

Decode the symbols before using them.

  • Q, K, and V are query, key, and value matrices from Day 6.
  • S = QK^T / sqrt(head_dim) is the scaled score matrix before softmax.
  • P = softmax(S) is the probability matrix.
  • B_r is the number of Q rows in a tile.
  • B_c is the number of K/V rows in a tile.
  • m is the running maximum used for stable online softmax.
  • d or l is the running denominator, the sum of exponentials after max correction.
Standard Attention IO

The problem is the T by T tensors.

Objective

By the end of this section, you should be able to point at the memory bottleneck in standard attention.

Standard attention is usually written:

S = QK^T / sqrt(head_dim)
P = softmax(S)
O = P V

That formula is correct, but a naive implementation materializes S and P, each shaped [T, T] per head. For T = 4096, that is:

T^2 = 4096 * 4096 = 16,777,216 elements
FP16 bytes = 2
one T x T matrix ~= 33.6 MB per head
S and P together ~= 67 MB per head

The actual useful input and output tensors are much smaller:

Q, K, V, O each ~= T * head_dim * 2 bytes
for head_dim = 128: 4096 * 128 * 2 ~= 1 MB

The expensive part is not the idea of attention. It is writing and rereading the giant intermediate matrices.
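
To make the materialization concrete, here is a minimal NumPy sketch of naive single-head attention. The function name `naive_attention`, the random inputs, and the choice of T = 512 in float32 (so it runs quickly on a laptop) are illustrative assumptions, not part of the lesson's notebook.

```python
import math
import numpy as np

def naive_attention(q, k, v):
    """Reference attention that materializes the full [T, T] matrices S and P."""
    head_dim = q.shape[-1]
    s = (q @ k.T) / math.sqrt(head_dim)            # [T, T] scaled scores
    shifted = s - s.max(axis=-1, keepdims=True)    # subtract row max for stability
    e = np.exp(shifted)
    p = e / e.sum(axis=-1, keepdims=True)          # [T, T] probabilities
    o = p @ v                                      # [T, head_dim] output
    return o, s, p

# Hypothetical sizes, chosen only to keep the example fast.
rng = np.random.default_rng(0)
T, head_dim = 512, 128
q, k, v = (rng.standard_normal((T, head_dim)).astype(np.float32) for _ in range(3))

o, s, p = naive_attention(q, k, v)
print("S and P shapes:", s.shape, p.shape)
print("bytes in S + P:", s.nbytes + p.nbytes)
print("bytes in O:    ", o.nbytes)
```

Even at this small T, the two intermediates are several times larger than the output, and the gap grows linearly with T.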

[Figure: Standard Attention Materializes T x T. Q, K -> S = QK^T (written to HBM) -> P = softmax(S) (written to HBM) -> P, V -> O. The tensors highlighted in red, S and P, are O(T^2). At long context, moving them dominates.]
Standard attention creates O(T^2) HBM traffic for scores and probabilities.
[Figure: HBM Traffic at T = 4096, head_dim = 128. Standard attention intermediates: ~67 MB/head. FlashAttention linear traffic: ~4 MB/head (Q, K, V, and O counted once each). The exact ratio depends on what reads/writes you count, but the key change is quadratic traffic to linear traffic.]
A concrete IO comparison. FlashAttention does not make attention free; it removes the T by T HBM round-trips.
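
The comparison can be reproduced with a rough counting model. This is a sketch under simplifying assumptions: Q, K, V, and O are each counted once, S and P are counted by size only, and K/V re-reads across query tiles are ignored. The function name `attention_hbm_bytes` is hypothetical, and the numbers are estimates rather than measurements.

```python
def attention_hbm_bytes(T, head_dim=128, bytes_per_el=2):
    """Rough per-head byte counts (an illustrative model, not a profile)."""
    qkvo = 4 * T * head_dim * bytes_per_el     # Q, K, V, O counted once each
    s_and_p = 2 * T * T * bytes_per_el         # the two [T, T] intermediates
    standard = qkvo + s_and_p                  # intermediates go through HBM
    flash = qkvo                               # intermediates stay in SRAM tiles
    return standard, flash

for T in (512, 1024, 2048, 4096):
    standard, flash = attention_hbm_bytes(T)
    print(f"T={T:5d}  standard={standard / 1e6:7.1f} MB  "
          f"flash={flash / 1e6:5.1f} MB  ratio={standard / flash:5.1f}x")
```

The ratio grows with T because the numerator has a T^2 term and the denominator does not.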
Online Softmax

You can normalize without seeing every score at once.

Objective

By the end of this section, you should be able to compute online softmax by hand.

Recall stable softmax from Day 1:

softmax(s_i) = exp(s_i - max(s)) / sum_j exp(s_j - max(s))

Online softmax keeps two pieces of state while processing score blocks:

  • m: the running maximum score seen so far.
  • d: the running denominator after correcting for the current maximum.

Worked example with scores [2.0, -1.5, 0.3, 4.2], block size 2.

Block 1: [2.0, -1.5]

m = 2.0
d = exp(2.0 - 2.0) + exp(-1.5 - 2.0)
  = 1 + exp(-3.5)
  = 1.030

Block 2: [0.3, 4.2]

m_new = max(2.0, 4.2) = 4.2
d_new = old_d * exp(old_m - m_new) + sum(exp(block - m_new))
      = 1.030 * exp(2.0 - 4.2) + exp(0.3 - 4.2) + exp(4.2 - 4.2)
      = 1.030 * 0.111 + 0.020 + 1
      = 1.134

That denominator is the same denominator standard softmax would compute after subtracting the global max. Nothing approximate happened.
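
The worked example can be checked with a short NumPy sketch. The function names `stable_softmax` and `online_max_and_denominator` are illustrative choices, not names from the companion notebook.

```python
import numpy as np

def stable_softmax(scores):
    """Standard softmax with the global max subtracted."""
    e = np.exp(scores - scores.max())
    return e / e.sum()

def online_max_and_denominator(scores, block_size):
    """Process scores block by block, keeping only a running max m and denominator d."""
    m = -np.inf
    d = 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, block.max())
        # Rescale the old denominator to the new max, then add the new block.
        d = d * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return m, d

scores = np.array([2.0, -1.5, 0.3, 4.2])
m, d = online_max_and_denominator(scores, block_size=2)
online = np.exp(scores - m) / d

print("running max:", m)
print("running denominator:", round(float(d), 3))
print("max abs difference:", np.abs(online - stable_softmax(scores)).max())
```

On the first block the running max starts at negative infinity, so the rescale factor is exp(-inf) = 0 and the empty denominator contributes nothing.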

[Figure: Online Softmax Keeps Running State. Block 1: [2.0, -1.5] gives m = 2.0, d = 1.030. Block 2: [0.3, 4.2] gives m = 4.2, d = 1.134. Update rule: d' = d * exp(m_old - m_new) + sum(exp(block - m_new)). The denominator matches standard softmax without seeing every score at once.]
The running denominator is rescaled whenever the running maximum increases.
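
One way to see why the rescale is exact is to split the full sum into scores already seen and scores in the new block, then factor the old maximum out of the first part. The index sets "old" and "block" are notation introduced here for illustration.

```latex
\begin{aligned}
d_{\text{new}}
  &= \sum_{j \in \text{old}} e^{s_j - m_{\text{new}}}
   + \sum_{j \in \text{block}} e^{s_j - m_{\text{new}}} \\
  &= e^{m_{\text{old}} - m_{\text{new}}} \sum_{j \in \text{old}} e^{s_j - m_{\text{old}}}
   + \sum_{j \in \text{block}} e^{s_j - m_{\text{new}}} \\
  &= d_{\text{old}} \, e^{m_{\text{old}} - m_{\text{new}}}
   + \sum_{j \in \text{block}} e^{s_j - m_{\text{new}}}
\end{aligned}
```

The second line uses only the identity exp(s_j - m_new) = exp(m_old - m_new) * exp(s_j - m_old), so no term is dropped or approximated.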
FlashAttention Loop

Keep S and P inside SRAM tiles.

Objective

By the end of this section, you should be able to narrate the FlashAttention forward pass.

For each tile of query rows:

  1. Load a Q tile into SRAM.
  2. Stream a K tile and V tile into SRAM.
  3. Compute the partial score tile S_tile = Q_tile K_tile^T / sqrt(head_dim).
  4. Update the running max and denominator for each query row.
  5. Accumulate the output tile using the corrected probabilities.
  6. Move to the next K,V tile.
  7. Write the final O tile once.

The full S and P matrices never cross HBM. They exist only as tile-local values. This is why FlashAttention is called IO-aware: the math is standard attention, but the schedule is designed around memory movement.
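
The loop above can be sketched in NumPy. This is a teaching sketch under stated assumptions, not a real kernel: NumPy arrays stand in for both HBM and SRAM, there is no causal mask, and the names `tiled_attention`, `block_r`, and `block_c` (for B_r and B_c) are illustrative.

```python
import math
import numpy as np

def naive_attention(q, k, v):
    """Reference that materializes the full score matrix."""
    s = (q @ k.T) / math.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block_r=64, block_c=64):
    """FlashAttention-style forward pass: only [block_r, block_c] score tiles exist."""
    T_q, head_dim = q.shape
    T_k = k.shape[0]
    scale = 1.0 / math.sqrt(head_dim)
    out = np.empty((T_q, v.shape[1]))
    for i in range(0, T_q, block_r):
        q_tile = q[i:i + block_r]                    # step 1: load a Q tile
        rows = q_tile.shape[0]
        m = np.full((rows, 1), -np.inf)              # running max per query row
        d = np.zeros((rows, 1))                      # running denominator per row
        acc = np.zeros((rows, v.shape[1]))           # unnormalized output accumulator
        for j in range(0, T_k, block_c):
            k_tile = k[j:j + block_c]                # step 2: stream K and V tiles
            v_tile = v[j:j + block_c]
            s_tile = (q_tile @ k_tile.T) * scale     # step 3: partial scores
            m_new = np.maximum(m, s_tile.max(axis=-1, keepdims=True))
            correction = np.exp(m - m_new)           # rescale factor for old state
            p_tile = np.exp(s_tile - m_new)
            d = d * correction + p_tile.sum(axis=-1, keepdims=True)   # step 4
            acc = acc * correction + p_tile @ v_tile                  # step 5
            m = m_new
        out[i:i + block_r] = acc / d                 # step 7: write the O tile once
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((200, 32)) for _ in range(3))
max_err = np.abs(tiled_attention(q, k, v) - naive_attention(q, k, v)).max()
print("max abs error vs naive attention:", max_err)
```

The output accumulator is rescaled by the same correction factor as the denominator, which is what lets the final division by d produce exact softmax-weighted values. The sequence length of 200 is deliberately not a multiple of the tile size, so the last tile is ragged.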

[Figure: FlashAttention: Stream K/V Tiles Through SRAM. A Q tile stays resident while K/V tile 0, K/V tile 1, and K/V tile 2 stream past it; the O tile accumulates in SRAM. S and P exist only inside the tile loop, never as full T x T tensors in HBM.]
FlashAttention streams K/V blocks through a resident Q tile and output accumulator.
Versions and Dispatch

v1 changed IO. v2 improved work partitioning. v3 targets Hopper.

FlashAttention v1 introduced the IO-aware exact attention algorithm. v2 reduced non-matmul overhead and improved parallel work partitioning. v3 targets Hopper GPUs with features such as asynchronous data movement, warp specialization, and FP8 support.

In PyTorch, the practical entry point is:

torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

On CUDA with supported dtype, shape, and mask, PyTorch can dispatch to a flash backend. On CPU, MPS, unsupported masks, or unsuitable shapes, it falls back to another backend, such as the memory-efficient or plain math implementation. Your code should call the high-level API when possible and profile the actual dispatch when performance matters.
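
A minimal usage sketch, assuming PyTorch 2.0 or later is installed. The tensor sizes are arbitrary, and the naive reference exists only to confirm that the high-level call returns the same result whichever backend is chosen. The code does not inspect which backend ran; recent PyTorch versions expose `torch.nn.attention.sdpa_kernel` for restricting backends when you need to check.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is what flash backends typically expect; use float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

# Hypothetical sizes; layout is [batch, heads, T, head_dim].
batch, heads, T, head_dim = 1, 4, 128, 64
q, k, v = (torch.randn(batch, heads, T, head_dim, device=device, dtype=dtype)
           for _ in range(3))

# High-level API: PyTorch picks a backend based on device, dtype, shape, and mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Naive causal reference in float32 for comparison.
scores = (q.float() @ k.float().transpose(-2, -1)) / math.sqrt(head_dim)
causal = torch.ones(T, T, dtype=torch.bool, device=device).tril()
scores = scores.masked_fill(~causal, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v.float()

max_err = (out.float() - ref).abs().max().item()
print(device, tuple(out.shape), max_err)
```

Agreement with the reference says nothing about speed, so treat this as a correctness check and use a profiler for the dispatch question.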

FlashAttention helps most in long-context prefill, where T_q and T_k are large. Decode with T_q = 1 benefits less from tiling because there is no large query block, though fused decode attention kernels still matter.
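
A rough size comparison shows why decode looks different. The sketch below assumes FP16, head_dim = 128, and a single head, and the helper name `score_and_kv_bytes` is hypothetical.

```python
def score_and_kv_bytes(T_q, T_k, head_dim=128, bytes_per_el=2):
    """Per-head sizes of the score matrix and of the K and V tensors."""
    score_bytes = T_q * T_k * bytes_per_el         # [T_q, T_k] scores
    kv_bytes = 2 * T_k * head_dim * bytes_per_el   # K and V together
    return score_bytes, kv_bytes

prefill = score_and_kv_bytes(T_q=4096, T_k=4096)
decode = score_and_kv_bytes(T_q=1, T_k=4096)

print("prefill: scores =", prefill[0], "bytes, K+V =", prefill[1], "bytes")
print("decode:  scores =", decode[0], "bytes, K+V =", decode[1], "bytes")
```

In prefill the score matrix is far larger than K and V, so keeping it out of HBM is the main win. In decode the single score row is tiny and reading K and V dominates, which matches the Day 20 point that cache memory dominates decode.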

Did You Know?

A systems detail worth remembering.

The online softmax trick predates FlashAttention. FlashAttention's insight was to use that numerically stable recurrence to tile exact attention without materializing the full score matrix.
Exercise

Do the arithmetic, then run the notebook.

Use the notebook to:

  1. Implement standard stable softmax and online softmax in NumPy.
  2. Verify the example [2.0, -1.5, 0.3, 4.2] matches to 1e-7.
  3. Compute HBM traffic estimates for standard attention and FlashAttention at T = 512, 1024, 2048, 4096.
  4. If CUDA PyTorch is available, benchmark naive attention against scaled_dot_product_attention.

Read one FlashAttention kernel after the notebook and identify the tile loop, running max, denominator, and output accumulator.

Self-Check

Answer these from memory.

  1. Why is standard attention memory-heavy? It materializes S and P, two O(T^2) matrices.
  2. Is FlashAttention approximate? No. It computes exact attention up to floating-point differences.
  3. What state does online softmax keep? A running max and a running denominator.
  4. Where does FlashAttention help most? Long-context prefill, where T by T intermediates dominate memory traffic.
  5. What PyTorch API should you use first? torch.nn.functional.scaled_dot_product_attention.

"FlashAttention is standard attention scheduled as if bytes matter."

Further Reading

Go deeper.

Primary references and the companion notebook for today's exercise.

  • Paper: FlashAttention v1. The original IO-aware exact attention algorithm.
  • Paper: FlashAttention v2. Work partitioning and non-matmul efficiency improvements.
  • Paper: FlashAttention v3. Hopper-focused FlashAttention with asynchronous hardware features.
  • Paper: Online normalizer softmax. The online normalization trick FlashAttention relies on.
  • Talk: Tri Dao FlashAttention talk. A useful visual walkthrough from the author.
  • Notebook · Day 21: FlashAttention notebook. Companion Jupyter notebook with runnable calculations and optional hardware-specific cells.