FlashAttention is the attention algorithm behind much of modern fast LLM inference and training. Its central trick is not a new approximation: it computes exact attention while never materializing the full T x T score and probability matrices in HBM, the GPU's main memory.
Day 6 taught attention as math. Day 16 taught roofline thinking. Day 20 taught that KV cache memory traffic dominates decode. Day 21 combines them: attention is not just softmax(QK^T)V; it is also a memory movement problem.
FlashAttention is important because it is exact, widely deployed, and conceptually reusable. Once you understand the online softmax update, other IO-aware kernels become much less mysterious.
T and d.
scaled_dot_product_attention and know when it can dispatch to a flash backend.
Q, K, and V are the query, key, and value matrices from Day 6.
S = QK^T / sqrt(head_dim) is the scaled score matrix before softmax.
P = softmax(S) is the probability matrix.
B_r is the number of Q rows in a tile.
B_c is the number of K/V rows in a tile.
m is the running maximum used for stable online softmax.
d or l is the running denominator: the sum of exponentials after max correction.
By the end of this section, you should be able to point at the memory bottleneck in standard attention.
Standard attention is usually written:
S = QK^T / sqrt(head_dim)
P = softmax(S)
O = P V
That formula is correct, but a naive implementation materializes S and P, each shaped [T, T] per head. For T = 4096, that is:
T^2 = 4096 * 4096 = 16,777,216 elements
FP16 bytes = 2
one T x T matrix ~= 33.6 MB per head
S and P together ~= 67 MB per head
The actual useful input and output tensors are much smaller:
Q, K, V, O each ~= T * head_dim * 2 bytes
for head_dim = 128: 4096 * 128 * 2 ~= 1 MB
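The same arithmetic can be swept over several context lengths to show the quadratic versus linear growth. This is a minimal sketch; `attention_bytes` is a hypothetical helper, and it assumes FP16 (2 bytes per element), one head, one sequence, and decimal megabytes (1 MB = 1e6 bytes).

```python
def attention_bytes(seq_len: int, head_dim: int, bytes_per_elem: int = 2) -> dict:
    """Per-head, per-sequence sizes in MB for naive attention (FP16 by default)."""
    score_matrix = seq_len * seq_len * bytes_per_elem  # one [T, T] matrix: S or P
    io_tensor = seq_len * head_dim * bytes_per_elem    # one [T, head_dim] tensor: Q, K, V, or O
    return {
        "one_TxT_MB": score_matrix / 1e6,
        "S_and_P_MB": 2 * score_matrix / 1e6,
        "QKVO_MB": 4 * io_tensor / 1e6,
    }


for t in (512, 1024, 2048, 4096):
    sizes = attention_bytes(t, head_dim=128)
    print(t, {name: round(mb, 1) for name, mb in sizes.items()})
```

Each doubling of T quadruples the S and P columns but only doubles Q, K, V, and O; at T = 4096 the two intermediates are about 67.1 MB against about 4.2 MB of useful inputs and outputs.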
The expensive part is not the idea of attention. It is writing and rereading the giant intermediate matrices.
By the end of this section, you should be able to compute online softmax by hand.
Recall stable softmax from Day 1:
softmax(s_i) = exp(s_i - max(s)) / sum_j exp(s_j - max(s))
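The overflow that the max subtraction prevents is easy to reproduce in plain Python. This is a sketch with hypothetical helper names; it relies on Python's `math.exp` raising `OverflowError` for arguments above roughly 709, which is where the naive version fails.

```python
import math


def naive_softmax(scores):
    exps = [math.exp(s) for s in scores]  # math.exp raises OverflowError for large inputs
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(scores):
    m = max(scores)  # subtract the max so every exponent is <= 0
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


print(stable_softmax([1000.0, 999.0]))  # exponents become 0 and -1, no overflow
try:
    naive_softmax([1000.0, 999.0])
except OverflowError as err:
    print("naive softmax overflowed:", err)
```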
Online softmax keeps two pieces of state while processing score blocks:
m: the running maximum score seen so far.
d: the running denominator after correcting for the current maximum.
Worked example with scores [2.0, -1.5, 0.3, 4.2], block size 2.
Block 1: [2.0, -1.5]
m = 2.0
d = exp(2.0 - 2.0) + exp(-1.5 - 2.0)
= 1 + exp(-3.5)
= 1.030
Block 2: [0.3, 4.2]
m_new = max(2.0, 4.2) = 4.2
d_new = old_d * exp(old_m - m_new) + sum(exp(block - m_new))
= 1.030 * exp(2.0 - 4.2) + exp(0.3 - 4.2) + exp(4.2 - 4.2)
= 1.030 * 0.111 + 0.020 + 1
= 1.134
That denominator is the same denominator standard softmax would compute after subtracting the global max. Nothing approximate happened.
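The hand calculation above can be checked with a few lines of plain Python. This is a minimal sketch; `online_softmax_stats` is a hypothetical helper that keeps only the two pieces of state, m and d, while streaming over score blocks.

```python
import math


def online_softmax_stats(scores, block_size):
    """Stream over score blocks, keeping only the running max m and denominator d."""
    m = float("-inf")
    d = 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old denominator to the new max, then add this block's terms.
        # On the first block, math.exp(-inf) is 0.0, so the empty state contributes nothing.
        d = d * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return m, d


scores = [2.0, -1.5, 0.3, 4.2]
m, d = online_softmax_stats(scores, block_size=2)

# Reference: one-shot stable softmax denominator using the global max.
d_ref = sum(math.exp(s - max(scores)) for s in scores)
# Final probabilities only need the final m and d.
probs = [math.exp(s - m) / d for s in scores]

print(m, round(d, 3))         # 4.2 1.134
print(abs(d - d_ref) < 1e-7)  # True
```

Any block size gives the same (m, d) up to floating-point rounding, which is the property the tiled kernel relies on.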
By the end of this section, you should be able to narrate the FlashAttention forward pass.
For each tile of query rows:
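Load the Q tile into SRAM.
Loop over the sequence, loading each K tile and V tile into SRAM.
Compute S_tile = Q_tile K_tile^T.
Update the running max, running denominator, and output accumulator with this K,V tile.
After the last K,V tile, normalize and write the O tile once.
The full S and P matrices never cross HBM. They exist only as tile-local values. This is why FlashAttention is called IO-aware: the math is standard attention, but the schedule is designed around memory movement.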
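The loop structure can be sketched in plain Python for a single head without a causal mask. This is an illustration of the schedule, not a GPU kernel: list slices stand in for SRAM tiles, `block_r` and `block_c` play the roles of B_r and B_c, and the helper names (`matmul_abt`, `naive_attention`, `tiled_attention`) are hypothetical.

```python
import math
import random


def matmul_abt(a, b):
    """Return a @ b^T for row-major lists: a is [n, d], b is [m, d], result is [n, m]."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


def naive_attention(q, k, v):
    """Reference: builds the full [T, T] score matrix S, then softmax row by row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for row in matmul_abt(q, k):
        row = [x * scale for x in row]
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        p = [x / total for x in e]
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_r=2, block_c=3):
    """FlashAttention-style forward pass: only [block_r, block_c] score tiles exist."""
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out = []
    for r0 in range(0, len(q), block_r):
        q_tile = q[r0:r0 + block_r]               # stands in for loading Q into SRAM
        m = [float("-inf")] * len(q_tile)         # running max per query row
        denom = [0.0] * len(q_tile)               # running denominator (d or l)
        acc = [[0.0] * d_v for _ in q_tile]       # unnormalized output accumulator
        for c0 in range(0, len(k), block_c):
            k_tile = k[c0:c0 + block_c]           # stands in for loading K and V into SRAM
            v_tile = v[c0:c0 + block_c]
            for i, s_row in enumerate(matmul_abt(q_tile, k_tile)):
                s_row = [x * scale for x in s_row]
                m_new = max(m[i], max(s_row))
                correction = math.exp(m[i] - m_new)  # rescales old state to the new max
                p = [math.exp(x - m_new) for x in s_row]
                denom[i] = denom[i] * correction + sum(p)
                acc[i] = [
                    a * correction + sum(p[j] * v_tile[j][c] for j in range(len(v_tile)))
                    for c, a in enumerate(acc[i])
                ]
                m[i] = m_new
        # Normalize once per row, then "write" the finished O tile.
        out.extend([a / denom[i] for a in acc[i]] for i in range(len(q_tile)))
    return out


random.seed(0)
T, D = 7, 4
q, k, v = ([[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(T)] for _ in range(3))
ref = naive_attention(q, k, v)
tiled = tiled_attention(q, k, v)
max_err = max(abs(a - b) for ra, rb in zip(ref, tiled) for a, b in zip(ra, rb))
print(f"max abs difference: {max_err:.1e}")
```

T = 7 is deliberately not a multiple of either block size, so the last Q and K,V tiles are partial; the result still matches the reference to floating-point rounding because the update is exact.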
FlashAttention v1 introduced the IO-aware exact attention algorithm. v2 reduced non-matmul overhead and improved parallel work partitioning. v3 targets Hopper GPUs with features like asynchronous data movement and warp-group scheduling.
In PyTorch, the practical entry point is:
torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
On CUDA, with a supported GPU, dtype, head dimension, and mask, PyTorch can dispatch to a flash backend. Otherwise (CPU, MPS, unsupported masks, or unsuitable shapes) it falls back to another backend. Your code should call the high-level API when possible and profile the actual dispatch when performance matters.
FlashAttention helps most in long-context prefill, where T_q and T_k are large. Decode with T_q = 1 benefits less from tiling because there is no large query block, though fused decode attention kernels still matter.
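A quick count shows why. Using a hypothetical helper `score_elements` and the same T = 4096, head_dim = 128, FP16 setting as earlier, a decode step's score row is only 8 KB, while reading K and V from the cache is about 2 MB, so there is little intermediate matrix left to keep out of HBM.

```python
def score_elements(t_q: int, t_k: int) -> int:
    """Entries in one head's score matrix S for t_q query rows against t_k keys."""
    return t_q * t_k


context, head_dim, fp16 = 4096, 128, 2
prefill = score_elements(t_q=context, t_k=context)  # whole prompt attends at once
decode = score_elements(t_q=1, t_k=context)         # one new token vs. the cached keys
kv_read_bytes = 2 * context * head_dim * fp16       # K and V read from the KV cache
print(prefill, decode)                              # 16777216 4096
print(decode * fp16, kv_read_bytes)                 # 8192 2097152
```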
Use the notebook to:
[2.0, -1.5, 0.3, 4.2] matches to 1e-7.
T = 512, 1024, 2048, 4096.
scaled_dot_product_attention.
Read one FlashAttention kernel after the notebook and identify the tile loop, running max, denominator, and output accumulator.
S and P, two O(T^2) matrices.
torch.nn.functional.scaled_dot_product_attention.
"FlashAttention is standard attention scheduled as if bytes matter."
Primary references and the companion notebook for today's exercise.
Companion Jupyter notebook with runnable calculations and optional hardware-specific cells.