Yesterday's loop was minimal and legible. Today we add everything that separates a toy run from a real one: a deep dive into optimizers (SGD through AdamW), learning-rate warmup and cosine decay, mixed precision, gradient accumulation, gradient checkpointing, a full memory budget breakdown, throughput measurement, robust checkpointing, and a practical debugging guide. The same techniques scale from your laptop to a thousand-GPU cluster.
Day 9's loop works, but it leaves performance and stability on the table. A production training run is faster, more memory-efficient, and far more robust against the things that go wrong over hours or days of training. None of today's techniques change the model or the objective — they change how efficiently and reliably you optimize it. Every one of them is standard in nanoGPT, Megatron, and the recipe behind every open model you have heard of.
This lesson also matters directly for inference. Mixed precision is the same numerics story as quantization (Week 4). Gradient checkpointing trades compute for memory — the same tradeoff you will see in KV-cache management. Throughput and MFU are the same efficiency lens you will apply to serving. And the checkpoint you write here is the artifact your inference engine loads on Day 27. Training and inference share far more machinery than they appear to.
zero_grad to opt.step(), with every technique in the right order.
autocast + GradScaler correctly.
The Day 9 loop did four things per step: zero gradients, forward pass, backward pass, optimizer step. A production loop does seven, and the order is not arbitrary. Missing a step or swapping two produces subtle bugs that can silently corrupt training for thousands of steps before you notice anything wrong.
Two habits you must wire in from the start. First, always split your data before training and measure held-out validation loss at fixed intervals; it is the only honest measure of generalization. Second, seed everything before creating the model and data loaders: torch.manual_seed(42), random.seed(42), numpy.random.seed(42). Seeding makes runs reproducible, which is essential when debugging or comparing two configs. Record the seed in your checkpoint and save the RNG state alongside it: the seed alone reproduces a run from step 0, but only the saved RNG state lets a resumed run continue the same random stream.
# Deterministic setup — do this before anything else.
import random, numpy as np, torch
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# Data split.
n = int(0.9 * len(tokens)) # 90/10 split
train_data, val_data = tokens[:n], tokens[n:]
An optimizer is the algorithm that turns gradients into weight updates. The right choice affects convergence speed, final loss, and memory consumption — and the memory impact is large enough to break your budget on a big model. Let's build up from first principles.
Gradient descent is conceptually simple: move the weights in the direction that reduces the loss, scaled by a step size (learning rate).
Plain SGD requires zero extra memory beyond the weights themselves — one scalar per parameter. The problem is that the gradient on a single mini-batch is noisy; the step direction changes erratically, making convergence slow and sensitive to the learning rate.
Momentum smooths the gradient signal by maintaining a running average — the "velocity" — of past gradients. Think of it as a ball rolling down a loss surface that accumulates speed in consistent directions and dampens noise.
Momentum adds one extra state tensor (v) per parameter — doubling memory vs. plain SGD. For a 7B-parameter model in FP32 that is an extra 28 GB. Already significant.
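The two update rules described so far can be sketched in a few lines of plain Python. This is an illustrative toy, not the PyTorch implementation: `sgd_step`, `momentum_step`, and `grad` are hypothetical helper names, the momentum form used here (v = mu * v + g, then w = w - lr * v) is one common convention, and the quadratic loss is chosen only because its gradient is trivial to write down.

```python
def sgd_step(w, g, lr):
    """Plain SGD: step against the gradient, no optimizer state."""
    return [wi - lr * gi for wi, gi in zip(w, g)]


def momentum_step(w, g, v, lr, mu=0.9):
    """SGD with momentum: v is an exponentially decayed sum of past gradients."""
    v = [mu * vi + gi for vi, gi in zip(v, g)]
    w = [wi - lr * vi for wi, vi in zip(w, v)]
    return w, v


def grad(w):
    """Gradient of the toy loss L(w) = 0.5 * sum(w_i ** 2), which is just w."""
    return list(w)


w_sgd = [1.0, -2.0]
w_mom, v = [1.0, -2.0], [0.0, 0.0]  # momentum needs one state slot per weight
for _ in range(3):
    w_sgd = sgd_step(w_sgd, grad(w_sgd), lr=0.1)
    w_mom, v = momentum_step(w_mom, grad(w_mom), v, lr=0.1)

print(w_sgd)  # each weight shrinks by a factor of 0.9 per step
print(w_mom)  # closer to the minimum at 0 because velocity has built up
```

Note that `v` has exactly as many entries as `w`, which is the doubling of memory the paragraph above describes.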
Adam (Adaptive Moment Estimation) maintains two running averages per parameter: the first moment m (mean of gradients, like momentum) and the second moment v (mean of squared gradients, a per-parameter estimate of gradient magnitude rather than of true curvature). The update divides by the square root of v, giving each parameter its own effective learning rate: parameters with consistently large gradients get smaller steps, and parameters with small gradients get larger steps.
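The bias corrections (dividing by 1 − βᵗ) matter most in the early steps, when the exponential averages are initialized at zero and would otherwise underestimate the true moment. How quickly they fade depends on β: with β₁ = 0.9 the first-moment correction is negligible within about 50 steps, and with β₂ = 0.95 the second-moment correction is negligible by step ~100, but with β₂ = 0.999 it is still a factor of about 10 at step 100 (1 − 0.999¹⁰⁰ ≈ 0.095) and takes a few thousand steps to vanish. Adam converges faster than SGD on almost every language model task. The catch: two extra state tensors per parameter means Adam's optimizer state alone is 2× the weight size. For a 7B model in FP32 that is 56 GB of optimizer state on top of the 28 GB of weights.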
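As a rough sketch of the update just described, here is one Adam step for a single scalar parameter in plain Python. The function name `adam_step` and the default hyperparameters are illustrative assumptions, and real implementations operate on whole tensors rather than scalars.

```python
import math


def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g        # first moment: running mean of g
    v = beta2 * v + (1 - beta2) * g * g    # second moment: running mean of g^2
    m_hat = m / (1 - beta1 ** t)           # bias correction for the zero init
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


# On the first step the corrections cancel the zero init, so the step size is ~lr.
w, m, v = adam_step(w=1.0, g=0.5, m=0.0, v=0.0, t=1)
first_step = 1.0 - w
print(first_step)

# The factor (1 - beta^t) that each moment is divided by; 1.0 means no correction.
for t in (1, 10, 100, 1000):
    print(t, 1 - 0.9 ** t, 1 - 0.95 ** t, 1 - 0.999 ** t)
```

The printed factors show how quickly each correction approaches 1 for a given β, and the two state values `m` and `v` carried between calls are the two extra tensors per parameter.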
Regular Adam with weight decay applies the decay inside the gradient: g_eff = g + λ·w. The decay term then passes through the adaptive scaling, so weights with large gradient history are decayed less than weights with small gradient history, which is conceptually wrong and empirically worse. AdamW decouples weight decay: it shrinks the weights directly, w ← w − η·λ·w (η is the learning rate), separately from the adaptive update.
This one line is the only difference between Adam and AdamW, but it matters. Decoupling makes the regularization strength independent of the per-parameter adaptive scaling, which is exactly what you want: parameters with small v̂ still take large adaptive steps, while every weight is decayed at the same relative rate, and the two no longer interfere. AdamW is the default optimizer in most modern LLM training recipes. Use weight_decay=0.1 on weight matrices; set it to 0.0 on biases and layer-norm parameters (they should not be decayed).
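The difference is easiest to see with the task gradient set to zero, so that only the decay acts. The sketch below is a scalar toy under that assumption; `adam_l2_step` and `adamw_step` are hypothetical names, and the AdamW form follows the common convention of decaying the pre-update weight by lr * wd.

```python
import math


def adam_l2_step(w, g, m, v, t, lr, wd, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam with L2 regularization: decay is folded into the gradient."""
    g = g + wd * w                                  # coupled: rescaled below
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat, v_hat = m / (1 - beta1 ** t), v / (1 - beta2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


def adamw_step(w, g, m, v, t, lr, wd, beta1=0.9, beta2=0.999, eps=1e-8):
    """AdamW: the adaptive step sees only g; decay acts on w directly."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat, v_hat = m / (1 - beta1 ** t), v / (1 - beta2 ** t)
    adaptive = lr * m_hat / (math.sqrt(v_hat) + eps)
    return w - adaptive - lr * wd * w, m, v


# Zero task gradient: only weight decay should move the weights.
l2_big, _, _ = adam_l2_step(2.0, 0.0, 0.0, 0.0, t=1, lr=0.1, wd=0.1)
l2_small, _, _ = adam_l2_step(0.02, 0.0, 0.0, 0.0, t=1, lr=0.1, wd=0.1)
aw_big, _, _ = adamw_step(2.0, 0.0, 0.0, 0.0, t=1, lr=0.1, wd=0.1)
aw_small, _, _ = adamw_step(0.02, 0.0, 0.0, 0.0, t=1, lr=0.1, wd=0.1)

print(2.0 - l2_big, 0.02 - l2_small)  # both ~0.1: decay was normalized away
print(2.0 - aw_big, 0.02 - aw_small)  # proportional to the weight itself
```

With coupled L2 the large and the small weight both move by about lr, because the adaptive denominator cancels the size of the decay term; with AdamW each weight shrinks by the same 1% of its own value.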
β₁ (default 0.9): how quickly the first-moment estimate tracks the current gradient. Lower values make it more reactive; higher values smooth more. Almost nobody changes this.
β₂ (0.95 for LLMs, 0.999 for other tasks): controls the second moment. A lower β₂ like 0.95 makes the adaptive scaling react faster to changes in gradient magnitude; Karpathy uses 0.95 for his GPT-2 reproduction.
ε (1e-8): prevents division by zero; rarely matters unless some parameters receive gradients that are consistently near zero.
weight_decay (0.1): the decoupled weight-decay strength. Too high shrinks weights toward zero and hurts learning; too low and the model overfits. 0.1 is a common default for LLMs.
This is one of the most important practical points in the whole lesson. At inference time you only need the weights. During training you also maintain optimizer state. For AdamW with FP32 masters:
That is why inference memory ≈ model size but training memory ≈ 8× model size (before you add activation memory). This factor-of-8 gap is the single most common source of surprise when people first try to train a model they can easily serve. We will account for the activations in the Memory Budget section below.
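The factor of 8 can be checked with a few lines of arithmetic. This is a back-of-the-envelope sketch under stated assumptions: every training tensor (weights, gradients, and both Adam moments) is FP32 at 4 bytes per value, inference holds only BF16 weights at 2 bytes per value, 1 GB means 10^9 bytes, and activations are ignored. The helper name `static_memory_gb` is hypothetical.

```python
def static_memory_gb(n_params, bytes_per_value=4):
    """Static AdamW training state, assuming every tensor uses the same dtype."""
    per_tensor = n_params * bytes_per_value / 1e9  # decimal GB
    return {
        "weights": per_tensor,
        "gradients": per_tensor,
        "adam_m": per_tensor,
        "adam_v": per_tensor,
    }


n_params = 7e9
breakdown = static_memory_gb(n_params)
training_gb = sum(breakdown.values())
inference_gb = n_params * 2 / 1e9  # BF16 weights only
ratio = training_gb / inference_gb

for name, gb in breakdown.items():
    print(f"{name:>10}: {gb:6.1f} GB")
print(f"{'total':>10}: {training_gb:6.1f} GB ({ratio:.0f}x BF16 inference)")
```

For 7B parameters this gives four 28 GB components, 112 GB in total, against 14 GB for serving.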
# AdamW: apply weight decay only to weights, not biases or norms.
decay_params = [p for n, p in model.named_parameters()
                if p.dim() >= 2 and p.requires_grad]  # matrices
no_decay_params = [p for n, p in model.named_parameters()
                   if p.dim() < 2 and p.requires_grad]  # biases, norms
opt = torch.optim.AdamW([
    {"params": decay_params, "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=6e-4, betas=(0.9, 0.95), eps=1e-8)
Adam's optimizer state takes exactly 2× the memory of the weights when m and v are stored in the same dtype as the weights. For a 70B-parameter model in FP32, that is 560 GB of optimizer state alone. This is why optimizer-state sharding (ZeRO Stage 1) was one of the most important distributed-training innovations of recent years: it splits that state evenly across GPUs instead of replicating it.
A constant learning rate is rarely optimal. Two ideas, used together, define the standard for LLMs. Warmup: start the LR near zero and ramp it up linearly over the first few hundred to few thousand steps. Early in training the model's weights are random, the gradients are large and noisy, and Adam's second-moment estimate v has seen only a handful of samples, so its per-parameter step sizes are unreliable and can be far too large. Warmup eases past this fragile phase. Cosine decay: after warmup, decay the LR following a cosine curve down to a small final value (often 10% of peak). The cosine shape spends the early part of training at a high, productive LR and then gently anneals for a clean convergence.
Worked numbers for a GPT-2 124M reproduction: peak_lr = 6e-4, min_lr = 6e-5 (10%), warmup = 715 steps (~0.5% of 143k total steps). Notice that min_lr is 10× smaller than the peak: the cosine ends at this floor, not at zero. Decaying all the way to zero means the final steps of the run barely update the weights.
import math
def lr_at(step, peak_lr, warmup, total, min_lr):
    if step < warmup:  # linear warmup
        return peak_lr * (step + 1) / warmup
    if step > total:  # past the schedule
        return min_lr
    ratio = (step - warmup) / (total - warmup)  # [0, 1]
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return min_lr + coeff * (peak_lr - min_lr)

# Apply every step before opt.step():
for g in opt.param_groups:
    g["lr"] = lr_at(step, peak_lr=6e-4, warmup=715, total=143000, min_lr=6e-5)
Why not use a PyTorch scheduler? You can, but writing lr_at explicitly makes the schedule fully transparent — you always know the exact LR at every step, and resuming from a checkpoint is trivial (just pass the step counter). Scheduler state can drift or reset unexpectedly during multi-node restarts.
By default PyTorch stores and computes in 32-bit floats (FP32). But modern accelerators have dedicated hardware (NVIDIA Tensor Cores, Apple Matrix Engines) that run 16-bit matrix multiplies several times faster, and 16-bit tensors take half the memory. Mixed precision training runs the heavy operations in 16-bit while keeping a 32-bit master copy of the weights and doing gradient accumulation and optimizer math in 32-bit, capturing most of the speed and memory win with negligible accuracy loss.
There are two 16-bit floating-point formats and the difference is consequential. Both use 16 bits total. The split between exponent and mantissa bits determines the tradeoff between range (which values can be represented) and precision (how finely they are subdivided).
| Format | Bits (S/E/M) | Max value | Precision | Loss scaling? | Best for |
|---|---|---|---|---|---|
| FP32 | 1/8/23 | 3.4×10³⁸ | High | No | Master weights, optimizer state |
| FP16 | 1/5/10 | 65,504 | Good | Yes | Older GPUs (V100), Apple A-series |
| BF16 | 1/8/7 | 3.4×10³⁸ | Lower | No | A100/H100, Apple M-series (preferred) |
BF16 keeps FP32's 8 exponent bits, so it has the same dynamic range and almost never overflows or underflows — you can use it without a gradient scaler. FP16 has only 5 exponent bits, so small gradients underflow to zero; it needs a GradScaler that multiplies the loss by a large factor before backprop and unscales the gradients afterward. Rule of thumb: use BF16 if your hardware supports it (Ampere/Hopper GPUs, Apple Silicon M-series); otherwise use FP16 with a scaler.
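The range and precision figures follow directly from the bit layout, which a short calculation can confirm. This sketch assumes the standard IEEE-style layout (exponent bias of 2^(E-1) - 1, all-ones exponent reserved for inf/NaN); `float_format_stats` is a hypothetical helper, not a library function.

```python
def float_format_stats(exp_bits, man_bits):
    """Largest value, smallest positive values, and spacing near 1.0 for a float format."""
    bias = 2 ** (exp_bits - 1) - 1
    return {
        "max": (2 - 2 ** -man_bits) * 2.0 ** bias,
        "min_normal": 2.0 ** (1 - bias),
        "min_subnormal": 2.0 ** (1 - bias - man_bits),
        "epsilon": 2.0 ** -man_bits,  # gap between 1.0 and the next value
    }


fp32 = float_format_stats(exp_bits=8, man_bits=23)
fp16 = float_format_stats(exp_bits=5, man_bits=10)
bf16 = float_format_stats(exp_bits=8, man_bits=7)

for name, stats in (("FP32", fp32), ("FP16", fp16), ("BF16", bf16)):
    print(name, {k: f"{val:.3g}" for k, val in stats.items()})

# A gradient of 1e-8 is below anything FP16 can represent, but fine in BF16.
tiny_grad = 1e-8
print(tiny_grad < fp16["min_subnormal"], tiny_grad > bf16["min_normal"])
```

The last line is the case for loss scaling in one comparison: multiplying the loss by a large factor lifts such gradients back into FP16's representable range, while BF16 never needed the help.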
from torch.amp import autocast, GradScaler
# BF16 path (A100/H100/M-series): no scaler needed.
for xb, yb in batches:
    with autocast(device_type="cuda", dtype=torch.bfloat16):
        _, loss = model(xb, yb)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
# FP16 path (older GPUs): GradScaler prevents gradient underflow.
scaler = GradScaler()
for xb, yb in batches:
    with autocast(device_type="cuda", dtype=torch.float16):
        _, loss = model(xb, yb)
    opt.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
BF16 (bfloat16) was designed by Google Brain for TPUs precisely so that deep-learning code could drop FP32 to 16-bit without the loss-scaling gymnastics FP16 requires. It throws away mantissa precision but keeps the full exponent range — and it turns out neural nets care far more about range than precision. That single insight reshaped ML hardware design. Every major accelerator since 2020 (A100, H100, TPUv4, Apple M-series) has dedicated BF16 compute paths.
Large batches stabilize training, reduce variance in gradient estimates, and let you use a higher peak LR (the linear scaling rule: double the batch, roughly double the LR). But large batches may not fit in memory. Gradient accumulation decouples the compute batch (what fits in memory) from the optimizer batch (what you want to train on). You run several forward/backward passes, accumulating gradients in .grad, and only call opt.step() after accum_steps micro-batches. The effective batch size becomes:
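For GPT-2 124M, Karpathy uses micro_batch=16 at sequence length 1024 with accum_steps=32 on a single GPU to reach the target of 16 × 1024 × 32 = 524,288 tokens per step (2¹⁹); spread across num_gpus=8, the same target needs only accum_steps=4. Either way, each individual GPU only needs to fit 16 × 1024 = 16,384 tokens at once.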
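The effective batch is simply a product, which makes it easy to solve for the accumulation count. A minimal sketch of that arithmetic follows, counting in tokens and assuming a sequence length of 1024 (taken from the 16 × 1024 figure above); both function names are hypothetical.

```python
def tokens_per_step(micro_batch, seq_len, accum_steps, num_gpus):
    """Tokens that contribute to one optimizer step."""
    return micro_batch * seq_len * accum_steps * num_gpus


def accum_steps_for(target_tokens, micro_batch, seq_len, num_gpus):
    """Accumulation steps needed to hit a target tokens-per-step exactly."""
    per_pass = micro_batch * seq_len * num_gpus
    if target_tokens % per_pass:
        raise ValueError("target is not a multiple of one forward/backward pass")
    return target_tokens // per_pass


single_gpu = tokens_per_step(micro_batch=16, seq_len=1024, accum_steps=32, num_gpus=1)
eight_gpu_accum = accum_steps_for(524_288, micro_batch=16, seq_len=1024, num_gpus=8)
print(single_gpu)       # 524288
print(eight_gpu_accum)  # 4
```

Adding GPUs and adding accumulation steps are interchangeable ways to reach the same optimizer batch; the first buys wall-clock time, the second only memory headroom.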
loss/4 to the backward pass. PyTorch adds gradients into .grad automatically. After all micro-batches, a single clip + optimizer step updates the weights, equivalent to a batch 4× larger.
accum_steps = 4 # effective batch = micro_batch * accum_steps
opt.zero_grad(set_to_none=True)
for micro in range(accum_steps):
    xb, yb = get_batch(train_data)
    with autocast(device_type=device, dtype=torch.bfloat16):
        _, loss = model(xb, yb)
    loss = loss / accum_steps # CRITICAL: divide so sum == mean
    loss.backward() # gradients accumulate in .grad
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()
The one subtlety to burn into memory: divide the loss by accum_steps so that summing the per-micro-batch gradients yields the mean gradient, matching what a single large batch would produce. Forget this and your gradients are silently accum_steps× too large: with plain SGD that acts as an accum_steps× larger learning rate, and with clipping at a fixed norm it means almost every step gets clipped. The bug is very easy to miss because the loss still decreases, just noisily and to a worse optimum.
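The equivalence can be verified on a toy problem without any framework. The sketch below assumes a one-parameter linear model with a mean-squared-error loss and equally sized micro-batches; `mse_grad` is a hypothetical helper standing in for a backward pass.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w * x - y) ** 2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x + 1.0 for x in xs]
w = 0.5
accum_steps = 4
micro = len(xs) // accum_steps

full = mse_grad(w, xs, ys)  # gradient of one large batch
scaled, unscaled = 0.0, 0.0
for i in range(accum_steps):
    xb, yb = xs[i * micro:(i + 1) * micro], ys[i * micro:(i + 1) * micro]
    g = mse_grad(w, xb, yb)
    scaled += g / accum_steps  # what backward() on loss / accum_steps adds
    unscaled += g              # what you get if you forget to divide

print(full, scaled, unscaled)
```

The scaled sum matches the full-batch gradient, while the unscaled sum is exactly accum_steps times too large. If the micro-batches had unequal sizes, a plain mean of means would no longer match, which is one reason to keep them the same size.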
Occasionally a batch produces an enormous gradient — a rare token, a numerical edge case, an unlucky initialization region. Without protection, that single step can take a giant stride in parameter space, spike the loss, and sometimes never recover (you see the dreaded loss-goes-to-NaN). Gradient clipping rescales the whole gradient vector so its global norm never exceeds a threshold. The direction is preserved; only the magnitude is capped:
Clipping is cheap insurance. Almost every LLM training recipe uses max_norm = 1.0. Monitor the pre-clip gradient norm as a diagnostic signal: if it is usually well under 1.0 and occasionally spikes, clipping is doing its job. If it is constantly being clipped (norm consistently above 1.0), your learning rate is probably too high.
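The rescaling itself is only a few lines. Here is a plain-Python sketch that treats each gradient tensor as a list of floats; `clip_by_global_norm` is a hypothetical name, and the small 1e-6 term in the denominator is an assumption added to avoid dividing by zero.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm.

    Returns the clipped gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm


# Two "parameter tensors" whose joint norm is sqrt(9 + 16 + 144) = 13.
clipped, pre_clip_norm = clip_by_global_norm([[3.0, 4.0], [12.0]], max_norm=1.0)
post_clip_norm = math.sqrt(sum(g * g for tensor in clipped for g in tensor))
print(pre_clip_norm, post_clip_norm)
print(clipped)  # still in the ratio 3 : 4 : 12

# Gradients already under the threshold pass through unchanged.
small, small_norm = clip_by_global_norm([[0.3, 0.4]], max_norm=1.0)
print(small, small_norm)
```

Because one scale factor is applied to every tensor, the direction of the update is unchanged; clipping each tensor separately would distort it.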
# After backward(), before step(). With FP16 scaler, unscale first.
if amp_dtype == torch.float16:
    scaler.unscale_(opt) # undo the loss scaling so the norm is measured on true gradients
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(opt)
scaler.update()
# grad_norm is the pre-clip norm — log it every step as a health metric.
During the backward pass PyTorch needs the intermediate activations computed during the forward pass — they are used in the chain rule to compute each layer's gradient. By default all of them are stored simultaneously in GPU memory. For a deep Transformer this can be enormous: a 12-layer model with batch size 64 and sequence length 1024 might store several gigabytes of activations.
Gradient checkpointing (also called activation recomputation) relaxes this memory cost: instead of storing every activation, you store only the outputs of a subset of "checkpoint" layers and recompute the in-between activations during the backward pass when they are needed. This trades compute for memory: you run parts of the forward pass twice, but beyond the saved checkpoints you only need to hold one block's worth of activations at any time.
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    # norm1, attn, norm2, and ffn are assumed to be created in __init__ (omitted here).
    def forward(self, x):
        # Wrap with checkpoint to recompute activations during backward.
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
Use gradient checkpointing when you are running out of activation memory and cannot reduce the batch size further. Do not use it when memory is not the bottleneck and your run is already compute-bound: recomputing activations costs roughly one extra forward pass, about 30% more compute per step, and throughput drops accordingly.
Training is a memory-budgeting problem. Unlike inference, where you only need to hold the weights and the KV cache, training simultaneously holds four components in GPU memory. Understanding each lets you predict whether a config fits and which component to attack when it does not.
The key takeaway: training needs approximately 8–16× more memory than inference for the same model. Inference only needs BF16 weights (2 bytes/param). Training needs FP32 master weights (4) + FP32 gradients (4) + FP32 optimizer state (8) = 16 bytes/param, plus activation memory that grows with batch size. This is why you can serve LLaMA-7B on a 16 GB GPU but cannot train it on a single GPU without distributed optimizations like ZeRO (Day 11).
| Model scale | Params | Inference (BF16) | Training (FP32, AdamW) | Training + checkpointing |
|---|---|---|---|---|
| GPT-2 small | 124M | 0.25 GB | ~3.5 GB | ~2.5 GB |
| GPT-2 XL | 1.5B | 3 GB | ~24 GB | ~18 GB |
| LLaMA-7B | 7B | 14 GB | ~112 GB | ~80 GB |
| LLaMA-70B | 70B | 140 GB | ~1.1 TB | ~800 GB |
The training estimates assume a single step at B=1 and count mainly the static state (weights, gradients, optimizer moments). Real training at large batch sizes adds substantial activation memory on top. The 70B row makes clear why you cannot train frontier-scale models without ZeRO sharding or tensor parallelism: no single GPU has 1 TB of memory.
To know whether your training run is efficient you need two numbers. Tokens/sec is the raw throughput — how many tokens the model processes per second. Model FLOPs Utilization (MFU) compares the FLOPs your model actually performs against the hardware's theoretical peak, telling you what fraction of the silicon you are extracting. The two metrics answer different questions: tokens/sec tells you how fast you are going; MFU tells you whether you could go faster.
Well-tuned large training runs on A100 clusters reach 40–55% MFU. Small models on a laptop are often below 1% because a 124M model's matrix multiplications are too small to saturate the hardware. MFU is the single best efficiency metric:
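Written out, the ratio looks like this. The expression is a sketch that assumes the common approximation of about 6 FLOPs per parameter per token for a combined forward and backward pass (it ignores attention FLOPs, which grow with sequence length); the symbols are introduced here for illustration: N is the parameter count, T the tokens processed per optimizer step, Δt the measured step time in seconds, and F_peak the hardware's peak FLOP/s.

```latex
\mathrm{MFU} \;=\; \frac{\text{achieved FLOP/s}}{\text{peak FLOP/s}}
\;\approx\; \frac{6 \, N \, T \,/\, \Delta t}{F_{\text{peak}}}
```

Tokens/sec is the T / Δt factor on its own, so MFU is tokens/sec rescaled by model size and hardware capability.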
This is the exact same throughput lens you will apply to inference serving in Week 4, where MFU becomes the key metric for comparing KV-cache implementations and batching strategies.
import time

N_params = sum(p.numel() for p in model.parameters())
tokens_per_step = cfg.block_size * batch_size * accum_steps
# Approximate device peak FLOP/s (fill in your hardware's spec).
PEAK_FLOPS = {"cuda": 312e12, "mps": 10e12, "cpu": 0.4e12}.get(device, 1e12)

def mfu(step_time_s):
    # ~6 FLOPs per parameter per token for forward + backward.
    flops = 6 * N_params * tokens_per_step
    return (flops / step_time_s) / PEAK_FLOPS

t0 = time.time()
loss, lr, gnorm = train_step(step, max_steps)
if device == "cuda":
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before reading the clock
dt = time.time() - t0
tps = tokens_per_step / dt
print(f"step {step}: {tps:,.0f} tok/s | MFU {mfu(dt)*100:.2f}%")
Real runs take hours to months and will be interrupted — preemption, OOM, node failure, or a power outage. A truly resumable checkpoint must capture everything needed to continue as if nothing happened. Anything you forget means the resumed run starts from a subtly different state, producing a discontinuous loss curve and undermining your experimental results.
| Component | Why it matters on resume |
|---|---|
| Model weights (state_dict) | The obvious part: where the parameters are. |
| Optimizer state (opt.state_dict()) | Adam's m and v moments; without them the first steps after resume use wrong estimates. |
| GradScaler state (scaler.state_dict()) | The current loss scale; wrong scale causes instability. |
| Step counter | Determines the LR schedule position. |
| Best val loss | Preserves the "save best" logic across restarts. |
| Config / hyperparameters | Documents what settings produced this checkpoint. |
| RNG state (torch.get_rng_state()) | Exact reproducibility; resumed batch order matches original. |
def save_ckpt(path, step, best_val):
    torch.save({
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "scaler": scaler.state_dict(),  # omit if BF16
        "step": step,
        "best_val": best_val,
        "config": cfg,
        "rng": torch.get_rng_state(),
    }, path)

def resume_ckpt(path):
    ck = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ck["model"])
    opt.load_state_dict(ck["optimizer"])
    if "scaler" in ck:  # absent when the run used BF16
        scaler.load_state_dict(ck["scaler"])
    # set_rng_state expects a CPU tensor; map_location may have moved it.
    torch.set_rng_state(ck["rng"].cpu())
    return ck["step"], ck.get("best_val", float("inf"))
Save periodically — every N steps or every M minutes, not just at the end. Save the best checkpoint by validation loss in a separate file from the latest checkpoint. Frontier labs typically keep the last 5–10 checkpoints (a ring buffer) so they can roll back past a loss spike. The checkpoint is also the artifact your inference engine loads: the format and naming convention you establish now carries forward to deployment.
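The ring buffer can be as simple as deleting old files after each save. Below is a minimal sketch using only the standard library; the function name `prune_checkpoints` and the `step_000100.pt` naming scheme are assumptions made for illustration, and the "best" checkpoint is protected simply by giving it a name that does not match the pattern.

```python
import re
import tempfile
from pathlib import Path


def prune_checkpoints(ckpt_dir, keep=5, pattern=r"step_(\d+)\.pt"):
    """Delete all but the `keep` highest-step checkpoints; other files are untouched."""
    found = []
    for path in Path(ckpt_dir).iterdir():
        match = re.fullmatch(pattern, path.name)
        if match:
            found.append((int(match.group(1)), path))
    found.sort()  # oldest step first
    stale = found[:max(len(found) - keep, 0)]
    for _, path in stale:
        path.unlink()
    return [path.name for _, path in found[len(stale):]]


# Demo on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for step in range(100, 900, 100):
        (Path(tmp) / f"step_{step:06d}.pt").touch()
    (Path(tmp) / "best.pt").touch()  # separate file for the best val loss
    kept = prune_checkpoints(tmp, keep=3)
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print(kept)
print(remaining)
```

Calling this right after each periodic save keeps disk usage bounded while still leaving several recent checkpoints to roll back to.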
You will eventually launch a training run that goes wrong. The loss plateaus, or it spikes and recovers, or it spikes and does not recover, or it drifts to NaN. Knowing where to look first saves hours of GPU time. Here is the diagnostic playbook, in order of frequency.
| Symptom | Most likely cause | Fix |
|---|---|---|
| Loss immediately diverges to NaN/Inf | LR too high, or gradient norm spike at step 0 | Reduce LR 10×; add warmup; check for NaN in the data |
| Gradients constantly at clip norm (≥ 1.0) | LR too high for the current phase | Lower peak LR or lengthen warmup |
| Isolated loss spike then recovery | One bad batch; clipping (and the scaler, under FP16) working correctly | Normal; monitor frequency; if frequent, lower LR |
| Loss plateaus early | LR too low, or data shuffle bug (seeing same batches) | Check LR schedule; verify data loading seeds |
| Val loss rises while train loss falls | Overfitting; model too large for data | Increase weight decay, add dropout, or get more data |
| Loss diverges after checkpoint restore | Optimizer state or scaler state not saved/restored | Save and restore the full checkpoint dict |
| NaN only on specific tokens | Embedding lookup overflow, or missing pad-mask | Check token id range; verify the padding and causal masks |
| Throughput drops mid-run | Data loader falls behind; GPU waits on CPU | Increase num_workers; pre-tokenize; memory-map data |
The single most useful debugging instrument is a dashboard plotting: (1) train loss, (2) val loss, (3) gradient norm (pre-clip), (4) current learning rate, and (5) tokens/sec — all as a function of step. Anomalies in any one of these usually point directly to the cause. Gradient norm spikes before a loss spike are particularly telling — they often appear 1–2 steps before the loss reacts, giving you early warning.
# Minimal logging — log these at every eval interval.
metrics = {
    "step": step,
    "train_loss": train_loss,
    "val_loss": val_loss,
    "lr": current_lr,
    "grad_norm": grad_norm,
    "tokens_sec": tokens_per_step / step_time,
    "mfu_pct": mfu(step_time) * 100,
}
print(" | ".join(f"{k}: {v:.4g}" for k, v in metrics.items()))
Companion notebook: day-10-training-loop.ipynb.
lr_at, plot the warmup+cosine curve, and compare against a constant-LR baseline. How many steps does warmup need to matter?
autocast. Report speedup and peak memory with torch.cuda.max_memory_allocated() (or torch.mps.current_allocated_memory() on Apple Silicon).
amp_dtype=torch.float16 with GradScaler(enabled=False) and watch instability; re-enable and confirm it stabilizes.
peak_lr=3e-3 and observe the clipping frequency increase.
training_memory_gb(N, B, S, L, D) that estimates total GPU memory needed and print the breakdown for a few model sizes from the table above.
torch.utils.checkpoint.checkpoint. Measure the change in peak memory and step time. Is the tradeoff worth it at batch size 16? What about batch size 64?
Close the page and answer from memory. If you can't, re-read the relevant section.
accum_steps during accumulation, and what goes wrong if you forget?
"None of this changes the model. It changes how fast, how cheaply, and how reliably you can teach it."
The production-training canon.
Mixed precision, clipping, schedules, accumulation, MFU — all of today, in code, with live commentary.
The reference loop. Every technique from today appears in ~300 lines of readable Python.
The paper that showed Adam's weight decay coupling is a bug and introduced AdamW.
The autocast + GradScaler API with BF16/FP16 details and worked examples.
The original loss-scaling / FP16 master-weights paper; explains the GradScaler need from first principles.
Cosine annealing with warm restarts: the origin of the cosine schedule now standard in LLM training.
The original Adam paper. The bias-correction math and intuition for adaptive moment estimation.
The foundational paper on activation checkpointing; explains the O(√L) memory tradeoff with depth.