Training asks: "How do we make the weights better?" Inference asks a different question: "Given fixed weights, how do we turn a prompt into tokens as quickly, cheaply, and controllably as possible?" Today you separate prefill from decode, compute first-pass latency estimates, and implement the sampling loop from raw logits.
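Every chat completion has the same shape: run the model over the prompt, sample the next token, append it, and repeat until a stop condition or the token limit is reached.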
That simple loop creates almost every performance constraint in production LLM systems. Prefill is the first pass over the prompt. Decode is the one-token-at-a-time loop after that. If you understand this split, KV cache (Day 20), FlashAttention (Day 21), speculative decoding (Day 23), continuous batching (Day 24), and the capstone engine (Days 27-30) all become much easier to reason about.
By the end of this lesson you should be able to look at a request like "prompt length 512, generate 128 tokens with a 7B model" and estimate where the time and memory go.
Concretely, you should be able to:

- Estimate rough inference compute with 2 * N * tokens.
- Explain max_new_tokens, stop tokens, repetition penalty, frequency penalty, and presence penalty.
- Write a generate() loop without using transformers.generate.

This is a light-math lesson, but notation matters because inference papers and serving docs use the same few symbols repeatedly.
Shape variables:
- B: batch size, the number of sequences processed together. B = 1 for one chat request.
- T_prompt: prompt length in tokens. If the tokenizer turns your prompt into 512 integers, T_prompt = 512.
- T_new: number of generated tokens. If the answer streams 128 tokens, T_new = 128.
- T_total: total context length after generation, T_prompt + T_new.
- V: vocabulary size, the number of token IDs the model can choose from.
- D: model width, also called d_model. The size of the residual-stream vector.

Model and compute variables:
- N: number of model parameters. A 7B model has N = 7,000,000,000.
- FLOP: floating-point operation. By the usual ML counting convention, one multiply or one add is one FLOP.
- 2 * N FLOPs/token: the rough inference rule of thumb. Each token the model processes, whether a prompt token in prefill or a generated token in decode, costs about two operations per parameter for the matrix multiplies. This ignores attention's context-dependent terms (small for short contexts), kernel overhead, and cache effects, but it is good enough for first estimates.
- BW: memory bandwidth, how many bytes per second the GPU can move from HBM/VRAM.

Sampling variables:
- logits: raw model scores before softmax, shape (B, T, V).
- p_i: probability of token i after softmax.
- tau: temperature. I use tau instead of T so we do not confuse temperature with sequence length.
- top-k: keep the k highest-probability tokens, set the rest to zero, renormalize.
- top-p: keep the smallest set of tokens whose cumulative probability is at least p, set the rest to zero, renormalize.

Softmax reminder from Day 1:
For logits z = [z_1, z_2, ..., z_V]:
softmax(z)_i = exp(z_i) / sum_j exp(z_j)
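Decoded:

- exp(x) means e ** x, where e ≈ 2.718.
- sum_j means "sum over every token index j in the vocabulary."

In code, always subtract the max logit for numerical stability. Subtracting the same constant c from every logit leaves the softmax unchanged, because the factor exp(-c) cancels between numerator and denominator, and it keeps exp() from overflowing: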
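```python
import torch


def softmax(logits: torch.Tensor) -> torch.Tensor:
    # Shift by the row max so the largest exponent is exp(0) = 1 and nothing overflows.
    shifted = logits - logits.max(dim=-1, keepdim=True).values
    probs = shifted.exp()
    return probs / probs.sum(dim=-1, keepdim=True)
```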
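To see why the shift matters, here is a plain-Python sketch that compares a naive softmax with the max-shifted one. It uses no torch, and the function names are illustrative. Adding 1000 to every logit leaves the probabilities mathematically unchanged, but math.exp overflows long before an argument of 1000.

```python
import math


def naive_softmax(logits):
    # exp() of a large logit overflows: math.exp(1000) raises OverflowError.
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(logits):
    # Shift by the max so the largest exponent is exp(0) = 1.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


small = [2.0, -1.5, 0.3, 4.2, -0.8]
large = [z + 1000.0 for z in small]  # same differences, shifted by 1000

print([round(p, 3) for p in stable_softmax(small)])
# [0.097, 0.003, 0.018, 0.876, 0.006]
print([round(p, 3) for p in stable_softmax(large)])
# [0.097, 0.003, 0.018, 0.876, 0.006]

try:
    naive_softmax(large)
except OverflowError as err:
    print("naive softmax failed:", err)
```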
By the end of this section you should be able to point at any part of a generation trace and say whether it is prefill or decode.
Suppose the prompt is:
"Explain KV caching in one sentence."
The tokenizer might turn it into 8 token IDs:
[315, 849, 701, 23892, 287, 530, 6827, 13]
Do not worry about the exact IDs. Tokenization was Day 5. For now, the only important fact is the shape:
tokens: (B, T_prompt) = (1, 8)
The model's forward pass turns those token IDs into logits:
logits = model(tokens)
logits shape: (B, T_prompt, V) = (1, 8, V)
There is one logit vector for each input position. For generation, we only need the last one:
next_token_logits = logits[:, -1, :]
next_token_logits shape: (B, V) = (1, V)
That first full-prompt forward pass is prefill.
Prefill processes the whole prompt in parallel.
Input:
(B, T_prompt) token IDs
Operation:
one full transformer forward pass
Output:
(B, T_prompt, V) logits
Generation uses:
logits[:, -1, :] # scores for the next token after the prompt
During prefill, attention can compute all prompt positions at once. Position 0, position 1, position 2, and so on are all known. The causal mask still prevents position 2 from looking at position 3, but the GPU can perform the matrix multiplications for all positions in one large batch.
That is why prefill usually has high GPU utilization: it is a large dense computation.
Decode happens after the first token has been sampled.
Suppose prefill sampled token 42. The sequence becomes:
[315, 849, 701, 23892, 287, 530, 6827, 13, 42]
Now the model must predict token 10. But token 10 depends on token 9. Token 11 depends on token 10. Token 12 depends on token 11. This dependency chain is why decode is sequential.
Decode loop:
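```python
# Conceptual decode loop without a KV cache.
# `model` maps (B, T) token ids to (B, T, V) logits; `sample` maps (B, V) logits to (B, 1) ids.
for step in range(max_new_tokens):
    logits = model(tokens_so_far)          # reruns the whole prefix every step
    next_token = sample(logits[:, -1, :])  # (B, 1)
    tokens_so_far = torch.cat([tokens_so_far, next_token], dim=1)
```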
That version is conceptually correct but inefficient because it recomputes the entire prefix on every step. Production inference uses a KV cache: keys and values for previous tokens are stored and reused. You do not need the implementation yet; Day 20 is the deep dive. For today, remember the effect:
Without KV cache: each decode step recomputes all previous tokens.
With KV cache: each decode step processes only the newest token, while attending to cached previous K/V.
The key difference is not "prefill uses the model and decode uses something else." Both use the same weights. The difference is the shape and dependency pattern:
| Phase | Input shape | Parallel across token positions? | Main user-facing metric |
|---|---|---|---|
| Prefill | (B, T_prompt) | Yes | Time to first token |
| Decode | repeated (B, 1) steps with KV cache | No: one generated position per step | Inter-token latency / streaming rate |
By the end of this section you should be able to read a latency report and know what each number means.
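TTFT means time to first token: the time from submitting the request until the first generated token arrives. It includes any queueing and tokenization, the prefill forward pass over the prompt, and sampling the first token.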
If a chat app feels slow before anything appears, TTFT is the number you are noticing.
ITL means inter-token latency: the time between streamed output tokens after the first one.
If a chat app starts quickly but then writes slowly, ITL is the number you are noticing.
Sometimes systems report tokens/sec, which is the reciprocal of ITL:
tokens_per_second = 1 / seconds_per_token
So:
20 ms/token = 0.020 seconds/token
tokens/sec = 1 / 0.020 = 50 tokens/sec
Suppose a request has:
TTFT = 300 ms
ITL = 20 ms/token
T_new = 50 generated tokens
Total time is approximately:
300 ms + 50 * 20 ms = 1300 ms
Average over generated tokens:
1300 ms / 50 = 26 ms/token
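The same bookkeeping as a small Python sketch. The helper names are illustrative, and the sketch assumes, as the estimate above does, that every generated token costs one ITL on top of TTFT.

```python
def request_latency_ms(ttft_ms, itl_ms, t_new):
    """Approximate end-to-end latency: TTFT plus one inter-token gap per generated token."""
    return ttft_ms + t_new * itl_ms


def tokens_per_second(itl_ms):
    # tokens/sec is the reciprocal of seconds per token.
    return 1000.0 / itl_ms


total = request_latency_ms(ttft_ms=300, itl_ms=20, t_new=50)
print(total)                  # 1300
print(total / 50)             # 26.0
print(tokens_per_second(20))  # 50.0
```

Whether the first token's gap is counted separately (300 + 49 * 20 = 1280 ms) is a bookkeeping choice. At this level of estimate, the difference does not matter.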
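But that average hides two different user experiences: a 300 ms wait before anything appears (set by TTFT), then a steady stream at 20 ms/token, or 50 tokens/sec (set by ITL).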
For inference engineering, optimize these separately. Long prompts hurt TTFT. Slow decode hurts streaming.
By the end of this section you should be able to compute a rough request cost on paper.
For a dense decoder-only Transformer, the main matrix multiplications cost about:
FLOPs per token ≈ 2 * N
where N is the number of parameters.
Why the factor 2? In a matrix multiply, each weight takes part in one multiply and one add for every token that passes through it. The exact count depends on architecture details, but 2 * N is the standard first-pass estimate.
You used the same idea in Day 8 and Day 9 when estimating training compute. In training, the backward pass adds roughly twice the forward cost, for about 6 * N FLOPs per token. In inference, we only do forward passes.
Parameters:
N = 7B = 7 * 10^9
T_prompt = 512
T_new = 128
Prefill FLOPs:
2 * N * T_prompt
= 2 * 7 * 10^9 * 512
= 14 * 10^9 * 512
= 7168 * 10^9
= 7.168 * 10^12 FLOPs
Decode FLOPs:
T_new * 2 * N
= 128 * 2 * 7 * 10^9
= 128 * 14 * 10^9
= 1792 * 10^9
= 1.792 * 10^12 FLOPs
Pure FLOPs make prefill look about 4x more expensive than decode:
7.168e12 / 1.792e12 = 4
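The same FLOP counts as a tiny Python helper, a sketch of the 2 * N rule of thumb only. Like the hand calculation, it ignores attention's context-dependent terms.

```python
def inference_flops(n_params, n_tokens):
    # ~2 FLOPs per parameter per processed token (one multiply + one add per weight).
    return 2 * n_params * n_tokens


N = 7e9
prefill = inference_flops(N, 512)  # whole prompt in one forward pass
decode = inference_flops(N, 128)   # one token per step, 128 steps
print(f"{prefill:.4g} {decode:.4g} {prefill / decode:.1f}")
# 7.168e+12 1.792e+12 4.0
```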
But wall-clock time often disagrees. Why?
A rough A100 FP16 peak is 312 TFLOP/s. A TFLOP is 10^12 FLOPs.
If prefill fully used that compute:
7.168e12 FLOPs / 312e12 FLOPs/sec
= 0.02297 sec
≈ 23 ms
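That is an idealized lower bound, because real kernels rarely sustain peak throughput. The order of magnitude is still realistic because prefill is a large dense operation, and large matmuls keep the GPU busy.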
For decode at batch size 1, the GPU repeatedly does small matrix-vector operations. Each decode step has to read essentially every weight in the model, but does only one token of work with it. For a 7B FP16 model, the weights occupy roughly:
7e9 params * 2 bytes/param = 14 GB
If each token requires reading about 14 GB of weights from HBM, and memory bandwidth is around 2 TB/s:
14 GB / 2000 GB/sec = 0.007 sec = 7 ms
That is a lower bound. Kernel overhead, imperfect bandwidth, attention, sampling, and framework overhead can easily push a single-token decode step toward 10-20 ms. At 14 ms/token:
128 tokens * 14 ms/token = 1792 ms = 1.79 sec
So the request can look like:
prefill: ~23 ms
decode: ~1792 ms
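Both time estimates as a Python sketch. It uses the same rough A100 figures as above (312 TFLOP/s FP16 peak, about 2 TB/s of HBM bandwidth) and assumes FP16 weights at batch size 1. Each function returns a lower bound. The prefill estimate assumes perfect use of peak compute, and the decode estimate assumes each step does nothing but stream the weights once.

```python
PEAK_FLOPS = 312e12     # rough A100 FP16 dense peak, FLOP/s
HBM_BANDWIDTH = 2e12    # rough HBM bandwidth, bytes/s
BYTES_PER_PARAM = 2     # FP16 weights


def prefill_time_s(n_params, t_prompt):
    # Compute-bound estimate: all prompt tokens go through one large matmul pass.
    return 2 * n_params * t_prompt / PEAK_FLOPS


def decode_step_time_s(n_params):
    # Bandwidth-bound estimate at batch size 1: every step reads all weights once.
    return n_params * BYTES_PER_PARAM / HBM_BANDWIDTH


N = 7e9
print(f"prefill lower bound: {prefill_time_s(N, 512) * 1e3:.0f} ms")
print(f"decode step lower bound: {decode_step_time_s(N) * 1e3:.0f} ms")
```

At 7 ms per step, 128 tokens would take about 0.9 s even in the ideal case. The 14 ms/token figure above, which leaves room for overhead, doubles that.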
The important lesson: FLOPs are not latency. FLOPs tell you how much arithmetic exists. Latency depends on whether the GPU is limited by arithmetic throughput, memory bandwidth, launch overhead, or scheduling.
This is the bridge to Day 16's roofline model.
By the end of this section you should be able to implement the common sampling algorithms from logits alone.
Suppose the model's final logit vector has 5 entries:
token id: 0 1 2 3 4
logit: 2.0 -1.5 0.3 4.2 -0.8
Logits are not probabilities. They are unnormalized scores. Higher means more likely, but the values do not sum to 1.
Apply softmax:
softmax([2.0, -1.5, 0.3, 4.2, -0.8])
≈ [0.097, 0.003, 0.018, 0.876, 0.006]
Token 3 dominates with probability about 87.6%.
Greedy decoding chooses the highest-logit token:
argmax(logits) = token 3
This is deterministic. Same prompt, same model, same output. It is useful for tests, extraction tasks, and any setting where randomness is unwanted.
In API settings, this is often described as temperature = 0, though mathematically temperature 0 is a limit, not a literal division by zero.
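```python
import torch


def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    # logits: (B, V) -> (B, 1) ids of the highest-scoring token in each row.
    return logits.argmax(dim=-1, keepdim=True)
```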
Temperature rescales logits before softmax:
probs = softmax(logits / tau)
where tau is temperature.
- tau < 1 sharpens the distribution.
- tau = 1 leaves it unchanged.
- tau > 1 flattens the distribution.

Use tau, not T, because T already means sequence length.
Concrete example with the logits above:
tau = 1.0:
[0.097, 0.003, 0.018, 0.876, 0.006]
tau = 0.7:
[0.041, 0.000, 0.004, 0.954, 0.001]
tau = 2.0:
[0.206, 0.036, 0.088, 0.619, 0.051]
At tau = 0.7, the best token becomes even more likely. At tau = 2.0, weaker tokens get a chance.
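A plain-Python sketch that reproduces the three rows above. The helper name softmax_with_temperature is illustrative. It divides the logits by tau and then applies the max-shifted softmax.

```python
import math


def softmax_with_temperature(logits, tau):
    # Divide logits by tau, then run a max-shifted (stable) softmax.
    scaled = [z / tau for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, -1.5, 0.3, 4.2, -0.8]
for tau in (1.0, 0.7, 2.0):
    probs = softmax_with_temperature(logits, tau)
    print(tau, [round(p, 3) for p in probs])
# 1.0 [0.097, 0.003, 0.018, 0.876, 0.006]
# 0.7 [0.041, 0.0, 0.004, 0.954, 0.001]
# 2.0 [0.206, 0.036, 0.088, 0.619, 0.051]
```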
The same idea can be seen on a simpler probability distribution:
base probabilities = [0.7, 0.2, 0.1]
Temperature can be applied to probabilities as:
p_i' = p_i ** (1 / tau) / sum_j p_j ** (1 / tau)
For tau = 0.5:
[0.7^2, 0.2^2, 0.1^2] = [0.49, 0.04, 0.01]
sum = 0.54
renormalized = [0.907, 0.074, 0.019]
For tau = 2.0:
[sqrt(0.7), sqrt(0.2), sqrt(0.1)]
≈ [0.837, 0.447, 0.316]
sum ≈ 1.600
renormalized ≈ [0.523, 0.279, 0.198]
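The probability form and the logit form are the same operation. Normalizing p_i ** (1 / tau) equals softmax(log(p) / tau), because log(p_i) are valid logits for p. The plain-Python check below uses the [0.7, 0.2, 0.1] example, and its function names are illustrative.

```python
import math


def temper_probs(probs, tau):
    # p_i ** (1 / tau), then renormalize (the formula above).
    powered = [p ** (1.0 / tau) for p in probs]
    total = sum(powered)
    return [p / total for p in powered]


def temper_via_logits(probs, tau):
    # Treat log(p_i) as logits, divide by tau, then apply a stable softmax.
    scaled = [math.log(p) / tau for p in probs]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


base = [0.7, 0.2, 0.1]
for tau in (0.5, 2.0):
    a = temper_probs(base, tau)
    b = temper_via_logits(base, tau)
    print(tau, [round(p, 3) for p in a], all(math.isclose(x, y) for x, y in zip(a, b)))
# 0.5 [0.907, 0.074, 0.019] True
# 2.0 [0.523, 0.279, 0.198] True
```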
Top-k keeps the k highest-logit tokens and removes the rest.
For our logits:
token id: 0 1 2 3 4
logit: 2.0 -1.5 0.3 4.2 -0.8
The top 2 tokens are:
token 3: logit 4.2
token 0: logit 2.0
Softmax only over those two:
softmax([4.2, 2.0]) ≈ [0.900, 0.100]
So top-k with k = 2 says: "sample only from tokens 3 and 0; ignore the other three completely."
Top-k is simple, but it is not adaptive. k = 50 might be too wide when the model is certain and too narrow when the model is uncertain.
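A plain-Python sketch of top-k on the example logits. The helper top_k_probs is illustrative, not a library function. It sorts token ids by logit, keeps the first k, and applies softmax over only those. Ties at the cutoff are broken by token id, whereas the torch sample_next_token in this lesson keeps every token tied with the k-th logit.

```python
import math


def top_k_probs(logits, k):
    # Keep the k highest logits and softmax over only those; the rest get probability 0.
    kept = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in kept)
    exps = {i: math.exp(logits[i] - m) for i in kept}
    total = sum(exps.values())
    return [exps[i] / total if i in exps else 0.0 for i in range(len(logits))]


logits = [2.0, -1.5, 0.3, 4.2, -0.8]
print([round(p, 3) for p in top_k_probs(logits, k=2)])
# [0.1, 0.0, 0.0, 0.9, 0.0]
```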
Top-p, also called nucleus sampling, keeps the smallest set of tokens whose cumulative probability reaches at least p.
Using the original softmax probabilities:
token id sorted by probability: 3 0 2 4 1
probability: 0.876 0.097 0.018 0.006 0.003
cumulative: 0.876 0.973 0.991 0.997 1.000
For top_p = 0.9, keep tokens 3 and 0:
token 3 alone reaches 0.876, which is below 0.9
token 3 + token 0 reaches 0.973, which is above 0.9
Then renormalize those kept probabilities:
token 3: 0.876 / 0.973 ≈ 0.900
token 0: 0.097 / 0.973 ≈ 0.100
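In this example, top-p 0.9 keeps the same two tokens as top-k 2. In general, it adapts. When the model is confident, one or two tokens may already reach p. When it is uncertain, many tokens are needed to reach p.
That adaptivity is why top-p often works better than a fixed top-k for open-ended writing.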
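The same walk in plain Python. The helper top_p_probs is hypothetical, not a library function. It sorts by probability, accumulates until the kept mass reaches p, and then renormalizes. The second call reuses the logits at temperature 2.0 to show the adaptivity: the flatter distribution needs three tokens to reach 0.9.

```python
import math


def top_p_probs(logits, p):
    # Stable softmax over the full vocabulary.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Walk tokens from most to least likely until the kept mass reaches p.
    kept, mass = [], 0.0
    for i in sorted(range(len(probs)), key=lambda i: probs[i], reverse=True):
        kept.append(i)
        mass += probs[i]
        if mass >= p:
            break
    # Renormalize over the kept set; everything else gets probability 0.
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]


logits = [2.0, -1.5, 0.3, 4.2, -0.8]
print([round(q, 3) for q in top_p_probs(logits, p=0.9)])
# [0.1, 0.0, 0.0, 0.9, 0.0]

flatter = [z / 2.0 for z in logits]  # the same scores at temperature 2.0
print(sum(q > 0 for q in top_p_probs(flatter, p=0.9)))
# 3
```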
By the end of this section you should be able to write the token picker yourself.
```python
import torch
import torch.nn.functional as F


def sample_next_token(
    logits: torch.Tensor,
    *,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> torch.Tensor:
    """
    logits: (B, V), raw scores for the next token.
    returns: (B, 1), sampled token ids.
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0:
        # Greedy decoding, the tau -> 0 limit.
        return logits.argmax(dim=-1, keepdim=True)

    logits = logits / temperature

    if top_k is not None:
        # Clamp k so torch.topk never asks for more tokens than the vocabulary has.
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k=k, dim=-1)
        kth_value = values[:, [-1]]  # (B, 1): smallest logit that is still kept
        logits = torch.where(
            logits < kth_value,
            torch.full_like(logits, float("-inf")),
            logits,
        )

    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative = sorted_probs.cumsum(dim=-1)

        # Mark tokens once the cumulative probability has reached top_p...
        remove_sorted = cumulative >= top_p
        # ...then shift right by one so the token that reaches top_p is kept.
        remove_sorted[:, 1:] = remove_sorted[:, :-1].clone()
        remove_sorted[:, 0] = False  # always keep the most likely token

        # Map the mask from sorted order back to vocabulary order.
        remove = torch.zeros_like(remove_sorted).scatter(1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```
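The top-p block is the only tricky part. Read it slowly:

- Sort the logits from highest to lowest, keeping sorted_idx so you can map back to vocabulary positions.
- Apply softmax to the sorted logits and take the running sum, cumulative.
- Mark the positions where cumulative crosses top_p.
- Shift that mask right by one, so the token that first reaches top_p survives, and never remove the single most likely token.
- Scatter the mask from sorted order back to vocabulary order.
- Set the removed logits to -inf, so their softmax probability becomes 0.

This version assumes your model has the same simple interface as Day 9: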
logits = model(idx) # idx shape: (B, T), logits shape: (B, T, V)
```python
import torch


@torch.no_grad()
def generate(
    model,
    idx: torch.Tensor,
    *,
    max_new_tokens: int,
    block_size: int,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
    eos_token_id: int | None = None,
) -> torch.Tensor:
    """
    idx: (B, T_start), prompt token ids.
    returns: (B, T_start + generated), prompt plus generated ids.

    This intentionally does not use a KV cache yet.
    Day 20 will replace repeated full-prefix forward passes with cached decode.
    """
    model.eval()
    for _ in range(max_new_tokens):
        # Crop to the model's context window.
        idx_cond = idx[:, -block_size:]

        # Forward pass over the current context (recomputes the whole prefix).
        logits = model(idx_cond)        # (B, T_context, V)
        next_logits = logits[:, -1, :]  # (B, V)

        # Convert next-token logits into an actual sampled token id.
        next_id = sample_next_token(
            next_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )  # (B, 1)

        # Append the sampled token.
        idx = torch.cat([idx, next_id], dim=1)

        # Stops only when every sequence emits EOS on the same step;
        # a production loop tracks finished sequences individually.
        if eos_token_id is not None and (next_id == eos_token_id).all():
            break
    return idx
```
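This is intentionally minimal. Production engines add the machinery previewed at the start of this lesson, such as a KV cache, FlashAttention-style attention kernels, speculative decoding, and continuous batching.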
But the skeleton is the same: forward pass, sample, append, repeat.
By the end of this section you should be able to explain the common generation knobs without treating them as magic.
max_new_tokens is a hard cap on the number of generated tokens.
If it is too low, the model can be cut off mid-answer:
Prompt: "Explain FlashAttention in two paragraphs."
max_new_tokens = 12
Output: "FlashAttention is an optimized attention algorithm that reduces memory traffic by"
This is not a model-quality failure. It is a decoding configuration failure.
A stop condition ends generation early.
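Examples:

- Stop when the model emits its end-of-sequence token ID (the eos_token_id argument of the generate() loop in this lesson).
- Stop when the generated text contains a string such as "\n\nUser:".
- Stop when the model closes a tool call with "</tool_call>".

Token-level stop conditions are easier and safer than string-level ones because generation happens in token IDs. String stops require decoding partial text and checking boundaries, since one stop string can span several tokens.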
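A minimal plain-Python sketch of a string-level stop check. It assumes you re-decode the accumulated text after each step. The function name check_stop and its return convention are hypothetical, not any library's API. Besides cutting at a completed stop string, it holds back a suffix that might still grow into one, which is the "checking boundaries" part.

```python
def check_stop(text, stop_strings):
    """Return (should_stop, visible_text) for the text decoded so far.

    If a stop string appears, cut the text just before its earliest
    occurrence. Otherwise hold back any suffix that could still grow into
    a stop string, so a half-emitted stop sequence is never streamed.
    """
    hits = [text.find(s) for s in stop_strings if s and s in text]
    if hits:
        return True, text[: min(hits)]

    hold = 0
    for stop in stop_strings:
        # Longest proper prefix of `stop` that the text currently ends with.
        for n in range(min(len(stop) - 1, len(text)), 0, -1):
            if text.endswith(stop[:n]):
                hold = max(hold, n)
                break
    return False, text[: len(text) - hold]


stops = ["\n\nUser:", "</tool_call>"]
print(check_stop("Sure, here it is.\n\nUser:", stops))  # (True, 'Sure, here it is.')
print(check_stop("Sure, here it is.\n\nUs", stops))     # (False, 'Sure, here it is.')
print(check_stop("Done.", stops))                       # (False, 'Done.')
```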
Repetition penalty reduces the chance of repeating tokens that already appeared.
A simple version:
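```python
import torch


def apply_repetition_penalty(
    logits: torch.Tensor, generated_ids: torch.Tensor, penalty: float = 1.1
) -> torch.Tensor:
    # logits: (B, V) next-token scores; generated_ids: (B, T) tokens seen so far.
    logits = logits.clone()  # leave the caller's tensor untouched
    for b in range(logits.shape[0]):
        seen = set(generated_ids[b].tolist())
        for token_id in seen:
            # With penalty > 1, dividing a positive logit and multiplying a
            # negative one both lower the token's score.
            if logits[b, token_id] > 0:
                logits[b, token_id] /= penalty
            else:
                logits[b, token_id] *= penalty
    return logits
```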
This is heuristic, not a training objective. It is useful when a model gets stuck in loops, but too much penalty can make outputs weird by suppressing legitimate repeated words.
Presence and frequency penalties are common in APIs.
Presence penalty says: "If this token has appeared at all, subtract a fixed amount from its logit."
logit[token] -= presence_penalty
Frequency penalty says: "Subtract more if this token has appeared many times."
logit[token] -= frequency_penalty * count[token]
Concrete example:
base logit for token "the" = 3.0
count["the"] = 4
presence_penalty = 0.2
frequency_penalty = 0.1
adjusted logit = 3.0 - 0.2 - 0.1 * 4
= 3.0 - 0.2 - 0.4
= 2.4
The token can still be sampled. It is just less likely.
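A plain-Python sketch of both penalties together. It assumes logits are a plain list indexed by token id, whereas a real loop would apply the same update to the (B, V) tensor before sampling. The helper name, token ids, and logits are made up to reproduce the "the" example.

```python
from collections import Counter


def apply_presence_frequency_penalties(
    logits, generated_ids, presence_penalty=0.0, frequency_penalty=0.0
):
    """logits: list of floats indexed by token id.
    generated_ids: token ids produced so far."""
    counts = Counter(generated_ids)
    adjusted = list(logits)  # leave the caller's list untouched
    for token_id, count in counts.items():
        adjusted[token_id] -= presence_penalty           # flat cost for appearing at all
        adjusted[token_id] -= frequency_penalty * count  # grows with every repeat
    return adjusted


# Toy vocabulary: pretend token id 0 is "the".
logits = [3.0, 1.0, 0.5]
history = [0, 2, 0, 0, 0]  # "the" seen 4 times, token 2 once, token 1 never
adjusted = apply_presence_frequency_penalties(logits, history, 0.2, 0.1)
print([round(x, 2) for x in adjusted])
# [2.4, 1.0, 0.2]
```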
By the end of this section you should understand why inference systems obsess over batching and caches.
Token n + 1 depends on token n.
token 1 -> token 2 -> token 3 -> token 4 -> ...
You cannot generate token 4 before you know token 3. This is the autoregressive constraint.
Prefill can parallelize across positions because all prompt tokens are known. Decode cannot parallelize across generated positions because future generated tokens are not known yet.
At batch size 1, decode often looks like:
read model weights -> do one token of useful work -> repeat
At larger batches, the same weights can be used for many sequences at once:
read model weights -> do B tokens of useful work -> repeat
That raises arithmetic intensity: more math per byte loaded. This is why continuous batching is such a big deal in production inference. Day 24 is the deep dive.
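A back-of-envelope sketch of that effect. It assumes FP16 weights and counts only weight traffic, ignoring KV-cache reads and activations, which grow with batch size and context. Each decode step does about 2 * N * B FLOPs but reads the 2 * N bytes of weights once, so arithmetic intensity is roughly B FLOPs per byte. Dividing the rough A100 figures used earlier (312 TFLOP/s, about 2 TB/s) gives the intensity at which weight reads stop being the bottleneck, a preview of Day 16's roofline model.

```python
PEAK_FLOPS = 312e12     # rough A100 FP16 dense peak, FLOP/s
HBM_BANDWIDTH = 2e12    # rough HBM bandwidth, bytes/s
N = 7e9                 # parameters
BYTES_PER_PARAM = 2     # FP16


def decode_step_intensity(batch_size):
    flops = 2 * N * batch_size          # one token for each sequence in the batch
    bytes_moved = N * BYTES_PER_PARAM   # weights read once, shared by the whole batch
    return flops / bytes_moved          # FLOPs per byte


ridge = PEAK_FLOPS / HBM_BANDWIDTH      # intensity where compute catches up with bandwidth
for b in (1, 8, 64, 256):
    bound = "compute" if decode_step_intensity(b) >= ridge else "memory"
    print(b, decode_step_intensity(b), bound)
print("ridge point:", ridge)  # ridge point: 156.0
```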
Without a cache, the decode step after a 512-token prompt would rerun the full forward pass over all 512 previous tokens, recomputing their keys and values just to process one new token. The next step would recompute 513 previous tokens, then 514. That is wasteful.
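With a KV cache, each layer keeps the keys and values of every previous token. A decode step computes Q, K, and V only for the newest token, appends the new K and V to the cache, and attends over everything cached.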
This turns decode from "recompute the full prefix every time" into "append one new token's K/V and attend over the cache." Day 20 will turn this into exact memory math.
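A counting sketch of that difference. The helper positions_processed is hypothetical. It tallies how many token positions go through the model's projection and MLP layers for the 512-token prompt and 128 new tokens used earlier, assuming no context cropping. With a cache, attention still reads every cached K/V entry on each step, so per-step attention cost keeps growing with context even though per-step matmul work does not.

```python
def positions_processed(t_prompt, t_new, use_kv_cache):
    """Token positions pushed through the model to generate t_new tokens.

    Prefill handles the whole prompt and yields the first generated token;
    each of the remaining t_new - 1 decode steps yields one more.
    """
    total = t_prompt
    for step in range(1, t_new):
        context = t_prompt + step  # prompt plus tokens generated so far
        total += 1 if use_kv_cache else context
    return total


print(positions_processed(512, 128, use_kv_cache=False))  # 73664
print(positions_processed(512, 128, use_kv_cache=True))   # 639
```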
Use the companion notebook, then swap in your Day 9 GPT for meaningful timings.
Do this exercise in a notebook or a single Python file. Use your Day 9 tiny GPT if you have it; otherwise stub the model with random logits and focus on sampling first.
Given:
logits = [2.0, -1.5, 0.3, 4.2, -0.8]
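1. Compute the softmax probabilities.
2. Find the greedy token.
3. Apply temperature tau = 2.0 and recompute probabilities.
4. Apply top-k with k = 2. Which tokens remain?
5. Apply top-p with p = 0.9. Which tokens remain?

Expected approximate values: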
softmax tau=1.0 -> [0.097, 0.003, 0.018, 0.876, 0.006]
softmax tau=2.0 -> [0.206, 0.036, 0.088, 0.619, 0.051]
greedy token -> 3
top-k k=2 -> tokens 3 and 0
top-p p=0.9 -> tokens 3 and 0
Implement:
greedy_sample(logits)
temperature_sample(logits, tau)
top_k_sample(logits, k)
top_p_sample(logits, p)
Test each on the logits above. Use torch.manual_seed(0) before stochastic tests so results are reproducible.
Implement the full generate() loop from this lesson. Run it with several decoding settings, such as greedy, a few different temperatures, top-k, and top-p.
Compare outputs on the same prompt. Do not judge only "quality"; observe diversity, repetition, and failure modes.
Measure:
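```python
import time

import torch

start = time.perf_counter()
out = generate(...)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # GPU work is asynchronous; wait for it before stopping the clock
elapsed = time.perf_counter() - start
generated_tokens = out.shape[1] - idx.shape[1]  # EOS can stop short of max_new_tokens
tokens_per_sec = generated_tokens / elapsed
```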
Then repeat with prompt lengths:
T_prompt = 16, 64, 256, 512
Because this no-cache version recomputes the prefix, longer prompts should slow down decode. That pain is the motivation for KV caching.
Close the page and answer from memory. If you cannot, re-read the relevant section.
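- What can too small a max_new_tokens cause? A cut-off answer: a decoding configuration failure, not a model-quality failure.
- Why is decode sequential? Token n+1 depends on generated token n, so future decode steps cannot start yet.
- How many FLOPs does prefill cost for a 512-token prompt on a 7B model? 2 * 7e9 * 512 = 7.168e12 FLOPs.

An LLM inference request is not one operation. It is two regimes glued together: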
prefill: big parallel prompt pass -> first token
decode: sequential one-token loop -> remaining tokens
Prefill affects time to first token. Decode affects streaming speed. Sampling turns the model's probability distribution into an actual token, and therefore into text. The rest of Week 3 is about making this loop faster by understanding the hardware and removing unnecessary memory movement.
"Inference is two regimes glued together: a wide prompt pass to get the first token, then a narrow sequential loop that has to be made fast one token at a time."
Sampling docs, inference arithmetic, and the serving systems that build on today's split.
- Practical documentation for greedy, beam, temperature, top-k, top-p, and other decoding strategies.
- Concise compute and memory math for inference. Useful before the KV cache and roofline lessons.
- Systems-oriented overview of LLM inference metrics, serving constraints, and deployment tradeoffs.
- The paper behind vLLM memory management. It turns the KV-cache pressure introduced today into a serving architecture.
- Companion notebook for stable softmax, greedy/temperature/top-k/top-p sampling, and a minimal generate loop.