LLM Inference Engineer · Day 10
Day 10 · Week 2 · Training & Architectures

The Training Loop, Production-Grade

Yesterday's loop was minimal and legible. Today we add everything that separates a toy run from a real one: a deep dive into optimizers (SGD through AdamW), learning-rate warmup and cosine decay, mixed precision, gradient accumulation, gradient checkpointing, a full memory budget breakdown, throughput measurement, robust checkpointing, and a practical debugging guide. The same techniques scale from your laptop to a thousand-GPU cluster.

Time: ~180 min
Difficulty: Medium-Hard
Prerequisite: Day 9
Why This Lesson

The gap between "it trains" and "it trains well" is engineering.

Day 9's loop works, but it leaves performance and stability on the table. A production training run is faster, more memory-efficient, and far more robust against the things that go wrong over hours or days of training. None of today's techniques change the model or the objective; they change how efficiently and reliably you optimize it. All of them are standard practice in production codebases such as Megatron-LM, most of them appear in nanoGPT, and together they underpin the training recipes of the open models you have heard of.

This lesson also matters directly for inference. Mixed precision is the same numerics story as quantization (Week 4). Gradient checkpointing trades compute for memory — the same tradeoff you will see in KV-cache management. Throughput and MFU are the same efficiency lens you will apply to serving. And the checkpoint you write here is the artifact your inference engine loads on Day 27. Training and inference share far more machinery than they appear to.

Learning objectives

  1. Trace a full production training step from zero_grad to opt.step(), with every technique in the right order.
  2. Understand SGD, momentum, Adam, and AdamW from first principles; know their update equations and memory costs.
  3. Implement and explain learning-rate warmup + cosine decay.
  4. Explain FP32 vs FP16 vs BF16; use autocast + GradScaler correctly.
  5. Apply gradient accumulation and gradient clipping; derive the effective-batch-size formula.
  6. Describe gradient checkpointing and when to use it over simply buying more memory.
  7. Build a full memory budget for a training run and explain why training needs so much more memory than inference.
  8. Measure throughput in tokens/sec and estimate Model FLOPs Utilization (MFU).
  9. Write a checkpoint/resume cycle that survives a crash mid-run.
  10. Diagnose common training pathologies: loss spikes, NaNs, and divergence.
The Production Step

Seven operations in a precise order — getting one wrong breaks everything.

The Day 9 loop did four things per step: zero gradients, forward pass, backward pass, optimizer step. A production loop does seven, and the order is not arbitrary. Missing a step or swapping two produces subtle bugs that can silently corrupt training for thousands of steps before you notice anything wrong.

1. zero_grad: set_to_none=True saves memory.
2. Set LR: warmup + cosine, every step.
3. Forward: under autocast(bf16).
4. Scale loss: loss / accum_steps (mean, not sum).
5. Backward: accumulates into .grad (through scaler.scale(loss) on the FP16 path).
6. Clip grads: unscale first, then clip the global norm to ≤ 1.0.
7. opt.step(): update the weights, then scaler.update().
Steps 3–5 repeat accum_steps times, once per micro-batch, before proceeding to step 6.
The seven operations of a production training step. Steps 3–5 repeat for each micro-batch in the accumulation loop; steps 6 and 7 execute once per logical batch. Getting the order wrong, especially clipping before unscaling, produces silent bugs.

Train/eval split and deterministic seeding

Two habits you must wire in from the start. First, always split your data before training and measure held-out validation loss at fixed intervals; it is the only honest measure of generalization. Second, seed everything before creating the model and data loaders: torch.manual_seed(42), random.seed(42), numpy.random.seed(42). Seeding makes runs reproducible, which is essential when debugging or comparing two configs. Save the seed and the RNG states in your checkpoint so a resumed run can restore the exact random stream.

# Deterministic setup — do this before anything else.
import random, numpy as np, torch
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

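# Data split. `tokens` is the 1-D tensor of encoded token ids (your Day 9 dataset).
n = int(0.9 * len(tokens))          # 90/10 split: first 90% train, last 10% val
train_data, val_data = tokens[:n], tokens[n:]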
Optimizers in Depth

SGD to AdamW — update equations, intuition, and memory cost.

An optimizer is the algorithm that turns gradients into weight updates. The right choice affects convergence speed, final loss, and memory consumption — and the memory impact is large enough to break your budget on a big model. Let's build up from first principles.

Vanilla SGD

Gradient descent is conceptually simple: move the weights in the direction that reduces the loss, scaled by a step size (learning rate).

w ← w − lr · ∇L(w)

Plain SGD keeps no optimizer state: beyond the weights and their gradients it needs zero extra memory. The problem is that the gradient on a single mini-batch is noisy; the step direction changes erratically, which makes convergence slow and sensitive to the learning rate.

SGD with momentum

Momentum smooths the gradient signal by maintaining a running average — the "velocity" — of past gradients. Think of it as a ball rolling down a loss surface that accumulates speed in consistent directions and dampens noise.

v ← β · v + (1 − β) · ∇L(w)    # β ≈ 0.9 is typical
w ← w − lr · v

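Momentum adds one extra state tensor (v) per parameter, the same size as the weights. For a 7B-parameter model in FP32 that is an extra 28 GB, which is already significant. (PyTorch's torch.optim.SGD uses the undamped form v ← β · v + ∇L(w); it differs from the EMA above only by a constant factor that the learning rate absorbs.)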
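To see the smoothing effect in isolation, here is a minimal plain-Python sketch that feeds simulated noisy gradients through the EMA velocity update above. The numbers are made up for illustration (a true gradient of 1.0 plus unit Gaussian noise, β = 0.9). For stationary noise, the EMA's standard deviation is √((1 − β)/(1 + β)) ≈ 0.23 times the raw gradient's, while its mean still tracks the true gradient.

```python
import random
import statistics

def momentum_velocity(grads, beta=0.9):
    """Return the EMA velocity after each step: v = beta * v + (1 - beta) * g."""
    v, history = 0.0, []
    for g in grads:
        v = beta * v + (1 - beta) * g
        history.append(v)
    return history

# Simulated noisy mini-batch gradients for one parameter:
# true gradient 1.0 plus Gaussian noise with std 1.0 (made-up numbers).
rng = random.Random(0)
grads = [1.0 + rng.gauss(0.0, 1.0) for _ in range(5000)]
velocity = momentum_velocity(grads, beta=0.9)

burn_in = 100          # skip the start, where v is still climbing up from 0
raw_std = statistics.pstdev(grads[burn_in:])
vel_std = statistics.pstdev(velocity[burn_in:])
print(f"raw gradient std {raw_std:.3f} | velocity std {vel_std:.3f}")
```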

Adam — adaptive learning rates per parameter

Adam (Adaptive Moment Estimation) maintains two running averages per parameter: the first moment m (an EMA of gradients, like momentum) and the second moment v (an EMA of squared gradients, a per-parameter estimate of gradient magnitude). The update divides by the square root of v, giving each parameter its own effective learning rate: parameters with consistently large gradients take smaller steps, and parameters with small gradients take larger ones.

m ← β₁ · m + (1 − β₁) · g        # first moment (β₁ = 0.9)
v ← β₂ · v + (1 − β₂) · g²       # second moment (β₂ = 0.95 or 0.999)
m̂ = m / (1 − β₁ᵗ)                # bias correction: m and v start at 0, so early estimates are biased low
v̂ = v / (1 − β₂ᵗ)
w ← w − lr · m̂ / (√v̂ + ε)        # ε = 1e-8 prevents division by zero

The bias corrections (dividing by 1 − βᵗ) matter most in the early steps, when the exponential averages, initialized at zero, would otherwise underestimate the true moments. How quickly they fade depends on β: with β₁ = 0.9 or β₂ = 0.95 they are negligible by step ~100, but with β₂ = 0.999 the correction on v is still about 10× at step 100 (1 − 0.999¹⁰⁰ ≈ 0.095) and fades only over a few thousand steps. Adam converges faster than SGD on almost every language-model task. The catch: two extra state tensors per parameter means Adam's optimizer state alone is 2× the weight size. For a 7B model in FP32 that is 56 GB of optimizer state on top of the 28 GB of weights.
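A scalar sketch of the update equations above, in plain Python with no framework assumed (the function name adam_step is ours). It checks two consequences of the bias correction: the very first step has a magnitude of about lr regardless of the gradient's scale, and how much correction is still active at step 100 depends strongly on β.

```python
import math

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar weight at step t (1-indexed); returns (w, m, v)."""
    m = beta1 * m + (1 - beta1) * g          # first moment: EMA of g
    v = beta2 * v + (1 - beta2) * g * g      # second moment: EMA of g^2
    m_hat = m / (1 - beta1 ** t)             # bias correction: undo the pull toward 0
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# With bias correction, the first update is about -lr * sign(g) for any gradient scale.
first_steps = {g: adam_step(0.0, g, 0.0, 0.0, t=1)[0] for g in (1e-4, 1.0, 1e4)}
for g, w1 in first_steps.items():
    print(f"g = {g:g}: first update = {w1:.6g}")

# Denominator 1 - beta**100 at step 100 (1.0 would mean no correction is needed).
remaining = {beta: 1 - beta ** 100 for beta in (0.9, 0.95, 0.999)}
for beta, denom in remaining.items():
    print(f"beta = {beta}: 1 - beta**100 = {denom:.4f}")
```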

AdamW — decoupled weight decay

Regular Adam with weight decay folds the decay into the gradient: g_eff = g + λ·w. The decay term then passes through the adaptive scaling and gets divided by √v̂, so parameters with large historical gradients are barely regularized. That is conceptually wrong and empirically worse. AdamW decouples weight decay: it shrinks the weights directly, outside the adaptive update:

w ← w − lr · m̂ / (√v̂ + ε) − lr · λ · w # λ = weight_decay

This one extra term is the only difference between Adam and AdamW, but it matters. Decoupling makes the regularization strength independent of the per-parameter adaptive learning rate: every weight shrinks at the same rate lr · λ, whatever its gradient history (its v̂). AdamW is the default optimizer in nearly all modern LLM training recipes. Use weight_decay=0.1 on weight matrices; set it to 0.0 on biases and layer-norm parameters, which should not be decayed.
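A toy, single-step comparison makes the difference concrete. The sketch below (plain Python; all parameter values are made up) takes a weight w = 1.0 whose data gradient happens to be zero this step, so weight decay is the only force acting on it, and gives it either a small or a large second-moment history v0. Under Adam + L2 the decay term is divided by √v̂, so the parameter with the large history barely shrinks; under AdamW both shrink by exactly lr · λ.

```python
import math

def one_step(w, g, m, v, t, lr, wd, decoupled, beta1=0.9, beta2=0.95, eps=1e-8):
    """One scalar step of Adam + L2 (decoupled=False) or AdamW (decoupled=True)."""
    if not decoupled:
        g = g + wd * w                       # Adam + L2: decay enters through the gradient
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    update = lr * m_hat / (math.sqrt(v_hat) + eps)
    if decoupled:
        update += lr * wd * w                # AdamW: decay applied to the weight directly
    return w - update

# w = 1.0 with zero data gradient this step, so decay is the only force.
# v0 is a made-up second-moment history: small for a quiet parameter, large for
# one that usually sees big gradients. t is large so bias corrections are ~1.
lr, wd, t = 1e-3, 0.1, 10_000
shrink = {}
for v0 in (1e-6, 1.0):
    for decoupled in (False, True):
        w_new = one_step(1.0, 0.0, 0.0, v0, t, lr, wd, decoupled)
        shrink[(v0, decoupled)] = 1.0 - w_new
        name = "AdamW" if decoupled else "Adam+L2"
        print(f"v0 = {v0:g}, {name}: weight shrinks by {1.0 - w_new:.3e}")
```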

What the hyperparameters do

  • β₁ (default 0.9): how quickly the first-moment estimate tracks the current gradient. Lower values make it more reactive; higher values smooth more. Almost nobody changes this.
  • β₂ (0.95 for LLMs, 0.999 for many other tasks): controls the second moment. A lower β₂ such as 0.95 makes the adaptive scaling react faster to changes in gradient magnitude; Karpathy uses 0.95 for his GPT-2 reproduction.
  • ε (1e-8): prevents division by zero; it rarely matters unless some parameters have near-zero gradient variance.
  • weight_decay (0.1): the decoupled weight-decay strength. Too high shrinks weights toward zero and hurts learning; too low and the model can overfit. 0.1 is a common default for LLM pretraining.

| Optimizer | Extra state / param | Key property | Typical use | WD coupled? |
|---|---|---|---|---|
| SGD | 0 (none) | Simple, noisy steps | CNNs, fine-tuning | Equivalent either way |
| SGD + momentum | 1 (velocity v) | Smooth direction via EMA of g | Vision models | Yes |
| Adam | 2 (m and v) | Per-param LR via second moment | NLP generally | Yes (bug) |
| AdamW ✓ | 2 (m and v) | Decoupled weight decay | All LLMs | No (correct) |

Adam and AdamW require 2 extra state tensors per parameter, so optimizer state alone is 2× the weight size.
Optimizer comparison. AdamW's only difference from Adam is decoupled weight decay: the L2 penalty is applied directly to weights rather than through the gradient, making regularization independent of the per-parameter adaptive learning rate.

Optimizer state memory — the training tax

This is one of the most important practical points in the whole lesson. At inference time you only need the weights. During training you also maintain optimizer state. For AdamW with FP32 masters:

Memory per parameter (AdamW, FP32 masters):
  weights (FP32):      4 bytes
  gradients (FP32):    4 bytes   ← same size as weights
  Adam m (FP32):       4 bytes ┐
  Adam v (FP32):       4 bytes ┘ optimizer state = 2× weights
  Total:              16 bytes / parameter

Example, 7B parameters:
  Weights only (BF16/FP16):   7B × 2  = 14 GB    ← inference
  Training (FP32 + AdamW):    7B × 16 = 112 GB   ← training, before activations

That is why inference memory ≈ model size but training memory ≈ 8× model size (before you add activation memory). This factor-of-8 gap is the single most common source of surprise when people first try to train a model they can easily serve. We will account for the activations in the Memory Budget section below.

import torch

# AdamW: apply weight decay only to weight matrices, not to biases or norms.
decay_params = [p for n, p in model.named_parameters()
                if p.dim() >= 2 and p.requires_grad]        # matrices
no_decay_params = [p for n, p in model.named_parameters()
                   if p.dim() < 2 and p.requires_grad]      # biases, norms
opt = torch.optim.AdamW([
    {"params": decay_params,    "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=6e-4, betas=(0.9, 0.95), eps=1e-8)

Adam's optimizer state takes exactly 2× the memory of the weights when m and v share the weights' dtype: 8 bytes per parameter in FP32. For a 70B-parameter model that is 560 GB of optimizer state alone. This is why optimizer-state sharding (ZeRO Stage 1, extended to gradients in Stage 2) was one of the most important distributed-training innovations of recent years: it splits that state evenly across GPUs instead of replicating it on each one.

Learning-Rate Schedule

Warm up linearly, then decay on a cosine. The near-universal LLM schedule.

A constant learning rate is rarely optimal. Two ideas, used together, define the standard for LLMs. Warmup: start the LR near zero and ramp it up linearly over the first few hundred to few thousand steps. Early in training the weights are random, the gradients are large and noisy, and Adam's second-moment estimate v has seen only a handful of gradients, so its per-parameter step sizes are unreliable and can be artificially large. Warmup eases the model past this fragile phase. Cosine decay: after warmup, decay the LR along a cosine curve down to a small final value (often 10% of peak). The cosine shape keeps the LR high for much of training and then anneals gently for a clean convergence.

lr(step) =
  peak_lr × (step + 1) / warmup,                                    if step < warmup
  min_lr + 0.5 × (peak_lr − min_lr) × (1 + cos(π × decay_ratio)),   if warmup ≤ step ≤ total
  min_lr,                                                           if step > total

where decay_ratio = (step − warmup) / (total − warmup) ∈ [0, 1]

Worked numbers in the style of a GPT-2 124M reproduction: peak_lr = 6e-4, min_lr = 6e-5 (10% of peak), warmup = 715 steps, here 0.5% of a 143k-step run. Notice that min_lr is 10× smaller than the peak rather than zero. This follows GPT-3, which decayed its LR to 10% of the peak value; it is a common convention rather than a hard rule.

[Figure: learning rate vs. step. Linear warmup to the peak LR (6e-4), then cosine decay to the minimum LR (6e-5, 10% of peak, not zero).]
The canonical LLM learning-rate schedule: linear warmup over a small fraction of total steps, then a smooth cosine decay to 10% of the peak LR. The minimum is intentionally nonzero.
import math

def lr_at(step, peak_lr, warmup, total, min_lr):
    if step < warmup:                          # linear warmup
        return peak_lr * (step + 1) / warmup
    if step > total:                           # past the schedule
        return min_lr
    ratio = (step - warmup) / (total - warmup) # [0, 1]
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return min_lr + coeff * (peak_lr - min_lr)

# Apply every step before opt.step():
for g in opt.param_groups:
    g["lr"] = lr_at(step, peak_lr=6e-4, warmup=715, total=143000, min_lr=6e-5)

Why not use a PyTorch scheduler? You can, but writing lr_at explicitly makes the schedule fully transparent: you always know the exact LR at every step, and resuming from a checkpoint is trivial (just pass the step counter). A scheduler object carries its own state, which must also be saved and restored; forget it and the schedule silently restarts after a resume.
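A quick sanity check of the schedule, repeating lr_at so the snippet runs on its own and using the same GPT-2-style settings as above. Because the LR is a pure function of the step counter, resuming at any step reproduces it exactly; the checks confirm that warmup hands off to the cosine without a jump and that the LR ends at min_lr.

```python
import math

def lr_at(step, peak_lr, warmup, total, min_lr):
    """Linear warmup, then cosine decay to min_lr (same logic as the schedule above)."""
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    if step > total:
        return min_lr
    ratio = (step - warmup) / (total - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return min_lr + coeff * (peak_lr - min_lr)

cfg = dict(peak_lr=6e-4, warmup=715, total=143_000, min_lr=6e-5)
for s in (0, 357, 714, 715, 71_857, 143_000, 200_000):
    print(f"step {s:>7,}: lr = {lr_at(s, **cfg):.3e}")
```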

Mixed Precision

Compute in 16 bits, keep master weights in 32. Often 2–3× faster, with activations at half the size.

By default PyTorch stores and computes in 32-bit floats (FP32). Modern accelerators, however, have dedicated matrix hardware (NVIDIA Tensor Cores, for example) that runs 16-bit matrix multiplies several times faster, and 16-bit tensors take half the memory. Mixed precision training runs the heavy operations in 16-bit while keeping a 32-bit master copy of the weights and doing gradient accumulation and optimizer math in 32-bit, capturing most of the speed and memory win with negligible accuracy loss. With torch.autocast, the parameters themselves stay in FP32 and eligible ops, mostly matmuls, cast their inputs to 16-bit on the fly.

FP16 vs BF16 — the exponent/mantissa tradeoff

There are two 16-bit floating-point formats and the difference is consequential. Both use 16 bits total. The split between exponent and mantissa bits determines the tradeoff between range (which values can be represented) and precision (how finely they are subdivided).

[Figure: bit layouts. FP32: 1 sign, 8 exponent, 23 mantissa bits, range ±3.4×10³⁸. FP16: 1 sign, 5 exponent, 10 mantissa bits, range ±65504 (can overflow). BF16: 1 sign, 8 exponent, 7 mantissa bits, same range as FP32.]
FP32, FP16, and BF16 bit layouts. BF16 keeps FP32's 8 exponent bits (same dynamic range) and gives up mantissa precision. FP16's 5-bit exponent limits it to magnitudes between roughly 6×10⁻⁸ (smallest subnormal) and 65504: without loss scaling, small gradients underflow to zero and large activations can overflow.
| Format | Bits (S/E/M) | Max value | Precision | Loss scaling? | Best for |
|---|---|---|---|---|---|
| FP32 | 1/8/23 | 3.4×10³⁸ | High | No | Master weights, optimizer state |
| FP16 | 1/5/10 | 65,504 | Good | Yes | Older GPUs (V100), Apple A-series |
| BF16 | 1/8/7 | 3.4×10³⁸ | Lower | No | A100/H100, Apple M-series (preferred) |

BF16 keeps FP32's 8 exponent bits, so it has the same dynamic range and almost never overflows or underflows — you can use it without a gradient scaler. FP16 has only 5 exponent bits, so small gradients underflow to zero; it needs a GradScaler that multiplies the loss by a large factor before backprop and unscales the gradients afterward. Rule of thumb: use BF16 if your hardware supports it (Ampere/Hopper GPUs, Apple Silicon M-series); otherwise use FP16 with a scaler.
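Python's struct module can round-trip a value through IEEE FP16 (format code e), and BF16 can be emulated by rounding an FP32 bit pattern to its top 16 bits. The sketch below is an illustration under that emulation, not how PyTorch performs the cast, though round-to-nearest-even is the usual conversion. It shows both sides of the tradeoff: FP16 flushes 1e-8 to zero and overflows past 65504 where BF16 does not, while BF16 cannot represent 1 + 2⁻⁹, which FP16 stores exactly.

```python
import struct

def to_fp16(x):
    """Round x to IEEE half precision (FP16) and back; overflow becomes +/-inf."""
    try:
        return struct.unpack("=e", struct.pack("=e", x))[0]
    except OverflowError:                      # struct refuses finite values beyond FP16 range
        return float("inf") if x > 0 else float("-inf")

def to_bf16(x):
    """Emulate bfloat16: round the FP32 bit pattern to its top 16 bits (nearest-even)."""
    bits = struct.unpack("=I", struct.pack("=f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)        # round to nearest, ties to even
    return struct.unpack("=f", struct.pack("=I", bits & 0xFFFF0000))[0]

for x in (1e-8, 1.0 + 2**-9, 65504.0, 70000.0):
    print(f"{x:12g}  fp16 -> {to_fp16(x):12g}  bf16 -> {to_bf16(x):g}")
```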

import torch
from torch.amp import autocast, GradScaler

# BF16 path (A100/H100/M-series): no scaler needed.
for xb, yb in batches:
    with autocast(device_type="cuda", dtype=torch.bfloat16):
        _, loss = model(xb, yb)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

# FP16 path (older GPUs): GradScaler prevents gradient underflow.
scaler = GradScaler()
for xb, yb in batches:
    with autocast(device_type="cuda", dtype=torch.float16):
        _, loss = model(xb, yb)
    opt.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

BF16 (bfloat16) was designed at Google Brain for TPUs precisely so that deep-learning code could drop from FP32 to 16 bits without the loss-scaling gymnastics FP16 requires. It throws away mantissa precision but keeps the full exponent range, and it turns out neural nets care far more about range than precision. That single insight reshaped ML hardware design: most major accelerators since 2020 (A100, H100, TPUv4, recent Apple Silicon) support BF16 natively.
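For the FP16 path, GradScaler's core idea is dynamic loss scaling: multiply the loss by a large factor so small gradients stay representable, skip any step whose gradients overflow, halve the scale after an overflow, and grow it again after a run of clean steps. Below is a simplified plain-Python sketch of that control loop. The class name is hypothetical, and its defaults mirror GradScaler's documented defaults (init_scale 2¹⁶, growth_factor 2, backoff_factor 0.5, growth_interval 2000); the real implementation works on tensors and differs in detail.

```python
import math

class ToyLossScaler:
    """Simplified dynamic loss scaling on plain floats (the idea behind GradScaler)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def unscale(self, scaled_grads):
        """Divide by the current scale; report whether any gradient is inf/NaN."""
        found_inf = not all(math.isfinite(g) for g in scaled_grads)
        return [g / self.scale for g in scaled_grads], found_inf

    def update(self, found_inf):
        """Back off after an overflow; grow after growth_interval clean steps in a row."""
        if found_inf:
            self.scale *= self.backoff_factor
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._clean_steps = 0

# Scaled gradients as backward would produce them. The first pair is fine (true
# values 1e-8 and -2e-8, which would underflow in FP16 unscaled); the second overflowed.
scaler = ToyLossScaler()
for scaled_grads in ([6.5536e-4, -1.31072e-3], [float("inf"), 0.1]):
    grads, found_inf = scaler.unscale(scaled_grads)
    action = "skip optimizer step" if found_inf else f"step with grads {grads}"
    scaler.update(found_inf)
    print(f"{action}; scale is now {scaler.scale:g}")
```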

Gradient Accumulation

Simulate a huge batch by summing gradients over several small ones.

Large batches stabilize training, reduce the variance of gradient estimates, and allow a higher peak LR (the linear scaling rule from SGD training says: double the batch, roughly double the LR; it holds only approximately, and only up to a point). But large batches may not fit in memory. Gradient accumulation decouples the compute batch (what fits in memory) from the optimizer batch (what you want to train on): you run several forward/backward passes, accumulating gradients in .grad, and call opt.step() only after accum_steps micro-batches. The effective batch size becomes:

effective_batch = micro_batch × accum_steps × num_gpus

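Here the batch sizes count sequences; multiply by the sequence length to get tokens. Karpathy's GPT-2 124M reproduction targets 524,288 (2¹⁹) tokens per optimizer step. With micro_batch=16 sequences of 1024 tokens, a single GPU needs accum_steps=32 to get there (16 × 1024 × 32 = 524,288), while 8 GPUs need only accum_steps=4. Either way, each GPU only has to fit 16 × 1024 = 16,384 tokens at once.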
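The batch arithmetic is easy to get wrong, so a tiny helper is worth having. The function names below are hypothetical; the numbers reproduce the 524,288-token (2¹⁹) target discussed above.

```python
def tokens_per_optimizer_step(micro_batch, seq_len, accum_steps, num_gpus=1):
    """Tokens behind one optimizer step: micro_batch x seq_len x accum_steps x num_gpus."""
    return micro_batch * seq_len * accum_steps * num_gpus

def accum_steps_for(target_tokens, micro_batch, seq_len, num_gpus=1):
    """Accumulation steps needed to reach target_tokens per optimizer step."""
    per_pass = micro_batch * seq_len * num_gpus      # tokens per forward/backward across GPUs
    if target_tokens % per_pass:
        raise ValueError(f"{target_tokens} is not a multiple of {per_pass}")
    return target_tokens // per_pass

TARGET = 2**19                                       # 524,288 tokens per optimizer step
print(accum_steps_for(TARGET, micro_batch=16, seq_len=1024, num_gpus=1))   # 32
print(accum_steps_for(TARGET, micro_batch=16, seq_len=1024, num_gpus=8))   # 4
```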

[Figure: accum_steps = 4. Micro-batches 1–4 each run forward, loss / 4, backward; param.grad += g₁/4 + g₂/4 + g₃/4 + g₄/4 (accumulated in place); then a single clip_grad_norm + opt.step() performs one parameter update.]
Gradient accumulation: four micro-batches each contribute loss/4 to the backward pass. PyTorch adds gradients into .grad automatically. After all micro-batches, a single clip + optimizer step updates the weights — equivalent to a batch 4× larger.
accum_steps = 4        # effective batch = micro_batch * accum_steps
opt.zero_grad(set_to_none=True)
for micro in range(accum_steps):
    xb, yb = get_batch(train_data)
    with autocast(device_type=device, dtype=torch.bfloat16):
        _, loss = model(xb, yb)
        loss = loss / accum_steps   # CRITICAL: divide so sum == mean
    loss.backward()                 # gradients accumulate in .grad
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()

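The one subtlety to burn into memory: divide the loss by accum_steps so that summing the per-micro-batch gradients yields the mean gradient, matching what a single large batch would produce. Forget this and your gradients are silently accum_steps× too large. With plain SGD that multiplies the effective learning rate by accum_steps; Adam largely normalizes a constant gradient scale away, but the inflated norm makes gradient clipping fire far more often, which changes the updates. The bug is easy to miss because the loss still decreases, just to a worse optimum.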
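A plain-Python check of the claim, on a toy scalar linear model with a mean-squared-error loss (the data are made up). With equal-sized micro-batches, dividing each micro-batch gradient by accum_steps reproduces the full-batch gradient exactly, while the forgotten division gives accum_steps times too much. With unequal micro-batches (for example, different numbers of unmasked tokens) the simple division is only approximate.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)**2) over a batch, for the scalar model y_hat = w*x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [0.5, -1.0, 2.0, 1.5, -0.5, 3.0, 0.25, -2.0]
ys = [1.0, -2.0, 3.5, 3.0, -1.0, 6.5, 0.5, -4.5]
w = 0.3
accum_steps = 4
micro = len(xs) // accum_steps                  # 2 examples per micro-batch

full = grad_mse(w, xs, ys)                      # one big batch of 8

acc_mean = 0.0                                  # correct: each micro-batch loss / accum_steps
acc_sum = 0.0                                   # bug: division forgotten
for k in range(accum_steps):
    xb = xs[k * micro:(k + 1) * micro]
    yb = ys[k * micro:(k + 1) * micro]
    g = grad_mse(w, xb, yb)                     # gradient of this micro-batch's mean loss
    acc_mean += g / accum_steps                 # what .grad accumulates with loss / accum_steps
    acc_sum += g                                # what .grad accumulates without it

print(f"full batch {full:.6f} | accumulated mean {acc_mean:.6f} | forgot to divide {acc_sum:.6f}")
```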

Gradient Clipping & Gradient Checkpointing

Clip to prevent blowup. Checkpoint activations to save memory.

Gradient clipping

Occasionally a batch produces an enormous gradient — a rare token, a numerical edge case, an unlucky initialization region. Without protection, that single step can take a giant stride in parameter space, spike the loss, and sometimes never recover (you see the dreaded loss-goes-to-NaN). Gradient clipping rescales the whole gradient vector so its global norm never exceeds a threshold. The direction is preserved; only the magnitude is capped:

if ‖g‖ > max_norm:  g ← g × (max_norm / ‖g‖)

In practice: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Clipping is cheap insurance. Almost every LLM training recipe uses max_norm = 1.0. Monitor the pre-clip gradient norm as a diagnostic signal: if it is usually well under 1.0 and occasionally spikes, clipping is doing its job. If it is constantly being clipped (norm consistently above 1.0), your learning rate is probably too high.
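A minimal plain-Python sketch of clipping by global norm (the function name is ours): the norm is computed jointly over all parameter tensors, every gradient is multiplied by the same coefficient so the direction is preserved, and the pre-clip norm is returned for logging. PyTorch's clip_grad_norm_ follows the same idea but differs in details, for example by adding a small epsilon to the denominator.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale every gradient vector in place so their joint L2 norm is at most max_norm.

    Returns the pre-clip norm, which is the number worth logging.
    """
    total = math.sqrt(sum(x * x for vec in grads for x in vec))
    if total > max_norm:
        coef = max_norm / total                 # one coefficient for all tensors
        for vec in grads:
            for i in range(len(vec)):
                vec[i] *= coef
    return total

grads = [[3.0, 0.0], [0.0, 4.0]]                # two "parameter tensors"; joint norm 5
pre = clip_by_global_norm(grads, max_norm=1.0)
print(pre, grads)                               # direction kept, norm capped at max_norm
```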

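# After backward(), before step(). Create the scaler once, enabled only for FP16:
#   scaler = GradScaler(enabled=(amp_dtype == torch.float16))
# When disabled (BF16), unscale_ is a no-op and scaler.step(opt) just calls opt.step().
scaler.unscale_(opt)       # divide grads by the loss scale so the norm is in true units
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(opt)           # skips the update if the grads contain inf/NaN
scaler.update()
# grad_norm is the pre-clip norm; log it every step as a health metric.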

Gradient checkpointing — recompute activations to save memory

During the backward pass PyTorch needs the intermediate activations computed during the forward pass — they are used in the chain rule to compute each layer's gradient. By default all of them are stored simultaneously in GPU memory. For a deep Transformer this can be enormous: a 12-layer model with batch size 64 and sequence length 1024 might store several gigabytes of activations.

Gradient checkpointing (also called activation recomputation) changes this tradeoff: instead of storing every activation, you store only the outputs of a subset of "checkpoint" layers and recompute the in-between activations during the backward pass, when they are needed. This trades compute for memory: you run parts of the forward pass twice, but at any moment you hold only the saved checkpoint outputs plus the full activations of the block currently being recomputed.

[Figure: standard backprop vs. gradient checkpointing. Standard (PyTorch default): every block's activations are stored, so memory grows as O(L) with depth, at 1× forward compute. Checkpointing (torch.utils.checkpoint.checkpoint): only checkpoint boundaries are stored and the rest is recomputed; with segment checkpointing memory can drop to O(√L), for roughly one extra forward pass (~30% more total compute). Worth it when you are out of memory, not when you are compute-bound.]
Gradient checkpointing stores activations only at designated "checkpoint" boundaries and recomputes the discarded intermediate activations during the backward pass. This can cut peak activation memory substantially (how much depends on where the checkpoints sit) at the cost of roughly 30% extra compute: the forward pass is about one third of a training step's FLOPs, and it now runs twice. That is a common tradeoff when fitting a larger model or batch size matters more than speed.
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    # norm1, attn, norm2 and ffn are built in __init__ exactly as in your Day 9 block.
    def forward(self, x):
        # Recompute this block's activations during backward instead of storing them.
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

Use gradient checkpointing when you are running out of activation memory and cannot reduce the batch size further. Do not use it when your run already fits and is compute-bound: the ~30% extra compute comes straight out of throughput.
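A toy cost model helps decide where checkpoints pay off. In the sketch below (plain Python; all costs are in made-up units of one block's output tensor), a block keeps block_cost units of activations alive during the forward pass. Without checkpointing every block keeps its activations; with segment checkpointing you keep one boundary per segment and rebuild one segment at a time during backward. When a block's internals dwarf its output, per-block checkpointing (as in the TransformerBlock above) wins; when every layer costs the same as a boundary, the best segment length is about √L, the classic O(√L) result. PyTorch's torch.utils.checkpoint.checkpoint_sequential applies the segment idea to sequential models. The compute overhead follows from the 6NT accounting: one extra forward pass adds about 2NT to the 6NT of a training step.

```python
import math

def peak_activation_units(num_blocks, block_cost, segment_len=None):
    """Toy peak activation memory, in units of one block's output tensor.

    block_cost: units of activations saved inside one block during the forward pass
    segment_len: None = no checkpointing; k = keep one boundary per k-block segment
    """
    if segment_len is None:
        return num_blocks * block_cost                     # everything stays alive
    boundaries = math.ceil(num_blocks / segment_len)       # saved segment inputs
    return boundaries + segment_len * block_cost           # plus one segment rebuilt at a time

def recompute_overhead():
    """Training FLOPs: forward ~2NT + backward ~4NT; recomputation adds ~2NT more."""
    return (2 + 4 + 2) / (2 + 4)

print("12 blocks of cost 10:", peak_activation_units(12, 10), "units ->",
      peak_activation_units(12, 10, segment_len=1), "with per-block checkpoints")
best_k = min(range(1, 17), key=lambda k: peak_activation_units(16, 1, k))
print("16 unit-cost layers: best segment length", best_k,
      "with peak", peak_activation_units(16, 1, best_k))
print(f"compute overhead: {recompute_overhead():.2f}x")
```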

Memory Budget

Weights + gradients + optimizer states + activations. All four, simultaneously.

Training is a memory-budgeting problem. Unlike inference, where you only need to hold the weights and the KV cache, training simultaneously holds four components in GPU memory. Understanding each lets you predict whether a config fits and which component to attack when it does not.

The four components

Let N = number of parameters, B = batch size, S = sequence length, L = number of layers, D = model dim.

1. Weights: 4 bytes × N as FP32 masters (2 bytes × N if kept only in BF16/FP16)
2. Gradients (same dtype as the weights): 2–4 bytes × N
3. Optimizer state (AdamW, FP32 m + v): 8 bytes × N
4. Activations (16-bit, no checkpointing): roughly 34 × B × S × D × L bytes when attention scores are not materialized (e.g. with FlashAttention), plus the B × S × vocab logits; naive attention adds a term quadratic in S. Rough rule: ~0.3 GB per sequence for GPT-2 124M at S = 1024 (D = 768, L = 12), before the logits.

Total (FP32 masters, no checkpointing): ≈ (4 + 4 + 8) × N = 16 bytes/param, plus activation memory
[Figure: memory breakdown for GPT-2 124M with FP32 masters. Weights ~0.5 GB, gradients ~0.5 GB, optimizer state (Adam m + v) ~1.0 GB, plus activations that grow linearly with batch size. Inference needs only ~0.25 GB of BF16 weights. For a 7B model: inference ≈ 14 GB (BF16), training ≈ 112+ GB (FP32 masters, no checkpointing).]
Memory breakdown for a 124M GPT-2-scale model at batch size 32. The optimizer state (Adam m + v) alone is 2× the weights. Activations dominate at large batch sizes. Gradient checkpointing trades ~30% compute to shrink the activation component significantly.

The key takeaway: training needs roughly 8× more memory than inference for the same model before activations are counted, and often much more once they are. Inference needs only the BF16 weights (2 bytes/param) plus the KV cache. Training needs FP32 master weights + FP32 gradients + FP32 optimizer state = 16 bytes/param, plus activation memory that grows with batch size and sequence length. This is why you can serve LLaMA-7B on a 16 GB GPU but cannot train it on a single GPU without memory-saving techniques such as ZeRO sharding across devices (Day 11).

| Model scale | Params | Inference (BF16 weights) | Training (FP32 + AdamW, before activations) |
|---|---|---|---|
| GPT-2 small | 124M | 0.25 GB | ~2 GB |
| GPT-2 XL | 1.5B | 3 GB | ~24 GB |
| LLaMA-7B | 7B | 14 GB | ~112 GB |
| LLaMA-70B | 70B | 140 GB | ~1.1 TB |

The training column counts only parameter-proportional memory (FP32 weights, gradients, and AdamW state at 16 bytes/param); activation memory is excluded, and real training at large batch sizes adds substantially to it. Gradient checkpointing shrinks only that activation term, never these parameter-proportional costs. The 70B row makes clear why you cannot train frontier-scale models without ZeRO sharding or tensor parallelism: no single GPU has 1 TB of memory.
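The parameter-proportional numbers can be reproduced in a few lines. This sketch (the helper name is hypothetical) uses decimal gigabytes and the accounting from this section: 2 bytes/param for BF16 inference weights, and 16 bytes/param for FP32 weights, gradients, and AdamW state. It deliberately leaves out activations and the KV cache.

```python
def memory_gb(n_params, bytes_per_param):
    """Parameter-proportional memory in decimal GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9

INFERENCE_BYTES = 2            # BF16 weights only; KV cache not included
TRAINING_BYTES = 4 + 4 + 8     # FP32 weights + FP32 grads + AdamW m and v in FP32

MODELS = {"GPT-2 small": 124e6, "GPT-2 XL": 1.5e9, "LLaMA-7B": 7e9, "LLaMA-70B": 70e9}
for name, n in MODELS.items():
    print(f"{name:12} inference {memory_gb(n, INFERENCE_BYTES):7.2f} GB | "
          f"training {memory_gb(n, TRAINING_BYTES):8.1f} GB (before activations)")
```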

Throughput & MFU

Tokens per second is your speed. MFU is how much of the GPU you're actually using.

To know whether your training run is efficient you need two numbers. Tokens/sec is the raw throughput — how many tokens the model processes per second. Model FLOPs Utilization (MFU) compares the FLOPs your model actually performs against the hardware's theoretical peak, telling you what fraction of the silicon you are extracting. The two metrics answer different questions: tokens/sec tells you how fast you are going; MFU tells you whether you could go faster.

tokens_per_step = batch_size × seq_len × accum_steps
flops_per_step  = 6 × N × tokens_per_step      # forward ≈ 2NT, backward ≈ 4NT (2× forward): 6NT total
achieved_flops  = flops_per_step / step_time_seconds
MFU             = achieved_flops / hardware_peak_flops

Example: 124M model, B=32, S=1024, accum=4, assumed hardware peak 10 TFLOP/s (MPS)
  tokens/step = 32 × 1024 × 4 = 131,072
  flops/step  = 6 × 124M × 131,072 ≈ 9.75 × 10¹³
  fastest possible step (100% MFU) = 9.75×10¹³ / 10¹³ ≈ 9.8 s
  if a step takes 30 s (illustrative): achieved ≈ 3.25×10¹² FLOP/s = 3.25 TFLOP/s
  MFU = 3.25 / 10 ≈ 32.5%

Well-tuned large training runs on A100 clusters reach 40–55% MFU. Small models on a laptop often land far lower, because their matrix multiplications are too small to keep the hardware busy and fixed overheads dominate. MFU is the single best efficiency metric:

  • Low MFU + slow run: software-bound. Your data loader, small batch size, unfused ops, or Python overhead are the bottlenecks, and they are usually fixable in software.
  • High MFU + slow run: hardware-bound. You are extracting most of what the hardware can do; the only fix is more or faster hardware.
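The MFU arithmetic as a small plain-Python sketch, using the 6NT approximation. The 10 TFLOP/s peak and the 30-second step time are assumptions for illustration, not measurements. Note that 6NT ignores the attention-score FLOPs, which grow with sequence length, so it somewhat understates the work at long contexts.

```python
def flops_per_step(n_params, tokens_per_step):
    """Approximate training FLOPs per optimizer step: 2NT forward + 4NT backward = 6NT."""
    return 6 * n_params * tokens_per_step

def mfu(n_params, tokens_per_step, step_time_s, peak_flops):
    """Model FLOPs Utilization: achieved FLOP/s divided by the hardware peak."""
    return flops_per_step(n_params, tokens_per_step) / step_time_s / peak_flops

N = 124e6                       # parameters
T = 32 * 1024 * 4               # batch x seq_len x accum_steps = 131,072 tokens
PEAK = 10e12                    # assumed device peak: 10 TFLOP/s
print(f"FLOPs per step: {flops_per_step(N, T):.3e}")
print(f"fastest possible step (100% MFU): {flops_per_step(N, T) / PEAK:.1f} s")
step_time = 30.0                # hypothetical measured step time, seconds
print(f"MFU at {step_time:.0f} s/step: {mfu(N, T, step_time, PEAK):.1%}")
```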

This is the exact same throughput lens you will apply to inference serving in Week 4, where MFU becomes a key metric for comparing KV-cache implementations and batching strategies.

import time
import torch

N_params = sum(p.numel() for p in model.parameters())
tokens_per_step = cfg.block_size * batch_size * accum_steps

# Approximate device peak FLOP/s (fill in your hardware's spec; 312e12 is A100 BF16 dense).
PEAK_FLOPS = {"cuda": 312e12, "mps": 10e12, "cpu": 0.4e12}.get(device, 1e12)

def mfu(step_time_s):
    flops = 6 * N_params * tokens_per_step
    return (flops / step_time_s) / PEAK_FLOPS

t0 = time.perf_counter()
loss, lr, gnorm = train_step(step, max_steps)
if device == "cuda":
    torch.cuda.synchronize()   # kernels run asynchronously; wait before reading the clock
elif device == "mps":
    torch.mps.synchronize()
dt = time.perf_counter() - t0
tps = tokens_per_step / dt
print(f"step {step}: {tps:,.0f} tok/s | MFU {mfu(dt)*100:.2f}%")
Checkpointing & Resumption

Assume the run will crash. Make sure it can resume exactly.

Real runs take hours to months and will be interrupted — preemption, OOM, node failure, or a power outage. A truly resumable checkpoint must capture everything needed to continue as if nothing happened. Anything you forget means the resumed run starts from a subtly different state, producing a discontinuous loss curve and undermining your experimental results.

What to save

| Component | Why it matters on resume |
|---|---|
| Model weights (state_dict) | The obvious part: where the parameters are. |
| Optimizer state (opt.state_dict()) | Adam's m and v moments; without them the first steps after resume run with reset estimates. |
| GradScaler state (scaler.state_dict()) | The current loss scale; a wrong scale causes instability. |
| Step counter | Determines the position in the LR schedule. |
| Best val loss | Preserves the "save best" logic across restarts. |
| Config / hyperparameters | Documents what settings produced this checkpoint. |
| RNG state (torch.get_rng_state(), plus the CUDA RNG states) | Exact reproducibility; the resumed batch order matches the original. |
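def save_ckpt(path, step, best_val):
    torch.save({
        "model":     model.state_dict(),
        "optimizer": opt.state_dict(),
        "scaler":    scaler.state_dict(),   # empty dict if the scaler is disabled (BF16)
        "step":      step,
        "best_val":  best_val,
        "config":    cfg,
        "rng":       torch.get_rng_state(),
        "cuda_rng":  torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }, path)

def resume_ckpt(path):
    # weights_only=False because the checkpoint also pickles cfg; only load files you trust.
    ck = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ck["model"])
    opt.load_state_dict(ck["optimizer"])
    if "scaler" in ck:
        scaler.load_state_dict(ck["scaler"])
    # map_location may have moved the RNG states to the GPU; they must be CPU ByteTensors.
    torch.set_rng_state(ck["rng"].cpu())
    if ck.get("cuda_rng") is not None:
        torch.cuda.set_rng_state_all([s.cpu() for s in ck["cuda_rng"]])
    return ck["step"], ck.get("best_val", float("inf"))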

Operational habits

Save periodically, every N steps or every M minutes, not just at the end. Save the best checkpoint by validation loss in a separate file from the latest checkpoint. Large runs commonly keep the last several checkpoints in a ring buffer so they can roll back past a loss spike. The checkpoint is also the artifact your inference engine loads: the format and naming convention you establish now carries forward to deployment.
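One way to make periodic saves crash-safe is to write each checkpoint to a temporary file and rename it into place, then prune old files to keep a ring buffer. The sketch below uses JSON as a stand-in for torch.save so it runs anywhere, and its helper names are hypothetical; with PyTorch you would torch.save to the temporary path and keep the same rename-and-prune logic. os.replace is atomic on POSIX when both paths are on the same filesystem, which is why the temporary file is created inside the checkpoint directory.

```python
import json
import os
import tempfile
from pathlib import Path

def save_atomic(obj, path):
    """Write obj as JSON to a temp file beside `path`, then rename it into place."""
    path = Path(path)
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())       # make sure the bytes hit disk before the rename
        os.replace(tmp, path)          # readers see the old file or the new one, never half
    except BaseException:
        os.unlink(tmp)                 # don't leave a partial temp file behind
        raise

def save_ring(obj, ckpt_dir, step, keep=3):
    """Save ckpt_<step>.json, then delete all but the newest `keep` checkpoints."""
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    save_atomic(obj, ckpt_dir / f"ckpt_{step:08d}.json")
    ckpts = sorted(ckpt_dir.glob("ckpt_*.json"))   # zero-padded names sort by step
    for old in ckpts[:-keep]:
        old.unlink()

def latest(ckpt_dir):
    """Path of the newest checkpoint, or None if there is none yet."""
    ckpts = sorted(Path(ckpt_dir).glob("ckpt_*.json"))
    return ckpts[-1] if ckpts else None
```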

Debugging Training Runs

Loss spikes, NaNs, and divergence — how to diagnose them fast.

You will eventually launch a training run that goes wrong. The loss plateaus, or it spikes and recovers, or it spikes and does not recover, or it drifts to NaN. Knowing where to look first saves hours of GPU time. Here is the diagnostic playbook, in order of frequency.

| Symptom | Most likely cause | Fix |
|---|---|---|
| Loss immediately diverges to NaN/Inf | LR too high, or a gradient-norm spike at step 0 | Reduce LR 10×; add warmup; check the data for NaNs |
| Gradients constantly at the clip norm (≥ 1.0) | LR too high for the current phase | Lower the peak LR or lengthen warmup |
| Isolated loss spike, then recovery | One bad batch; clipping (or the scaler) absorbed it | Normal; monitor frequency; if frequent, lower LR |
| Loss plateaus early | LR too low, or a data-shuffle bug (seeing the same batches) | Check the LR schedule; verify data-loading seeds |
| Val loss rises while train loss falls | Overfitting; model too large for the data | Increase weight decay, add dropout, or get more data |
| Loss diverges after checkpoint restore | Optimizer or scaler state not saved/restored | Save and restore the full checkpoint dict |
| NaN only on specific tokens | Embedding lookup overflow, or missing pad mask | Check the token-id range; verify the causal mask |
| Throughput drops mid-run | Data loader falls behind; GPU waits on CPU | Increase num_workers; pre-tokenize; memory-map data |

The single most useful debugging instrument is a dashboard plotting: (1) train loss, (2) val loss, (3) gradient norm (pre-clip), (4) current learning rate, and (5) tokens/sec — all as a function of step. Anomalies in any one of these usually point directly to the cause. Gradient norm spikes before a loss spike are particularly telling — they often appear 1–2 steps before the loss reacts, giving you early warning.

# Minimal logging — log these at every eval interval.
metrics = {
    "step":       step,
    "train_loss": train_loss,
    "val_loss":   val_loss,
    "lr":         current_lr,
    "grad_norm":  grad_norm,
    "tokens_sec": tokens_per_step / step_time,
    "mfu_pct":    mfu(step_time) * 100,
}
print(f"step {step} | " + " | ".join(f"{k}: {v:.4g}" for k, v in metrics.items() if k != "step"))
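Spotting gradient-norm spikes early is easy to automate. A minimal sketch, assuming you call it once per step with the pre-clip norm: flag any value more than factor × the median of a recent window, and always flag NaN or inf. The class and its thresholds are hypothetical and untuned; treat an alert as a prompt to inspect the dashboard, not as proof of a problem.

```python
import math
import statistics
from collections import deque

class SpikeDetector:
    """Flag values far above the recent median of a metric such as the pre-clip grad norm.

    Hypothetical helper: the window, factor and min_history defaults are illustrative.
    """

    def __init__(self, window=100, factor=3.0, min_history=10):
        self.history = deque(maxlen=window)    # most recent `window` finite values
        self.factor = factor
        self.min_history = min_history

    def check(self, value):
        """Return True if value is NaN/inf or above factor x the recent median, then record it."""
        value = float(value)
        if not math.isfinite(value):
            return True                        # never let NaN/inf into the history
        is_spike = (len(self.history) >= self.min_history
                    and value > self.factor * statistics.median(self.history))
        self.history.append(value)
        return is_spike

# Usage inside the training loop (grad_norm from clip_grad_norm_):
#   if detector.check(grad_norm):
#       print(f"step {step}: grad-norm spike {float(grad_norm):.3g}")
```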
Exercises

Ten exercises, all in the notebook.

Companion notebook: day-10-training-loop.ipynb.

  1. AdamW vs SGD. Train the Day 9 toy GPT for 1000 steps with SGD, SGD+momentum, and AdamW. Overlay the loss curves and report final validation loss for each.
  2. LR schedule. Implement lr_at, plot the warmup+cosine curve, and compare it against a constant-LR baseline. How long does warmup have to be before it makes a visible difference?
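  3. Mixed precision speedup. Time 200 steps with and without autocast. Report the speedup and the peak memory from torch.cuda.max_memory_allocated(), calling torch.cuda.reset_peak_memory_stats() before each run. On Apple Silicon, torch.mps.current_allocated_memory() reports current rather than peak allocation, so sample it inside the step (for example just before backward()).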
  4. BF16 vs FP16 stability. Force amp_dtype=torch.float16 with GradScaler(enabled=False) and watch for instability (inf/NaN losses or gradients underflowing to zero). Then re-enable the scaler and confirm that training stabilizes.
  5. Gradient clipping. Log the pre-clip gradient norm every step and plot it. Set peak_lr=3e-3 and observe the clipping frequency increase.
  6. Accumulation equivalence. Train batch-64/accum-1 vs batch-16/accum-4. Overlay loss curves to confirm they track each other (up to shuffle noise).
  7. Memory budget calculator. Write a function training_memory_gb(N, B, S, L, D) (parameter count, batch size, sequence length, layers, model width) that estimates the total GPU memory needed. Print the breakdown for a few model sizes from the table above.
  8. MFU meter. Measure tokens/sec on your hardware, look up its peak FLOP/s for the dtype you train in, and compute your MFU. Why is it low for a 10M-parameter model?
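  9. Crash & resume round-trip. Save a checkpoint at step 1000 and record the val loss. Simulate a crash by deleting the model and optimizer from memory, then resume from the checkpoint and assert that the restored model reproduces the recorded val loss. Continue to step 2000 and verify that the loss curve stays smooth across the resume point.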
  10. Gradient checkpointing. Wrap one Transformer block in torch.utils.checkpoint.checkpoint (passing use_reentrant=False). Measure the change in peak memory and step time. Is the tradeoff worth it at batch size 16? What about batch size 64?
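Before running exercise 6 on a real model, you can check the equivalence exactly on a toy problem. The minimal stdlib sketch below assumes a one-parameter model ŷ = w·x, a mean-squared-error loss averaged per micro-batch, and equal-size micro-batches. All names are illustrative. Summing the gradients of loss / accum_steps over 4 micro-batches of 2 reproduces the batch-8 gradient. Skipping the division makes the gradient accum_steps times too large.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) w.r.t. w for the 1-parameter model y_hat = w*x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accum_steps, scale=True):
    """Sum per-micro-batch gradients, the way repeated backward() calls fill param.grad.

    With scale=True each micro-batch loss is divided by accum_steps, which
    (because gradients are linear in the loss) divides its gradient too.
    """
    if len(xs) % accum_steps:
        raise ValueError("this sketch assumes equal-size micro-batches")
    micro = len(xs) // accum_steps
    grad = 0.0  # stands in for param.grad, zeroed once per optimizer step
    for i in range(accum_steps):
        batch = slice(i * micro, (i + 1) * micro)
        g = grad_mse(w, xs[batch], ys[batch])
        grad += g / accum_steps if scale else g
    return grad


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w = 0.5

full = grad_mse(w, xs, ys)                                    # batch 8, accum 1
accum = accumulated_grad(w, xs, ys, accum_steps=4)            # batch 2, accum 4
unscaled = accumulated_grad(w, xs, ys, accum_steps=4, scale=False)
print(math.isclose(full, accum), round(unscaled / full, 6))   # True 4.0
```

The match is exact only because every micro-batch has the same size and the loss is a per-batch mean. In the real exercise, the remaining differences come from data order and floating-point summation order.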
Self-Check

Ten questions before moving on.

Close the page and answer from memory. If you can't, re-read the relevant section.

  1. List the seven operations of a production training step in the correct order, including where unscale_ goes relative to clipping.
  2. What are the update equations for AdamW? Which single line differs from Adam with L2 regularization?
  3. Why does Adam need 2× extra memory per parameter compared to plain SGD? For a 7B model in FP32, how many GB is the optimizer state alone?
  4. Why does FP16 need a GradScaler but BF16 doesn't? What is the key bit-layout difference?
  5. Write the effective-batch-size formula with gradient accumulation and multiple GPUs.
  6. Why divide the loss by accum_steps during accumulation, and what goes wrong if you forget?
  7. Why warm up the learning rate instead of starting at the peak? What optimizer-state property makes warmup especially important for Adam?
  8. What is gradient checkpointing and what is the memory vs compute tradeoff? When should you use it?
  9. Define MFU. What does a slow run with low MFU imply versus a slow run with high MFU?
  10. List every component that a fully-resumable checkpoint must contain. Which one is most commonly forgotten?

"None of this changes the model. It changes how fast, how cheaply, and how reliably you can teach it."

Day 10 · Production training
Further Reading

Go deeper.

The production-training canon.

YouTube · 4 hr

Karpathy — Reproduce GPT-2 (124M)

Mixed precision, clipping, schedules, accumulation, MFU — all of today, in code, with live commentary.

Repo · Karpathy

nanoGPT — train.py

The reference loop. Every technique from today appears in ~300 lines of readable Python.

Paper · 2019

Loshchilov & Hutter — Decoupled Weight Decay Regularization (AdamW)

The paper that showed L2 regularization and weight decay are not equivalent under adaptive optimizers like Adam, and introduced AdamW's decoupled weight decay.

Docs · PyTorch

Automatic Mixed Precision

The autocast + GradScaler API with BF16/FP16 details and worked examples.

Paper · 2017

Micikevicius et al. — Mixed Precision Training

The original loss-scaling and FP32-master-weights paper. It explains from first principles why FP16 training needs loss scaling, which is what GradScaler automates.

Paper · 2016

Loshchilov & Hutter — SGDR (cosine restarts)

Cosine annealing with warm restarts — the origin of the cosine schedule now standard in LLM training.

Paper · 2014

Kingma & Ba — Adam

The original Adam paper. The bias-correction math and intuition for adaptive moment estimation.

Paper · 2016

Chen et al. — Training Deep Nets with Sublinear Memory Cost

The foundational paper on activation checkpointing. It shows how to cut activation memory for an L-layer network to O(√L) at the cost of roughly one extra forward pass.
