The KV cache is the optimization that makes autoregressive decode practical, and it is often the biggest memory object in an inference server. Today you compute its exact size, implement the update pattern, and see why GQA, sliding windows, and PagedAttention exist.
Without KV caching, every decode step recomputes keys and values for the whole prefix. With caching, old keys and values are stored once and reused. That turns a catastrophic recompute loop into an append-and-attend loop.
The price is memory. At useful context lengths and batches, KV cache can exceed model weights. Day 24's PagedAttention is mostly a solution to KV cache allocation and fragmentation, so Day 20 gives you the arithmetic first.
B is batch size.
T or seq_len is cached sequence length.
n_layers is the number of Transformer layers.
n_heads is the number of query heads.
n_kv_heads is the number of key/value heads; in MHA it equals n_heads.
head_dim is the vector width per head.
dtype_bytes is bytes per scalar, such as 2 for FP16/BF16.
By the end of this section, you should be able to explain why caching is mathematically valid.
In self-attention, each token produces a query Q, key K, and value V. During decode, the old prompt tokens are unchanged. Their keys and values are also unchanged because the model weights are fixed and, under causal masking, the representation at each position depends only on that position and the ones before it. Appending a new token therefore cannot alter any K or V that has already been computed.
So instead of:
step 1: recompute K,V for tokens 1..512
step 2: recompute K,V for tokens 1..513
step 3: recompute K,V for tokens 1..514
we do:
prefill: compute and store K,V for tokens 1..512
step 1: compute K,V only for token 513, append
step 2: compute K,V only for token 514, append
Attention for the new token still reads every cached K and V. The saved work is re-running the prefix through the model to recompute its K and V projections at every step.
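The equivalence above can be checked numerically. The following is a minimal single-head, single-layer sketch in pure Python; the 2x2 projection matrices, the token vectors, and all function names are hypothetical and chosen only for illustration. It decodes the same sequence twice, once recomputing K and V for the whole prefix at every step and once appending to a cache, and compares the outputs.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(W, x):
    # W is a list of rows; returns the matrix-vector product W @ x.
    return [dot(row, x) for row in W]

def attend(q, keys, values):
    # Scaled dot-product attention of one query over all visible positions.
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(dim)]

# Hypothetical fixed projection weights (assumption, not from any real model).
W_Q = [[0.5, -0.2], [0.1, 0.3]]
W_K = [[0.4, 0.1], [-0.3, 0.2]]
W_V = [[1.0, 0.0], [0.5, 0.5]]

def decode_no_cache(xs):
    # Every step recomputes K and V for the entire prefix.
    outs = []
    for t in range(1, len(xs) + 1):
        keys = [matvec(W_K, x) for x in xs[:t]]
        values = [matvec(W_V, x) for x in xs[:t]]
        outs.append(attend(matvec(W_Q, xs[t - 1]), keys, values))
    return outs

def decode_with_cache(xs):
    # Every step computes K and V for the new token only, then appends.
    keys, values, outs = [], [], []
    for x in xs:
        keys.append(matvec(W_K, x))
        values.append(matvec(W_V, x))
        outs.append(attend(matvec(W_Q, x), keys, values))
    return outs

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -1.0]]
uncached = decode_no_cache(tokens)
cached = decode_with_cache(tokens)
max_diff = max(abs(a - b)
               for u, c in zip(uncached, cached)
               for a, b in zip(u, c))
print("max difference:", max_diff)
```

In this toy the layer inputs are fixed vectors. In a multi-layer model the same argument holds layer by layer, because causal masking keeps each position's input to the next layer independent of later tokens.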
By the end of this section, you should be able to compute cache size from model config alone.
Per token, per layer:
K and V scalars = 2 * n_kv_heads * head_dim
bytes = 2 * n_kv_heads * head_dim * dtype_bytes
All layers, all tokens, all batch items:
cache_bytes = B * T * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes
Concrete LLaMA-2 7B-style example:
n_layers = 32
n_kv_heads = 32
head_dim = 128
dtype_bytes = 2
T = 4096
B = 1
per token per layer = 2 * 32 * 128 * 2
= 16,384 bytes
= 16 KB
per token all layers = 32 * 16 KB = 512 KB
cache at T=4096 = 4096 * 512 KB ~= 2 GB
At batch 64, that same cache is about 128 GB. At 2 bytes per parameter, the weights of a 7B-parameter model take about 14 GB, so the cache is roughly 9x larger than the model itself.
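The arithmetic above can be reproduced with a few lines of Python. This is a small sketch of the cache_bytes formula; the function name kv_cache_bytes and the cfg dictionary are illustrative choices, and sizes are reported in binary units (1 GiB = 1024^3 bytes), which is what the KB and GB figures above correspond to.

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes):
    # The factor 2 accounts for storing both K and V.
    return batch * seq_len * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes

# LLaMA-2 7B-style configuration from the example above.
cfg = dict(n_layers=32, n_kv_heads=32, head_dim=128, dtype_bytes=2)

per_token = kv_cache_bytes(1, 1, **cfg)     # one token, all layers
one_seq = kv_cache_bytes(1, 4096, **cfg)    # B=1, T=4096
batch_64 = kv_cache_bytes(64, 4096, **cfg)  # B=64, T=4096

GIB = 1024 ** 3
print(per_token // 1024, "KiB per token")
print(one_seq / GIB, "GiB at B=1")
print(batch_64 / GIB, "GiB at B=64")
```

Because the formula is a plain product, doubling any single factor (batch, context length, layers, KV heads, head width, or bytes per scalar) doubles the cache.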
By the end of this section, you should be able to compute GQA memory savings.
Multi-head attention uses one K/V head per query head. Grouped-query attention keeps many query heads but shares fewer K/V heads.
LLaMA-2 70B-style example:
n_layers = 80
n_heads = 64
n_kv_heads = 8
head_dim = 128
dtype_bytes = 2
per token all layers = 80 * 2 * 8 * 128 * 2
= 327,680 bytes
= 320 KB
If it used full MHA with 64 KV heads instead of 8, the cache would be 64 / 8 = 8x larger. GQA is not a small tweak; it is a serving-enabling architectural choice.
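To make the 8x figure concrete, here is a short sketch that compares the per-token cache for the 70B-style configuration with 8 KV heads against a hypothetical full-MHA variant with 64 KV heads. The helper name kv_bytes_per_token is an illustrative choice, and the 4096-token context is carried over from the earlier 7B example rather than stated for this model.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes):
    # K and V for one token across all layers.
    return n_layers * 2 * n_kv_heads * head_dim * dtype_bytes

gqa = kv_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128, dtype_bytes=2)
mha = kv_bytes_per_token(n_layers=80, n_kv_heads=64, head_dim=128, dtype_bytes=2)
ratio = mha // gqa

GIB = 1024 ** 3
T = 4096  # assumed context length for illustration
print("GQA per token:", gqa // 1024, "KiB")
print("MHA per token:", mha // 1024, "KiB")
print("MHA / GQA:", ratio)
print("GQA cache at T=4096:", gqa * T / GIB, "GiB")
print("MHA cache at T=4096:", mha * T / GIB, "GiB")
```

The ratio depends only on n_heads / n_kv_heads, since every other factor in the formula is identical between the two variants.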
Common layouts:
[B, n_kv_heads, T, head_dim]
[B, T, n_kv_heads, head_dim]
During decode, a new token appends one position along T, and attention reads all previous T positions for the active heads. A layout that gives contiguous reads for the kernel's access pattern wins.
There is no universal answer independent of kernel implementation. The important habit is to ask: which axis does a warp walk across? Which values are adjacent in memory? Which append pattern causes fragmentation?
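One way to answer "which values are adjacent in memory?" is to compute row-major strides for each layout. The sketch below does this in pure Python, assuming dense C-order storage; the shape values and the helper name strides are hypothetical and chosen only for illustration.

```python
def strides(shape):
    # Row-major (C-order) strides, measured in elements rather than bytes.
    out = []
    step = 1
    for dim in reversed(shape):
        out.append(step)
        step *= dim
    return tuple(reversed(out))

B, H, T, D = 1, 8, 4096, 128  # batch, n_kv_heads, seq_len, head_dim

bhtd = strides((B, H, T, D))  # layout [B, n_kv_heads, T, head_dim]
bthd = strides((B, T, H, D))  # layout [B, T, n_kv_heads, head_dim]

# Distance in elements between consecutive tokens of the same head.
token_step_bhtd = bhtd[2]  # T is axis 2 in this layout
token_step_bthd = bthd[1]  # T is axis 1 in this layout

print("[B, H, T, D] strides:", bhtd, "token step:", token_step_bhtd)
print("[B, T, H, D] strides:", bthd, "token step:", token_step_bthd)
```

With [B, n_kv_heads, T, head_dim], the token step equals head_dim, so one head's history is a single dense run of T * head_dim elements, while appending a token writes n_kv_heads separate chunks. With [B, T, n_kv_heads, head_dim], appending a token writes one contiguous chunk of n_kv_heads * head_dim elements, while reading one head across T skips n_kv_heads * head_dim elements per token. Which trade-off is better depends on the kernel.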
Preallocation is simple but wasteful. Dynamic growth is flexible but can reallocate and fragment. Sliding-window attention caps memory by keeping only the most recent W tokens, which is useful for models designed for it. Paged caches, introduced in vLLM's PagedAttention, split KV storage into fixed-size blocks so many requests can share a memory pool without giant contiguous allocations.
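The sliding-window idea can be sketched with a bounded queue. This is a toy illustration using Python's collections.deque with maxlen; the class name SlidingWindowKV and the placeholder tuples standing in for K and V vectors are hypothetical, and a real implementation would more likely use a ring buffer inside a preallocated tensor.

```python
from collections import deque

class SlidingWindowKV:
    """Toy per-layer cache that keeps only the most recent `window` tokens."""

    def __init__(self, window):
        # A deque with maxlen drops the oldest entry when a new one is
        # appended to a full queue.
        self.keys = deque(maxlen=window)
        self.values = deque(maxlen=window)

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def __len__(self):
        return len(self.keys)

cache = SlidingWindowKV(window=4)
for pos in range(10):
    # Placeholder entries; real code would store K and V vectors.
    cache.append(("k", pos), ("v", pos))

kept = [k[1] for k in cache.keys]
print("cached positions:", kept)
print("cache length:", len(cache))
```

Memory stays bounded by the window size regardless of how many tokens are generated, at the cost of the model no longer attending directly to positions that have been evicted.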
You do not need PagedAttention yet. Day 24 builds it. Today, remember the reason it exists: KV cache memory grows linearly with B * T, while live request lengths vary wildly.
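The effect of varying request lengths can be illustrated with slot counting alone. The sketch below compares reserving max_len token slots per request against allocating fixed-size blocks on demand. The request lengths, max_len, and the block size of 16 are all assumptions made for illustration, and the function names are hypothetical; this is not an implementation of PagedAttention.

```python
import math

def prealloc_slots(lengths, max_len):
    # Every request reserves max_len token slots up front.
    return len(lengths) * max_len

def paged_slots(lengths, block_size):
    # Every request holds just enough fixed-size blocks to cover its length.
    return sum(math.ceil(n / block_size) * block_size for n in lengths)

lengths = [37, 512, 1200, 90]  # hypothetical live request lengths
max_len = 4096                 # assumed maximum context length
block_size = 16                # assumed block size in tokens

used = sum(lengths)
pre = prealloc_slots(lengths, max_len)
paged = paged_slots(lengths, block_size)

print("slots in use:", used)
print("preallocated:", pre, "wasted:", pre - used)
print("paged:", paged, "wasted:", paged - used)
```

With block allocation, the waste per request is bounded by one partially filled block, whereas preallocation wastes the full gap between each request's length and max_len.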
Use the notebook to:
(K, V) arrays.
Optional CUDA/PyTorch extension: time decode with and without cache for sequence lengths 128, 512, and 2048.
cache_bytes = B * T * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes.
GQA scales the cache by n_kv_heads / n_heads.
"The KV cache turns decode from recompute into reuse, but reuse has a memory bill."
Primary references and the companion notebook for today's exercise.
Companion Jupyter notebook with runnable calculations and optional hardware-specific cells.