LLM Inference Engineer · Day 06
Day 06 · Week 1 · Foundations

The Attention Mechanism

Scaled dot-product attention is the single most important architectural idea in deep learning since backpropagation. It is also famously confusing on first encounter. Today we kill the confusion: build it from scratch with real numbers, derive every formula, and trace the exact shapes that will matter in Weeks 3–4 when we optimize it.

Time: ~180 min
Difficulty: Medium-Hard
Prerequisite: Days 1–5
Why This Lesson

The one operation that changed everything — and the one that will dominate your inference budget.

Attention is the architecture. Not a component inside a larger model — the mechanism that is the Transformer. Every open-weight model you will ever serve — LLaMA, Mistral, Qwen, DeepSeek, Gemma — is a stack of attention layers and feedforward layers, repeated dozens of times. If you understand attention deeply, you understand the machine.

This is also the lesson with the highest inference-engineering payoff. The attention score matrix has shape (B, h, T, T), so its size grows with the square of the sequence length. That quadratic term is why long-context inference is hard, why the KV cache exists, why FlashAttention was a landmark paper, and why PagedAttention changed how we serve LLMs. You cannot reason about any of those systems without first understanding exactly what they are optimizing. That is what today is for.

Here is the mental model to hold throughout: attention is a soft, content-based dictionary lookup. A database has rows. Each row has a key ("this is what I am") and a value ("this is what I contain"). To retrieve information you issue a query ("this is what I need"). A normal database finds the one exact match and returns its value. Attention does the soft, differentiable version: compute how well your query matches every key, turn those scores into a probability distribution with softmax, and return a weighted average of all the values. That is all it is. The rest is engineering detail.
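To make the dictionary analogy concrete, here is a minimal sketch that contrasts a hard lookup with its soft, differentiable version. The four keys, their values, and the query are made-up numbers chosen only for illustration.

```python
import torch

# A toy "database" of four (key, value) rows; the numbers are made up.
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0],
                     [-1.0, 0.0]])
values = torch.tensor([[10.0], [20.0], [30.0], [40.0]])
query = torch.tensor([1.0, 1.0])

scores = keys @ query                    # match score of the query against every key
hard = values[scores.argmax()]           # ordinary lookup: the single best match wins
weights = torch.softmax(scores, dim=0)   # soft lookup: a probability for every key
soft = weights @ values                  # weighted average of all values (differentiable)

print("scores:", scores.tolist())
print("hard lookup:", hard.item())
print("soft lookup:", round(soft.item(), 2))
```

The hard lookup returns only the third value. The soft lookup returns a blend that the best match dominates but every row still influences, so gradients can reach all keys and values.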

Learning objectives

  1. Read the math notation cheatsheet and decode every symbol in the attention formula on sight.
  2. Work a 3-token, 2-dimensional example by hand through all four steps of scaled dot-product attention.
  3. Implement single-head attention from scratch in PyTorch and verify every intermediate shape.
  4. Implement multi-head attention with a combined Q/K/V projection and explain why the reshape trick is valid.
  5. Apply a causal mask and verify the causal property (no output depends on a later token) with a test.
  6. State the time and memory complexity of attention and explain when it dominates the FFN layer.
  7. Explain the KV cache idea at an architectural level and trace how it reduces redundant computation during generation.
  8. Read Q, K, V projections in any open Transformer source file and understand what they do.
Math Notation Cheatsheet

Every symbol decoded before you see it used.

This lesson is the most math-dense day in Week 1. Before we touch a formula, here is every symbol you will encounter, with a plain-English and Python reading. Bookmark this table — refer back whenever something looks unfamiliar.

| Symbol | Reads as | Python analogy | Where it appears |
|---|---|---|---|
| Q | "queries matrix" | a 2-D array of row vectors | input to every step |
| K | "keys matrix" | same shape as Q | provides what tokens represent |
| V | "values matrix" | same shape as Q (when d_v = d_k) | provides what tokens contribute |
| d_k | "dimension of keys" | d_k = 64 | scaling-factor denominator |
| d_v | "dimension of values" | d_v = 64 | output width of attention |
| T | "sequence length" | len(tokens) | height of Q, K, V; side of the score matrix |
| B | "batch size" | len(batch) | leading batch dimension |
| h | "number of heads" | n_heads = 8 | multi-head split |
| QKᵀ | "Q times K-transpose" | Q @ K.T | raw similarity scores |
| √d_k | "square root of d_k" | math.sqrt(d_k) | variance-normalizing divisor |
| softmax(z) | "softmax of vector z" | exp(z) / sum(exp(z)) | turns scores into probabilities |
| ⊙ | "elementwise multiply" | a * b (NumPy) | masking, gates |
| W_Q, W_K, W_V, W_O | "weight matrices" | nn.Linear(D, D, bias=False).weight | learnable projections |
| −∞ | "negative infinity" | float('-inf') | causal mask before softmax |
| ℝ^{T×d} | "real-valued matrix of shape T-by-d" | torch.Tensor of shape (T, d) | type annotations |

Key rule: when you see a subscript like S_{ij} it means element at row i, column j — the same as S[i, j] in Python. When you see Kᵀ it means the transpose of K — swap rows and columns, same as K.T or K.transpose(-2, -1).

From RNNs to Attention

Two fatal flaws that led to a better design.

Before transformers, the standard tool for sequence modeling was the recurrent neural network. An RNN processes tokens one at a time, maintaining a hidden state vector that summarizes everything seen so far. At each step it updates this state using the current token and then moves on.

This design has two killer flaws. The first is long-range dependencies. Information from token 1 must survive being rewritten at token 2, then token 3, and so on. By token 100, the original signal has been compressed and overwritten so many times that almost nothing useful remains. RNNs are notoriously bad at carrying information across long distances — the vanishing gradient problem is its mathematical manifestation.

The second is the sequential bottleneck. You cannot compute step 5 until step 4 has finished. GPUs want to execute millions of operations in parallel; forcing sequential ordering wastes almost all that parallelism.

In 2014, Bahdanau et al. added an "attention" mechanism on top of RNN encoders for machine translation: at each decoder step, instead of relying on a single compressed state, the model could directly look back at all encoder hidden states and weight them. This dramatically improved translation quality. In 2017, Vaswani et al. asked: what if we drop the RNN entirely? The Transformer was born — a network where every token directly attends to every other token in a single, fully-parallel operation. No compression, no forced sequential order, no information bottleneck.

The paper's title, "Attention Is All You Need", turned out to be right about the part that mattered: recurrence was not needed.

The Intuition

Query, key, value: the librarian analogy.

Before any equations, pin down the intuition with a concrete story. Imagine a library where every book has two metadata cards attached to the spine: a key card that describes what the book is about ("18th-century naval history") and a value card that describes what you actually learn from reading it (specific battle dates, fleet sizes, strategic decisions). When you walk in with a query ("I need information about the Battle of Trafalgar"), the librarian checks your query against every book's key card, computes a relevance score, and returns a weighted summary — spending 70% of the reading time on the most relevant book, 20% on a related one, and 10% across the rest.

In a Transformer, every token simultaneously plays all three roles. It issues a query asking "what context do I need to interpret myself?" It has a key announcing "this is what I represent to other tokens." And it has a value carrying "this is what I contribute if someone attends to me." The three roles are learned independently through three weight matrices W_Q, W_K, W_V. They start random and gradient descent trains them to be useful.
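The three roles come from three independent projections of the same token embeddings. The minimal sketch below shows this. The sizes T=4, d_model=16, and d_k=8 are illustrative assumptions, not values from any real model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, d_model, d_k = 4, 16, 8              # illustrative sizes, not from any real model
x = torch.randn(T, d_model)             # one embedding per token

# Three independent learned projections of the same token embeddings.
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)

q, k, v = W_Q(x), W_K(x), W_V(x)        # each (T, d_k): every token gets all three roles
print(q.shape, k.shape, v.shape)
```

The weights start random here, as in a real model before training. Only gradient descent makes the three projections mean "query", "key", and "value".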

[Figure: one query token ("what do I need?") scores three keys: "subject match" (weight 0.70, value: rich relevant info), "partial match" (0.20, somewhat relevant), and "weak match" (0.10, tangentially related). The softmax weights sum to 1 across the keys, and the output is 0.7·V₁ + 0.2·V₂ + 0.1·V₃.]
The librarian analogy for attention. One query token computes a relevance score against every other token's key, normalizes the scores with softmax into weights that sum to 1, then returns a weighted blend of the value vectors. In self-attention every token simultaneously issues a query AND acts as a key+value for everyone else.

The crucial difference from a real library: the keys, values, and queries are all learned continuous vectors, not categorical labels. The model learns, through gradient descent on billions of examples, which directions in vector space mean "this token is a verb looking for its subject" or "this token is a pronoun resolving its antecedent." The librarian analogy gives the architecture; training gives the semantics.

Step-by-Step With Numbers

Work a T=3, d_k=2 example by hand through all four steps.

By the end of this section you'll be able to execute every step of scaled dot-product attention with a pencil, and you'll know exactly what shape every intermediate tensor has. We use the smallest non-trivial example: three tokens, two-dimensional key/value space.

Suppose after the Q, K, V projections we have (all numbers made up but self-consistent):

Q = [[1, 0],     ← query for token 0
     [0, 1],     ← query for token 1
     [1, 1]]     ← query for token 2       shape: (T=3, d_k=2)

K = [[1, 0],     ← key for token 0
     [0, 1],     ← key for token 1
     [1, 1]]     ← key for token 2         shape: (T=3, d_k=2)

V = [[2, 0],     ← value for token 0
     [0, 3],     ← value for token 1
     [1, 1]]     ← value for token 2       shape: (T=3, d_v=2)

Step 1: Compute raw similarity scores S = Q Kᵀ

Matrix-multiply Q by the transpose of K. The result has shape (T, T) = (3, 3). Element S[i, j] is the dot product of token i's query with token j's key — how much token i "wants to attend to" token j.

Kᵀ = [[1, 0, 1],
      [0, 1, 1]]        shape: (d_k=2, T=3)

S = Q @ Kᵀ = [[1,0],[0,1],[1,1]] @ [[1,0,1],[0,1,1]]

S[0,0] = q₀·k₀ = 1·1 + 0·0 = 1
S[0,1] = q₀·k₁ = 1·0 + 0·1 = 0
S[0,2] = q₀·k₂ = 1·1 + 0·1 = 1
S[1,0] = q₁·k₀ = 0·1 + 1·0 = 0
S[1,1] = q₁·k₁ = 0·0 + 1·1 = 1
S[1,2] = q₁·k₂ = 0·1 + 1·1 = 1
S[2,0] = q₂·k₀ = 1·1 + 1·0 = 1
S[2,1] = q₂·k₁ = 1·0 + 1·1 = 1
S[2,2] = q₂·k₂ = 1·1 + 1·1 = 2

S = [[1, 0, 1],
     [0, 1, 1],
     [1, 1, 2]]
[Figure: Step 1, raw scores S = Q Kᵀ with shape (T, T) = (3, 3). Q (3×2) @ Kᵀ (2×3) = S (3×3). A large score means the query and key point in the same direction. q₂ matches k₂ most (score 2) because both are [1, 1]. S[i,j] = dot(q_i, k_j) measures how much query i "wants" information from token j.]
Step 1: raw score matrix S = Q Kᵀ, shape (T, T). Each cell S[i,j] is the dot product of query i with key j. Darker = higher score = stronger affinity. The bottom-right cell has score 2 because both q₂ and k₂ equal [1,1].
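Step 1 can be checked in a few lines. The minimal sketch below copies Q and K from the worked example and confirms the hand-computed S.

```python
import torch

# Q, K from the worked example above (T=3, d_k=2).
Q = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
K = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])

S = Q @ K.T            # (3, 2) @ (2, 3) -> (3, 3); S[i, j] = dot(Q[i], K[j])
print(S)
```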

Step 2: Scale by 1/√d_k — and why it matters

Divide every element of S by √d_k = √2 ≈ 1.414. This is not cosmetic. Here is the argument in numbers.

Suppose the entries of Q and K are independent draws from a standard normal distribution (mean 0, variance 1). The dot product of a d_k-dimensional query with a key is a sum of d_k products q_i·k_i. Each product has variance 1, because for independent zero-mean variables Var(q_i·k_i) = Var(q_i)·Var(k_i). Because variance is additive for independent terms, the dot product has variance d_k and standard deviation √d_k. When d_k = 64 (a typical head dimension), the standard deviation of raw scores is 8. When d_k = 128 it is about 11.3. Spreads this large push softmax into a near-one-hot state.

[Figure: Step 2, why 1/√d_k matters at d_k = 64. Left, without scaling (raw scores, σ ≈ 8): softmax puts ≈1.0 weight on one token, ∂L/∂score ≈ 0 everywhere except the max, and the gradient vanishes, so the model can't learn. Right, with 1/√d_k scaling (σ ≈ 1): weight is spread over several tokens, ∂L/∂score is meaningful for multiple tokens, and the gradient flows. Dividing by √d_k normalizes the score variance back to 1.]
Step 2: the scaling argument. Without 1/√d_k, the raw dot products have standard deviation √d_k, pushing softmax into a near-one-hot regime (left). With scaling, the distribution stays soft (right), and gradient flows to multiple tokens so the model can actually learn.
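The variance argument can be checked empirically. The sketch below samples random standard-normal query/key pairs and compares the spread of raw and scaled dot products. The sample count of 100 000 and the d_k values tried are arbitrary choices.

```python
import math
import torch

torch.manual_seed(0)
n = 100_000                                   # number of random (query, key) pairs
for d_k in (2, 64, 128):
    q = torch.randn(n, d_k)                   # entries drawn from N(0, 1)
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=-1)                # one raw dot product per pair
    scaled = dots / math.sqrt(d_k)            # the 1/sqrt(d_k) fix
    print(f"d_k={d_k:4d}  std(raw)={dots.std().item():6.2f}  "
          f"sqrt(d_k)={math.sqrt(d_k):6.2f}  std(scaled)={scaled.std().item():5.2f}")
```

The raw standard deviation should track √d_k, while the scaled one should stay near 1 for every d_k.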

Back to our example. With d_k = 2, divide every element of S by √2 ≈ 1.414:

S_scaled = S / √2 = [[0.71, 0.00, 0.71], [0.00, 0.71, 0.71], [0.71, 0.71, 1.41]]

Step 3: Softmax along the keys axis

Apply softmax to each row independently. Recall from Day 1: softmax(z)_i = exp(z_i) / Σ_j exp(z_j). Each row will sum to 1 — it is now a probability distribution describing how much that query attends to each key position.

Row 0: exp([0.71, 0.00, 0.71]) = [2.03, 1.00, 2.03]  → sum = 5.07
       A[0] = [0.40, 0.20, 0.40]   ← q₀ attends equally to k₀ and k₂
Row 1: exp([0.00, 0.71, 0.71]) = [1.00, 2.03, 2.03]  → sum = 5.07
       A[1] = [0.20, 0.40, 0.40]   ← q₁ attends equally to k₁ and k₂
Row 2: exp([0.71, 0.71, 1.41]) = [2.03, 2.03, 4.10]  → sum = 8.16
       A[2] = [0.25, 0.25, 0.50]   ← q₂ attends most to k₂ (highest similarity)

A = [[0.40, 0.20, 0.40],
     [0.20, 0.40, 0.40],
     [0.25, 0.25, 0.50]]
[Figure: Step 3, attention weights A = softmax(S / √d_k) drawn as a 3×3 heatmap with queries q₀–q₂ as rows and keys k₀–k₂ as columns; each row sums to 1. Darker cell means higher weight. Read a row left to right as "how much of each token's value enters my output?", e.g. output₂ = 0.25·V₀ + 0.25·V₁ + 0.50·V₂.]
Step 3: attention weight matrix A after softmax. Each row is a valid probability distribution over the keys (sums to 1). Token 2's query attends most strongly to key 2 (weight 0.50) because their vectors are most aligned.
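Steps 2 and 3 together can be reproduced from S in a short sketch. The lesson's table shows two-decimal values, so the check allows a small tolerance.

```python
import math
import torch

S = torch.tensor([[1., 0., 1.], [0., 1., 1.], [1., 1., 2.]])   # raw scores from Step 1
d_k = 2

S_scaled = S / math.sqrt(d_k)        # Step 2: divide by sqrt(d_k)
A = torch.softmax(S_scaled, dim=-1)  # Step 3: normalize over keys (last axis) for each query row
print(A)
print(A.sum(dim=-1))                 # each row sums to 1
```

Passing `dim=-1` is what makes each query's row a distribution over keys. Softmax over `dim=0` would instead normalize over queries, which is not attention.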

Step 4: Weighted sum of values — the output

Multiply the attention weights A by the value matrix V. This is a standard matrix multiply: output = A @ V, shape (T, d_v) = (3, 2). Each output row is a convex combination of the value vectors.

output = A @ V

output[0] = 0.40·[2,0] + 0.20·[0,3] + 0.40·[1,1]
          = [0.80, 0] + [0, 0.60] + [0.40, 0.40]
          = [1.20, 1.00]
output[1] = 0.20·[2,0] + 0.40·[0,3] + 0.40·[1,1]
          = [0.40, 0] + [0, 1.20] + [0.40, 0.40]
          = [0.80, 1.60]
output[2] = 0.25·[2,0] + 0.25·[0,3] + 0.50·[1,1]
          = [0.50, 0] + [0, 0.75] + [0.50, 0.50]
          = [1.00, 1.25]

output = [[1.20, 1.00],
          [0.80, 1.60],
          [1.00, 1.25]]        shape (T=3, d_v=2)

That is the full attention computation. In the general formula notation:

Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) · V

Read it as: "compute similarity scores (Q Kᵀ), normalize the variance (÷√d_k), turn into probabilities (softmax), blend the values (·V)." Four operations. One line.
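The one-line formula translates almost directly into code. The minimal sketch below runs it on the worked example and compares it with PyTorch's built-in `F.scaled_dot_product_attention`, which uses the same default 1/√d_k scale. The helper name `attention` is ours. The built-in expects leading batch/head axes, so we add two singleton axes.

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    """softmax(Q Kᵀ / √d_k) · V for (T, d) tensors."""
    d_k = Q.shape[-1]
    S = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # similarity, variance-normalized
    A = torch.softmax(S, dim=-1)                   # probabilities over keys
    return A @ V                                   # blend the values

Q = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
K = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
V = torch.tensor([[2., 0.], [0., 3.], [1., 1.]])

out = attention(Q, K, V)
print(out)   # close to the hand-computed [[1.20, 1.00], [0.80, 1.60], [1.00, 1.25]]

# Built-in kernel on (1, 1, T, d) inputs; take the single batch/head back out.
ref = F.scaled_dot_product_attention(Q[None, None], K[None, None], V[None, None])[0, 0]
print(torch.allclose(out, ref, atol=1e-5))
```

The exact result differs from the hand-worked table in the second decimal place (row 0 comes out near [1.20, 0.99]), because the hand calculation rounds the weights first.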

Causal Masking

Position t can never see t+1. This is how we train an autoregressive model in parallel.

For a language model like GPT, the training objective is next-token prediction: given tokens 0..t, predict token t+1. If the model could attend to future tokens during training, it would trivially copy them — no learning would happen. We have to enforce that position t only sees positions 0, 1, ..., t.

The mechanism is elegant. Before applying softmax, we set the scores at future positions to negative infinity. After softmax, exp(−∞) = 0, so those positions receive exactly zero weight. No information leaks forward.

The mask is a lower-triangular matrix of ones (1 = allowed, 0 = block). Apply it as:

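import torch

T = 5
scores = torch.randn(T, T)                       # stand-in for Q @ Kᵀ / √d_k
causal_mask = torch.tril(torch.ones(T, T))       # 1 = visible, 0 = future
# Replace blocked (0) positions with -inf in the score matrix before softmax:
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)             # upper triangle is exactly 0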
[Figure: the T = 5 causal mask. Row t keeps columns 0..t (value 1, gradient flows) and sets columns t+1..4 to −∞, which softmax maps to 0. Because the output at position t is a function of tokens 0..t only, training computes the loss for all T positions in a single, fully parallel forward pass. This is why the Transformer trains so much faster than RNNs. The (T, T) mask broadcasts over (B, h, T, T) without a copy.]
The causal mask is a lower-triangular matrix applied to the score matrix before softmax. Masked cells get score −∞, which softmax maps to exactly 0. This enforces the autoregressive constraint while keeping the full forward pass parallel — every position's loss is computed simultaneously during training.

Self vs cross attention

| Variant | Q comes from | K, V come from | Mask | Use |
|---|---|---|---|---|
| Self-attention (encoder) | same sequence X | same sequence X | none (bidirectional) | BERT, understanding |
| Self-attention (decoder) | same sequence X | same sequence X | causal lower-triangular | GPT, LLaMA, generation |
| Cross-attention | decoder hidden states | encoder outputs | none (attend to full encoder) | T5, translation decoder |
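The only mechanical difference in cross-attention is where Q and K/V come from, and that shows up in the score shape. The sketch below uses the hidden states directly as Q, K, and V, skipping the learned projections for brevity. All sizes are illustrative assumptions.

```python
import math
import torch

torch.manual_seed(0)
d = 8                                  # illustrative head dimension
T_dec, T_enc = 3, 5                    # decoder and encoder lengths can differ

dec_hidden = torch.randn(T_dec, d)     # queries come from the decoder
enc_out = torch.randn(T_enc, d)        # keys and values come from the encoder

# Projections omitted for brevity: use the hidden states directly as Q, K, V.
Q, K, V = dec_hidden, enc_out, enc_out
S = Q @ K.T / math.sqrt(d)             # (T_dec, T_enc): one row per decoder position
A = torch.softmax(S, dim=-1)           # no causal mask: every encoder position is visible
out = A @ V                            # (T_dec, d): one output per decoder position
print(S.shape, out.shape)
```

In self-attention T_dec and T_enc are the same sequence, so the score matrix is square. In cross-attention it is rectangular.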

For the rest of this curriculum we are exclusively concerned with decoder self-attention with causal masking — GPT, LLaMA, Mistral, Qwen, DeepSeek, Gemma all live here.

Multi-Head Attention

Run h independent heads in parallel. Concat. Project. Each head specializes.

A single attention head produces one set of attention weights per query, so it can express only one pattern of relationships at a time. A model tracking syntactic subject-verb agreement, long-range coreference, and local repetition simultaneously must blend all of them into that one pattern. The fix is straightforward: run h smaller attention computations in parallel, each with its own independent Q, K, V projections, and concatenate their outputs.

Each head operates on dimension d_head = d_model / h. The computation budget stays roughly constant — we split d_model into h pieces rather than running one big computation. After all heads produce their (B, h, T, d_head) outputs, we concatenate back to (B, T, d_model) and apply a final linear layer W_O to mix information across heads.

[Figure: multi-head attention shape flow, (B, T, d_model) → … → (B, T, d_model). Input x (B, T, D) → W_QKV = Linear(D, 3D) → chunk(3) into Q, K, V, each (B, T, D) → split heads with view + transpose to (B, h, T, d_head) → SDPA per head (Q @ Kᵀ / √d_head, softmax → A, A @ V) → concat heads with transpose(1, 2) back to (B, T, D) → W_O = Linear(D, D). All h heads run in parallel on the same input x; d_head = d_model / h. Total parameters: 3·D² (QKV) + D² (W_O) = 4·D², the same as one big head.]
Multi-head attention shape flow. The key insight is the reshape (B, T, D) → (B, h, T, d_head) via view + transpose. This turns one large attention into h parallel small attentions, handled as a single batched matrix multiply. Parameter count is identical to one big head: 4·D² total.
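Why is the view + transpose trick valid? Head i simply receives the i-th contiguous block of d_head projected features for every token. That is equivalent to giving each head its own D → d_head projection. The sketch below checks this slice-by-slice and confirms the inverse reshape restores the input. Sizes are illustrative.

```python
import torch

torch.manual_seed(0)
B, T, h, d_head = 2, 4, 3, 5           # illustrative sizes
D = h * d_head
q = torch.randn(B, T, D)               # output of the Q projection

# The reshape used in the implementation: (B, T, D) -> (B, h, T, d_head)
q_heads = q.view(B, T, h, d_head).transpose(1, 2)

# Head i is exactly the i-th contiguous slice of D features, for every token.
for i in range(h):
    assert torch.equal(q_heads[:, i], q[..., i * d_head:(i + 1) * d_head])

# The inverse (transpose back, then merge) recovers the original tensor.
merged = q_heads.transpose(1, 2).contiguous().view(B, T, D)
assert torch.equal(merged, q)
```

The transpose is needed so the head axis sits beside the batch axis. Then a single batched matmul over (B, h) computes every head's (T, T) scores at once.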

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_QKV = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_O   = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        qkv = self.W_QKV(x)                                         # (B, T, 3D)
        q, k, v = qkv.chunk(3, dim=-1)                              # each (B, T, D)

        def split_heads(t):
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        q, k, v = map(split_heads, (q, k, v))                       # (B, h, T, d_head)

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)   # (B, h, T, T)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = attn @ v                                               # (B, h, T, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T, D)         # (B, T, D)
        return self.W_O(out)

One optimization worth noting: W_QKV is a single Linear(D, 3D) instead of three separate Linear(D, D) layers. A single large matrix multiply runs faster on GPU than three smaller ones (fewer kernel launches, better BLAS utilization). It is mathematically equivalent.
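The equivalence claim is easy to verify. The sketch below slices the fused (3D, D) weight into three (D, D) blocks and loads them into three separate layers. It then checks that both routes give the same Q, K, V. D = 16 is an arbitrary choice, and the tolerance allows for different matmul shapes rounding slightly differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D = 16                                              # illustrative model width
fused = nn.Linear(D, 3 * D, bias=False)             # one (3D, D) weight matrix

# Three separate projections built from row-slices of the fused weight.
W_q, W_k, W_v = fused.weight.chunk(3, dim=0)        # each (D, D)
sep_q, sep_k, sep_v = (nn.Linear(D, D, bias=False) for _ in range(3))
with torch.no_grad():
    sep_q.weight.copy_(W_q)
    sep_k.weight.copy_(W_k)
    sep_v.weight.copy_(W_v)

x = torch.randn(2, 5, D)
q, k, v = fused(x).chunk(3, dim=-1)                 # same split as in forward()
print(torch.allclose(q, sep_q(x), atol=1e-6),
      torch.allclose(k, sep_k(x), atol=1e-6),
      torch.allclose(v, sep_v(x), atol=1e-6))
```

nn.Linear stores its weight as (out_features, in_features). Splitting rows of the weight therefore corresponds to splitting the last axis of the output.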

Parameter count

Total parameters per attention layer (no biases): 3D² + D² = 4D². Three from the combined QKV projection and one from the output projection. At d_model = 512, n_heads = 8, that is 4 × 512² ≈ 1.05M parameters per layer. LLaMA-7B has d_model = 4096 and 32 layers: 4 × 4096² × 32 ≈ 2.15B parameters in attention alone.
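The count can be confirmed by instantiating the two projections and summing their parameters. The helper name `attn_params` is ours. Notice that the number of heads never enters the count. For the 7B figure we use the formula rather than allocating 4096-wide layers.

```python
import torch.nn as nn

def attn_params(d_model: int) -> int:
    """Parameters in one attention layer: fused W_QKV plus W_O, no biases."""
    w_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
    w_o = nn.Linear(d_model, d_model, bias=False)
    return sum(p.numel() for m in (w_qkv, w_o) for p in m.parameters())

print(attn_params(512))           # 4 * 512**2 = 1048576
print(4 * 4096**2 * 32)           # 32 layers at d_model = 4096: 2147483648
```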

The eight authors of "Attention Is All You Need" (Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, Polosukhin) marked themselves in the paper with the footnote "Equal contribution. Listing order is random." Most of them have since founded or co-founded companies, including Cohere, Character.ai, Adept, Inceptive, NEAR, and Sakana AI.

Complexity

O(T²·d) compute and O(T²) memory. The whole reason long context is hard.

The score matrix S = Q Kᵀ has shape (B, h, T, T). Its size grows with the square of sequence length. This quadratic scaling is the dominant cost at long context and the root motivation for everything in Week 3 of this curriculum.

Concrete numbers

| Sequence length T | Score matrix, one layer (B=1, h=32, FP16) | × 32 layers |
|---|---|---|
| 512 | 0.017 GB | 0.5 GB |
| 2 048 | 0.27 GB | 8.6 GB |
| 8 192 | 4.3 GB | 137 GB |
| 32 768 | 68.7 GB | 2.2 TB |
| 128 000 | 1 049 GB | 33.5 TB |

At 128k context (Claude 3 / GPT-4 Turbo range), a single layer's naive score matrix (≈1 TB) is roughly 75× the ≈14 GB of FP16 weights of a 7B-parameter model. This is why FlashAttention (Day 21) is not an optional optimization but a necessity. Its insight: never materialize the full (T, T) score matrix in high-bandwidth memory; compute attention in tiles that fit in fast on-chip SRAM instead.
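The table above follows from one multiplication: B · h · T · T · bytes per element. The sketch below reproduces it. It assumes decimal units (1 GB = 10⁹ bytes), which is what the table's figures match.

```python
def score_matrix_bytes(T: int, B: int = 1, h: int = 32, dtype_bytes: int = 2) -> int:
    """Bytes for one layer's (B, h, T, T) attention score matrix (FP16 by default)."""
    return B * h * T * T * dtype_bytes

for T in (512, 2_048, 8_192, 32_768, 128_000):
    one_layer_gb = score_matrix_bytes(T) / 1e9        # decimal GB, as in the table
    print(f"T={T:>7,}  one layer: {one_layer_gb:9.3f} GB   x32 layers: {32 * one_layer_gb:10.1f} GB")
```

Doubling T quadruples every entry, which is the quadratic blow-up in one line of arithmetic.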

Attention vs FFN crossover

Within one Transformer block, attention competes with the feedforward layer (which you will meet on Day 7) for compute time. Their scaling is different:

  • Attention: O(B · T² · d_model) — quadratic in sequence length.
  • FFN: O(B · T · d_model²) — linear in sequence length, quadratic in model dimension.

To first order, ignoring constant factors, the crossover happens when T ≈ d_model. For LLaMA-7B, d_model = 4096. At T = 4096 tokens the two terms are roughly equal. At T = 32 768, the attention term is 8× the FFN term per layer. Long-context inference is almost entirely an attention problem.

[Figure: attention vs FFN compute (FLOPs) as a function of sequence length T, from 1k to 128k. FFN grows as O(T·D²), attention as O(T²·D); the curves cross at T ≈ 4096 = D.]
Attention compute scales as T² (red); FFN scales linearly in T (gold). They cross at T ≈ d_model. For a model with d_model = 4096, attention dominates beyond ~4k tokens. At 128k tokens, the same first-order estimate puts attention at roughly 31× the FFN cost (128 000 / 4 096).
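The first-order comparison reduces to a single ratio. The sketch below uses only the two asymptotic terms from the bullets (T²·d for attention, T·d² for the FFN) and drops all constant factors. Real FLOP counts include multipliers such as the FFN hidden width, and those shift the exact crossover.

```python
def attn_to_ffn_ratio(T: int, d_model: int) -> float:
    """Ratio of the attention term (T² · d) to the FFN term (T · d²), constants ignored."""
    attn = T * T * d_model
    ffn = T * d_model * d_model
    return attn / ffn            # simplifies to T / d_model

for T in (1_024, 4_096, 32_768, 128_000):
    print(T, attn_to_ffn_ratio(T, d_model=4096))
```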
Why This Matters for Inference

Recomputing K and V for every past token, every step, is catastrophically wasteful.

During training, the entire sequence of T tokens is processed together in one forward pass. During inference generation, the model produces one token at a time. After generating token t, to generate token t+1 the model needs to attend over all t+1 positions. The naive approach reruns the full attention computation: recompute Q, K, V for all past tokens, recompute QKᵀ, re-softmax, re-blend. This is enormously wasteful — you are recomputing K and V for tokens 0..t−1 on every single generation step, even though nothing about those tokens has changed.

The KV cache is the fix. Keep the K and V vectors already computed for tokens 0..t−1. On the next step, compute K and V only for the one new token (position t), append them to the cache, and run attention for that token's query against the full cached sequence. The score computation shrinks from a full (T, T) matrix (naive recompute) to a single new row (1, T) (cached). With a 512-token prefix, that step computes 512× fewer scores.

[Figure: KV cache vs naive recompute at generation step t. Naive: all T tokens go through W_Q, W_K, W_V; Q (T×d) @ K (T×d)ᵀ gives a (T×T) score matrix; softmax(score) @ V gives the output; cost O(T²·d) per step, with K and V for past tokens recomputed every step. Cached: only the new token goes through W_Q, W_K, W_V; K_new, V_new are appended to K_cache, V_cache; Q (1×d) @ K_cache (T×d)ᵀ gives a (1×T) score row; cost O(T·d) per step, with K and V computed once.]
Without a KV cache, generating each new token reruns full attention over all past tokens — O(T²) per step. With the KV cache, past K and V vectors are stored and reused; only the new token's K and V are computed — O(T) per step. This is the single most important optimization for autoregressive inference. Day 20 covers the full implementation.
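The key property of the cache is that it changes the cost but not the result. The sketch below shows this for one attention layer with toy random projection matrices, which are assumptions made only for this demonstration. The helper `causal_attention` is ours. It processes the whole sequence at once, then token by token with a growing K/V cache, and checks that the outputs agree.

```python
import math
import torch

torch.manual_seed(0)
T, d_model, d = 6, 16, 8                        # illustrative sizes
# Toy projection weights, scaled so scores stay O(1).
W_Q, W_K, W_V = (torch.randn(d_model, d) / math.sqrt(d_model) for _ in range(3))
x = torch.randn(T, d_model)                     # embeddings of the whole sequence

def causal_attention(q, k, v):
    T_q, T_k = q.shape[0], k.shape[0]
    scores = q @ k.T / math.sqrt(q.shape[-1])
    # Query i sits at absolute position T_k - T_q + i and may see keys up to it.
    mask = torch.ones(T_q, T_k).tril(diagonal=T_k - T_q).bool()
    scores = scores.masked_fill(~mask, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v

# Naive: project every token and compute the full (T, T) score matrix.
full = causal_attention(x @ W_Q, x @ W_K, x @ W_V)          # (T, d)

# Cached: each step projects only the new token and appends its K, V.
K_cache = torch.empty(0, d)
V_cache = torch.empty(0, d)
cached_rows = []
for t in range(T):
    x_t = x[t:t + 1]                                        # (1, d_model)
    K_cache = torch.cat([K_cache, x_t @ W_K])               # (t+1, d)
    V_cache = torch.cat([V_cache, x_t @ W_V])
    cached_rows.append(causal_attention(x_t @ W_Q, K_cache, V_cache))  # one (1, t+1) score row
cached = torch.cat(cached_rows)

print(torch.allclose(full, cached, atol=1e-5))
```

In a real decoder the cache holds K and V per layer and per head, but the bookkeeping is the same append-then-attend loop.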

The KV cache has its own cost: it occupies GPU memory. For a model with n_layers layers, n_heads heads, head dimension d_head, and sequence length T, the cache size is:

KV cache memory = 2 · n_layers · n_heads · d_head · T · dtype_bytes

Example: LLaMA-7B (n_layers=32, n_heads=32, d_head=128, FP16):
  = 2 × 32 × 32 × 128 × T × 2 bytes
  = 524,288 × T bytes ≈ 0.5 MB per token
  At T = 4096 tokens: 2 GB of KV cache
  At T = 32768 tokens: 16 GB of KV cache
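The formula can be written as a small helper. The name `kv_cache_bytes` is ours. The lesson's round figures (0.5 MB, 2 GB, 16 GB) match binary units exactly, i.e. MiB and GiB. The leading factor of 2 accounts for storing both K and V.

```python
def kv_cache_bytes(n_layers: int, n_heads: int, d_head: int, T: int, dtype_bytes: int = 2) -> int:
    """K and V (the leading factor 2) for every layer, head, and cached position."""
    return 2 * n_layers * n_heads * d_head * T * dtype_bytes

# LLaMA-7B figures from the text: 32 layers, 32 heads, d_head = 128, FP16.
per_token = kv_cache_bytes(32, 32, 128, T=1)
print(per_token)                                        # 524288 bytes = 0.5 MiB
print(kv_cache_bytes(32, 32, 128, T=4096) / 2**30)      # 2.0 GiB
print(kv_cache_bytes(32, 32, 128, T=32768) / 2**30)     # 16.0 GiB
```

Unlike the score matrix, this grows only linearly in T. It is paid per concurrent request, though, which is what makes it a serving-memory problem.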

This is precisely why serving many long-context requests simultaneously is a memory management problem — and why PagedAttention (Day 24) exists. It borrows the page-table idea from operating systems to manage KV cache memory in non-contiguous blocks, enabling much higher serving throughput.

The KV cache, FlashAttention, and PagedAttention are the three most important inference-engineering optimizations for Transformer models. All three exist because of the quadratic T² cost of the attention score matrix you just computed by hand. Understanding why that matrix is expensive is the single most important thing you can take from today's lesson.

The inference pipeline, teased

Here is how the upcoming Week 3–4 lessons build on today's foundation:

| Day | Topic | Connection to today |
|---|---|---|
| Day 20 | KV Cache in depth | Exactly the cache above: memory layout, eviction, batch management |
| Day 21 | FlashAttention | Avoids materializing the (T, T) score matrix; tiles the computation in SRAM |
| Day 24 | PagedAttention | OS-style paged memory management for the KV cache; enables vLLM |
Implementation

Single-head attention: eight lines of PyTorch.

Here is the complete single-head attention module. Read the code line-by-line while tracking shapes in the comments.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SingleHeadAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.d_k = d_k

    def forward(self, x, mask=None):
        # x: (B, T, d_model)
        Q = self.W_Q(x)                                # (B, T, d_k)
        K = self.W_K(x)                                # (B, T, d_k)
        V = self.W_V(x)                                # (B, T, d_k)

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)  # (B, T, T)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)               # (B, T, T)
        return attn @ V                                # (B, T, d_k)

One detail: bias=False on every nn.Linear. Many modern decoder models (LLaMA, for example) drop these biases, though not all do (GPT-2 keeps them). In those models the bias adds parameters with little empirical benefit, since every attention block is paired with a layer normalization (we cover that on Day 7).

In modern PyTorch, replace the inner computation with F.scaled_dot_product_attention(Q, K, V, is_causal=True), which automatically dispatches to FlashAttention on supported hardware. Use it in production; keep the explicit version above for learning.
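Before swapping in the fused call, it is worth confirming that it matches the explicit version. The sketch below compares a hand-written masked attention against `F.scaled_dot_product_attention(..., is_causal=True)` (PyTorch 2.0+). It uses random (B, h, T, d_k) inputs with illustrative sizes. Which backend PyTorch picks depends on hardware and inputs, so we only check that the outputs agree.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, T, d_k = 2, 4, 6, 8                      # illustrative sizes
Q, K, V = (torch.randn(B, h, T, d_k) for _ in range(3))

# Explicit version, as in SingleHeadAttention.forward (plus a head axis).
mask = torch.tril(torch.ones(T, T))
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(mask == 0, float('-inf'))
manual = F.softmax(scores, dim=-1) @ V

# Fused PyTorch kernel with the causal mask applied internally.
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(manual, fused, atol=1e-5))
```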

Same Math, Different Framework

MLX: identical math, JAX-style idioms.

For completeness on Apple Silicon, here is multi-head attention in MLX. Compare each line to the PyTorch version — the differences are cosmetic but reveal framework philosophy.

import math
import mlx.core as mx
import mlx.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_QKV = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_O   = nn.Linear(d_model, d_model, bias=False)

    def __call__(self, x, mask=None):
        B, T, D = x.shape
        qkv = self.W_QKV(x)
        q, k, v = mx.split(qkv, 3, axis=-1)

        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3)

        # Plain Python scalar for the 1/sqrt(d_head) scale.
        scores = (q @ k.transpose(0, 1, 3, 2)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = mx.where(mask == 0, mx.array(-1e9), scores)
        attn = mx.softmax(scores, axis=-1)
        out = attn @ v
        out = out.transpose(0, 2, 1, 3).reshape(B, T, D)
        return self.W_O(out)

Key differences from PyTorch: __call__ instead of forward; mx.split instead of .chunk; mx.where with a large negative constant (-1e9) instead of masked_fill with -inf; no .contiguous() (MLX's reshape handles non-contiguous layouts itself). Given the same weights and inputs, and as long as no row is fully masked, both implementations produce the same outputs to within floating-point precision.

Sanity Checks

Four tests that catch almost every attention bug.

Attention bugs are notoriously hard to detect from a loss curve alone. Run these checks on any new implementation before trusting it.

  1. Shape preservation. Input (B, T, D) must produce output (B, T, D). If anything else comes back, you have a transpose or reshape bug.
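  2. Rows sum to 1. After softmax, each row of the attention weight matrix must sum to 1 (up to floating-point error): assert torch.allclose(A.sum(-1), torch.ones_like(A.sum(-1))). A fully-masked row is the exception. With −∞ masking, softmax returns NaN for it, which is itself a bug to catch.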
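  3. Causal correctness. Perturbing the token at position i must not change outputs at positions < i:
    attn = MultiHeadAttention(d_model=64, n_heads=8)   # the module defined above
    causal_mask = torch.tril(torch.ones(5, 5))        # broadcasts over (B, h, T, T)
    x = torch.randn(1, 5, 64)
    y1 = attn(x, mask=causal_mask)
    x2 = x.clone(); x2[0, 4] = torch.randn(64)   # change last token
    y2 = attn(x2, mask=causal_mask)
    assert torch.allclose(y1[:, :4], y2[:, :4])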
  4. Gradient flow. loss = attn(x).sum(); loss.backward() must not produce NaN gradients. If it does, your scaling or masking is wrong.
Exercise

Eight exercises in the notebook.

Companion notebook: day-6-attention.ipynb.

  1. Numerical worked example. Reproduce the T=3, d_k=2 example from this lesson in NumPy. Print every intermediate: Q, K, V, S, S_scaled, A, output. Verify against the values in the lesson.
  2. Scaling demonstration. For d_k ∈ {8, 64, 512, 4096}, show that max(softmax(raw_scores)) approaches 1.0 as d_k grows, and stays reasonable with 1/√d_k scaling.
  3. Single-head attention as nn.Module. Test shapes at every step. Verify against F.scaled_dot_product_attention with an assertion.
  4. Multi-head from raw matrices. Implement multi-head attention using only raw matrix multiplies (no nn.Linear). Confirm shapes match.
  5. Causal mask demo. Print a masked attention weight matrix and confirm the upper triangle is exactly zero. Run the causality test (perturb a later token, confirm earlier outputs do not change).
  6. Verify against nn.MultiheadAttention. Copy weights from your implementation to PyTorch's built-in and confirm outputs match to 1e-5.
  7. O(T²) cost measurement. Time and measure memory of naive attention at T ∈ {256, 512, 1024, 2048} and plot vs T². The quadratic blow-up should be unmistakable.
  8. Minimal KV-cache demo. Implement a function that produces identical output to full attention but reuses cached K and V. Assert that cached and non-cached give identical results.
Self-Check

Ten questions before moving on.

Close the page and answer from memory. If you can't, re-read the relevant section.

  1. Decode the full attention formula softmax(QKᵀ / √d_k) V symbol by symbol. What is Q? What is Kᵀ? What does softmax operate over?
  2. Why divide by √d_k? What is the variance argument? What goes wrong at d_k = 4096 without it?
  3. Why does softmax run along the keys axis and not the queries axis? Justify with the database analogy.
  4. Work through our T=3, d_k=2 example. Recompute S[2,2] by hand and verify the softmax weights for row 2.
  5. With a causal mask in place, what is attn[0, 0, 5, 7] after softmax? Why is it exactly zero?
  6. If d_model = 768 and n_heads = 12, what is d_head? What is the total parameter count of one attention layer (no biases)?
  7. In multi-head attention, do all heads operate on the same Q, K, V vectors? How does the view + transpose trick create independent heads?
  8. What is the time complexity of computing the attention score matrix? What is the memory complexity? How do they scale when you double T?
  9. Explain the KV cache idea in two sentences. Why does it reduce the cost of generation from O(T²) per step to O(T) per step?
  10. What is the difference between encoder self-attention, decoder self-attention (causal), and cross-attention? Which one do GPT-family models use?

"Attention is a soft, content-based dictionary lookup that turned out to be everything we needed, and the quadratic cost of that lookup is what Week 3 of inference engineering exists to solve."

Further Reading

Go deeper.

Hand-picked references for this lesson.

Paper · 2017

Vaswani et al. — Attention Is All You Need

The paper. Section 3 is the meat of today's lesson. Read it after completing the exercises.

Blog · Visual

Alammar — The Illustrated Transformer

The best visual walkthrough ever made. Read alongside the paper.

Blog · Code

Harvard NLP — The Annotated Transformer

Paper plus working PyTorch code, side by side. Every equation has its line of code.

YouTube · 2 hr

Karpathy — Let's build GPT from scratch

Self-attention from scratch on Shakespeare. The code that became nanoGPT.

YouTube · 25 min

3Blue1Brown — Attention, visually explained

Beautifully animated geometry of attention. Best complement to this lesson's diagrams.

Paper · 2014

Bahdanau et al. — Neural MT with attention

The original attention mechanism added to RNN encoder-decoders for translation.

Blog · Anthropic

Mathematical Framework for Transformer Circuits

What attention heads actually compute — explicit linear-algebra view. Advanced but rewarding.

Repo · Karpathy

nanoGPT model.py

Clean causal self-attention in ~50 lines. Production-quality reference implementation.

Blog · Lilian Weng

The Transformer Family v2.0

Survey of attention variants. A useful map of the design space once you know the basics.

Paper · 2022

Dao et al. — FlashAttention

The paper that solved the T² memory problem you computed in the complexity section. Required reading before Day 21.

Repo · HF

transformers — modeling_gpt2.py

Production reference with all corner cases. Find GPT2Attention and trace it against today's lesson.
