LLM Inference Engineer · Day 15
Day 15 · Week 3 · Inference & Hardware

How Inference Works: Prefill, Decode & Sampling

Training asks how to make the weights better. Inference asks how to turn fixed weights into streamed tokens quickly, cheaply, and controllably. Today you separate prefill from decode, compute first-pass latency estimates, and implement the sampling loop from raw logits.

Time: ~140 min
Difficulty: Medium
Prerequisite: Day 14
Notebook: day-15
Why This Lesson

A trained model becomes a latency problem.

Training asks: "How do we make the weights better?" Inference asks a different question: "Given fixed weights, how do we turn a prompt into tokens as fast, cheaply, and controllably as possible?"

Every chat completion has the same shape:

  1. Read the whole prompt.
  2. Predict the next token.
  3. Append that token.
  4. Predict the next token again.
  5. Repeat until a stop condition fires.

That simple loop creates almost every performance constraint in production LLM systems. Prefill is the first pass over the prompt. Decode is the one-token-at-a-time loop after that. If you understand this split, KV cache (Day 20), FlashAttention (Day 21), speculative decoding (Day 23), continuous batching (Day 24), and the capstone engine (Days 27-30) all become much easier to reason about.

By the end of this lesson you should be able to look at a request like "prompt length 512, generate 128 tokens with a 7B model" and estimate where the time and memory go.

Learning Objectives

What you should be able to do today.

By the end of this lesson you should be able to:

  1. Distinguish prefill from decode in both shape and runtime behavior.
  2. Explain TTFT (time to first token) and ITL (inter-token latency).
  3. Compute rough FLOPs for prefill and decode with 2 * N * tokens.
  4. Explain why prefill is usually compute-bound while batch-1 decode is usually memory-bound.
  5. Convert logits into probabilities with softmax.
  6. Implement greedy sampling, temperature sampling, top-k sampling, and top-p sampling.
  7. Explain max_new_tokens, stop tokens, repetition penalty, frequency penalty, and presence penalty.
  8. Write a simple PyTorch generate() loop without using Hugging Face's model.generate().
Math Notation Cheatsheet

The symbols used in inference docs.

This is a light-math lesson, but notation matters because inference papers and serving docs use the same few symbols repeatedly.

Shape variables:

  • B — batch size. Number of sequences processed together. B = 1 for one chat request.
  • T_prompt — prompt length in tokens. If the tokenizer turns your prompt into 512 integers, T_prompt = 512.
  • T_new — number of generated tokens. If the answer streams 128 tokens, T_new = 128.
  • T_total — total context length after generation: T_prompt + T_new.
  • V — vocabulary size. The number of token IDs the model can choose from.
  • D — model width, also called d_model. The residual-stream vector size.

Model and compute variables:

  • N — number of model parameters. A 7B model has N = 7,000,000,000.
  • FLOP — floating-point operation. One multiply or add is one floating-point operation by the usual ML counting convention.
  • 2 * N FLOPs/token — rough inference rule of thumb: every token a forward pass processes, whether a prompt token during prefill or a generated token during decode, costs about two operations per parameter for the matrix multiplies. This ignores attention's smaller terms, kernel overhead, and cache effects, but it is good enough for first estimates.
  • BW — memory bandwidth. How many bytes per second the GPU can move from HBM/VRAM.

Sampling variables:

  • logits — raw model scores before softmax. Shape (B, T, V).
  • p_i — probability of token i after softmax.
  • tau — temperature. I use tau instead of T so we do not confuse temperature with sequence length.
  • top-k — keep the k highest-probability tokens, set the rest to zero, renormalize.
  • top-p — keep the smallest set of tokens whose cumulative probability is at least p, set the rest to zero, renormalize.

Softmax reminder from Day 1:

For logits z = [z_1, z_2, ..., z_V]:

softmax(z)_i = exp(z_i) / sum_j exp(z_j)

Decoded:

  • exp(x) means e ** x, where e ≈ 2.718.
  • sum_j means "sum over every token index j in the vocabulary."
  • The output is a probability distribution: every value is positive, and all values sum to 1.

In code, always subtract the max logit for numerical stability:

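import torch

def softmax(logits: torch.Tensor) -> torch.Tensor:
    # Softmax is unchanged by subtracting a constant, so shift by the row max to keep exp() from overflowing.
    shifted = logits - logits.max(dim=-1, keepdim=True).values
    probs = shifted.exp()
    # Normalize each row so it sums to 1.
    return probs / probs.sum(dim=-1, keepdim=True)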
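To see why the shift matters, here is a small sketch comparing the formula as written against the shifted version on deliberately large logits. The values 1000 to 1002 are arbitrary, chosen only to overflow float32:

```python
import torch

def naive_softmax(z: torch.Tensor) -> torch.Tensor:
    # Textbook formula exp(z_i) / sum_j exp(z_j), with no shift.
    e = z.exp()
    return e / e.sum(dim=-1, keepdim=True)

def stable_softmax(z: torch.Tensor) -> torch.Tensor:
    # Same distribution after subtracting the row max: the largest exponent becomes exp(0) = 1.
    e = (z - z.max(dim=-1, keepdim=True).values).exp()
    return e / e.sum(dim=-1, keepdim=True)

big = torch.tensor([1000.0, 1001.0, 1002.0])   # float32 by default
naive = naive_softmax(big)      # exp(1000) overflows float32 to inf, and inf / inf is nan
stable = stable_softmax(big)    # finite, roughly [0.090, 0.245, 0.665]
print(naive, stable)
```

Only the relative gaps between logits matter, which is why the shifted version gives the same answer as the formula whenever the formula does not overflow.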
The Two Phases of Inference

Prefill reads the prompt. Decode writes the answer.

By the end of this section you should be able to point at any part of a generation trace and say whether it is prefill or decode.

Start with one concrete prompt

Suppose the prompt is:

"Explain KV caching in one sentence."

The tokenizer might turn it into 8 token IDs:

[315, 849, 701, 23892, 287, 530, 6827, 13]

Do not worry about the exact IDs. Tokenization was Day 5. For now, the only important fact is the shape:

tokens: (B, T_prompt) = (1, 8)

The model's forward pass turns those token IDs into logits:

logits = model(tokens)
logits shape: (B, T_prompt, V) = (1, 8, V)

There is one logit vector for each input position. For generation, we only need the last one:

next_token_logits = logits[:, -1, :]
next_token_logits shape: (B, V) = (1, V)

That first full-prompt forward pass is prefill.

Prefill

Prefill processes the whole prompt in parallel.

Input:

(B, T_prompt) token IDs

Operation:

one full transformer forward pass

Output:

(B, T_prompt, V) logits

Generation uses:

logits[:, -1, :]  # scores for the next token after the prompt

During prefill, attention can compute all prompt positions at once. Position 0, position 1, position 2, and so on are all known. The causal mask still prevents position 2 from looking at position 3, but the GPU can perform the matrix multiplications for all positions in one large batch.

That is why prefill usually has high GPU utilization: it is a large dense computation.
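A toy check of that claim, as a sketch: one attention head with random queries, keys, and values (no learned projections and no multiple heads; both are simplifications assumed for illustration). The masked, all-positions-at-once computation matches computing each position separately over only its own prefix:

```python
import torch

torch.manual_seed(0)
T, d = 8, 16                                   # T_prompt = 8 positions, toy head dimension 16
q, k, v = torch.randn(3, T, d).unbind(0)       # each (T, d)

# Prefill style: one (T x T) score matrix for every position, causal mask applied.
scores = q @ k.T / d ** 0.5
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)   # True above the diagonal = future
scores = scores.masked_fill(mask, float("-inf"))
out_parallel = torch.softmax(scores, dim=-1) @ v                     # (T, d)

# Sequential style: position t attends only to positions 0..t, one row at a time.
rows = []
for t in range(T):
    s = q[t] @ k[: t + 1].T / d ** 0.5                               # (t + 1,)
    rows.append(torch.softmax(s, dim=-1) @ v[: t + 1])                # (d,)
out_sequential = torch.stack(rows)                                    # (T, d)

print(torch.allclose(out_parallel, out_sequential, atol=1e-5))
```

The mask changes which scores count, not whether the rows can be computed together. That is the property prefill exploits.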

Decode

Decode happens after the first token has been sampled.

Suppose prefill sampled token 42. The sequence becomes:

[315, 849, 701, 23892, 287, 530, 6827, 13, 42]

Now the model must predict the 10th token. It depends on the 9th token (the 42 we just sampled), the 11th will depend on the 10th, the 12th on the 11th, and so on. This dependency chain is why decode is sequential.

Decode loop:

for step in range(max_new_tokens):
    logits = model(tokens_so_far)                  # (B, T_so_far, V): reruns the whole prefix
    next_token = sample(logits[:, -1, :])          # (B, 1): one token id per sequence
    tokens_so_far = torch.cat([tokens_so_far, next_token], dim=1)

That version is conceptually correct but inefficient because it recomputes the entire prefix on every step. Production inference uses a KV cache: keys and values for previous tokens are stored and reused. You do not need the implementation yet; Day 20 is the deep dive. For today, remember the effect:

Without KV cache: each decode step recomputes all previous tokens.
With KV cache: each decode step processes only the newest token, while attending to cached previous K/V.

The picture

[Figure: Prefill processes prompt tokens t1-t5 together in one full forward pass, (B, T_prompt) -> (B, T_prompt, V), yielding the first new token. Decode then runs one forward step per new token, (B, 1) -> (B, V), yielding new1, new2, ...]
Prefill is one full prompt pass; decode is the repeated one-token loop that follows. The weights are the same, but the shape and dependency pattern are different.

The key difference is not "prefill uses the model and decode uses something else." Both use the same weights. The difference is the shape and dependency pattern:

Phase    Input shape                          Parallel across token positions?            Main user-facing metric
Prefill  (B, T_prompt)                        Yes                                         Time to first token
Decode   repeated (B, 1) steps with KV cache  No (sequential across generated positions)  Inter-token latency / streaming rate
TTFT and ITL

Separate the first-token wait from the streaming rate.

By the end of this section you should be able to read a latency report and know what each number means.

Time to first token

TTFT means time to first token. It includes:

  1. Tokenizing the prompt.
  2. Queueing and scheduling.
  3. Prefill forward pass.
  4. Sampling the first generated token.
  5. Returning that first token to the client.

If a chat app feels slow before anything appears, TTFT is the number you are noticing.

Inter-token latency

ITL means inter-token latency: the time between streamed output tokens after the first one.

If a chat app starts quickly but then writes slowly, ITL is the number you are noticing.

Sometimes systems report tokens/sec, which is the reciprocal of ITL:

tokens_per_second = 1 / seconds_per_token

So:

20 ms/token = 0.020 seconds/token
tokens/sec = 1 / 0.020 = 50 tokens/sec

Why average latency hides the important split

Suppose a request has:

TTFT = 300 ms
ITL  = 20 ms/token
T_new = 50 generated tokens

Total time is approximately:

300 ms + 50 * 20 ms = 1300 ms

Average over generated tokens:

1300 ms / 50 = 26 ms/token

But that average hides two different user experiences:

  • The user waits 300 ms before seeing anything.
  • Then the answer streams smoothly at 50 tokens/sec.

For inference engineering, optimize these separately. Long prompts hurt TTFT. Slow decode hurts streaming.
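The same arithmetic as a tiny helper. One refinement, stated as an assumption about how the gaps are counted: the first token arrives at TTFT, so only T_new - 1 ITL gaps follow it, which gives 1280 ms rather than the approximate 1300 ms above. The difference does not change the conclusion:

```python
def request_latency_ms(ttft_ms: float, itl_ms: float, t_new: int) -> float:
    # The first token lands at TTFT; each of the remaining t_new - 1 tokens adds one ITL gap.
    return ttft_ms + (t_new - 1) * itl_ms

total_ms = request_latency_ms(ttft_ms=300, itl_ms=20, t_new=50)   # 300 + 49 * 20
avg_ms_per_token = total_ms / 50                                   # blended average per generated token
stream_tokens_per_sec = 1000 / 20                                  # streaming rate after the first token
print(total_ms, avg_ms_per_token, stream_tokens_per_sec)
```

The blended 25.6 ms/token describes neither the 300 ms wait nor the 20 ms streaming gap, which is the point of reporting TTFT and ITL separately.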

[Figure: A streaming request timeline. TTFT spans tokenize + queue + prefill up to token 1; ITL repeats once per decoded token after that.]
TTFT is the wait before the first streamed token. ITL is the repeated gap between later tokens. User-perceived latency depends on both.
FLOPs: The First Runtime Estimate

FLOPs are the starting estimate, not the latency.

By the end of this section you should be able to compute a rough request cost on paper.

The useful rule of thumb

For a dense decoder-only Transformer, the main matrix multiplications cost about:

FLOPs per token ≈ 2 * N

where N is the number of parameters.

Why the factor 2? A matrix-vector or matrix-matrix multiply uses one multiply and one add for each weight involved. The exact count depends on architecture details, but 2 * N is the standard first-pass estimate.

You used the same idea in Day 8 and Day 9 when estimating training compute. In training, the backward pass adds roughly twice the forward cost, which is where the familiar 6 * N FLOPs per token comes from. In inference, we only run forward passes.

Concrete example: 7B model, 512-token prompt, 128-token output

Parameters:

N = 7B = 7 * 10^9
T_prompt = 512
T_new = 128

Prefill FLOPs:

2 * N * T_prompt
= 2 * 7 * 10^9 * 512
= 14 * 10^9 * 512
= 7168 * 10^9
= 7.168 * 10^12 FLOPs

Decode FLOPs:

T_new * 2 * N
= 128 * 2 * 7 * 10^9
= 128 * 14 * 10^9
= 1792 * 10^9
= 1.792 * 10^12 FLOPs

Pure FLOPs make prefill look about 4x more expensive than decode:

7.168e12 / 1.792e12 = 4
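The two estimates as a small helper, so you can change N or the token counts and rerun. It encodes only the 2 * N rule of thumb, nothing more:

```python
def inference_flops(n_params: float, t_prompt: int, t_new: int) -> tuple[float, float]:
    # Rule of thumb from this lesson: ~2 * N FLOPs per token processed; attention and overheads ignored.
    prefill_flops = 2 * n_params * t_prompt
    decode_flops = 2 * n_params * t_new
    return prefill_flops, decode_flops

prefill_flops, decode_flops = inference_flops(7e9, t_prompt=512, t_new=128)
print(f"prefill {prefill_flops:.3e}  decode {decode_flops:.3e}  ratio {prefill_flops / decode_flops:.1f}")
```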

But wall-clock time often disagrees. Why?

Compute-bound vs memory-bound

A rough A100 FP16 peak is 312 TFLOP/s. A TFLOP is 10^12 FLOPs.

If prefill fully used that compute:

7.168e12 FLOPs / 312e12 FLOPs/sec
= 0.02297 sec
≈ 23 ms

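That is a floor, not a prediction, because real kernels do not sustain peak throughput. Still, it is the right order of magnitude: prefill is a large dense operation, and large matmuls keep the GPU busy.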

For decode at batch size 1, the GPU repeatedly does small matrix-vector operations. Each decode step must read essentially all of the model weights, but uses them for only one token. For a 7B FP16 model, the weights occupy roughly:

7e9 params * 2 bytes/param = 14 GB

If each token requires reading about 14 GB of weights from HBM, and memory bandwidth is around 2 TB/s (roughly an A100 80GB):

14 GB / 2000 GB/sec = 0.007 sec = 7 ms

That is a lower bound. Kernel overhead, imperfect bandwidth, attention, sampling, and framework overhead can easily push a single-token decode step toward 10-20 ms. At 14 ms/token:

128 tokens * 14 ms/token = 1792 ms = 1.79 sec

So the request can look like:

prefill: ~23 ms
decode:  ~1792 ms
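Both floors in one helper, as a sketch. The hardware defaults are the rough A100 numbers used above and are assumptions to replace for your GPU. Note that it returns the 7 ms/token bandwidth floor; the 14 ms/token figure above adds real-world overhead on top of that:

```python
def latency_floor_ms(n_params: float, t_prompt: int, t_new: int,
                     peak_flops: float = 312e12,      # assumed A100 FP16 peak, FLOP/s
                     hbm_bytes_per_s: float = 2e12,   # assumed ~2 TB/s HBM bandwidth
                     bytes_per_param: int = 2):       # FP16 weights
    # Prefill treated as compute-bound: time >= FLOPs / peak FLOP/s.
    prefill_ms = 2 * n_params * t_prompt / peak_flops * 1e3
    # Batch-1 decode treated as memory-bound: each step reads every weight once.
    step_ms = n_params * bytes_per_param / hbm_bytes_per_s * 1e3
    return prefill_ms, step_ms, step_ms * t_new

prefill_ms, step_ms, decode_ms = latency_floor_ms(7e9, t_prompt=512, t_new=128)
print(f"prefill >= {prefill_ms:.1f} ms, decode >= {step_ms:.1f} ms/token, {decode_ms:.0f} ms total")
```

Even at the floor, decode (about 896 ms) dwarfs prefill (about 23 ms) for this request, despite having a quarter of the FLOPs.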

The important lesson: FLOPs are not latency. FLOPs tell you how much arithmetic exists. Latency depends on whether the GPU is limited by arithmetic throughput, memory bandwidth, launch overhead, or scheduling.

This is the bridge to Day 16's roofline model.

Sampling: Turning Scores Into Tokens

Logits become text through a token picker.

By the end of this section you should be able to implement the common sampling algorithms from logits alone.

Start with raw logits

Suppose the model's final logit vector has 5 entries:

token id:   0      1      2      3      4
logit:    2.0   -1.5    0.3    4.2   -0.8

Logits are not probabilities. They are unnormalized scores. Higher means more likely, but the values do not sum to 1.

Apply softmax:

softmax([2.0, -1.5, 0.3, 4.2, -0.8])
≈ [0.097, 0.003, 0.018, 0.876, 0.006]

Token 3 dominates with probability about 87.6%.

Greedy decoding

Greedy decoding chooses the highest-logit token:

argmax(logits) = token 3

This is deterministic. Same prompt, same model, same output. It is useful for tests, extraction tasks, and any setting where randomness is unwanted.

In API settings, this is often described as temperature = 0, though mathematically temperature 0 is a limit, not a literal division by zero.

def greedy_sample(logits):
    # logits: (B, V)
    return logits.argmax(dim=-1, keepdim=True)

Temperature sampling

Temperature rescales logits before softmax:

probs = softmax(logits / tau)

where tau is temperature.

  • tau < 1 sharpens the distribution.
  • tau = 1 leaves it unchanged.
  • tau > 1 flattens the distribution.

Use tau, not T, because T already means sequence length.

Concrete example with the logits above:

tau = 1.0:
[0.097, 0.003, 0.018, 0.876, 0.006]

tau = 0.7:
[0.041, 0.000, 0.004, 0.954, 0.001]

tau = 2.0:
[0.206, 0.036, 0.088, 0.619, 0.051]

At tau = 0.7, the best token becomes even more likely. At tau = 2.0, weaker tokens get a chance.
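You can reproduce the three rows above in a few lines; this sketch just divides by tau and applies PyTorch's numerically stable softmax:

```python
import torch

logits = torch.tensor([2.0, -1.5, 0.3, 4.2, -0.8])
table = {}
for tau in (1.0, 0.7, 2.0):
    # Divide the logits by tau first, then normalize.
    table[tau] = torch.softmax(logits / tau, dim=-1)
    print(f"tau={tau}:", [round(p, 3) for p in table[tau].tolist()])
```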

The same idea can be seen on a simpler probability distribution:

base probabilities = [0.7, 0.2, 0.1]

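Temperature can also be applied directly to probabilities. Because p_i is proportional to exp(z_i), p_i ** (1 / tau) is proportional to exp(z_i / tau), so this is the same operation as dividing the logits by tau: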

p_i' = p_i ** (1 / tau) / sum_j p_j ** (1 / tau)

For tau = 0.5:

[0.7^2, 0.2^2, 0.1^2] = [0.49, 0.04, 0.01]
sum = 0.54
renormalized = [0.907, 0.074, 0.019]

For tau = 2.0:

[sqrt(0.7), sqrt(0.2), sqrt(0.1)]
≈ [0.837, 0.447, 0.316]
sum ≈ 1.600
renormalized ≈ [0.523, 0.279, 0.198]
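A quick numerical check that the two routes agree. Using log-probabilities as the logits is an assumption made only for this check; any logits that produce [0.7, 0.2, 0.1] give the same result, since softmax ignores a constant shift:

```python
import torch

p = torch.tensor([0.7, 0.2, 0.1])
results = {}
for tau in (0.5, 2.0):
    # Route 1: raise probabilities to 1 / tau, then renormalize.
    powered = p ** (1 / tau)
    via_probs = powered / powered.sum()
    # Route 2: use log-probabilities as logits, divide by tau, softmax.
    via_logits = torch.softmax(p.log() / tau, dim=-1)
    results[tau] = (via_probs, via_logits)
    print(f"tau={tau}:", via_probs.tolist(), via_logits.tolist())
```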

Top-k sampling

Top-k keeps the k highest-logit tokens and removes the rest.

For our logits:

token id:   0      1      2      3      4
logit:    2.0   -1.5    0.3    4.2   -0.8

The top 2 tokens are:

token 3: logit 4.2
token 0: logit 2.0

Softmax only over those two:

softmax([4.2, 2.0]) ≈ [0.900, 0.100]

So top-k with k = 2 says: "sample only from tokens 3 and 0; ignore the other three completely."

Top-k is simple, but it is not adaptive. k = 50 might be too wide when the model is certain and too narrow when the model is uncertain.
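A minimal sketch of top-k as a filter on the distribution (rather than a sampler), applied to the example logits. It uses the same -inf masking idea as the full helper later in this lesson; the name top_k_probs is just for illustration:

```python
import torch

def top_k_probs(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Value of the k-th largest logit, kept with a trailing dim so it broadcasts against logits.
    kth = torch.topk(logits, k, dim=-1).values[..., -1:]
    # Everything strictly below it becomes -inf, which softmax turns into probability 0.
    masked = logits.masked_fill(logits < kth, float("-inf"))
    return torch.softmax(masked, dim=-1)

logits = torch.tensor([2.0, -1.5, 0.3, 4.2, -0.8])
probs = top_k_probs(logits, k=2)
print(probs)   # nonzero only at token ids 0 and 3
```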

Top-p sampling

Top-p, also called nucleus sampling, keeps the smallest set of tokens whose cumulative probability reaches at least p.

Using the original softmax probabilities:

token id sorted by probability: 3      0      2      4      1
probability:                    0.876  0.097  0.018  0.006  0.003
cumulative:                     0.876  0.973  0.991  0.997  1.000

For top_p = 0.9, keep tokens 3 and 0:

token 3 alone reaches 0.876, which is below 0.9
token 3 + token 0 reaches 0.973, which is above 0.9

Then renormalize those kept probabilities:

token 3: 0.876 / 0.973 ≈ 0.900
token 0: 0.097 / 0.973 ≈ 0.100

In this example, top-p 0.9 keeps the same two tokens as top-k 2. In general it adapts:

  • If the model is very certain, top-p may keep only 1 or 2 tokens.
  • If the model is uncertain, top-p may keep dozens or hundreds.

That adaptivity is why top-p is often preferred over a fixed top-k for open-ended writing.
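To see the adaptivity directly, this sketch counts how many tokens a p = 0.9 nucleus keeps for three distributions: the example logits, the same logits at tau = 2, and a flat distribution over 1024 tokens (an artificial stand-in for maximal uncertainty). nucleus_size is a hypothetical helper name:

```python
import torch

def nucleus_size(probs: torch.Tensor, p: float) -> int:
    # Tokens before cumulative mass reaches p, plus the one token that crosses it.
    cumulative = torch.sort(probs, descending=True).values.cumsum(dim=-1)
    return min(int((cumulative < p).sum().item()) + 1, probs.numel())

logits = torch.tensor([2.0, -1.5, 0.3, 4.2, -0.8])
confident = torch.softmax(logits, dim=-1)          # peaked: about 0.876 on token 3
flatter = torch.softmax(logits / 2.0, dim=-1)      # same logits at tau = 2
uniform = torch.full((1024,), 1 / 1024)            # every token equally likely

for name, probs in [("confident", confident), ("flatter", flatter), ("uniform", uniform)]:
    print(name, nucleus_size(probs, p=0.9))
```

The same p keeps 2, 3, and 922 tokens respectively. A fixed k cannot follow that range.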

[Figure: Same logits, different sampling rules. Panels: Greedy (pick max), Temperature tau = 2 (flatter distribution), Top-p = 0.9 (keep nucleus, renormalize). Bars are illustrative probabilities from logits [2.0, -1.5, 0.3, 4.2, -0.8]; greedy picks one token, while sampling keeps a distribution.]
Greedy decoding collapses the distribution to one token. Temperature reshapes the distribution. Top-p truncates the tail and renormalizes the remaining nucleus.
Sampling Code From Scratch

The generation loop is forward, sample, append, repeat.

By the end of this section you should be able to write the token picker yourself.

Shared helper

import torch
import torch.nn.functional as F

def sample_next_token(
    logits: torch.Tensor,
    *,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> torch.Tensor:
    """
    logits: (B, V), raw scores for the next token.
    returns: (B, 1), sampled token ids.
    """
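    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0:
        # Greedy: the temperature -> 0 limit collapses the distribution onto the argmax.
        return logits.argmax(dim=-1, keepdim=True)

    logits = logits / temperature

    if top_k is not None:
        # Clamp k so a top_k larger than the vocabulary does not make torch.topk fail.
        values, _ = torch.topk(logits, k=min(top_k, logits.size(-1)), dim=-1)
        kth_value = values[:, [-1]]   # (B, 1): smallest logit that is still kept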
        logits = torch.where(
            logits < kth_value,
            torch.full_like(logits, float("-inf")),
            logits,
        )

    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative = sorted_probs.cumsum(dim=-1)

        # Remove tokens after cumulative probability exceeds top_p.
        # Keep the first token above the threshold, because it is what reaches p.
        remove_sorted = cumulative > top_p
        remove_sorted[:, 1:] = remove_sorted[:, :-1].clone()
        remove_sorted[:, 0] = False

        remove = torch.zeros_like(remove_sorted).scatter(1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

The top-p block is the only tricky part. Read it slowly:

  1. Sort logits from most likely to least likely.
  2. Softmax the sorted logits.
  3. Compute cumulative probability.
  4. Mark every token whose cumulative probability exceeds top_p.
  5. Shift the mask right so the token that crosses the threshold stays included.
  6. Scatter the sorted mask back to original vocabulary order.
  7. Set removed logits to -inf, so their softmax probability becomes 0.
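Tracing those steps on the lesson's 5-token example makes the shift and scatter concrete. This repeats the helper's top-p operations outside the function, with intermediate values as comments (approximate where marked ~):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.5, 0.3, 4.2, -0.8]])   # (B=1, V=5), the lesson's example
top_p = 0.9

sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)   # sorted_idx: [[3, 0, 2, 4, 1]]
cumulative = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)             # ~[[0.876, 0.973, 0.991, 0.997, 1.000]]
remove_sorted = cumulative > top_p                                        # [[False, True, True, True, True]]
remove_sorted[:, 1:] = remove_sorted[:, :-1].clone()                      # shift right by one position
remove_sorted[:, 0] = False                                               # [[False, False, True, True, True]]
remove = torch.zeros_like(remove_sorted).scatter(1, sorted_idx, remove_sorted)  # back to vocabulary order
probs = F.softmax(logits.masked_fill(remove, float("-inf")), dim=-1)      # ~[[0.100, 0, 0, 0.900, 0]]

torch.manual_seed(0)
draws = torch.multinomial(probs, num_samples=1000, replacement=True)     # only ids 0 and 3 can appear
print(remove.tolist(), sorted(set(draws.flatten().tolist())))
```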

Full generate loop

This version assumes your model has the same simple interface as Day 9:

logits = model(idx)  # idx shape: (B, T), logits shape: (B, T, V)

@torch.no_grad()
def generate(
    model,
    idx: torch.Tensor,
    *,
    max_new_tokens: int,
    block_size: int,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
    eos_token_id: int | None = None,
) -> torch.Tensor:
    """
    idx: (B, T_start), prompt token ids.
    returns: (B, T_start + generated), prompt plus generated ids.

    This intentionally does not use a KV cache yet.
    Day 20 will replace repeated full-prefix forward passes with cached decode.
    """
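    model.eval()
    # Tracks which sequences in the batch have already emitted EOS.
    finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

    for _ in range(max_new_tokens):
        # Crop to model context window.
        idx_cond = idx[:, -block_size:]

        # Forward pass over current context.
        logits = model(idx_cond)              # (B, T_context, V)
        next_logits = logits[:, -1, :]        # (B, V)

        # Convert next-token logits into an actual sampled token id.
        next_id = sample_next_token(
            next_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )                                     # (B, 1)

        if eos_token_id is not None:
            # Sequences that already finished keep emitting EOS as padding.
            next_id = next_id.masked_fill(finished.unsqueeze(1), eos_token_id)

        # Append sampled token.
        idx = torch.cat([idx, next_id], dim=1)

        if eos_token_id is not None:
            finished |= next_id.squeeze(1) == eos_token_id
            # Stop once every sequence in the batch has emitted EOS.
            if finished.all():
                break

    return idx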

This is intentionally minimal. Production engines add:

  • KV cache.
  • Batching and scheduling.
  • Streaming callbacks.
  • Stop-string matching.
  • Log-prob returns.
  • Penalties.
  • Speculative decoding.

But the skeleton is the same: forward pass, sample, append, repeat.

Inference Parameters

The common knobs are simple logit and stop rules.

By the end of this section you should be able to explain the common generation knobs without treating them as magic.

max_new_tokens

max_new_tokens is a hard cap on generated tokens.

If it is too low, the model can be cut off mid-answer:

Prompt: "Explain FlashAttention in two paragraphs."
max_new_tokens = 12
Output: "FlashAttention is an optimized attention algorithm that reduces memory traffic by"

This is not a model-quality failure. It is a decoding configuration failure.

Stop tokens and stop strings

A stop condition ends generation early.

Examples:

  • EOS token: model emits a special end-of-sequence token.
  • Stop string: generated text contains "\n\nUser:".
  • Tool protocol delimiter: generated text contains "</tool_call>".

Token-level stop conditions are easier and safer than string-level ones because generation happens in token IDs. String stops require decoding partial text and checking boundaries.
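A minimal sketch of string-level stop matching, assuming the engine already decodes each new token into a text piece (the pieces below are made up, not real tokenizer output). It shows the boundary problem: a stop string can straddle several tokens, so the check has to run on accumulated text rather than on individual pieces:

```python
def apply_stop_strings(pieces: list[str], stop_strings: list[str]) -> tuple[str, bool]:
    """Append decoded token pieces one at a time; stop at the first stop string, which is trimmed."""
    text = ""
    for piece in pieces:
        text += piece
        for stop in stop_strings:
            # A new match must end inside the newest piece, so only search the tail that could contain it.
            start = max(0, len(text) - len(piece) - len(stop) + 1)
            pos = text.find(stop, start)
            if pos != -1:
                return text[:pos], True
    return text, False

# Hypothetical decoded pieces: the stop string "\n\nUser:" is split across three of them.
pieces = ["Sure", ",", " here", ".\n", "\nUs", "er:", " ignored"]
output, stopped = apply_stop_strings(pieces, ["\n\nUser:"])
print(repr(output), stopped)
```

A streaming server has a further wrinkle this sketch ignores: text that might be the start of a stop string (here, the trailing "\n") has to be held back from the client until it is clear whether the stop string completes.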

Repetition penalty

Repetition penalty reduces the chance of repeating tokens that already appeared.

A simple version:

def apply_repetition_penalty(logits, generated_ids, penalty=1.1):
    # logits: (B, V), generated_ids: (B, T)
    logits = logits.clone()
    for b in range(logits.shape[0]):
        seen = set(generated_ids[b].tolist())
        for token_id in seen:
            if logits[b, token_id] > 0:
                logits[b, token_id] /= penalty
            else:
                logits[b, token_id] *= penalty
    return logits

This is heuristic, not a training objective. It is useful when a model gets stuck in loops, but too much penalty can make outputs weird by suppressing legitimate repeated words.
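The loop above is easy to read but iterates in Python over every sequence and token. A vectorized sketch of the same rule using torch.gather and Tensor.scatter, so the work stays in tensor operations:

```python
import torch

def apply_repetition_penalty_vectorized(logits: torch.Tensor, generated_ids: torch.Tensor,
                                        penalty: float = 1.1) -> torch.Tensor:
    # logits: (B, V) float; generated_ids: (B, T) long
    seen = torch.gather(logits, 1, generated_ids)                  # (B, T) logits of already-seen tokens
    seen = torch.where(seen > 0, seen / penalty, seen * penalty)   # same rule as the loop version
    return logits.scatter(1, generated_ids, seen)                  # repeated ids write identical values

out = apply_repetition_penalty_vectorized(
    torch.tensor([[2.0, -1.0, 0.5, 3.0]]),
    torch.tensor([[0, 1, 0]]),     # token 0 seen twice, token 1 once
    penalty=2.0,
)
print(out.tolist())   # [[1.0, -2.0, 0.5, 3.0]]
```

Like the loop version, a token is penalized once no matter how often it appeared; counting occurrences is what frequency penalty adds.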

Frequency and presence penalties

These are common in APIs.

Presence penalty says: "If this token has appeared at all, subtract a fixed amount from its logit."

logit[token] -= presence_penalty

Frequency penalty says: "Subtract more if this token has appeared many times."

logit[token] -= frequency_penalty * count[token]

Concrete example:

base logit for token "the" = 3.0
count["the"] = 4
presence_penalty = 0.2
frequency_penalty = 0.1

adjusted logit = 3.0 - 0.2 - 0.1 * 4
               = 3.0 - 0.2 - 0.4
               = 2.4

The token can still be sampled. It is just less likely.
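Both formulas can be applied to a whole batch at once by counting occurrences with scatter_add. This is a sketch; the 3-token vocabulary and token ids are hypothetical, with token 0 playing the role of "the" so the example above is reproduced:

```python
import torch

def apply_presence_frequency(logits: torch.Tensor, generated_ids: torch.Tensor,
                             presence_penalty: float = 0.0,
                             frequency_penalty: float = 0.0) -> torch.Tensor:
    # logits: (B, V) float; generated_ids: (B, T) long
    # counts[b, v] = number of times token v appears in generated_ids[b]
    counts = torch.zeros_like(logits).scatter_add(
        1, generated_ids, torch.ones_like(generated_ids, dtype=logits.dtype)
    )
    presence = (counts > 0).to(logits.dtype)   # 1 if seen at least once, else 0
    return logits - frequency_penalty * counts - presence_penalty * presence

logits = torch.tensor([[3.0, 1.0, 0.5]])
generated = torch.tensor([[0, 0, 2, 0, 0]])   # token 0 four times, token 2 once
adjusted = apply_presence_frequency(logits, generated, presence_penalty=0.2, frequency_penalty=0.1)
print(adjusted.tolist())   # approximately [[2.4, 1.0, 0.2]]
```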

Why Decode Is the Bottleneck

Autoregression makes serving a systems problem.

By the end of this section you should understand why inference systems obsess over batching and caches.

The serial dependency

Token n + 1 depends on token n.

token 1 -> token 2 -> token 3 -> token 4 -> ...

You cannot generate token 4 before you know token 3. This is the autoregressive constraint.

Prefill can parallelize across positions because all prompt tokens are known. Decode cannot parallelize across generated positions because future generated tokens are not known yet.

Batch size changes the economics

At batch size 1, decode often looks like:

read model weights -> do one token of useful work -> repeat

At larger batches, the same weights can be used for many sequences at once:

read model weights -> do B tokens of useful work -> repeat

That raises arithmetic intensity: more math per byte loaded. This is why continuous batching is such a big deal in production inference. Day 24 is the deep dive.
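A back-of-the-envelope sketch of that claim, reusing the lesson's 2 * N FLOPs per token and 2 bytes per FP16 weight, and deliberately ignoring KV-cache reads (which grow with batch size and pull the real number down):

```python
def decode_flops_per_byte(n_params: float, batch: int, bytes_per_param: int = 2) -> float:
    flops = 2 * n_params * batch               # ~2 * N FLOPs for each sequence in the batch
    weight_bytes = n_params * bytes_per_param  # weights read once per step, shared by the whole batch
    return flops / weight_bytes

intensity = {b: decode_flops_per_byte(7e9, b) for b in (1, 8, 64)}
print(intensity)   # {1: 1.0, 8: 8.0, 64: 64.0}
```

With the rough A100 numbers from earlier (312 TFLOP/s over about 2 TB/s, roughly 156 FLOPs per byte), batch-1 decode at about 1 FLOP per byte leaves most of the arithmetic units idle.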

KV cache changes the context cost

Without a cache, the first decode step after a 512-token prompt reruns the full forward pass over all 513 tokens (the prompt plus the newly sampled token), recomputing keys and values for the 512 tokens it has already processed. The next step reruns 514 tokens, then 515, and so on. That is wasteful.

With a KV cache:

  • Previous keys and values are stored.
  • The new token computes only its own key and value.
  • Attention reads old cached K/V and writes one new cache row.

This turns decode from "recompute the full prefix every time" into "append one new token's K/V and attend over the cache." Day 20 will turn this into exact memory math.
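To put numbers on the waste, this sketch counts how many token positions pass through the model's layers for the running example (512-token prompt, 128 new tokens), assuming step 0 is the prefill that yields the first new token, as in this lesson's generate loop. It counts positions only; with a cache, attention still reads every cached K/V row:

```python
def positions_processed(t_prompt: int, t_new: int, kv_cache: bool) -> int:
    # Step 0 is prefill and yields new token 1; steps 1..t_new-1 yield the rest.
    if kv_cache:
        return t_prompt + (t_new - 1)                  # prefill once, then one new position per step
    return sum(t_prompt + i for i in range(t_new))     # every step reruns the whole current sequence

no_cache = positions_processed(512, 128, kv_cache=False)
with_cache = positions_processed(512, 128, kv_cache=True)
print(no_cache, with_cache, round(no_cache / with_cache, 1))   # 73664 639 115.3
```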

Exercise

Implement sampling and measure generation.

Use the companion notebook, then swap in your Day 9 GPT for meaningful timings.

Do this exercise in a notebook or a single Python file. Use your Day 9 tiny GPT if you have it; otherwise stub the model with random logits and focus on sampling first.
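If you do not have the Day 9 model handy, a stub like the one below satisfies the same (B, T) -> (B, T, V) interface. RandomLogitsModel is a hypothetical name; its logits are pure noise, so use it only to exercise shapes and sampling code, not to judge quality or produce meaningful timings:

```python
import torch

class RandomLogitsModel(torch.nn.Module):
    """Stand-in for the Day 9 GPT: maps (B, T) token ids to random (B, T, V) logits."""

    def __init__(self, vocab_size: int = 50, seed: int = 0):
        super().__init__()
        self.vocab_size = vocab_size
        self.gen = torch.Generator().manual_seed(seed)   # reproducible noise on CPU

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        B, T = idx.shape
        return torch.randn(B, T, self.vocab_size, generator=self.gen)

model = RandomLogitsModel(vocab_size=50)
idx = torch.zeros(2, 4, dtype=torch.long)                 # dummy prompt: B=2, T_prompt=4
with torch.no_grad():
    for _ in range(3):                                    # three greedy decode steps, no KV cache
        next_id = model(idx)[:, -1, :].argmax(dim=-1, keepdim=True)   # (B, 1)
        idx = torch.cat([idx, next_id], dim=1)
print(idx.shape)   # torch.Size([2, 7])
```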

Part 1: Sampling by hand

Given:

logits = [2.0, -1.5, 0.3, 4.2, -0.8]
  1. Compute softmax by subtracting the max first.
  2. Identify the greedy token.
  3. Apply temperature tau = 2.0 and recompute probabilities.
  4. Apply top-k with k = 2. Which tokens remain?
  5. Apply top-p with p = 0.9. Which tokens remain?

Expected approximate values:

softmax tau=1.0 -> [0.097, 0.003, 0.018, 0.876, 0.006]
softmax tau=2.0 -> [0.206, 0.036, 0.088, 0.619, 0.051]
greedy token      -> 3
top-k k=2         -> tokens 3 and 0
top-p p=0.9       -> tokens 3 and 0

Part 2: Sampling code

Implement:

greedy_sample(logits)
temperature_sample(logits, tau)
top_k_sample(logits, k)
top_p_sample(logits, p)

Test each on the logits above. Use torch.manual_seed(0) before stochastic tests so results are reproducible.

Part 3: Generate loop

Implement the full generate() loop from this lesson. Run it with:

  1. Greedy.
  2. Temperature 0.8.
  3. Temperature 1.2.
  4. Top-p 0.9.

Compare outputs on the same prompt. Do not judge only "quality"; observe diversity, repetition, and failure modes.

Part 4: Timing

Measure:

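import time
import torch
start = time.perf_counter()
out = generate(...)                    # your model, prompt idx, and sampling settings
if torch.cuda.is_available():
    torch.cuda.synchronize()           # GPU kernels run asynchronously; wait before stopping the clock
elapsed = time.perf_counter() - start
tokens_per_sec = (out.shape[1] - idx.shape[1]) / elapsed   # idx: the prompt passed to generate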

Then repeat with prompt lengths:

T_prompt = 16, 64, 256, 512

Because this no-cache version recomputes the prefix, longer prompts should slow down decode. That pain is the motivation for KV caching.

Self-Check Questions

Ten questions before moving to hardware.

Close the page and answer from memory. If you cannot, re-read the relevant section.

  1. What is the difference between prefill and decode?
  2. Why can prefill parallelize over prompt tokens while decode cannot parallelize over generated tokens?
  3. What does TTFT measure? What does ITL measure?
  4. For a 7B model and a 512-token prompt, estimate prefill FLOPs.
  5. Why can prefill have more FLOPs than decode but still feel faster?
  6. What happens to a probability distribution when temperature is below 1? Above 1?
  7. Why is top-p more adaptive than top-k?
  8. What failure can max_new_tokens cause?
  9. What does the KV cache save?
  10. Why does batch size improve decode throughput?
Answer key
  1. Prefill is the full forward pass over the prompt; decode is the repeated one-token generation loop after that.
  2. Prompt tokens are all known, so their masked computations can run together. Generated token n+1 depends on generated token n, so future decode steps cannot start yet.
  3. TTFT is the wait until the first generated token. ITL is the time between subsequent streamed tokens.
  4. 2 * 7e9 * 512 = 7.168e12 FLOPs.
  5. Prefill uses large dense matmuls and can be compute-bound; batch-1 decode repeatedly reads weights for small matrix-vector work and is often memory-bound.
  6. Below 1 sharpens; above 1 flattens.
  7. Top-p keeps as many tokens as needed to reach a probability mass threshold, so it keeps fewer tokens when the model is confident and more when uncertain.
  8. It can truncate the answer mid-thought.
  9. It stores keys and values for previous tokens so decode does not recompute the whole prefix every step.
  10. More sequences reuse the same weight reads, increasing math per byte loaded.
What to Take Away

Two regimes, one loop.

An LLM inference request is not one operation. It is two regimes glued together:

prefill: big parallel prompt pass -> first token
decode:  sequential one-token loop -> remaining tokens

Prefill affects time to first token. Decode affects streaming speed. Sampling decides which token to emit, turning the model's probability distribution into actual text. The rest of Week 3 is about making this loop faster by understanding the hardware and removing unnecessary memory movement.

"Inference is two regimes glued together: a wide prompt pass to get the first token, then a narrow sequential loop that has to be made fast one token at a time."

Day 15 · Week 3 start
Further Reading

Go deeper.

Sampling docs, inference arithmetic, and the serving systems that build on today's split.

Docs · Hugging Face

Hugging Face generation strategies

Practical documentation for greedy, beam, temperature, top-k, top-p, and other decoding strategies.

Blog · Hugging Face

How to generate text

The classic walkthrough of decoding strategies and their tradeoffs.

Blog · Inference Math

Transformer Inference Arithmetic

Concise compute and memory math for inference. Useful before KV cache and roofline lessons.

Guide · BentoML

The LLM Inference Handbook

Systems-oriented overview of LLM inference metrics, serving constraints, and deployment tradeoffs.

Paper · vLLM

PagedAttention

The paper behind vLLM memory management. It turns the KV-cache pressure introduced today into a serving architecture.

Notebook · Day 15

Sampling from scratch

Companion notebook for stable softmax, greedy/temperature/top-k/top-p sampling, and a minimal generate loop.
