Scaled dot-product attention is the single most important architectural idea in deep learning since backpropagation. It is also famously confusing on first encounter. Today we kill the confusion: build it from scratch with real numbers, derive every formula, and trace the exact shapes that will matter in Weeks 3–4 when we optimize it.
Attention is the architecture. Not a component inside a larger model — the mechanism that is the Transformer. Every open-weight model you will ever serve — LLaMA, Mistral, Qwen, DeepSeek, Gemma — is a stack of attention layers and feedforward layers, repeated dozens of times. If you understand attention deeply, you understand the machine.
This is also the lesson with the highest inference-engineering payoff. The attention score matrix has shape (B, h, T, T). That T² term is why long-context inference is hard, why the KV cache exists, why FlashAttention was a landmark paper, and why PagedAttention changed how we serve LLMs. You cannot reason about any of those systems without first understanding exactly what they are optimizing. That is what today is for.
Here is the mental model to hold throughout: attention is a soft, content-based dictionary lookup. A database has rows. Each row has a key ("this is what I am") and a value ("this is what I contain"). To retrieve information you issue a query ("this is what I need"). A normal database finds the one exact match and returns its value. Attention does the soft, differentiable version: compute how well your query matches every key, turn those scores into a probability distribution with softmax, and return a weighted average of all the values. That is all it is. The rest is engineering detail.
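To make the lookup picture concrete, here is a minimal sketch in plain Python. The keys, values, and query are hypothetical toy numbers chosen for illustration, not anything from a real model. It contrasts a hard lookup (return the value of the single best-matching key) with the soft version described above (return a softmax-weighted average of all values).

```python
import math

# Hypothetical toy "database": three 2-D keys, each with a scalar value.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [10.0, 20.0, 30.0]
query = [1.0, 1.0]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Match score = dot product of the query with each key.
scores = [sum(q * k for q, k in zip(query, key)) for key in keys]

# Hard lookup: return the value of the single best-matching key.
hard_result = values[scores.index(max(scores))]

# Soft lookup: weight every value by how well its key matches the query.
weights = softmax(scores)
soft_result = sum(w * v for w, v in zip(weights, values))

print("scores:", scores)
print("weights:", [round(w, 3) for w in weights])
print("hard:", hard_result, "soft:", round(soft_result, 2))
```

The soft result sits between the stored values and leans toward the best match, and because every step is differentiable, gradients can flow into the query, keys, and values.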
This lesson is the most math-dense day in Week 1. Before we touch a formula, here is every symbol you will encounter, with a plain-English and Python reading. Bookmark this table — refer back whenever something looks unfamiliar.
| Symbol | Reads as | Python analogy | Where it appears |
|---|---|---|---|
| Q | "queries matrix" | a 2-D array of row vectors | input to every step |
| K | "keys matrix" | same shape as Q | provides what tokens represent |
| V | "values matrix" | same shape as Q (often) | provides what tokens contribute |
| d_k | "dimension of keys" | d_k = 64 | scaling factor denominator |
| d_v | "dimension of values" | d_v = 64 | output width of attention |
| T | "sequence length" | len(tokens) | height of Q, K, V; side of score matrix |
| B | "batch size" | len(batch) | leading batch dimension |
| h | "number of heads" | n_heads = 8 | multi-head split |
| QKᵀ | "Q times K-transpose" | Q @ K.T | raw similarity scores |
| √d_k | "square root of d_k" | math.sqrt(d_k) | variance-normalizing divisor |
| softmax(z) | "softmax of vector z" | exp(z) / sum(exp(z)) | turns scores into probabilities |
| ⊙ | "elementwise multiply" | a * b (NumPy) | masking, gates |
| W_Q, W_K, W_V | "weight matrices" | nn.Linear(D, D, bias=False).weight | learnable projections |
| −∞ | "negative infinity" | float('-inf') | causal mask before softmax |
| ℝ^{T×d} | "real-valued matrix of shape T-by-d" | torch.Tensor of shape (T, d) | type annotations |
Key rule: when you see a subscript like S_{ij} it means element at row i, column j — the same as S[i, j] in Python. When you see Kᵀ it means the transpose of K — swap rows and columns, same as K.T or K.transpose(-2, -1).
Before transformers, the standard tool for sequence modeling was the recurrent neural network. An RNN processes tokens one at a time, maintaining a hidden state vector that summarizes everything seen so far. At each step it updates this state using the current token and then moves on.
This design has two killer flaws. The first is long-range dependencies. Information from token 1 must survive being rewritten at token 2, then token 3, and so on. By token 100, the original signal has been compressed and overwritten so many times that almost nothing useful remains. RNNs are notoriously bad at carrying information across long distances — the vanishing gradient problem is its mathematical manifestation.
The second is the sequential bottleneck. You cannot compute step 5 until step 4 has finished. GPUs want to execute millions of operations in parallel; forcing sequential ordering wastes almost all that parallelism.
In 2014, Bahdanau et al. added an "attention" mechanism on top of RNN encoders for machine translation: at each decoder step, instead of relying on a single compressed state, the model could directly look back at all encoder hidden states and weight them. This dramatically improved translation quality. In 2017, Vaswani et al. asked: what if we drop the RNN entirely? The Transformer was born — a network where every token directly attends to every other token in a single, fully-parallel operation. No compression, no forced sequential order, no information bottleneck.
The paper's title — "Attention Is All You Need" — turned out to be literally true.
Before any equations, pin down the intuition with a concrete story. Imagine a library where every book has two metadata cards attached to the spine: a key card that describes what the book is about ("18th-century naval history") and a value card that describes what you actually learn from reading it (specific battle dates, fleet sizes, strategic decisions). When you walk in with a query ("I need information about the Battle of Trafalgar"), the librarian checks your query against every book's key card, computes a relevance score, and returns a weighted summary — spending 70% of the reading time on the most relevant book, 20% on a related one, and 10% across the rest.
In a Transformer, every token simultaneously plays all three roles. It issues a query asking "what context do I need to interpret myself?" It has a key announcing "this is what I represent to other tokens." And it has a value carrying "this is what I contribute if someone attends to me." The three roles are learned independently through three weight matrices W_Q, W_K, W_V. They start random and gradient descent trains them to be useful.
The crucial difference from a real library: the keys, values, and queries are all learned continuous vectors, not categorical labels. The model learns, through gradient descent on billions of examples, which directions in vector space mean "this token is a verb looking for its subject" or "this token is a pronoun resolving its antecedent." The librarian analogy gives the architecture; training gives the semantics.
By the end of this section you'll be able to execute every step of scaled dot-product attention with a pencil, and you'll know exactly what shape every intermediate tensor has. We use the smallest non-trivial example: three tokens, two-dimensional key/value space.
Suppose after the Q, K, V projections we have (all numbers made up but self-consistent):
Matrix-multiply Q by the transpose of K. The result has shape (T, T) = (3, 3). Element S[i, j] is the dot product of token i's query with token j's key — how much token i "wants to attend to" token j.
Figure: S[i,j] is the dot product of query i with key j. Darker = higher score = stronger affinity. The bottom-right cell has score 2 because both q₂ and k₂ equal [1,1].
Divide every element of S by √d_k = √2 ≈ 1.414. This is not cosmetic. Here is the argument in numbers.
Imagine entries of Q and K are drawn independently from a standard normal distribution (mean 0, variance 1). The dot product of a d_k-dimensional query with a key is a sum of d_k products. Each product has variance 1, because for independent zero-mean variables the variance of a product is the product of the variances (1 × 1). Because variance is additive for independent terms, the dot product has variance d_k and standard deviation √d_k. When d_k = 64 (a typical head dimension), the standard deviation of raw scores is 8. When d_k = 128 it is about 11.3. These large spreads push softmax into a near-one-hot state.
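The variance argument is easy to check empirically. The sketch below uses only the standard library, draws hypothetical queries and keys with i.i.d. standard-normal entries, and compares the measured standard deviation of their dot products with √d_k. The sample count and seed are arbitrary choices.

```python
import math
import random
import statistics

random.seed(0)  # arbitrary seed, only for repeatability


def dot_product_std(d_k, n_samples=2000):
    """Empirical std of q·k when entries of q and k are i.i.d. N(0, 1)."""
    dots = []
    for _ in range(n_samples):
        q = [random.gauss(0, 1) for _ in range(d_k)]
        k = [random.gauss(0, 1) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pstdev(dots)


for d_k in (2, 64, 128):
    raw = dot_product_std(d_k)
    scaled = raw / math.sqrt(d_k)  # std after dividing scores by sqrt(d_k)
    print(f"d_k={d_k:4d}  raw std={raw:6.2f}  "
          f"sqrt(d_k)={math.sqrt(d_k):6.2f}  scaled std={scaled:.2f}")
```

The raw standard deviation should track √d_k closely, while the scaled one stays near 1 regardless of head dimension, which is the whole point of the divisor.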
Figure: Without the 1/√d_k factor, the raw dot products have standard deviation √d_k, pushing softmax into a near-one-hot regime (left). With scaling, the distribution stays soft (right), and gradient flows to multiple tokens so the model can actually learn.
Back to our example. With d_k = 2, divide every element of S by √2 ≈ 1.414:
Apply softmax to each row independently. Recall from Day 1: softmax(z)_i = exp(z_i) / Σ_j exp(z_j). Each row will sum to 1 — it is now a probability distribution describing how much that query attends to each key position.
Multiply the attention weights A by the value matrix V. This is a standard matrix multiply: output = A @ V, shape (T, d_v) = (3, 2). Each output row is a convex combination of the value vectors.
That is the full attention computation. In the general formula notation:
Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) V
Read it as: "compute similarity scores (Q Kᵀ), normalize the variance (÷√d_k), turn into probabilities (softmax), blend the values (·V)." Four operations. One line.
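The four operations can be traced end to end in a few lines of PyTorch. The matrices below are hypothetical stand-ins: the only constraint taken from the lesson is that token 2's query and key both equal [1, 1], so S[2, 2] = 2. Every other number is an assumption made for illustration.

```python
import math
import torch

# Hypothetical Q, K, V for T = 3 tokens and d_k = d_v = 2.
Q = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # (T, d_k)
K = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # (T, d_k)
V = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (T, d_v)

d_k = Q.shape[-1]

S = Q @ K.T                           # 1. similarity scores, shape (3, 3)
S_scaled = S / math.sqrt(d_k)         # 2. normalize the variance
A = torch.softmax(S_scaled, dim=-1)   # 3. each row becomes a distribution
out = A @ V                           # 4. blend the values, shape (3, 2)

print("S =\n", S)
print("A =\n", A)
print("row sums of A:", A.sum(dim=-1))
print("out =\n", out)
```

Each row of `out` is a weighted average of the rows of V, so its entries always stay within the range spanned by the value vectors.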
For a language model like GPT, the training objective is next-token prediction: given tokens 0..t, predict token t+1. If the model could attend to future tokens during training, it would trivially copy them — no learning would happen. We have to enforce that position t only sees positions 0, 1, ..., t.
The mechanism is elegant. Before applying softmax, we set the scores at future positions to negative infinity. After softmax, exp(−∞) = 0, so those positions receive exactly zero weight. No information leaks forward.
The mask is a lower-triangular matrix of ones (1 = allowed, 0 = block). Apply it as:
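```python
import torch

T = 5
scores = torch.randn(T, T)                   # stand-in for the scaled Q @ K.T scores
causal_mask = torch.tril(torch.ones(T, T))   # 1 = allowed, 0 = blocked
# Replace scores at blocked (future) positions with -inf before softmax:
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
```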
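A quick way to convince yourself the mask does what is claimed is to run softmax on masked scores and inspect the result. This is a small sketch with random stand-in scores; the seed and T are arbitrary.

```python
import torch

torch.manual_seed(0)  # arbitrary seed
T = 5
scores = torch.randn(T, T)                    # random stand-in scores
causal_mask = torch.tril(torch.ones(T, T))    # lower-triangular ones
masked = scores.masked_fill(causal_mask == 0, float("-inf"))
attn = torch.softmax(masked, dim=-1)

# Boolean selector for strictly-upper-triangular (future) positions.
future = torch.triu(torch.ones(T, T), diagonal=1).bool()

print(attn)
print("max weight on a future position:", attn[future].max().item())
print("row sums:", attn.sum(dim=-1))
```

Row 0 can only attend to itself, so its single allowed weight is 1; every later row spreads its weight over the positions at or before it.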
| Variant | Q comes from | K, V come from | Mask | Use |
|---|---|---|---|---|
| Self-attention (encoder) | same sequence X | same sequence X | none (bidirectional) | BERT, understanding |
| Self-attention (decoder) | same sequence X | same sequence X | causal lower-triangular | GPT, LLaMA, generation |
| Cross-attention | decoder hidden states | encoder outputs | none (attend to full encoder) | T5, translation decoder |
For the rest of this curriculum we are exclusively concerned with decoder self-attention with causal masking — GPT, LLaMA, Mistral, Qwen, DeepSeek, Gemma all live here.
A single attention head can only learn one type of relationship at a time. A model tracking syntactic subject-verb agreement, long-range coreference, and local repetition simultaneously must compromise with a single head. The fix is straightforward: run h smaller attention computations in parallel, each with its own independent Q, K, V projections, and concatenate their outputs.
Each head operates on dimension d_head = d_model / h. The computation budget stays roughly constant — we split d_model into h pieces rather than running one big computation. After all heads produce their (B, h, T, d_head) outputs, we concatenate back to (B, T, d_model) and apply a final linear layer W_O to mix information across heads.
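The split and the merge are pure reshapes, which a short shape trace makes visible. The sizes below are hypothetical and kept small; the sketch checks that head i sees the i-th contiguous slice of the feature dimension and that merging the heads recovers the original tensor exactly.

```python
import torch

B, T, d_model, n_heads = 2, 5, 16, 4   # hypothetical small sizes
d_head = d_model // n_heads
x = torch.randn(B, T, d_model)

# Split: (B, T, D) -> (B, T, h, d_head) -> (B, h, T, d_head)
heads = x.view(B, T, n_heads, d_head).transpose(1, 2)

# Merge: (B, h, T, d_head) -> (B, T, h, d_head) -> (B, T, D)
merged = heads.transpose(1, 2).contiguous().view(B, T, d_model)

print("split shape :", tuple(heads.shape))
print("merged shape:", tuple(merged.shape))
print("head 1 is features 4..7:",
      torch.equal(heads[:, 1], x[:, :, d_head:2 * d_head]))
```

The `.contiguous()` call is needed because `transpose` returns a strided view, and `view` requires a layout it can reinterpret without copying.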
Figure: (B, T, D) → (B, h, T, d_head) via view + transpose. This turns one large attention into h parallel small attentions, handled as a single batched matrix multiply. Parameter count is identical to one big head: 4·D² total.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_QKV = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        qkv = self.W_QKV(x)                      # (B, T, 3D)
        q, k, v = qkv.chunk(3, dim=-1)           # each (B, T, D)

        def split_heads(t):
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        q, k, v = map(split_heads, (q, k, v))    # (B, h, T, d_head)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)  # (B, h, T, T)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = attn @ v                           # (B, h, T, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T, D)       # (B, T, D)
        return self.W_O(out)
```
One optimization worth noting: W_QKV is a single Linear(D, 3D) instead of three separate Linear(D, D) layers. A single large matrix multiply runs faster on GPU than three smaller ones (fewer kernel launches, better BLAS utilization). It is mathematically equivalent.
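The equivalence claim can be checked directly. In this sketch (sizes and seed are arbitrary), three separate projection layers are built from row-slices of a fused weight matrix and their outputs are compared with the chunks of the fused output. The variable names are illustrative, not taken from any particular codebase.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed
d_model = 8           # hypothetical small width
fused = nn.Linear(d_model, 3 * d_model, bias=False)

# nn.Linear stores its weight as (out_features, in_features),
# so the Q, K, V blocks are consecutive groups of rows.
w_q, w_k, w_v = fused.weight.chunk(3, dim=0)

separate = []
for w in (w_q, w_k, w_v):
    layer = nn.Linear(d_model, d_model, bias=False)
    with torch.no_grad():
        layer.weight.copy_(w)   # reuse the fused weights slice by slice
    separate.append(layer)

x = torch.randn(2, 5, d_model)
q, k, v = fused(x).chunk(3, dim=-1)             # one matmul, then split
q2, k2, v2 = (layer(x) for layer in separate)   # three matmuls

print("max |q - q2|:", (q - q2).abs().max().item())
print("max |k - k2|:", (k - k2).abs().max().item())
print("max |v - v2|:", (v - v2).abs().max().item())
```

Any difference should be at the level of floating-point rounding, since both paths compute the same dot products.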
Total parameters per attention layer (no biases): 3D² + D² = 4D², with 3D² from the combined QKV projection and D² from the output projection. At d_model = 512, n_heads = 8, that is 4 × 512² ≈ 1.05M parameters per layer. LLaMA-7B has d_model = 4096 and 32 layers: 4 × 4096² × 32 ≈ 2.15B parameters in attention alone.
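The arithmetic is simple enough to script. This sketch uses a hypothetical helper, `attention_params`, that counts weights for the bias-free layout described above (one fused QKV projection plus one output projection).

```python
def attention_params(d_model: int) -> int:
    """Weights in one bias-free attention layer: W_QKV (3*D^2) + W_O (D^2)."""
    return 3 * d_model * d_model + d_model * d_model


small = attention_params(512)            # d_model = 512, one layer
llama_7b = 32 * attention_params(4096)   # d_model = 4096, 32 layers

print(f"d_model=512 : {small:,} parameters per layer")
print(f"LLaMA-7B    : {llama_7b:,} attention parameters in total")
```

Note that the count does not depend on the number of heads: splitting d_model into h pieces rearranges the same weights rather than adding new ones.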
The eight authors of "Attention Is All You Need" (Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, Polosukhin) are marked in the paper with the footnote: "Equal contribution. Listing order is random." Most of them have since founded or co-founded companies, including Cohere, Character.ai, Adept, Essential AI, Inceptive, Sakana AI, and NEAR. One paper, a long list of startups.
The score matrix S = Q Kᵀ has shape (B, h, T, T). Its size grows with the square of sequence length. This quadratic scaling is the dominant cost at long context and the root motivation for everything in Week 3 of this curriculum.
| Sequence length T | Score matrix (B=1, h=32, FP16) | ×32 layers |
|---|---|---|
| 512 | 0.017 GB | 0.5 GB |
| 2 048 | 0.27 GB | 8.6 GB |
| 8 192 | 4.3 GB | 137 GB |
| 32 768 | 68.7 GB | 2.2 TB |
| 128 000 | 1 049 GB | 33.5 TB |
At 128k context (Claude 3 / GPT-4 Turbo range), a single layer's naive score matrix (about 1 TB) is roughly 75× the ~14 GB of FP16 weights in a 7B-parameter model. This is why FlashAttention (Day 21) is less an optimization than a necessity. Its insight: never materialize the full (T, T) score matrix in high-bandwidth memory; compute attention in tiles that fit in fast on-chip SRAM instead.
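The table values follow from one multiplication. The sketch below recomputes them under the same assumptions as the table (B = 1, h = 32, 2 bytes per FP16 element, decimal gigabytes); `score_matrix_bytes` is a hypothetical helper name.

```python
def score_matrix_bytes(T, n_heads=32, batch=1, bytes_per_element=2):
    """Bytes needed to materialize one layer's (B, h, T, T) score matrix."""
    return batch * n_heads * T * T * bytes_per_element


for T in (512, 2048, 8192, 32768, 128000):
    per_layer_gb = score_matrix_bytes(T) / 1e9   # decimal GB, as in the table
    all_layers_gb = per_layer_gb * 32            # 32 layers
    print(f"T={T:>7,}  per layer: {per_layer_gb:10.3f} GB  "
          f"x32 layers: {all_layers_gb:12.1f} GB")
```

Doubling T quadruples every number in the output, which is the quadratic scaling in its most tangible form.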
Within one Transformer block, attention competes with the feedforward layer (which you will meet on Day 7) for compute time. Their scaling is different:
Attention (score and blend steps): O(B · T² · d_model), quadratic in sequence length.
FFN: O(B · T · d_model²), linear in sequence length, quadratic in model dimension.
Ignoring constant factors, the crossover happens when T ≈ d_model. For LLaMA-7B, d_model = 4096. At T = 4096 tokens the two layers cost roughly the same. At T = 32 768, attention costs about 8× more per layer than FFN. Long-context inference is almost entirely an attention problem.
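Under this constant-free cost model the attention-to-FFN ratio reduces to T / d_model, which the following sketch makes explicit. It is only an asymptotic comparison: real FLOP counts carry constant factors (for example the FFN expansion ratio) that shift the exact crossover point.

```python
def attention_to_ffn_ratio(T: int, d_model: int) -> float:
    """Ratio of the two asymptotic cost terms, with constants dropped."""
    attention_cost = T * T * d_model        # ~ T^2 * d_model
    ffn_cost = T * d_model * d_model        # ~ T * d_model^2
    return attention_cost / ffn_cost


d_model = 4096  # LLaMA-7B width, as in the text
for T in (1024, 4096, 32768):
    ratio = attention_to_ffn_ratio(T, d_model)
    print(f"T={T:>6}  attention / FFN = {ratio:.2f}")
```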
Figure: Attention cost scales as T² (red); FFN scales linearly in T (gold). They cross at T ≈ d_model. For a model with d_model=4096, attention dominates beyond ~4k tokens. At 128k tokens, attention is roughly 30× the FFN cost per token generated (128 000 / 4096 ≈ 31).
During training, the entire sequence of T tokens is processed together in one forward pass. During generation at inference time, the model produces one token at a time. After generating token t, to generate token t+1 the model needs to attend over all t+1 positions. The naive approach reruns the full attention computation: recompute Q, K, V for all past tokens, recompute QKᵀ, re-softmax, re-blend. This is enormously wasteful: you are recomputing K and V for tokens 0..t−1 on every single generation step, even though nothing about those tokens has changed.
The KV cache is the fix. Before producing token t, cache the K and V vectors you already computed for tokens 0..t−1. On the next step, only compute K and V for the one new token (position t), append them to the cache, and run attention against the full cached sequence. The score matrix goes from (T, T) (full recompute) to just a single new row (1, T) (cached). For a 512-token prefix the savings are 512×.
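The idea can be sketched for a single head and a single layer, where each token's K and V depend only on that token's own input. Everything here (sizes, random weights, the `full_causal_attention` reference function) is a hypothetical toy setup, not a production cache; it only shows that attending one new row against cached K and V reproduces the full causal computation.

```python
import math
import torch

torch.manual_seed(0)                   # arbitrary seed
T, d_model, d_k = 6, 16, 8             # hypothetical sizes
W_q = torch.randn(d_model, d_k) / math.sqrt(d_model)
W_k = torch.randn(d_model, d_k) / math.sqrt(d_model)
W_v = torch.randn(d_model, d_k) / math.sqrt(d_model)
x = torch.randn(T, d_model)            # one input row per position


def full_causal_attention(x):
    """Reference: recompute Q, K, V and the whole (T, T) score matrix."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / math.sqrt(d_k)
    mask = torch.tril(torch.ones(len(x), len(x)))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


k_cache, v_cache, outputs = [], [], []
for t in range(T):
    x_t = x[t:t + 1]                                # (1, d_model): new token only
    k_cache.append(x_t @ W_k)                       # K, V computed once per token
    v_cache.append(x_t @ W_v)
    K, V = torch.cat(k_cache), torch.cat(v_cache)   # (t+1, d_k)
    scores = (x_t @ W_q) @ K.T / math.sqrt(d_k)     # a single row: (1, t+1)
    outputs.append(torch.softmax(scores, dim=-1) @ V)

cached_out = torch.cat(outputs)                     # (T, d_k)
print("max difference vs full recompute:",
      (cached_out - full_causal_attention(x)).abs().max().item())
```

No causal mask is needed in the cached loop: at step t the cache simply does not contain any future keys yet.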
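Figure: Without the KV cache, every generation step recomputes attention over the whole prefix: O(T²) per step. With the KV cache, past K and V vectors are stored and reused; only the new token's K and V are computed, which is O(T) per step. This is the single most important optimization for autoregressive inference. Day 20 covers the full implementation.
The KV cache has its own cost: it occupies GPU memory. For a model with n_layers layers, n_heads heads, head dimension d_head, and sequence length T, the cache size is:
KV cache bytes = 2 × n_layers × n_heads × d_head × T × bytes per element
The factor 2 counts K and V; multiply by the batch size B when several sequences are served at once.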
This is precisely why serving many long-context requests simultaneously is a memory management problem — and why PagedAttention (Day 24) exists. It borrows the page-table idea from operating systems to manage KV cache memory in non-contiguous blocks, enabling much higher serving throughput.
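To get a feel for the numbers, the cache size can be computed as a plain product: 2 (one K and one V tensor) times layers, heads, head dimension, tokens, batch size, and bytes per element. The sketch below plugs in a LLaMA-7B-like shape (32 layers, 32 heads, d_head = 4096 / 32 = 128, FP16); `kv_cache_bytes` is a hypothetical helper, and standard multi-head attention without grouped or shared KV heads is assumed.

```python
def kv_cache_bytes(n_layers, n_heads, d_head, T, batch=1, bytes_per_element=2):
    """2 (K and V) x layers x heads x head dim x tokens x batch x bytes."""
    return 2 * n_layers * n_heads * d_head * T * batch * bytes_per_element


per_token = kv_cache_bytes(32, 32, 128, T=1)
at_4k = kv_cache_bytes(32, 32, 128, T=4096)
batch_of_16 = kv_cache_bytes(32, 32, 128, T=4096, batch=16)

print(f"per token        : {per_token / 2**20:.2f} MiB")
print(f"one 4k sequence  : {at_4k / 2**30:.2f} GiB")
print(f"16 x 4k sequences: {batch_of_16 / 2**30:.2f} GiB")
```

Unlike the score matrix, the cache grows linearly in T, but it must stay resident for the whole lifetime of every active request, which is what turns serving into a memory-management problem.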
The KV cache, FlashAttention, and PagedAttention are the three most important inference-engineering optimizations for Transformer models. All three exist because every token attends to every earlier token: the score matrix you just computed by hand grows as T², and the K and V tensors it is built from must be kept around for the whole sequence. Understanding why that matrix is expensive is the single most important thing you can take from today's lesson.
Here is how each upcoming Week 3 day builds on today's foundation:
| Day | Topic | Connection to today |
|---|---|---|
| Day 20 | KV Cache in depth | Exactly the cache above — memory layout, eviction, batch management |
| Day 21 | FlashAttention | Avoids materializing the (T, T) score matrix; tiles the computation in SRAM |
| Day 24 | PagedAttention | OS-style paged memory management for the KV cache; enables vLLM |
Here is the complete single-head attention module. Read the code line-by-line while tracking shapes in the comments.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SingleHeadAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.d_k = d_k

    def forward(self, x, mask=None):
        # x: (B, T, d_model)
        Q = self.W_Q(x)  # (B, T, d_k)
        K = self.W_K(x)  # (B, T, d_k)
        V = self.W_V(x)  # (B, T, d_k)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)  # (B, T, T)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)  # (B, T, T)
        return attn @ V                   # (B, T, d_k)
```
One detail: bias=False on every nn.Linear. This is common in recent open-weight models such as LLaMA, though not universal (GPT-2, for example, keeps its biases). The bias adds parameters without much empirical benefit when attention is paired with a normalization layer, as it is in every Transformer block (we cover that on Day 7).
In modern PyTorch (2.0 and later), replace the inner computation with F.scaled_dot_product_attention(Q, K, V, is_causal=True), which can dispatch to a fused kernel such as FlashAttention when the hardware, dtype, and shapes support it. Use it in production; keep the explicit version above for learning.
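Before swapping implementations, it is worth confirming the two agree. This sketch assumes PyTorch 2.0 or newer (where `F.scaled_dot_product_attention` is available) and uses random hypothetical tensors already split into heads.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)             # arbitrary seed
B, h, T, d_head = 2, 4, 6, 8     # hypothetical sizes
q = torch.randn(B, h, T, d_head)
k = torch.randn(B, h, T, d_head)
v = torch.randn(B, h, T, d_head)

# Explicit version: scores, causal mask, softmax, blend.
scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)
mask = torch.tril(torch.ones(T, T))
scores = scores.masked_fill(mask == 0, float("-inf"))
explicit = torch.softmax(scores, dim=-1) @ v

# Built-in version; is_causal=True applies the same lower-triangular mask
# and the default scale is 1 / sqrt(d_head).
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print("max abs difference:", (explicit - fused).abs().max().item())
```

Small differences at the level of floating-point rounding are expected, because a fused kernel may accumulate in a different order than the explicit version.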
For completeness on Apple Silicon, here is multi-head attention in MLX. Compare each line to the PyTorch version — the differences are cosmetic but reveal framework philosophy.
```python
import mlx.core as mx
import mlx.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_QKV = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def __call__(self, x, mask=None):
        B, T, D = x.shape
        qkv = self.W_QKV(x)                     # (B, T, 3D)
        q, k, v = mx.split(qkv, 3, axis=-1)     # each (B, T, D)
        # (B, T, D) -> (B, h, T, d_head)
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3)
        scores = (q @ k.transpose(0, 1, 3, 2)) / mx.sqrt(self.d_head)  # (B, h, T, T)
        if mask is not None:
            scores = mx.where(mask == 0, mx.array(-1e9), scores)
        attn = mx.softmax(scores, axis=-1)
        out = attn @ v                          # (B, h, T, d_head)
        out = out.transpose(0, 2, 1, 3).reshape(B, T, D)               # (B, T, D)
        return self.W_O(out)
```
Key differences from PyTorch: __call__ instead of forward; mx.split instead of .chunk; mx.where instead of masked_fill; no .contiguous() (lazy graph handles layout). Given the same weights and inputs, both implementations produce identical outputs to within floating-point precision.
Attention bugs are notoriously hard to detect from a loss curve alone. Run these checks on any new implementation before trusting it.
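Shape: input (B, T, D) must produce output (B, T, D). If anything else comes back, you have a transpose or reshape bug.
Rows sum to 1: each row of the attention weights A is a probability distribution, so assert torch.allclose(A[0, 0].sum(-1), torch.ones(T)).
Causality: with the causal mask applied, changing the input at position i must not change outputs at positions < i:
```python
# Assumes `attn` is the module under test (d_model = 64) and
# `causal_mask` is torch.tril(torch.ones(5, 5)).
x = torch.randn(1, 5, 64)
y1 = attn(x, mask=causal_mask)
x2 = x.clone(); x2[0, 4] = torch.randn(64)  # change last token
y2 = attn(x2, mask=causal_mask)
assert torch.allclose(y1[:, :4], y2[:, :4])  # earlier positions are unaffected
```
Gradients: loss = attn(x).sum(); loss.backward() must not produce NaN gradients. If it does, your scaling or masking is wrong.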
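The checks above can be bundled into one runnable script. To keep it self-contained, the sketch tests a hypothetical projection-free helper, `causal_attention`, rather than the lesson's modules; the same assertions should carry over to any implementation that exposes its outputs and attention weights.

```python
import math
import torch


def causal_attention(q, k, v):
    """Explicit causal attention; returns (output, attention weights)."""
    T = q.shape[-2]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    mask = torch.tril(torch.ones(T, T))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)  # arbitrary seed
x = torch.randn(1, 5, 64, requires_grad=True)
out, weights = causal_attention(x, x, x)

# 1. Shape is preserved.
shape_ok = out.shape == x.shape

# 2. Every row of the weights sums to 1.
rows_ok = torch.allclose(weights.sum(-1), torch.ones(1, 5))

# 3. Changing the last token leaves earlier outputs untouched.
x2 = x.detach().clone()
x2[0, 4] = torch.randn(64)
out2, _ = causal_attention(x2, x2, x2)
causal_ok = torch.allclose(out[:, :4], out2[:, :4])

# 4. Backward pass produces no NaN gradients.
out.sum().backward()
grads_ok = not torch.isnan(x.grad).any().item()

print(shape_ok, rows_ok, causal_ok, grads_ok)
```

A common way to fail the fourth check is a row in which every position is masked: softmax over a row of all −inf is undefined and yields NaN.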
Companion notebook: day-6-attention.ipynb.
F.scaled_dot_product_attention with an assertion.
Close the page and answer from memory. If you can't, re-read the relevant section.
Explain softmax(QKᵀ / √d_k) V symbol by symbol. What is Q? What is Kᵀ? What does softmax operate over?
Why divide by √d_k? What is the variance argument? What goes wrong at d_k = 4096 without it?
Compute S[2,2] by hand and verify the softmax weights for row 2.
With a causal mask, what is attn[0, 0, 5, 7] after softmax? Why is it exactly zero?
If d_model = 768 and n_heads = 12, what is d_head? What is the total parameter count of one attention layer (no biases)?
How does the view + transpose trick create independent heads?
"Attention is a soft, content-based dictionary lookup that turned out to be everything we needed, and the quadratic cost of that lookup is what all of Week 3's inference engineering exists to solve."
Hand-picked references for this lesson.
The paper. Section 3 is the meat of today's lesson. Read it after completing the exercises. [Open paper]
The best visual walkthrough ever made. Read alongside the paper. [Read post]
Paper plus working PyTorch code, side by side. Every equation has its line of code. [Read post]
Self-attention from scratch on Shakespeare. The code that became nanoGPT. [Watch on YouTube]
Beautifully animated geometry of attention. Best complement to this lesson's diagrams. [Watch on YouTube]
The original attention mechanism added to RNN encoder-decoders for translation. [Open paper]
What attention heads actually compute: an explicit linear-algebra view. Advanced but rewarding. [Read post]
Clean causal self-attention in ~50 lines. Production-quality reference implementation. [View source]
Survey of attention variants. A useful map of the design space once you know the basics. [Read post]
The paper that solved the T² memory problem you computed in the complexity section. Required reading before Day 21. [Open paper]
Production reference with all corner cases. Find GPT2Attention and trace it against today's lesson.