LLM Inference Engineer · Day 20
Day 20 · Week 3 · Inference & Hardware

KV Cache Deep Dive: Memory Math, Layout & Optimization

The KV cache is the optimization that makes autoregressive decode practical, and it is often the biggest memory object in an inference server. Today you compute its exact size, implement the update pattern, and see why GQA, sliding windows, and PagedAttention exist.

Time: ~170 min
Difficulty: Hard
Prerequisite: Day 19
Notebook: day-20
Why This Lesson

Hardware limits shape inference behavior.

Without KV caching, every decode step recomputes keys and values for the whole prefix. With caching, old keys and values are stored once and reused. That turns a catastrophic recompute loop into an append-and-attend loop.

The price is memory. At useful context lengths and batches, KV cache can exceed model weights. Day 24's PagedAttention is mostly a solution to KV cache allocation and fragmentation, so Day 20 gives you the arithmetic first.

Learning Objectives

What you should be able to do today.

  1. State the per-layer KV cache shape and explain what every axis means.
  2. Compute KV cache memory for any model, batch, sequence length, and dtype.
  3. Explain why GQA and MQA shrink the KV cache in proportion to n_kv_heads / n_heads.
  4. Implement an append-only KV cache and verify it matches no-cache attention.
  5. Compare preallocation, dynamic growth, sliding windows, and paged caches.
  6. Connect KV cache memory to batching limits in production inference.
Math Notation Cheatsheet

Decode the symbols before using them.

  • B is batch size.
  • T (or seq_len) is the cached sequence length.
  • n_layers is the number of Transformer layers.
  • n_heads is the number of query heads.
  • n_kv_heads is the number of key/value heads: it equals n_heads in MHA, is a smaller divisor of n_heads in GQA, and is 1 in MQA.
  • head_dim is the vector width per head.
  • dtype_bytes is bytes per scalar, such as 2 for FP16/BF16.
Why Cache

Keys and values for old tokens do not change.

Objective

By the end of this section, you should be able to explain why caching is mathematically valid.

In self-attention, each token produces a query Q, a key K, and a value V. During decode, the tokens already processed (the prompt plus earlier generated tokens) do not change. Their keys and values do not change either: the weights are fixed, and because attention is causal, each position's hidden state at every layer depends only on that position and earlier ones, so appending a new token cannot alter it.

So instead of:

step 1: recompute K,V for tokens 1..512
step 2: recompute K,V for tokens 1..513
step 3: recompute K,V for tokens 1..514

we do:

prefill: compute and store K,V for tokens 1..512
step 1: compute K,V only for token 513, append
step 2: compute K,V only for token 514, append

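Attention still reads all previous K/V rows, so per-step attention cost still grows with T. What caching saves is re-running the prefix through the model at every step: the K/V projections and all the layer computation that feeds them.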
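To make the saving concrete, here is a small counting sketch (plain Python; the helper name `kv_rows_computed` is hypothetical) of how many per-layer K/V projections each strategy performs for a prompt of `prompt_len` tokens followed by `new_tokens` decode steps. It counts projections only, not attention reads, and assumes the uncached loop re-projects the full prefix at every step.

```python
def kv_rows_computed(prompt_len, new_tokens, use_cache):
    """Token positions whose K/V get projected (per layer) over prefill + decode."""
    prefill = prompt_len  # both strategies project the whole prompt once
    if use_cache:
        # Each decode step projects only the newly fed token and appends it
        return prefill + new_tokens
    # Without a cache, decode step s re-projects all prompt_len + s positions
    return prefill + sum(prompt_len + s for s in range(1, new_tokens + 1))


for n in (3, 256, 1024):
    no_cache = kv_rows_computed(512, n, use_cache=False)
    with_cache = kv_rows_computed(512, n, use_cache=True)
    print(f"{n:>5} new tokens: {no_cache:>9} without cache, {with_cache:>5} with cache")
```

The uncached count grows roughly as new_tokens * prompt_len + new_tokens² / 2, while the cached count grows linearly in new_tokens.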

Figure: KV Cache Shape Per Layer. K, V: [B, n_kv_heads, T, head_dim], with slots t0 through t5 filled and a new slot appended. Each decode step appends one K row and one V row for every layer.
The sequence axis grows by one slot per decode step.
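Here is a minimal, dependency-free sketch of the append-only pattern and the equivalence check, simplified to batch 1, one layer, and one KV head, so each cache is a [T, head_dim] slab rather than the full [B, n_kv_heads, T, head_dim] tensor. The class and function names are hypothetical, and the weights and per-token inputs are random stand-ins. Treating each token's input as fixed once it has been seen is exactly what causal masking guarantees within a single layer.

```python
import math
import random


def attend(q, keys, values):
    """Scaled dot-product attention of one query vector over lists of K and V rows."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max before exp for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def project(w, x):
    """Matrix-vector product: w is [out][in], x is [in]."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


class AppendOnlyKVCache:
    """Hypothetical single-layer, single-head, batch-1 cache: K and V are [T][head_dim]."""

    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, k, v):
        # One new row per decode step; existing rows are never rewritten
        self.keys.append(k)
        self.values.append(v)

    def __len__(self):
        return len(self.keys)


def decode_with_cache(xs, wq, wk, wv):
    cache = AppendOnlyKVCache()
    outputs = []
    for x in xs:
        cache.append(project(wk, x), project(wv, x))  # K/V for the new token only
        outputs.append(attend(project(wq, x), cache.keys, cache.values))
    return outputs, cache


def decode_without_cache(xs, wq, wk, wv):
    outputs = []
    for t in range(len(xs)):
        prefix = xs[: t + 1]  # causal: position t sees positions 0..t
        keys = [project(wk, x) for x in prefix]  # recomputed every step
        values = [project(wv, x) for x in prefix]
        outputs.append(attend(project(wq, xs[t]), keys, values))
    return outputs


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
head_dim, num_tokens = 8, 10
wq, wk, wv = (rand_matrix(head_dim, head_dim, rng) for _ in range(3))
xs = rand_matrix(num_tokens, head_dim, rng)  # stand-in inputs, one row per token

cached, cache = decode_with_cache(xs, wq, wk, wv)
full = decode_without_cache(xs, wq, wk, wv)
max_err = max(abs(a - b) for ra, rb in zip(cached, full) for a, b in zip(ra, rb))
print(len(cache), max_err)
```

Both paths feed identical K/V rows into the same attention function, so the outputs agree; the only difference is how many times those rows were computed.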
Memory Formula

The cache size is simple and unforgiving.

Objective

By the end of this section, you should be able to compute cache size from model config alone.

Per token, per layer:

K and V scalars = 2 * n_kv_heads * head_dim
bytes = 2 * n_kv_heads * head_dim * dtype_bytes

All layers, all tokens, all batch items:

cache_bytes = B * T * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes
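The formula translates directly into code. This sketch (function names are illustrative) also reads the model shape from a Hugging Face-style config dict using the num_hidden_layers, num_attention_heads, num_key_value_heads, and hidden_size keys. Deriving head_dim as hidden_size // num_attention_heads when no head_dim key is present, and falling back to MHA when num_key_value_heads is absent, are assumptions that fit LLaMA-family configs but should be checked for other architectures.

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes):
    """B * T * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes; the 2 counts K and V."""
    return batch * seq_len * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes


def kv_cache_bytes_from_config(config, batch, seq_len, dtype_bytes=2):
    n_heads = config["num_attention_heads"]
    # Assumption: a missing/None num_key_value_heads means plain MHA
    n_kv_heads = config.get("num_key_value_heads") or n_heads
    # Assumption: head_dim = hidden_size / n_heads unless the config states it
    head_dim = config.get("head_dim") or config["hidden_size"] // n_heads
    return kv_cache_bytes(batch, seq_len, config["num_hidden_layers"],
                          n_kv_heads, head_dim, dtype_bytes)


llama2_7b_like = {"hidden_size": 4096, "num_attention_heads": 32, "num_hidden_layers": 32}
print(kv_cache_bytes_from_config(llama2_7b_like, batch=1, seq_len=4096))  # 2147483648
```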

Concrete LLaMA-2 7B-style example:

n_layers = 32
n_kv_heads = 32
head_dim = 128
dtype_bytes = 2
T = 4096
B = 1

per token per layer = 2 * 32 * 128 * 2
                    = 16,384 bytes
                    = 16 KiB

per token all layers = 32 * 16 KiB = 512 KiB
cache at T=4096     = 4096 * 512 KiB = 2 GiB

At batch 64, that same cache is 128 GiB, roughly ten times the ~13.5 GB of FP16 weights for a 7B-parameter model.
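A quick script reproduces these numbers and finds the batch size at which the cache overtakes the weights. The 6.74e9 parameter count is an approximation for a LLaMA-2 7B-sized model, and FP16 storage is assumed for both weights and cache.

```python
N_LAYERS, N_KV_HEADS, HEAD_DIM, DTYPE_BYTES = 32, 32, 128, 2  # LLaMA-2 7B-style, FP16
SEQ_LEN = 4096
APPROX_PARAMS = 6.74e9  # assumption: approximate parameter count of a 7B model
GiB = 2**30

per_token_all_layers = N_LAYERS * 2 * N_KV_HEADS * HEAD_DIM * DTYPE_BYTES
per_sequence = SEQ_LEN * per_token_all_layers
weight_bytes = APPROX_PARAMS * DTYPE_BYTES


def cache_bytes(batch):
    return batch * per_sequence


# Smallest batch whose KV cache needs more memory than the weights
crossover = next(b for b in range(1, 1025) if cache_bytes(b) > weight_bytes)

print(per_token_all_layers // 1024, "KiB per token")  # 512 KiB per token
print(per_sequence / GiB, "GiB per sequence")         # 2.0 GiB per sequence
print(cache_bytes(64) / GiB, "GiB at batch 64")       # 128.0 GiB at batch 64
print("cache exceeds weights from batch", crossover)  # cache exceeds weights from batch 7
```

Under these assumptions, seven concurrent 4096-token sequences already need more memory for KV cache than for the weights.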

GQA and MQA

Fewer KV heads means fewer cache rows.

Objective

By the end of this section, you should be able to compute GQA memory savings.

Multi-head attention uses one K/V head per query head. Grouped-query attention keeps many query heads but shares fewer K/V heads.
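In code, sharing K/V heads comes down to an index mapping from query heads to K/V heads. The sketch below assumes the common contiguous grouping, where each run of n_heads / n_kv_heads adjacent query heads reads the same K/V head; the helper names are hypothetical.

```python
def kv_head_for_query_head(q_head, n_heads, n_kv_heads):
    """Index of the shared K/V head that query head q_head reads (contiguous grouping)."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    group_size = n_heads // n_kv_heads  # query heads per K/V head
    return q_head // group_size


def head_map(n_heads, n_kv_heads):
    return [kv_head_for_query_head(h, n_heads, n_kv_heads) for h in range(n_heads)]


print("MHA:", head_map(8, 8))  # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
print("GQA:", head_map(8, 2))  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
print("MQA:", head_map(8, 1))  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]
```

Only the distinct K/V heads are stored in the cache, which is why its size scales with n_kv_heads rather than n_heads.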

LLaMA-2 70B-style example:

n_layers = 80
n_heads = 64
n_kv_heads = 8
head_dim = 128
dtype_bytes = 2

per token (all layers) = 80 * 2 * 8 * 128 * 2
                       = 327,680 bytes
                       = 320 KiB

If it used full MHA with 64 KV heads instead of 8, the cache would be 64 / 8 = 8x larger. GQA is not a small tweak; it is a serving-enabling architectural choice.
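The same arithmetic at serving scale shows why the 8x matters. This sketch assumes T = 4096 and FP16 (the section does not fix a context length) and compares the 70B-style GQA configuration with a hypothetical full-MHA variant at batch 1, 16, and 64.

```python
def kv_cache_gib(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    """KV cache size in GiB: B * T * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes."""
    return batch * seq_len * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes / 2**30


SEQ_LEN = 4096  # assumption: context length chosen for illustration

for batch in (1, 16, 64):
    gqa = kv_cache_gib(batch, SEQ_LEN, n_layers=80, n_kv_heads=8, head_dim=128)
    mha = kv_cache_gib(batch, SEQ_LEN, n_layers=80, n_kv_heads=64, head_dim=128)
    print(f"B={batch:>2}: GQA {gqa:6.2f} GiB  MHA {mha:7.2f} GiB  ratio {mha / gqa:.0f}x")
```

Under those assumptions, GQA needs 1.25, 20, and 80 GiB, while full MHA would need 10, 160, and 640 GiB.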

Figure: MHA, GQA, MQA: Same Q Heads, Less KV. MHA: 32 KV heads (baseline); GQA: 8 KV heads (4x smaller than MHA); MQA: 1 KV head (32x smaller).
GQA and MQA shrink the KV slab while keeping many query heads.
Layouts

The best layout follows the access pattern.

Common layouts:

[B, n_kv_heads, T, head_dim]
[B, T, n_kv_heads, head_dim]

During decode, a new token appends one position along T, and attention reads all previous T positions for the active heads. A layout that gives contiguous reads for the kernel's access pattern wins.

There is no universal answer independent of kernel implementation. The important habit is to ask: which axis does a warp walk across? Which values are adjacent in memory? Which append pattern causes fragmentation?
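One way to answer the adjacency questions is to compute element strides. This sketch (illustrative sizes, hypothetical helper names) assumes densely packed, row-major tensors and compares the two layouts on two operations: stepping along T within one head, and writing one new token's K rows for all heads.

```python
def row_major_strides(shape):
    """Element strides for a contiguous (C-order) tensor of the given shape."""
    strides = [1] * len(shape)
    for axis in range(len(shape) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    return strides


def append_chunks(strides, head_axis, n_heads, head_dim):
    # A new token's rows form one contiguous write only if heads sit back to back
    return 1 if strides[head_axis] == head_dim else n_heads


B, H, T, D = 1, 8, 1024, 128  # illustrative sizes, not tied to a specific model

bhtd = row_major_strides([B, H, T, D])  # layout [B, n_kv_heads, T, head_dim]
bthd = row_major_strides([B, T, H, D])  # layout [B, T, n_kv_heads, head_dim]

print("step along T within a head:", bhtd[2], "vs", bthd[1])  # 128 vs 1024
print("chunks written per append:",
      append_chunks(bhtd, 1, H, D), "vs", append_chunks(bthd, 2, H, D))  # 8 vs 1
```

With [B, n_kv_heads, T, head_dim], one head's history is a single contiguous slab but an append touches n_kv_heads separate chunks; with [B, T, n_kv_heads, head_dim] the append is one contiguous write, but reading a single head strides over the others. Which cost dominates depends on the kernel.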

Sliding and Paging

Bounding memory changes the model's available context.

Preallocating the maximum sequence length for every request is simple but wastes memory on every request that finishes short of that limit. Dynamic growth is flexible but must occasionally reallocate and copy, and it can fragment memory. Sliding-window attention caps memory by keeping only the most recent W tokens, which is only appropriate for models trained with it. Paged caches, introduced in vLLM's PagedAttention, split KV storage into fixed-size blocks so many requests can share a memory pool without giant contiguous allocations.
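To see what reallocation costs, here is a simulation of a dynamically grown cache that doubles its capacity whenever it fills. Doubling is a common growth policy assumed here for illustration, not a claim about how any particular engine grows its cache, and the sketch counts K/V rows rather than bytes.

```python
def simulate_doubling_growth(final_len, initial_capacity=16):
    """Return (reallocations, rows_copied, peak_rows_reserved, final_capacity)."""
    capacity = initial_capacity
    length = 0
    reallocations = 0
    rows_copied = 0
    peak_rows_reserved = capacity
    for _ in range(final_len):
        if length == capacity:
            new_capacity = capacity * 2
            # Old and new buffers coexist while the existing rows are copied
            peak_rows_reserved = max(peak_rows_reserved, capacity + new_capacity)
            rows_copied += length
            capacity = new_capacity
            reallocations += 1
        length += 1  # append one token's K/V row
    return reallocations, rows_copied, peak_rows_reserved, capacity


print(simulate_doubling_growth(4096))  # (8, 4080, 6144, 4096)
print(simulate_doubling_growth(4097))  # (9, 8176, 12288, 8192)
```

Growing to 4096 tokens takes 8 reallocations and briefly holds 6144 rows during the last copy; one more token forces a ninth reallocation to a capacity of 8192, leaving almost half of it unused.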

You do not need PagedAttention yet. Day 24 builds it. Today, remember the reason it exists: KV cache memory grows linearly with B * T, while live request lengths vary wildly.
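A small calculation shows why fixed-size blocks help when request lengths vary. The request lengths below are hypothetical, and the 16-token block size and 4096-token maximum are assumptions chosen for illustration. The sketch compares token slots reserved by per-request preallocation with slots reserved when each request holds only the blocks it has started filling.

```python
def blocks_needed(num_tokens, block_size):
    return (num_tokens + block_size - 1) // block_size  # ceiling division


def reserved_slots(lengths, max_len, block_size):
    prealloc = len(lengths) * max_len  # every request reserves max_len up front
    paged = sum(blocks_needed(n, block_size) * block_size for n in lengths)
    return prealloc, paged


lengths = [100, 2000, 37, 4096, 512]  # hypothetical live request lengths
MAX_LEN, BLOCK_SIZE = 4096, 16  # assumed limits for illustration
PER_TOKEN_BYTES = 512 * 1024  # 7B-style MHA model in FP16, from the memory formula

used = sum(lengths)
prealloc, paged = reserved_slots(lengths, MAX_LEN, BLOCK_SIZE)
print("used:", used)  # used: 6745
print("prealloc waste:", prealloc - used)  # prealloc waste: 13735
print("paged waste:", paged - used)  # paged waste: 23
print(f"prealloc waste in GiB: {(prealloc - used) * PER_TOKEN_BYTES / 2**30:.2f}")
```

Paged allocation wastes at most block_size - 1 slots per request, in its last partially filled block, instead of the gap between max_len and the request's actual length.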

Figure: Sliding Window Attention Bounds the Cache. For window W = 6 over tokens 0-13, old tokens remain in the prompt history but leave the active KV cache: tokens 0-7 are dropped from the active window, and tokens 8-13 are kept in cache.
Sliding windows bound memory by making long-range attention a model-level tradeoff.
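The truncation in exercise 4 can be sketched with `collections.deque(maxlen=W)`, which evicts the oldest entry on each append once it is full. Token positions stand in for their K/V rows here; a real kernel would more likely use a fixed ring buffer indexed by position % W, which this sketch only approximates.

```python
from collections import deque


def simulate_sliding_window(num_tokens, window):
    cache = deque(maxlen=window)  # appending beyond maxlen drops the oldest entry
    peak = 0
    for pos in range(num_tokens):
        cache.append(pos)  # stands in for this position's (K, V) rows
        peak = max(peak, len(cache))
    return list(cache), peak


kept, peak = simulate_sliding_window(num_tokens=14, window=6)
print(kept, peak)  # [8, 9, 10, 11, 12, 13] 6
```

However many tokens are generated, the cache never holds more than W positions.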
Did You Know?

A systems detail worth remembering.

For large batches and long contexts, the KV cache can be bigger than the model weights. This is why serving engines talk about cache blocks, pages, eviction, and fragmentation as first-class concepts.
Exercise

Do the arithmetic, then run the notebook.

Use the notebook to:

  1. Compute cache sizes for a 70B model with MHA and GQA at batch 1, 16, and 64.
  2. Implement an append-only cache as (K, V) arrays.
  3. Verify cached attention for a new token matches full attention over the same prefix.
  4. Simulate sliding-window truncation and observe the memory cap.

Optional CUDA/PyTorch extension: time decode with and without cache for sequence lengths 128, 512, and 2048.

Self-Check

Answer these from memory.

  1. Why is caching valid? Attention is causal and the weights are fixed, so old tokens' K and V do not change once computed.
  2. What is the KV cache formula? B * T * n_layers * 2 * n_kv_heads * head_dim * dtype_bytes.
  3. How does GQA change memory? It scales memory by n_kv_heads / n_heads relative to MHA, an n_heads / n_kv_heads reduction.
  4. Why can cache exceed weights? It grows with batch and sequence length; weights are fixed.
  5. What problem does PagedAttention solve? KV cache allocation and fragmentation across variable-length requests.

"The KV cache turns decode from recompute into reuse, but reuse has a memory bill."

Further Reading

Go deeper.

Primary references and the companion notebook for today's exercise.

  • Blog · Transformer Inference Arithmetic: clear memory and compute math for inference.
  • Paper · PagedAttention: vLLM's KV cache paging paper and the basis for Day 24.
  • Paper · Mistral: sliding-window attention in a production-grade open model.
  • Source · Hugging Face cache utils: reference implementation of modern cache abstractions.
  • Notebook · Day 20 · KV Cache Deep Dive notebook: companion Jupyter notebook with runnable calculations and optional hardware-specific cells.