LLM Inference Engineer · Day 11
Day 11 · Week 2 · Training & Architectures

Distributed Training: Scaling Beyond One GPU

A 70B model does not fit on one GPU — not the weights, not the gradients, not the optimizer state. Today you master the four ways to split a training job across devices: data parallelism (DDP), sharded data parallelism (ZeRO/FSDP), tensor parallelism (Megatron-style), and pipeline parallelism — plus the collective communication primitives that make them work. These are not just training tools: tensor and pipeline parallelism are how vLLM and TensorRT-LLM serve billion-parameter models at inference time.

Time: ~180 min
Difficulty: Hard
Prerequisite: Day 10
Why This Lesson

Two walls: a memory wall and a time wall. Each parallelism breaks one.

Every distributed-training technique exists to smash one of two walls. The memory wall: the model — weights, gradients, optimizer state, and activations — no longer fits on one GPU. The time wall: the model fits but training takes weeks on a single device, which is too slow to be practical. Data parallelism (DDP) attacks the time wall. ZeRO/FSDP, tensor parallelism, and pipeline parallelism attack the memory wall. Gradient checkpointing attacks a specific sub-problem of the memory wall — activation memory.

You will likely never train a 70B model yourself, but as an inference engineer you must read the configs and understand the artifacts these systems produce: sharded checkpoints, tensor-parallel weight layouts, and the communication patterns that also appear in multi-GPU inference. The tensor parallelism you learn today is exactly how vLLM and TensorRT-LLM split a model across GPUs to serve it, and pipeline parallelism is how serving systems spread a model that is too large for one node across several nodes. Understanding these techniques at a mathematical level lets you reason about throughput, latency, and memory budgets for any model on any cluster, which is precisely the day-to-day job of an inference engineer.

Learning objectives

  1. Compute the full memory budget of a training run: weights, gradients, optimizer state, and activations — and identify which component dominates.
  2. Name the four collective communication primitives (all-reduce, all-gather, reduce-scatter, broadcast) and state what each does to a distributed tensor.
  3. Explain data parallelism (DDP), the ring all-reduce it uses, and the communication cost formula.
  4. Describe ZeRO stages 1/2/3 and FSDP — what each shards, the memory savings, and the communication overhead.
  5. Implement Megatron-style tensor parallelism for MLP and attention layers: column-parallel and row-parallel linear.
  6. Describe pipeline parallelism, the idle-time bubble, and how GPipe/1F1B micro-batching shrinks it.
  7. Combine DP × TP × PP into 3D parallelism and choose the right mix for a given model, GPU count, and cluster topology.
  8. Connect tensor parallelism and pipeline parallelism to multi-GPU inference: KV-cache sharding and serving throughput.
The Memory Budget

Weights are the small part. Optimizer state is the killer.

Before choosing a strategy, count what training actually stores. With mixed precision and AdamW, every parameter is held in five separate buffers at once. The standard accounting, the "16 bytes per parameter" rule, goes like this:

Per parameter, mixed-precision AdamW:
  BF16 weights ........... 2 bytes
  BF16 gradients ......... 2 bytes
  FP32 master weights .... 4 bytes ┐
  FP32 Adam moment m ..... 4 bytes │ optimizer state = 12 bytes/param
  FP32 Adam variance v ... 4 bytes ┘
  ─────────────────────────────────
  TOTAL ≈ 16 bytes / parameter (before activations)

Concrete numbers:
  7B model:  7 × 10⁹ × 16 ≈ 112 GB   → does NOT fit on one 80 GB H100
  70B model: 70 × 10⁹ × 16 ≈ 1.12 TB → needs 14+ H100s just for parameters + optimizer state

Three lessons jump out. First, the weights themselves are only 2/16 = 1/8 of the footprint. The FP32 master copy plus Adam's two moment vectors consume 12 of those 16 bytes — the optimizer state alone is 75% of the budget. Second, even a 7B model overflows a single 80 GB GPU once you account for gradients and optimizer state, before activations. Third, inference needs only the 2-byte BF16 weights plus a KV cache — which is why you can serve a 70B model on a cluster that cannot train it.
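To make the accounting concrete, here is a minimal sketch of the 16-bytes-per-parameter rule as a calculator. The helper name `adamw_training_bytes` is hypothetical, and it counts model state only (no activations, no framework overhead).

```python
# Per-parameter byte counts for mixed-precision AdamW, as listed above.
BYTES_PER_PARAM = {
    "bf16_weights": 2,
    "bf16_grads": 2,
    "fp32_master": 4,
    "adam_m": 4,
    "adam_v": 4,
}


def adamw_training_bytes(n_params: float) -> dict:
    """Bytes for each model-state component (activations excluded)."""
    parts = {name: n_params * b for name, b in BYTES_PER_PARAM.items()}
    parts["optimizer_state"] = parts["fp32_master"] + parts["adam_m"] + parts["adam_v"]
    parts["total"] = sum(n_params * b for b in BYTES_PER_PARAM.values())
    return parts


for n in (7e9, 70e9):
    m = adamw_training_bytes(n)
    share = m["optimizer_state"] / m["total"]
    print(f"{n / 1e9:.0f}B: total {m['total'] / 1e9:,.0f} GB, "
          f"weights {m['bf16_weights'] / 1e9:,.0f} GB, optimizer share {share:.0%}")
```

The optimizer share is 75% for every model size, because it is a fixed 12 of the 16 bytes.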

Why FP32 master weights?

Training adds many small updates to the weights. If the weights live only in BF16, any update smaller than about half of BF16's relative resolution (machine epsilon 2⁻⁷ ≈ 0.008) rounds away to nothing, and the loss compounds over millions of steps. The FP32 master copy preserves those updates. The BF16 working copy is just what gets loaded into the GPU's tensor cores for the forward and backward computation.
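A small experiment shows the effect. Plain Python has no BF16 type, so this sketch emulates it by keeping the top 16 bits of an FP32 bit pattern with round-to-nearest-even (an assumption that matches BF16's layout but skips NaN/Inf handling); the helper names are made up for illustration.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Emulate BF16: keep the top 16 bits of the FP32 pattern, round-to-nearest-even.
    NaN/Inf handling is omitted; fine for the finite values used here."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even on the dropped bits
    bits &= 0xFFFF0000                    # zero the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(update: float, steps: int, rounder) -> float:
    """Start from w = 1.0 and apply `steps` identical updates, storing w via `rounder`."""
    w = rounder(1.0)
    for _ in range(steps):
        w = rounder(w + update)   # one tiny optimizer update, then store at that precision
    return w


w_bf16 = accumulate(1e-4, 1000, to_bf16)   # every update is rounded away
w_fp32 = accumulate(1e-4, 1000, to_fp32)   # updates survive in the FP32 master copy
print(f"BF16 weight: {w_bf16}   FP32 master weight: {w_fp32:.4f}")
```

Near 1.0, BF16 values are spaced 2⁻⁷ apart, so an update of 1e-4 is far below half a step and vanishes every time, while FP32 keeps all 1000 of them.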

Activations — the fourth cost

On top of the 16 bytes per parameter, the forward pass stores activations for the backward pass. Activation memory scales roughly as batch × sequence_length × hidden × num_layers × bytes_per_element. For a 7B model with batch 16, sequence 2048, hidden 4096, 32 layers, in BF16 the residual-stream activations come to 16 × 2048 × 4096 × 32 × 2 ≈ 8.6 GB. That is only the residual term; storing attention scores and FFN intermediates as well pushes the true footprint several times higher (the companion notebook uses a ~12× estimate, ≈ 100 GB at these settings). At sequence 32768 the residual term alone grows 16× to ≈ 137 GB. This is why gradient checkpointing exists, and why sequence length has historically been the binding constraint for long-context training.
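The residual-stream arithmetic above as a short hypothetical helper. The 12× multiplier is the rough notebook estimate quoted in the text, not a measured value.

```python
def residual_activation_bytes(batch, seq, hidden, layers, bytes_per_elem=2):
    """Residual-stream term only: one [batch, seq, hidden] tensor per layer."""
    return batch * seq * hidden * layers * bytes_per_elem


base = residual_activation_bytes(16, 2048, 4096, 32)
long_ctx = residual_activation_bytes(16, 32768, 4096, 32)
full_estimate = 12 * base   # rough multiplier for attention scores / FFN intermediates

print(f"residual @2k:  {base / 1e9:.1f} GB")
print(f"~12x estimate: {full_estimate / 1e9:.0f} GB")
print(f"residual @32k: {long_ctx / 1e9:.0f} GB")
```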

[Figure: per-parameter memory in mixed-precision AdamW (16 bytes total: 2 B BF16 weights, 2 B BF16 gradients, 12 B FP32 optimizer state), and memory needed vs. model size against one 80 GB H100: 125M → 2 GB, 1.3B → 20.8 GB, 7B → 112 GB, 70B → >1 TB.]
The 16-byte-per-parameter breakdown for mixed-precision AdamW. The optimizer state (red, 12 bytes) is 75% of the budget — it is what sharding strategies are actually eliminating. The dashed line marks an 80 GB H100: a 7B model already overflows it before activations.
Collective Communication Primitives

Four operations. Every parallelism is built from them.

Distributed training is I/O between GPUs. Before the parallelisms, you need to understand the four collective operations they compose. Think of N GPUs, each holding a tensor chunk. The collectives specify what happens to those chunks.

Primitive      | Input (each GPU)                         | Output (each GPU)                       | Data moved per GPU (M = full tensor size)
Broadcast      | Root has the full tensor; others empty   | All hold the full tensor (root's copy)  | M × (N-1) / N
All-reduce     | Each holds a full tensor; values differ  | All hold the element-wise sum/mean      | 2M × (N-1) / N
Reduce-scatter | Each holds a full tensor                 | Each holds 1/N of the sum (its shard)   | M × (N-1) / N
All-gather     | Each holds 1/N of a tensor               | All hold the full concatenated tensor   | M × (N-1) / N

One key identity: all-reduce = reduce-scatter + all-gather. This decomposition is exactly what the ring algorithm exploits and what ZeRO leverages to shard efficiently. Memorize this: reduce-scatter gives every GPU its portion of the sum; all-gather then broadcasts those portions so everyone has the full result.
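The identity can be checked with plain Python lists standing in for per-GPU tensors. This sketch models only what each collective computes, not how data travels; the function names mirror the primitives but are not any library's API.

```python
from typing import List

Vec = List[float]


def all_reduce(tensors: List[Vec]) -> List[Vec]:
    """Every GPU ends with the element-wise sum of all inputs."""
    total = [sum(vals) for vals in zip(*tensors)]
    return [list(total) for _ in tensors]


def reduce_scatter(tensors: List[Vec]) -> List[Vec]:
    """GPU k ends with shard k (1/N of the elements) of the element-wise sum."""
    n = len(tensors)
    size = len(tensors[0]) // n          # assume the length divides evenly by N
    total = [sum(vals) for vals in zip(*tensors)]
    return [total[k * size:(k + 1) * size] for k in range(n)]


def all_gather(shards: List[Vec]) -> List[Vec]:
    """Every GPU ends with the concatenation of all shards, in rank order."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


# 4 "GPUs", each holding a different 8-element gradient.
grads = [[float(rank * 10 + i) for i in range(8)] for rank in range(4)]

direct = all_reduce(grads)
composed = all_gather(reduce_scatter(grads))
assert direct == composed                # all-reduce == reduce-scatter + all-gather
print(composed[0])
```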

[Figure: left, ring all-reduce over 4 GPUs holding g0 to g3, run as two passes of N-1 steps each and ending with every GPU holding mean(g0+g1+g2+g3); right, the same operation as Phase 1 reduce-scatter (each GPU receives its 1/N shard of the summed gradient, M×(N-1)/N data) plus Phase 2 all-gather (each shard sent to all GPUs), for a total of 2 × M × (N-1)/N per GPU. ZeRO stops after phase 1: each GPU keeps only its shard.]
Left: ring all-reduce circulates gradient chunks in two passes so every GPU accumulates the global sum without any single bottleneck. Right: all-reduce decomposes into reduce-scatter (each GPU gets its 1/N shard of the sum) then all-gather (shards are broadcast to all). ZeRO-2/3 exploit this by stopping after reduce-scatter and keeping only the shard each GPU is responsible for.

Bandwidth cost of all-reduce

For a gradient tensor of M bytes and N GPUs, a ring all-reduce moves exactly 2M(N-1)/N ≈ 2M bytes per GPU. For large N this is essentially independent of N: scaling from 8 to 64 GPUs barely changes the per-GPU communication volume. For large tensors the bottleneck is therefore link bandwidth, not GPU count (the ring's 2(N-1) sequential steps add latency, which matters mainly for small messages). This is why the gap between NVLink (≈600 GB/s bidirectional on A100, ≈900 GB/s on H100) and InfiniBand (200–400 Gb/s per port, i.e. 25–50 GB/s) matters enormously.
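To see where 2M(N-1)/N comes from, this sketch simulates the ring algorithm step by step and counts the elements each rank sends. It assumes the tensor length divides evenly by N and that each step's messages are delivered simultaneously; `ring_all_reduce` is a hypothetical illustration, not NCCL's implementation.

```python
def ring_all_reduce(tensors):
    """Simulate a ring all-reduce; return (results, elements_sent_per_rank)."""
    n = len(tensors)
    size = len(tensors[0]) // n                       # elements per chunk
    # bufs[r][c] is rank r's current copy of chunk c
    bufs = [[list(t[c * size:(c + 1) * size]) for c in range(n)] for t in tensors]
    sent = [0] * n

    # Phase 1, reduce-scatter (N-1 steps): rank r sends chunk (r - s) % n to rank r+1,
    # which adds it into its own copy of that chunk.
    for s in range(n - 1):
        msgs = [(r, (r - s) % n, list(bufs[r][(r - s) % n])) for r in range(n)]
        for r, c, payload in msgs:                    # deliver all messages "at once"
            dst = (r + 1) % n
            bufs[dst][c] = [a + b for a, b in zip(bufs[dst][c], payload)]
            sent[r] += len(payload)
    # Now rank r holds the fully reduced chunk (r + 1) % n.

    # Phase 2, all-gather (N-1 steps): rank r forwards chunk (r + 1 - s) % n to rank r+1.
    for s in range(n - 1):
        msgs = [(r, (r + 1 - s) % n, list(bufs[r][(r + 1 - s) % n])) for r in range(n)]
        for r, c, payload in msgs:
            bufs[(r + 1) % n][c] = payload            # overwrite with the reduced chunk
            sent[r] += len(payload)

    results = [[x for chunk in b for x in chunk] for b in bufs]
    return results, sent


N, M = 4, 1024                                        # 4 GPUs, 1024-element gradient each
grads = [[float(r + 1)] * M for r in range(N)]        # rank r holds all (r + 1)s
out, sent = ring_all_reduce(grads)
print(sent[0], 2 * M * (N - 1) // N)                  # elements sent per rank vs 2M(N-1)/N
```

Each rank sends (N-1) chunks of M/N elements in each phase, which is where the two factors of M(N-1)/N come from.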

Data Parallelism (DDP)

Replicate the model, split the batch, all-reduce the gradients.

Data parallelism (DP) is the simplest and most common strategy. Every GPU holds a complete copy of the model. The global batch is split across GPUs — each processes its own shard independently in the forward and backward passes. Then an all-reduce averages every gradient across all GPUs so all replicas apply the identical optimizer step and stay in sync. From the optimizer's perspective it is as if one GPU processed the full batch.
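Why averaging local gradients reproduces the full-batch gradient: when the loss is a mean over examples and the shards are equal-sized, the mean of the shard gradients equals the gradient of the full-batch mean. A one-parameter linear model in plain Python is enough to check it; the data and names are illustrative.

```python
import random


def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) for a 1-parameter linear model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(64)]
ys = [3.0 * x + random.gauss(0, 0.1) for x in xs]
w = 0.5

# One GPU, full batch of 64.
g_full = grad_mse(w, xs, ys)

# Four "GPUs", each with a 16-example shard; DDP's all-reduce averages the local gradients.
n_gpus, shard = 4, 16
local = [grad_mse(w, xs[k * shard:(k + 1) * shard], ys[k * shard:(k + 1) * shard])
         for k in range(n_gpus)]
g_avg = sum(local) / n_gpus

print(g_full, g_avg)
```

With unequal shard sizes the plain mean would over-weight the smaller shards, which is one reason data loaders give every rank the same batch size.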

PyTorch implements this as DistributedDataParallel (DDP). The key implementation insight: DDP does not wait for the full backward pass to finish before communicating. It groups gradients into buckets and launches an all-reduce for each bucket as soon as every gradient in it is ready, overlapping computation with communication so that, on fast links, most of the all-reduce is hidden behind the remaining backward pass.

[Figure: data parallelism on 4 GPUs. Each GPU holds the full model W, processes one quarter of the batch, and computes its own gradient g₀ to g₃; an all-reduce, g_avg = mean(g₀, g₁, g₂, g₃), moves 2M(N-1)/N bytes per GPU and overlaps with backward in DDP. Every GPU then applies the identical optimizer step, so the weights stay in sync. Limitation: each GPU still needs the full model in memory, so DDP solves speed, not memory.]
Data parallelism (DDP). Each GPU runs the full model on its batch shard, then an all-reduce averages every gradient so all replicas apply the identical update. Throughput ideally scales linearly with GPU count; per-GPU memory is unchanged.
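# DDP: launch with `torchrun --nproc_per_node=4 train.py`
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")              # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR/PORT
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = GPT(cfg).cuda()                      # GPT / cfg: your model class and config, defined elsewhere
model = DDP(model, device_ids=[local_rank])  # gradient all-reduce is hooked into .backward() automatically
# The training loop is otherwise unchanged. Give the DataLoader a
# torch.utils.data.distributed.DistributedSampler so each rank reads a different shard.

dist.destroy_process_group()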

Scaling efficiency and the communication wall

With perfect hardware, doubling GPU count halves training time. In practice, scaling efficiency (actual speedup / ideal speedup) falls below 1 for two reasons. First, the all-reduce volume is proportional to the number of parameters: a 7B model's gradients weigh 14 GB in BF16, and a ring all-reduce moves about twice that per GPU, so even at 600 GB/s NVLink it takes ~47 ms per step. Second, small models spend a larger fraction of each step communicating relative to computing, so they scale worse. The ratio that matters is compute-to-communication: larger per-GPU batches and faster interconnects both improve it.

Scaling efficiency ≈ T_compute / (T_compute + T_allreduce)   (worst case: no compute/communication overlap)
T_allreduce ≈ (2 × 2 bytes × P) / bandwidth
            = (4 × 7e9) / (600e9 B/s) ≈ 47 ms   [7B, NVLink 600 GB/s]
If a training step on 8 GPUs takes 200 ms of compute:
  efficiency ≈ 200 / (200 + 47) ≈ 81%   (good, but not perfect)
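The same back-of-the-envelope estimate as code, under the formula's assumptions: BF16 gradients, the large-N ring approximation of about 2 × gradient bytes per GPU, and no overlap between compute and communication. The function names are hypothetical.

```python
def allreduce_seconds(n_params, bytes_per_grad=2, bandwidth_bytes_per_s=600e9):
    """Large-N ring approximation: each GPU moves about 2 x the gradient bytes.
    (The exact ring factor is 2(N-1)/N, slightly less than 2.)"""
    grad_bytes = n_params * bytes_per_grad
    return 2 * grad_bytes / bandwidth_bytes_per_s


def scaling_efficiency(compute_s, comm_s):
    """Worst case: backward compute and the all-reduce do not overlap."""
    return compute_s / (compute_s + comm_s)


t_comm = allreduce_seconds(7e9)               # 7B params, BF16 grads, 600 GB/s
eff = scaling_efficiency(0.200, t_comm)       # 200 ms of compute per step
print(f"all-reduce ≈ {t_comm * 1e3:.0f} ms, efficiency ≈ {eff:.0%}")
```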
ZeRO & FSDP

Don't replicate what you can shard. Eliminate redundancy one layer at a time.

Plain data parallelism wastes memory: N GPUs store N identical copies of the optimizer state, gradients, and weights. The Zero Redundancy Optimizer (ZeRO, from DeepSpeed) and PyTorch's Fully Sharded Data Parallel (FSDP) remove that redundancy by sharding those tensors across the data-parallel group and gathering them on demand. The name "ZeRO" is literal: at stage 3 there is zero redundancy, and each byte of optimizer state, gradient, or weight is stored by exactly one GPU at rest.

ZeRO proceeds in three stages of increasing aggressiveness. Each stage shards one more category of tensors while preserving the same mathematical result as DDP — the gradient flowing to the optimizer is identical; only where it lives differs.

[Figure: ZeRO stages for a 7B model on 8 GPUs, showing which of weights (2 B/param), gradients (2 B/param), and optimizer state (12 B/param) are replicated ×8 or sharded ÷8. Per-GPU memory: ZeRO-0 (DDP) 112 GB; ZeRO-1 38.5 GB (~2.9× reduction); ZeRO-2 26.25 GB (~4.3×); ZeRO-3 (= FSDP) 14 GB (8×, = 112/8).]
ZeRO stages for a 7B model on 8 GPUs (112 GB baseline). ZeRO-3/FSDP shards everything, achieving linear memory scaling: 112 GB / 8 = 14 GB of model state per GPU, small enough (before activations and the transiently gathered layer) to fit on a single 24 GB consumer-grade GPU.
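Stage          | Shards           | 7B on 8 GPUs (GB/GPU) | Communication per step, per GPU (M = gradient size)            | PyTorch API
ZeRO-0 (DDP)   | Nothing          | 112                   | 1 gradient all-reduce (≈ 2M)                                   | DDP
ZeRO-1         | Optimizer state  | 38.5                  | ≈ 2M (same as DDP)                                             | DeepSpeed / manual
ZeRO-2         | + Gradients      | 26.25                 | Reduce-scatter grads + all-gather params (≈ 2M, same as DDP)   | DeepSpeed / manual
ZeRO-3 / FSDP  | + Parameters     | 14                    | All-gather per layer in fwd and bwd + reduce-scatter (≈ 3M)    | FSDP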
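The per-GPU memory column follows from sharding one byte category at a time. A minimal sketch, assuming the 2 + 2 + 12 bytes-per-parameter split above and ignoring activations and temporary buffers (the helper name is hypothetical):

```python
def zero_bytes_per_gpu(n_params, stage, n_gpus):
    """Model-state bytes per GPU for ZeRO stage 0-3 (activations excluded)."""
    weights, grads, opt = 2.0, 2.0, 12.0   # bytes per parameter
    if stage >= 1:
        opt /= n_gpus                      # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_gpus                    # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus                  # ZeRO-3 / FSDP: also shard parameters
    return n_params * (weights + grads + opt)


for stage in range(4):
    print(f"ZeRO-{stage}: {zero_bytes_per_gpu(7e9, stage, 8) / 1e9:6.2f} GB/GPU")
```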

The mechanism for ZeRO-3/FSDP: each GPU permanently stores only its 1/N shard of each parameter tensor. When the forward pass reaches a layer, an all-gather reconstitutes the full weights just in time; the GPU computes the layer output, then frees the gathered weights. The backward pass repeats the all-gather, then a reduce-scatter leaves each GPU with only its shard of the gradients. So instead of one gradient all-reduce per step (≈ 2M per GPU), ZeRO-3 performs two parameter all-gathers plus one gradient reduce-scatter (≈ 3M per GPU), about 1.5× DDP's communication volume, spread across many smaller per-layer collectives. That extra volume is the price of the memory savings. On fast NVLink this is usually fine; on slow interconnects it can dominate.

FSDP is PyTorch's implementation of the ZeRO-3 idea: Microsoft's DeepSpeed shipped ZeRO in 2019, and FSDP followed in FairScale (2021) before being upstreamed into PyTorch in 2022. The difference is packaging: FSDP is built into PyTorch with no external dependencies. Both achieve near-linear memory scaling with GPU count. Meta used FSDP to train LLaMA-2 70B on 2048 A100 GPUs; without sharding, each GPU would need to hold the full 1.12 TB of parameters + optimizer state.

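# FSDP: PyTorch-native ZeRO-3. Wraps the model; shards params, grads, and optimizer state.
# Assumes the process group is initialized (as in the DDP example), `model` exists,
# and `Block` is your transformer-block class.
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={Block}  # one FSDP unit per block
)
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,                  # shard at transformer-block boundaries
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
    device_id=torch.cuda.current_device(),
)
# Each GPU stores ~model_size / world_size; each block is all-gathered on demand.
# Create the optimizer AFTER wrapping so it sees the sharded parameters.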
Tensor Parallelism

Split the matrix multiply inside a layer across GPUs.

ZeRO/FSDP still gathers the full weights of each layer during the forward pass: it avoids storing redundant copies at rest, but each GPU still materializes one complete layer at a time and computes it alone. Tensor parallelism (TP) goes further: it keeps the weights sharded even during computation by splitting each matrix multiplication across GPUs. The math works out cleanly because a matrix multiply can be partitioned along either the columns or the rows of the weight matrix.

Column-parallel and row-parallel linear

For a linear Y = XW where X ∈ ℝ^(B×d_in) and W ∈ ℝ^(d_in×d_out):

Column-parallel (partition W by output dim):
  GPU_k holds columns W[:, k·d_out/N : (k+1)·d_out/N]
  Each GPU computes Y_k = X W_k (a slice of the output)
  Result: each GPU holds one slice of Y; an all-gather assembles the full Y

Row-parallel (partition W by input dim, the dual case):
  GPU_k holds rows W[k·d_in/N : (k+1)·d_in/N, :]
  Each GPU holds the matching slice X_k of the input
  Each GPU computes a partial sum Y_k = X_k W_k
  Result: Y = Σ_k Y_k; an all-reduce sums the partial outputs

In Megatron-LM, the MLP is split as column-parallel Linear → GeLU → row-parallel Linear. The column-parallel layer shards by output dimension, so each GPU computes a slice of the hidden activation; because GeLU is elementwise, it runs on that slice with no communication. That slice is exactly the input slice the row-parallel layer expects, so no all-gather is needed in between; the row-parallel layer shards by input dimension and finishes with an all-reduce to sum the partial contributions. Net result: one all-reduce per MLP in the forward pass (and one in the backward), with both weight matrices split N-way across GPUs.
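A numerical check of the column-parallel → GeLU → row-parallel split. Plain Python lists stand in for GPU tensors, a loop over k stands in for the N GPUs, and the final element-wise sum plays the role of the all-reduce; all helper names are illustrative.

```python
import math
import random


def matmul(a, b):
    """Plain-Python matrix multiply: a is [n x k], b is [k x m]."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def gelu(m):
    """Exact (erf-based) GeLU, applied elementwise."""
    return [[0.5 * v * (1 + math.erf(v / math.sqrt(2))) for v in row] for row in m]


def cols(w, lo, hi):
    return [row[lo:hi] for row in w]   # W[:, lo:hi]


def rows(w, lo, hi):
    return w[lo:hi]                    # W[lo:hi, :]


def rand(n, m):
    return [[random.gauss(0, 1) for _ in range(m)] for _ in range(n)]


random.seed(0)
B, d_in, h, d_out, N = 3, 4, 8, 5, 2          # N = tensor-parallel degree
X, W1, W2 = rand(B, d_in), rand(d_in, h), rand(h, d_out)

# Reference: unsharded MLP, Y = GeLU(X W1) W2
Y_ref = matmul(gelu(matmul(X, W1)), W2)

# Tensor-parallel: GPU k owns column block k of W1 and the matching row block k of W2.
step = h // N
partials = []
for k in range(N):
    H_k = gelu(matmul(X, cols(W1, k * step, (k + 1) * step)))         # column-parallel; local GeLU
    partials.append(matmul(H_k, rows(W2, k * step, (k + 1) * step)))  # row-parallel partial sum

# The single all-reduce: element-wise sum of the N partial outputs.
Y_tp = [[sum(p[i][j] for p in partials) for j in range(d_out)] for i in range(B)]

max_err = max(abs(a - b) for ra, rb in zip(Y_ref, Y_tp) for a, b in zip(ra, rb))
print(f"max |Y_ref - Y_tp| = {max_err:.2e}")
```

Each GeLU output slice H_k feeds its row-parallel shard directly, so the only collective in the forward pass is the final sum.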

[Figure: Megatron-LM MLP tensor parallelism with N=2 GPUs. Input X [B, d_in] is broadcast to both GPUs; W1 is column-parallel (GPU 0: W1[:, 0:h/2], GPU 1: W1[:, h/2:h], each [d_in, h/2]); GeLU runs locally with no communication; W2 is row-parallel (GPU 0: W2[0:h/2, :], GPU 1: W2[h/2:h, :], each [h/2, d_out]); an all-reduce sums the partials into Y [B, d_out]. One all-reduce per MLP layer; needs NVLink (intra-node only); weights halved on each GPU.]
Megatron-style tensor parallelism for a two-layer MLP. W1 is column-partitioned (each GPU gets half the output neurons); GeLU runs locally; W2 is row-partitioned, finishing with one all-reduce to sum partial results. Net: weights are split N-way, one all-reduce per layer — must run over NVLink.

Attention is tensor-parallel too

For multi-head attention, TP is even cleaner: each GPU handles a disjoint subset of attention heads. With 32 heads on 4 GPUs, GPU 0 does heads 0–7, GPU 1 does heads 8–15, etc. Each head is fully self-contained (no cross-head communication during attention), so the only collective is one all-reduce after projecting the concatenated head outputs through the output matrix W_O. In practice the QKV projection is also column-split and the output projection is row-split, matching the MLP pattern.
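The same check for attention, splitting heads across two simulated GPUs. This sketch assumes a single unmasked sequence, no biases, and head_dim = d / n_heads; each GPU's heads use their column slices of W_Q, W_K, W_V and the matching row slice of W_O, and one sum (the all-reduce) recovers the unsharded output. Helper names are hypothetical.

```python
import math
import random


def matmul(a, b):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(r) for r in zip(*m)]


def softmax_rows(m):
    out = []
    for row in m:
        mx = max(row)                              # subtract the max for stability
        e = [math.exp(v - mx) for v in row]
        s = sum(e)
        out.append([v / s for v in e])
    return out


def cols(w, lo, hi):
    return [row[lo:hi] for row in w]


def rand(n, m):
    return [[random.gauss(0, 1) for _ in range(m)] for _ in range(n)]


def heads_output(X, Wq, Wk, Wv, head_ids, hd):
    """Concatenated outputs of the given heads, shape [T, len(head_ids) * hd].
    Only the columns belonging to these heads are read, as a TP rank would store."""
    outs = []
    for h in head_ids:
        lo, hi = h * hd, (h + 1) * hd
        Q, K, V = (matmul(X, cols(W, lo, hi)) for W in (Wq, Wk, Wv))
        scores = [[v / math.sqrt(hd) for v in row] for row in matmul(Q, transpose(K))]
        outs.append(matmul(softmax_rows(scores), V))
    return [sum((o[t] for o in outs), []) for t in range(len(X))]


random.seed(1)
T, d, n_heads, N = 4, 8, 4, 2                      # 4 heads split over N=2 "GPUs"
hd = d // n_heads
X, Wq, Wk, Wv, Wo = rand(T, d), rand(d, d), rand(d, d), rand(d, d), rand(d, d)

# Reference: all heads on one device, then the output projection W_O.
Y_ref = matmul(heads_output(X, Wq, Wk, Wv, range(n_heads), hd), Wo)

# TP: GPU k owns a contiguous block of heads and the matching rows of W_O.
per_gpu = n_heads // N
partials = []
for k in range(N):
    my_heads = range(k * per_gpu, (k + 1) * per_gpu)
    local = heads_output(X, Wq, Wk, Wv, my_heads, hd)       # no cross-GPU traffic
    Wo_rows = Wo[k * per_gpu * hd:(k + 1) * per_gpu * hd]   # row-parallel W_O
    partials.append(matmul(local, Wo_rows))

Y_tp = [[sum(p[i][j] for p in partials) for j in range(d)] for i in range(T)]   # one all-reduce
max_err = max(abs(a - b) for ra, rb in zip(Y_ref, Y_tp) for a, b in zip(ra, rb))
print(f"max |Y_ref - Y_tp| = {max_err:.2e}")
```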

Why TP must stay intra-node

Megatron-style TP needs two all-reduces per transformer layer in the forward pass (one after attention, one after the MLP) and two in the backward. For a 32-layer model that is 128 all-reduces per training step, whatever the TP degree. Each one all-reduces an activation tensor: at ~50 MB per all-reduce and 600 GB/s NVLink, each takes ~0.17 ms (≈ 21 ms per step in total), which is acceptable. Over InfiniBand at ~25 GB/s per port, the same operation takes ~4 ms (≈ 0.5 s per step), making TP across nodes prohibitively slow. Rule: TP always runs within a node, over NVLink.

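For inference this rule carries over directly: vLLM's tensor parallelism groups GPUs within a node. If you have two nodes, you either run two separate TP groups (one per node) and route requests to them independently, or, for a model too large for one node, combine TP within each node with pipeline parallelism across nodes.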

Pipeline Parallelism

Assign contiguous layer ranges to GPUs. Stream micro-batches to hide the bubble.

Pipeline parallelism (PP) partitions the model's layer stack vertically: GPU 0 holds layers 1–8, GPU 1 holds layers 9–16, and so on. A batch enters GPU 0 and flows through its layers; the activations at the stage boundary are then sent point-to-point to GPU 1, which continues. No GPU needs to store all layers simultaneously: each GPU holds 1/P of the model, where P is the pipeline depth.

The bubble problem

The naive pipeline is embarrassingly serial: GPU 1 must wait for GPU 0 to finish before it starts. During that wait, GPU 1 is idle. This idle time is called the pipeline bubble. In a pipeline of depth P, the naive bubble fraction is (P-1)/P — with 8 stages, 7/8 of time is wasted. Useless for production.

Naive bubble fraction = (P - 1) / P
  P=4: 75% idle
  P=8: 87.5% idle (catastrophic)

With M micro-batches (GPipe/1F1B):
  bubble fraction = (P - 1) / (M + P - 1)
  P=4, M=8:  3/11 ≈ 27%
  P=4, M=32: 3/35 ≈ 9%   ← practical range
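The bubble formula as a hypothetical helper, plus a search for the smallest M that pushes the bubble below a target. It uses the (P - 1) / (M + P - 1) model above and ignores communication time.

```python
def bubble_fraction(p: int, m: int) -> float:
    """Idle fraction of a GPipe/1F1B pipeline with p stages and m micro-batches."""
    return (p - 1) / (m + p - 1)


def min_microbatches(p: int, target: float) -> int:
    """Smallest m whose bubble fraction is strictly below target."""
    m = 1
    while bubble_fraction(p, m) >= target:
        m += 1
    return m


for p, m in [(4, 1), (4, 8), (4, 32)]:
    print(f"P={p}, M={m}: bubble = {bubble_fraction(p, m):.1%}")
print({p: min_microbatches(p, 0.10) for p in (4, 8)})
```

Rearranging (P - 1) / (M + P - 1) < 0.1 gives M > 9(P - 1), so deeper pipelines need proportionally more micro-batches.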

GPipe and 1F1B schedules

The fix is to split each batch into M micro-batches that flow through the pipeline in a staggered schedule, so different stages work on different micro-batches at the same time. GPipe runs all micro-batches forward, then all backward. 1F1B (one-forward-one-backward) interleaves them: after a short warm-up, each stage alternates one forward and one backward micro-batch. Both schedules have the same bubble fraction; 1F1B's advantage is memory, because each stage holds activations for at most P in-flight micro-batches instead of accumulating all M before the backward starts.

[Figure: top, naive pipeline (M=1) on 4 GPUs, each stage idle while it waits for F₀ and B₀ from its neighbors, bubble = (P-1)/P = 75%; bottom, 1F1B schedule with M=4 micro-batches (F1–F4, B1–B4) staggered across 4 GPUs, bubble = (P-1)/(M+P-1) = 3/7 ≈ 43%, or 3/19 ≈ 16% with M=16. Stages communicate only boundary activations, which tolerates slower inter-node InfiniBand. Legend: green = forward, red = backward, gray = idle (bubble).]
Top: naive pipeline wastes 75% of time with a single micro-batch. Bottom: 1F1B schedule with M=4 micro-batches interleaves forward (green) and backward (red) passes, shrinking the bubble from 75% to 43%; with M=16 it drops to 16%. The remaining gray cells are the unavoidable startup/drain bubble.

PP vs TP: communication pattern

Pipeline parallelism sends activations only at stage boundaries: for a micro-batch of 32 tokens at hidden dim 4096, that is 32 × 4096 × 2 bytes = 256 KB per boundary. Tiny. Tensor parallelism all-reduces a tensor of the same shape, but twice per layer in the forward pass and twice in the backward, and each all-reduce moves about twice the tensor's size per GPU. So per micro-batch, TP moves tens of times more data than PP, and far more often. This is why PP tolerates InfiniBand (~25 GB/s) while TP needs NVLink (~600 GB/s): volume × frequency tells you which link matters.
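A rough per-micro-batch tally for the example above, assuming BF16 activations, Megatron-style TP with two all-reduces per layer in the forward pass, a ring all-reduce moving 2(N-1)/N of the tensor per GPU, and a 32-layer model split into 4 pipeline stages. Backward traffic roughly mirrors these numbers. The helpers are illustrative.

```python
def tensor_bytes(tokens, hidden, bytes_per_elem=2):
    """Size of one [tokens, hidden] activation tensor."""
    return tokens * hidden * bytes_per_elem


def pp_forward_bytes(tokens, hidden, stages):
    """Point-to-point sends per micro-batch: one activation per stage boundary."""
    return (stages - 1) * tensor_bytes(tokens, hidden)


def tp_forward_bytes_per_gpu(tokens, hidden, layers, tp):
    """Two all-reduces per layer (after attention, after the MLP),
    each moving ~2(N-1)/N x the tensor per GPU."""
    per_allreduce = 2 * (tp - 1) / tp * tensor_bytes(tokens, hidden)
    return 2 * layers * per_allreduce


act = tensor_bytes(32, 4096)                              # 256 KiB
pp = pp_forward_bytes(32, 4096, stages=4)
tp = tp_forward_bytes_per_gpu(32, 4096, layers=32, tp=8)
print(f"one boundary tensor: {act / 1024:.0f} KiB")
print(f"PP (4 stages) fwd:   {pp / 1024:.0f} KiB per micro-batch")
print(f"TP (8-way) fwd:      {tp / 1024:.0f} KiB per GPU per micro-batch ({tp / pp:.0f}x PP)")
```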

3D Parallelism & Decision Guide

Combine DP × TP × PP. The cluster topology decides which axis you use.

Frontier-scale training runs, such as GPT-3 and Falcon 180B, combine all three parallelisms. The assignment follows hardware topology: TP runs within a node (NVLink), PP runs across a small number of nodes (moderate-bandwidth InfiniBand), and DP runs across the remaining replica groups (where bandwidth, for example between data-center pods, is lowest). On top of this, ZeRO shards the optimizer state across the DP group. The product TP × PP × DP gives the total GPU count.

Total GPUs = TP_degree × PP_degree × DP_degree

Example (illustrative): a 70B model trained on 2048 A100s
  TP = 8   (within each 8-GPU node, over NVLink)
  PP = 4   (across 4 nodes, pipeline depth)
  DP = 2048 / (8 × 4) = 64   (64 data-parallel replicas)
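Given a 3D layout, the model-state memory per GPU follows directly. A minimal sketch, assuming TP × PP splits the parameters evenly, ZeRO-1 shards only the 12 bytes/param of optimizer state across the DP group, and activations are ignored; the function names are hypothetical.

```python
def plan_3d(total_gpus, tp, pp):
    """DP degree implied by TP x PP x DP = total GPUs."""
    if total_gpus % (tp * pp):
        raise ValueError("TP x PP must divide the GPU count")
    return total_gpus // (tp * pp)


def model_state_bytes_per_gpu(n_params, tp, pp, dp, zero1=True):
    """Weights + grads + optimizer state per GPU (activations excluded).
    TP x PP splits the parameters; ZeRO-1 also shards optimizer state over DP."""
    local_params = n_params / (tp * pp)
    opt = 12 / dp if zero1 else 12
    return local_params * (2 + 2 + opt)


dp = plan_3d(2048, tp=8, pp=4)
mem = model_state_bytes_per_gpu(70e9, tp=8, pp=4, dp=dp)
print(f"DP = {dp}, model state ≈ {mem / 1e9:.1f} GB per GPU")
```

Without ZeRO the same layout would need 16 B × 70e9 / 32 ≈ 35 GB per GPU, so sharding the optimizer state over DP frees most of the memory for activations.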
3D Parallelism: What to Use When

Strategy           | What it splits                                   | Interconnect needed                        | Memory savings                               | Use when…
DDP / ZeRO-1/2     | Batch; ZeRO also shards optimizer state/grads    | Any (only gradients communicated)          | Weights 1× (replicated); optimizer state ÷N  | Model fits on 1 GPU; want throughput
FSDP / ZeRO-3      | Batch + weights, grads, optimizer state          | Medium (all-gather per layer)              | ÷ world_size (linear)                        | Model too large for 1 GPU
Tensor Parallel    | Matrix dimensions within each layer              | NVLink mandatory (intra-node)              | ÷ TP_degree per layer                        | Layer too big; fast intra-node link
Pipeline Parallel  | Layer stack across nodes                         | InfiniBand OK (boundary activations only)  | ÷ PP_degree per GPU                          | Many nodes; slow inter-node link
3D = TP × PP × DP  | All of the above simultaneously                  | NVLink (TP) + IB (PP) + any (DP)           | ÷ (TP × PP), plus ZeRO sharding over DP      | Frontier: 1000+ GPU runs

Rule of thumb: TP ≤ GPUs per node; add PP across nodes if needed; DP fills the rest. Always run ZeRO on top of DP.
Decision guide for 3D parallelism. The cluster topology — specifically the ratio of intra-node to inter-node bandwidth — determines which strategies are viable at each axis. TP is constrained by NVLink availability; PP tolerates InfiniBand; DP is always cheapest.

Intra-node vs inter-node bandwidth

NVLink 4.0 (H100 SXM) delivers ~900 GB/s bidirectional per GPU within a node. InfiniBand HDR/NDR delivers 200–400 Gb/s (25–50 GB/s) per port across nodes, with much higher latency, and the fabric is shared by all GPUs communicating across nodes. (The current flagship, Blackwell B200/GB200, raises both per-GPU memory, to 192 GB of HBM3e, and NVLink bandwidth; the same partitioning math applies, just with larger per-GPU capacity.) Intra-node links are therefore roughly an order of magnitude faster. This is why TP, with all-reduces in every layer, must stay within the NVLink domain, while PP, with one small activation tensor per micro-batch per boundary, can span nodes.

How this connects to inference

The parallelisms you learned in training reappear at inference time with the same logic. Tensor parallelism shards the model's weight matrices across GPUs within a node; serving frameworks replicate this exactly. With TP, the KV cache is also sharded: each GPU in the TP group holds the KV entries for its subset of attention heads. Pipeline parallelism assigns contiguous transformer blocks to pipeline stages, so each stage keeps the KV cache for its own layers and a request's cache is split across stages by layer. Understanding these mechanics is why an inference engineer needs to know distributed training: the checkpoint format, the sharding layout, and the communication patterns are the same artifacts at training and inference time.

Gradient Checkpointing

Trade one extra forward pass for a massive activation memory reduction.

Activations dominate memory for long sequences, but most of them are needed only briefly during the backward pass. Gradient checkpointing (also called activation recomputation) stores activations only at a few checkpoint boundaries (typically the input to each transformer block) and recomputes the intermediate activations inside a block during backpropagation instead of storing them. The tradeoff: one extra forward pass (roughly +33% compute, since backward costs about twice as much as forward) in exchange for keeping just one tensor per block, plus the intermediates of the single block currently being recomputed.

Without checkpointing (all intermediates kept; ~12× the residual tensor per layer, the notebook's rough estimate):
  Activation memory ≈ 12 × B × T × d × L × 2 bytes
  B=16, T=2048, d=4096, L=32, BF16 → 12 × 8.6 GB ≈ 103 GB
With checkpointing (store only each block's input; recompute one block at a time in backward):
  Activation memory ≈ B × T × d × L × 2 bytes        (block inputs)
                    + 12 × B × T × d × 2 bytes       (one block's intermediates)
                    ≈ 8.6 GB + 3.2 GB ≈ 11.8 GB  → roughly 9× reduction
Cost: every block's forward pass runs twice (once more during backward)
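import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Stand-in for one transformer block; any nn.Module works the same way.
transformer_block = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
x = torch.randn(8, 128, 512, requires_grad=True)

# Without checkpointing: autograd keeps every intermediate activation of the block.
y = transformer_block(x)

# With gradient checkpointing: keeps only x (the block input); internals are recomputed in backward.
y = checkpoint(transformer_block, x, use_reentrant=False)
y.sum().backward()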

Selective recomputation

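Modern frameworks go further with selective recomputation. Megatron-LM recomputes only the activations that are large but cheap to compute (attention scores, softmax outputs, dropout masks, which grow with sequence length squared) and stores the ones that are expensive to recompute (linear-layer outputs). FlashAttention achieves a similar effect for attention by never materializing the score matrix and recomputing it in the backward pass. This keeps most of checkpointing's memory savings at a much smaller compute overhead (a few percent instead of ~33%).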

Choosing a strategy — a decision guide

  • Model fits on one GPU, want it faster? Data parallelism (DDP). Add ZeRO-1 for free optimizer-state savings.
  • Model almost fits; optimizer state is the problem? ZeRO-1/2 or FSDP.
  • Model far too big for one GPU? ZeRO-3 / FSDP. Add TP within each node if layers are still too large.
  • Many nodes, slow inter-node interconnect? Pipeline parallelism across nodes, TP within each node.
  • Activation-bound (long sequences)? Gradient checkpointing on top of any of the above.
  • Frontier scale (1000+ GPUs)? 3D parallelism (TP × PP × DP) + ZeRO + gradient checkpointing.
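One way to encode the guide, as a rough sketch and a possible starting point for the strategy-chooser exercise. The thresholds are assumptions: only model state is counted (16 B/param for DDP), activations are ignored, and 1000+ GPUs counts as frontier scale. The function name is hypothetical.

```python
def recommend_strategy(n_params, gpu_mem_gb, num_gpus, has_nvlink, long_context=False):
    """Map a scenario onto the decision guide (hypothetical thresholds)."""
    full_gb = 16 * n_params / 1e9                            # DDP: everything replicated
    zero2_gb = n_params * (2 + (2 + 12) / num_gpus) / 1e9   # weights replicated, grads + opt sharded
    zero3_gb = full_gb / num_gpus                           # everything sharded

    if num_gpus >= 1000:
        plan = "3D parallelism (TP x PP x DP) + ZeRO"
    elif full_gb <= gpu_mem_gb:
        plan = "DDP (+ ZeRO-1)"
    elif zero2_gb <= gpu_mem_gb:
        plan = "ZeRO-1/2 or FSDP"
    elif zero3_gb <= gpu_mem_gb:
        plan = "ZeRO-3 / FSDP"
        if has_nvlink:
            plan += " (+ TP within each node if layers are still too large)"
    else:
        plan = "Does not fit: add GPUs, or PP across nodes with TP within each node"
    if long_context:
        plan += " + gradient checkpointing"
    return plan


scenarios = [
    (1.3e9, 80, 8, True),     # small model, one node
    (7e9, 80, 8, True),       # 7B: optimizer state is the problem
    (70e9, 80, 64, True),     # 70B on 8 nodes
    (7e9, 24, 1, False),      # 7B on one consumer GPU
    (175e9, 80, 1024, True),  # frontier scale
]
for s in scenarios:
    print(s, "->", recommend_strategy(*s))
```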
Connection to Inference

The same techniques that train large models also serve them.

You might wonder why an inference curriculum spends a full day on training parallelisms. The answer is that the techniques are not separate: serving a 70B model uses the same TP and PP machinery as training it. Weights often arrive from training in sharded form: a ZeRO-3 or FSDP checkpoint can be saved as separate per-rank shards rather than one monolithic file. An inference engineer loading such a checkpoint must re-shard it correctly onto the serving GPUs, which may use a different TP degree than training did.
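Re-sharding is mostly bookkeeping once the layout is known. This sketch assumes a column-parallel weight saved as N equal column shards (a TP-style layout); it concatenates them and re-splits for the serving TP degree. Row-parallel weights would be split along rows instead, and other checkpoint layouts need their own conversion. The helpers are hypothetical.

```python
def split_columns(w, n):
    """Column-parallel sharding: rank k keeps columns [k*c/n, (k+1)*c/n)."""
    c = len(w[0]) // n
    return [[row[k * c:(k + 1) * c] for row in w] for k in range(n)]


def concat_columns(shards):
    """Inverse of split_columns: glue shards back together along the column axis."""
    return [sum((s[i] for s in shards), []) for i in range(len(shards[0]))]


def reshard_columns(shards, new_n):
    """Re-shard a column-parallel weight from len(shards) ranks to new_n ranks."""
    return split_columns(concat_columns(shards), new_n)


W = [[r * 100 + c for c in range(16)] for r in range(3)]   # toy [3 x 16] weight
train_shards = split_columns(W, 8)                         # saved by an 8-way TP run
serve_shards = reshard_columns(train_shards, 4)            # loaded onto 4-way TP
print([len(s[0]) for s in train_shards], [len(s[0]) for s in serve_shards])
```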

KV-cache sharding with tensor parallelism

When you serve a model with TP degree N, each GPU in the TP group holds 1/N of the attention heads. During generation, the KV cache for each request follows the same partition: GPU 0 stores KV entries for its head subset, GPU 1 stores its subset, etc. This means the KV cache memory is naturally distributed, which is one reason TP at inference is not just about weight memory — it also distributes the per-request KV-cache footprint.
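The per-GPU KV-cache size follows from the head split. A minimal sketch, assuming a Llama-3-70B-like shape (80 layers, 8 KV heads with grouped-query attention, head_dim 128, BF16 cache) and a TP degree that divides the KV-head count; the helper names are hypothetical.

```python
def kv_cache_bytes_per_token(layers, kv_heads, head_dim, bytes_per_elem=2):
    """K and V for every layer and KV head, for one token."""
    return 2 * layers * kv_heads * head_dim * bytes_per_elem


def kv_cache_per_gpu(tokens, layers, kv_heads, head_dim, tp, bytes_per_elem=2):
    """With TP, each rank stores only kv_heads / tp of the heads."""
    if kv_heads % tp:
        raise ValueError("this sketch assumes TP divides the KV-head count")
    per_token = kv_cache_bytes_per_token(layers, kv_heads // tp, head_dim, bytes_per_elem)
    return tokens * per_token


# Assumed Llama-3-70B-like shape: 80 layers, 8 KV heads (GQA), head_dim 128, BF16 cache.
total = kv_cache_bytes_per_token(80, 8, 128)
per_gpu = kv_cache_per_gpu(tokens=8192, layers=80, kv_heads=8, head_dim=128, tp=8)
print(f"{total / 1024:.0f} KiB per token; "
      f"{per_gpu / 2**30:.2f} GiB per GPU for an 8k-token request at TP=8")
```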

Pipeline parallelism at inference

During prefill, a long prompt supplies many tokens at once, so there is enough work to fill a pipeline: the prompt can be split into chunks and streamed through the stages in a micro-batch pattern analogous to training. The autoregressive decode phase is harder, since each decode step produces a single token per request; a lone request leaves most stages idle, and the pipeline bubble is severe unless many requests are in flight. This is why most serving systems today rely on TP within a node and add PP mainly for models too large for a single node or for large batched prefill workloads.

vLLM's tensor_parallel_size parameter applies the same Megatron-style TP you learned for training: it splits attention heads and MLP weight matrices across the GPUs of a node. If you serve Llama-3 70B on all eight GPUs of an 8×H100 node, you are running TP=8, the same column-parallel/row-parallel split you learned in this lesson. At load time, each GPU keeps only its own partition of every weight tensor.

Exercise

Eight exercises, in the notebook.

Companion notebook: day-11-distributed-training.ipynb.

  1. Memory calculator. Write training_memory(N, bytes_per_param=16) and break it into weights / gradients / optimizer. Confirm a 7B model needs ~112 GB. Find the largest model that fits on 24 GB (single GPU).
  2. ZeRO savings table. Extend the calculator with zero_stage and num_gpus. Print a table of per-GPU memory for stages 0–3 and GPU counts 1, 2, 4, 8, 16 for a 7B model.
  3. All-reduce by hand. Implement the averaging all-reduce over a list of fake per-GPU gradient tensors (simulating N replicas). Verify every "GPU" ends with the mean. Then implement reduce-scatter and all-gather separately and show all-reduce = reduce-scatter + all-gather.
  4. Manual DDP gradient sync. Create N model replicas with different random gradients; simulate the DDP all-reduce by averaging their gradients; verify all replicas now have the same gradients and would take the same optimizer step.
  5. Pipeline bubble calculator. Write bubble_fraction(P, M) and plot bubble fraction vs M for P=4 and P=8. Find the minimum M to get bubble below 10% for each P.
  6. Memory accounting for TP. For a single transformer MLP layer with d_model=4096, d_ff=16384, show the per-GPU weight memory at TP=1, 2, 4, 8. Verify it is exactly 1/TP of the full weight.
  7. Gradient checkpointing. Run a stack of transformer blocks with and without checkpoint(). On CPU, time both. On GPU (if available), measure torch.cuda.max_memory_allocated() and compare.
  8. Strategy chooser. Given (model_params, gpu_mem_gb, num_gpus, has_nvlink), write a function that returns a recommended strategy string using the decision guide. Test on at least five realistic scenarios.
Self-Check

Ten questions before moving on.

Close the page and answer from memory. If you can't, re-read the relevant section.

  1. Why is the per-parameter training footprint ~16 bytes? Name each component and say which one dominates.
  2. What are the four collective communication primitives? State the identity relating all-reduce, reduce-scatter, and all-gather.
  3. What does DDP (data parallelism) solve and what does it not solve? What is the communication cost formula for its all-reduce?
  4. What does each ZeRO stage (1/2/3) shard, and how does ZeRO-3 relate to FSDP?
  5. For a 7B model on 8 GPUs, compute the approximate per-GPU memory for ZeRO-0, ZeRO-1, ZeRO-2, and ZeRO-3.
  6. Explain column-parallel and row-parallel linear in Megatron-style TP. Which collective does each require?
  7. Why must tensor parallelism stay intra-node, and why can pipeline parallelism span nodes?
  8. What is the pipeline bubble fraction for P stages and M micro-batches? How many micro-batches do you need to get below 10% bubble with P=8?
  9. What does gradient checkpointing trade? State the approximate memory reduction and compute overhead for a 32-layer model.
  10. How does tensor parallelism at training time connect to multi-GPU inference? What happens to the KV cache in a TP-parallel serving setup?

"Data parallel splits the batch. Tensor parallel splits the layer. Pipeline parallel splits the stack. ZeRO refuses to store anything twice. The interconnect decides which combination you can afford."

Further Reading

Go deeper.

The distributed-training canon.

Paper · 2019

Rajbhandari et al. — ZeRO

The Zero Redundancy Optimizer. Stages 1/2/3 defined here; includes memory analysis and large-scale benchmarks.

Paper · 2019

Shoeybi et al. — Megatron-LM

Tensor parallelism for transformer attention and MLP: column-parallel and row-parallel linear, and benchmarks up to 8.3B on 512 GPUs.

Paper · 2021

Narayanan et al. — Efficient Large-Scale LM Training (Megatron v2)

The 3D parallelism paper: combines TP, PP, and DP. Uses the 1F1B schedule, introduces an interleaved variant, and sustains 52% of theoretical peak throughput per GPU on 3072 GPUs.

Paper · 2018

Huang et al. — GPipe

Pipeline parallelism with micro-batches. First rigorous analysis of the bubble and how M micro-batches shrink it.

Docs · PyTorch

FSDP Tutorial

PyTorch-native ZeRO-3. Covers auto_wrap_policy, mixed precision config, and checkpoint saving/loading.

Docs · PyTorch

DDP Tutorial

DistributedDataParallel and torchrun, hands-on. Covers bucketing, gradient hooks, and multi-node setup.

Blog · HF

The Ultra-Scale Playbook

Hugging Face's interactive guide to 3D parallelism, memory math, and hardware topology. Best visual reference available.

Blog · Lilian Weng

Large Model Training Techniques

Comprehensive survey of parallelism strategies, gradient checkpointing, mixed precision, and memory optimization with clean diagrams.
