A 70B model does not fit on one GPU — not the weights, not the gradients, not the optimizer state. Today you master the four ways to split a training job across devices: data parallelism (DDP), sharded data parallelism (ZeRO/FSDP), tensor parallelism (Megatron-style), and pipeline parallelism — plus the collective communication primitives that make them work. These are not just training tools: tensor and pipeline parallelism are how vLLM and TensorRT-LLM serve billion-parameter models at inference time.
Every distributed-training technique exists to smash one of two walls. The memory wall: the model — weights, gradients, optimizer state, and activations — no longer fits on one GPU. The time wall: the model fits but training takes weeks on a single device, which is too slow to be practical. Data parallelism (DDP) attacks the time wall. ZeRO/FSDP, tensor parallelism, and pipeline parallelism attack the memory wall. Gradient checkpointing attacks a specific sub-problem of the memory wall — activation memory.
You will likely never train a 70B model yourself, but as an inference engineer you must read the configs and understand the artifacts these systems produce: sharded checkpoints, tensor-parallel weight layouts, and the communication patterns that also appear in multi-GPU inference. The tensor parallelism you learn today is exactly how vLLM and TensorRT-LLM split a model across GPUs to serve it, and pipeline parallelism is how serving systems span models too large for a single node. Understanding these techniques at a mathematical level means you can reason about throughput, latency, and memory budgets for any model on any cluster, which is precisely the day-to-day job of an inference engineer.
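Before choosing a strategy, count what training actually stores. With mixed precision and AdamW, each parameter incurs memory in four distinct categories simultaneously. The standard accounting, the "16 bytes per parameter" rule, goes like this:

| Category | Precision | Bytes / param |
|---|---|---|
| Working weights (forward/backward) | BF16 | 2 |
| Gradients | BF16 | 2 |
| Master weights | FP32 | 4 |
| Adam moments m and v | FP32 | 4 + 4 = 8 |
| Total | | 16 |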
Three lessons jump out. First, the weights themselves are only 2/16 = 1/8 of the footprint. The FP32 master copy plus Adam's two moment vectors consume 12 of those 16 bytes — the optimizer state alone is 75% of the budget. Second, even a 7B model overflows a single 80 GB GPU once you account for gradients and optimizer state, before activations. Third, inference needs only the 2-byte BF16 weights plus a KV cache — which is why you can serve a 70B model on a cluster that cannot train it.
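The 16-byte accounting is easy to turn into a calculator. The sketch below is a minimal illustration and not the companion notebook's code. The helper names and the decimal-GB convention (1 GB = 1e9 bytes) are assumptions, and activations are deliberately left out.

```python
# Bytes per parameter for mixed-precision AdamW training (activations excluded).
BYTES_PER_PARAM = {
    "bf16_weights": 2,          # working copy used in forward/backward
    "bf16_gradients": 2,        # produced by backward
    "fp32_master_weights": 4,   # high-precision copy the optimizer updates
    "fp32_adam_m": 4,           # Adam first moment
    "fp32_adam_v": 4,           # Adam second moment
}


def training_memory_gb(n_params: float) -> dict:
    """Per-category training memory in decimal GB (hypothetical helper)."""
    mem = {name: n_params * b / 1e9 for name, b in BYTES_PER_PARAM.items()}
    mem["total"] = sum(mem.values())
    return mem


def inference_weights_gb(n_params: float) -> float:
    """BF16 weights only; the KV cache comes on top of this."""
    return n_params * 2 / 1e9


mem_7b = training_memory_gb(7e9)
for name, gb in mem_7b.items():
    print(f"{name:>20}: {gb:6.1f} GB")
optimizer_share = (mem_7b["fp32_master_weights"] + mem_7b["fp32_adam_m"]
                   + mem_7b["fp32_adam_v"]) / mem_7b["total"]
print(f"optimizer share: {optimizer_share:.0%}")           # master copy + both moments
print(f"70B inference weights: {inference_weights_gb(70e9):.0f} GB")
```

The total for 7B comes to 112 GB, which already exceeds an 80 GB GPU. The optimizer-related share comes to 75%.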
Why keep an FP32 master copy at all? Each weight receives many small updates over the course of training. BF16 keeps only 8 significant bits, so adjacent BF16 values are up to 2^-7 ≈ 0.8% apart in relative terms. An update smaller than about half that gap rounds away entirely, and if you accumulate in BF16 these rounding errors compound. The FP32 master copy preserves precision across millions of steps. The BF16 working copy is just what gets loaded into the GPU's tensor cores for the forward/backward computation.
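You can reproduce the vanishing-update effect without a GPU. This sketch emulates BF16 by rounding a float32 bit pattern to its top 16 bits, using Python's struct module with round-to-nearest-even and no NaN/Inf handling. The helper names are illustrative. A constant update of 1e-3 stands in for real Adam updates, which is an assumption made to keep the example small.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest BF16 value by keeping the top 16 bits of the FP32 pattern.

    Uses round-to-nearest-even on the discarded low 16 bits; NaN/Inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)      # rounding bias; ties go to even
    bits &= 0xFFFF0000                        # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


update = 1e-3          # a small per-step update, ~0.1% of the weight
w_bf16_only = 1.0      # weight stored and updated only in BF16
w_master = 1.0         # FP32 master copy, as in mixed-precision training

for _ in range(1000):
    w_bf16_only = to_bf16(w_bf16_only + update)   # rounds back to 1.0 every step
    w_master = to_f32(w_master + update)          # accumulates the updates

w_working = to_bf16(w_master)   # BF16 working copy derived from the master
print(w_bf16_only, w_master, w_working)
```

After 1000 steps, the BF16-only weight has not moved from 1.0. The FP32 master has reached about 2.0, and its BF16 working copy follows it.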
On top of the 16-bytes-per-parameter cost, the forward pass stores activations for the backward pass. Activation memory scales roughly as batch × sequence_length × hidden × num_layers × bytes_per_element. Take a 7B model with batch 16, sequence 2048, hidden 4096, and 32 layers, in BF16. Its residual-stream activations come to 16 × 2048 × 4096 × 32 × 2 ≈ 8.6 GB. That is only the residual term. Storing attention scores and FFN intermediates as well pushes the true footprint several times higher: the companion notebook uses a ~12× estimate, ≈ 100 GB at these settings. At sequence 32768 the residual term alone becomes ~137 GB. This is why gradient checkpointing exists, and why sequence length has historically been the binding constraint for long-context training.
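Here is the same arithmetic as a few lines of Python, for quick what-if checks. The 12× multiplier mirrors the companion notebook's rough estimate quoted above. Treat it as an assumption, since the real factor depends on the architecture, the attention implementation, and which tensors the framework saves.

```python
def residual_activation_bytes(batch, seq_len, hidden, n_layers, bytes_per_elem=2):
    """Residual-stream term only: one (batch, seq_len, hidden) tensor saved per layer."""
    return batch * seq_len * hidden * n_layers * bytes_per_elem


FUDGE = 12  # assumed multiplier for attention scores, FFN intermediates, etc.

base = residual_activation_bytes(16, 2048, 4096, 32)
long_ctx = residual_activation_bytes(16, 32768, 4096, 32)
print(f"seq 2048:  {base / 1e9:5.1f} GB residual, ~{FUDGE * base / 1e9:.0f} GB estimated total")
print(f"seq 32768: {long_ctx / 1e9:5.1f} GB residual term alone")
```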
Distributed training is I/O between GPUs. Before the parallelisms, you need to understand the four collective operations they compose. Think of N GPUs, each holding a tensor chunk. The collectives specify what happens to those chunks.
| Primitive | Input (each GPU) | Output (each GPU) | Bytes per GPU (ring; M = full tensor size, N = GPUs) |
|---|---|---|---|
| Broadcast | Root has full tensor; others empty | All have full tensor (root's copy) | ≈ M (every non-root GPU must receive the whole tensor) |
| All-reduce | Each holds a full tensor, values differ | All hold the element-wise sum/mean | 2M × (N-1) / N |
| Reduce-scatter | Each holds a full tensor | Each holds 1/N-th of the sum (its shard) | M × (N-1) / N |
| All-gather | Each holds 1/N-th of a tensor | All hold the full concatenated tensor | M × (N-1) / N |
One key identity: all-reduce = reduce-scatter + all-gather. This decomposition is exactly what the ring algorithm exploits and what ZeRO leverages to shard efficiently. Memorize this: reduce-scatter gives every GPU its portion of the sum; all-gather then broadcasts those portions so everyone has the full result.
For a gradient tensor of size M bytes and N GPUs, a ring all-reduce sends (and receives) 2M(N-1)/N ≈ 2M bytes per GPU. This is nearly independent of N: scaling from 8 to 64 GPUs barely changes the per-GPU communication volume. For large tensors the bottleneck is therefore link bandwidth, not GPU count. The latency term, 2(N-1) ring steps, does grow with N. This is why NVLink versus InfiniBand matters enormously: NVLink runs at ≈600 GB/s bidirectional on A100 and ≈900 GB/s on H100, while InfiniBand runs at 200–400 Gb/s per port, i.e. 25–50 GB/s.
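To see where 2M(N-1)/N comes from, here is a small single-process simulation of ring all-reduce built from its two halves: reduce-scatter, then all-gather. It is a teaching sketch and not how NCCL is implemented. Pure Python lists stand in for GPU buffers, a "send" is a copy between list entries, and the function name is hypothetical. It also counts the elements each rank sends.

```python
def ring_all_reduce(inputs):
    """Simulate ring all-reduce (sum) on N ranks as reduce-scatter + all-gather.

    inputs: N equal-length lists of numbers, one per rank; length must be divisible by N.
    Returns (outputs, sent), where sent[r] counts the elements rank r transmitted.
    """
    n = len(inputs)
    length = len(inputs[0])
    assert length % n == 0, "tensor length must split evenly into N chunks"
    c = length // n
    # chunks[r][k] is rank r's current copy of chunk k.
    chunks = [[list(t[k * c:(k + 1) * c]) for k in range(n)] for t in inputs]
    sent = [0] * n

    # Phase 1, reduce-scatter: after N-1 steps rank r owns the full sum of chunk (r+1) % n.
    for step in range(n - 1):
        msgs = [(r, (r - step) % n, list(chunks[r][(r - step) % n])) for r in range(n)]
        for src, k, payload in msgs:            # all sends in a step happen "simultaneously"
            dst = (src + 1) % n
            chunks[dst][k] = [a + b for a, b in zip(chunks[dst][k], payload)]
            sent[src] += len(payload)

    # Phase 2, all-gather: N-1 more steps pass each finished chunk around the ring.
    for step in range(n - 1):
        msgs = [(r, (r + 1 - step) % n, list(chunks[r][(r + 1 - step) % n])) for r in range(n)]
        for src, k, payload in msgs:
            chunks[(src + 1) % n][k] = payload
            sent[src] += len(payload)

    outputs = [[v for chunk in rank_chunks for v in chunk] for rank_chunks in chunks]
    return outputs, sent


outputs, sent = ring_all_reduce([[float(r + 1)] * 8 for r in range(4)])
print(outputs[0])   # [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
print(sent)         # [12, 12, 12, 12] = 2 * M * (N-1) / N with M = 8, N = 4
```

Each phase moves M(N-1)/N elements per rank, which is the reduce-scatter and all-gather rows of the table. Their sum is the all-reduce row.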
Data parallelism (DP) is the simplest and most common strategy. Every GPU holds a complete copy of the model. The global batch is split across GPUs — each processes its own shard independently in the forward and backward passes. Then an all-reduce averages every gradient across all GPUs so all replicas apply the identical optimizer step and stay in sync. From the optimizer's perspective it is as if one GPU processed the full batch.
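A toy problem is enough to check the claim that averaging gradients across GPUs reproduces the full-batch gradient. This sketch assumes a scalar linear model with mean-squared-error loss and equal-sized shards. With unequal shard sizes, a plain average of the per-shard means would no longer match the full-batch mean.

```python
import math


def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) for a scalar linear model y ~ w*x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [float(i) for i in range(16)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5
n_gpus = 4
shard = len(xs) // n_gpus            # equal-sized shards, one per simulated GPU

full_batch = mse_grad(w, xs, ys)
per_gpu = [mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
           for r in range(n_gpus)]
all_reduced = sum(per_gpu) / n_gpus  # all-reduce (sum) followed by division by N
print(full_batch, all_reduced, math.isclose(full_batch, all_reduced))
```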
PyTorch implements this as DistributedDataParallel (DDP). The key implementation insight is that DDP does not wait for the full backward pass to finish before communicating. It groups gradients into buckets and starts all-reducing each bucket as soon as all of its gradients are ready. This overlaps computation with communication, so on fast links most of the all-reduce cost hides behind the rest of the backward pass.
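```python
# DDP: launch with `torchrun --nproc_per_node=4 train.py`
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")              # rendezvous info comes from env vars set by torchrun
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = GPT(cfg).cuda()                      # placeholder: your model class and config
model = DDP(model, device_ids=[local_rank])  # gradient all-reduce hooked into backward automatically
# Give the DataLoader a torch.utils.data.DistributedSampler so each rank sees its own shard.
# The training loop is otherwise unchanged: DDP's autograd hooks all-reduce gradient
# buckets during loss.backward().
```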
With perfect hardware, doubling the GPU count halves training time. In practice, scaling efficiency (actual speedup / ideal speedup) falls below 1 for two reasons. First, the all-reduce volume is proportional to the number of parameters. A 7B model's gradients weigh 14 GB in BF16, and a ring all-reduce moves about 2 × 14 GB per GPU. Even at 600 GB/s NVLink, that is roughly 40–50 ms per step unless it is hidden behind the backward pass. Second, small models spend a larger fraction of time communicating relative to computing, so they scale worse. The ratio that matters is compute-to-communication: larger models, bigger batches, and faster interconnects all improve it.
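A back-of-the-envelope model makes the compute-to-communication argument concrete. This sketch uses only the bandwidth term of the ring formula from the collectives section. The 25 GB/s InfiniBand figure assumes a single 200 Gb/s port, the helper name is hypothetical, and real NCCL throughput will differ.

```python
def ring_allreduce_seconds(msg_bytes, n_gpus, link_bytes_per_s):
    """Bandwidth-only estimate of a ring all-reduce.

    Each GPU sends 2*M*(N-1)/N bytes. Latency, protocol overhead, and the
    distinction between per-direction and bidirectional bandwidth are ignored.
    """
    return 2 * msg_bytes * (n_gpus - 1) / n_gpus / link_bytes_per_s


grad_bytes = 7e9 * 2   # 7B parameters, BF16 gradients = 14 GB
for name, bw in [("NVLink, 600 GB/s", 600e9), ("InfiniBand, 25 GB/s", 25e9)]:
    t = ring_allreduce_seconds(grad_bytes, n_gpus=8, link_bytes_per_s=bw)
    print(f"{name:>20}: {t * 1e3:6.0f} ms per step if not overlapped")
```

Across nodes, the same all-reduce costs close to a second per step. This is why DDP across nodes relies heavily on overlap and large per-GPU batches.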
Plain data parallelism wastes memory: N GPUs store N identical copies of the optimizer state, gradients, and weights. The Zero Redundancy Optimizer (ZeRO, DeepSpeed) and PyTorch's Fully Sharded Data Parallel (FSDP) eliminate that redundancy. They shard those tensors across the data-parallel group and gather them on demand. The name "ZeRO" is literal: at full sharding (ZeRO-3), each byte of optimizer state, gradient, or weight is stored by exactly one GPU at rest.
ZeRO proceeds in three stages of increasing aggressiveness. Each stage shards one more category of tensors while preserving the same mathematical result as DDP — the gradient flowing to the optimizer is identical; only where it lives differs.
| Stage | Shards | 7B on 8 GPUs (GB/GPU, model state only) | Communication per step | PyTorch API |
|---|---|---|---|---|
| ZeRO-0 (DDP) | Nothing | 112 | 1× all-reduce of gradients | DDP |
| ZeRO-1 | Optimizer state | 38.5 | Same volume as DDP | DeepSpeed / manual |
| ZeRO-2 | + Gradients | 26.25 | Reduce-scatter grads + all-gather params (same volume as DDP) | DeepSpeed / manual |
| ZeRO-3 / FSDP | + Parameters | 14 | All-gather per layer (fwd + bwd) + reduce-scatter grads, ~1.5× DDP | FSDP |
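The table's memory column follows directly from the 16-byte rule, since each stage divides one more category by N. Below is a minimal sketch. The function name is hypothetical, and it ignores activations and the temporary buffers that all-gathers allocate.

```python
def zero_memory_gb(n_params, n_gpus, stage):
    """Per-GPU model-state memory in decimal GB for ZeRO stage 0-3 (16-byte rule)."""
    weights, grads, optim = 2.0, 2.0, 12.0   # bytes/param: BF16 w, BF16 g, FP32 master + m + v
    if stage >= 1:
        optim /= n_gpus                      # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_gpus                      # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus                    # ZeRO-3 / FSDP: also shard parameters
    return n_params * (weights + grads + optim) / 1e9


for stage in range(4):
    print(f"ZeRO-{stage}: {zero_memory_gb(7e9, 8, stage):6.2f} GB per GPU")
```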
The mechanism for ZeRO-3/FSDP: each GPU permanently stores only its 1/N shard of each parameter tensor. When the forward pass reaches a layer, an all-gather reconstitutes the full weights just in time. The GPU computes the layer output, then frees the gathered weights. In the backward pass the same all-gather happens again, followed by a reduce-scatter that leaves each GPU holding only its shard of the gradients. This is why ZeRO-3 communicates more than DDP. DDP does one all-reduce per step, which equals one reduce-scatter plus one all-gather. ZeRO-3 instead does two all-gathers plus one reduce-scatter over the parameters, layer by layer, which is about 1.5× DDP's traffic per step. The price of memory is communication volume. On fast NVLink this is usually fine; on slow interconnects it can dominate.
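The 1.5× figure comes from counting full-model passes over the wire. This sketch assumes ring collectives, BF16 parameters and gradients, and only the bandwidth term. The helper name is hypothetical.

```python
def per_step_traffic_bytes(n_params, n_gpus, strategy, bytes_per_elem=2):
    """Approximate bytes each GPU sends per training step."""
    m = n_params * bytes_per_elem          # one full BF16 copy of the model
    ring = (n_gpus - 1) / n_gpus           # one reduce-scatter or all-gather moves m * ring
    passes = {
        "ddp": 2,     # all-reduce of gradients = reduce-scatter + all-gather
        "zero1": 2,   # same volume as DDP
        "zero2": 2,   # reduce-scatter grads + all-gather updated params
        "zero3": 3,   # all-gather (fwd) + all-gather (bwd) + reduce-scatter grads
    }
    return passes[strategy] * m * ring


ddp = per_step_traffic_bytes(7e9, 8, "ddp")
z3 = per_step_traffic_bytes(7e9, 8, "zero3")
print(f"DDP: {ddp / 1e9:.1f} GB, ZeRO-3: {z3 / 1e9:.1f} GB, ratio {z3 / ddp:.2f}")
```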
ZeRO-3 and FSDP implement the same idea. Microsoft introduced ZeRO in 2019 and shipped it in DeepSpeed, and FSDP later brought the approach natively into PyTorch. The difference is packaging: FSDP is baked into PyTorch with no external dependencies. Both cut per-GPU memory for the sharded state roughly in proportion to GPU count. Meta used FSDP to train LLaMA-2 70B on 2048 A100 GPUs. Without sharding, each GPU would need to hold the full 1.12 TB (70B × 16 bytes) of weights, gradients, and optimizer state.
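```python
# FSDP (PyTorch-native ZeRO-3); launch with torchrun. `model` and `Block` are placeholders.
import functools, os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(  # one FSDP unit (shard boundary) per transformer block
        transformer_auto_wrap_policy, transformer_layer_cls={Block}),
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
    device_id=torch.cuda.current_device(),
)
optimizer = torch.optim.AdamW(model.parameters())  # create after wrapping: sees local shards
# Each GPU stores ~model_size / world_size; each block is all-gathered on demand.
```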
ZeRO/FSDP still gathers each layer's full weights during the forward pass. It avoids storing redundant copies at rest, but each GPU still materializes one full layer at a time. Tensor parallelism (TP) goes further: it keeps the weights sharded even during computation by splitting each matrix multiplication across GPUs. The math works out cleanly because a matrix multiply can be partitioned along either the columns or the rows of the weight matrix.
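For a linear layer Y = XW, where X ∈ ℝ^(B×d_in) and W ∈ ℝ^(d_in×d_out), there are two ways to split W across N GPUs:

Column-parallel: split W along d_out into W = [W_1, W_2, …, W_N]. GPU i computes Y_i = X W_i from the full input X, and Y = [Y_1, Y_2, …, Y_N] is the concatenation. Each GPU ends up with a slice of the output columns. An all-gather is needed only if a later op requires the full Y.

Row-parallel: split W along d_in into row blocks W_1, …, W_N, and split X along its columns into matching blocks X_1, …, X_N. GPU i computes the partial product X_i W_i, and Y = X_1 W_1 + X_2 W_2 + … + X_N W_N. An all-reduce (sum) produces the full output.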
In Megatron-LM, the MLP is split as column-parallel Linear → GeLU → row-parallel Linear. The column-parallel layer shards by output dimension, so each GPU computes a slice of the hidden activations. No communication is needed before the second layer, because GeLU is elementwise and each GPU's slice is exactly the input its row-parallel shard needs. The row-parallel layer shards by input dimension and finishes with an all-reduce that sums the partial outputs. Net result: one all-reduce per MLP in the forward pass (and one in the backward pass), with both weight matrices split N-way across GPUs.
With N = 2, for example, W1 is column-partitioned so that each GPU gets half the output neurons. GeLU runs locally. W2 is row-partitioned and finishes with one all-reduce that sums the partial results. Net: weights are split N-way, with one all-reduce per MLP in the forward pass, which must run over NVLink.

For multi-head attention, TP is even cleaner: each GPU handles a disjoint subset of attention heads. With 32 heads on 4 GPUs, GPU 0 does heads 0–7, GPU 1 does heads 8–15, and so on. Each head is self-contained, with no cross-head communication during attention. The only collective is therefore one all-reduce after the output projection W_O, whose row shards turn each GPU's local head outputs into partial sums. In practice the QKV projection is column-split and the output projection is row-split, matching the MLP pattern.
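The column-then-row MLP split can be verified numerically on a toy problem. This sketch simulates N GPUs in one process with plain Python lists. It is a hypothetical helper set, not Megatron-LM's code, and it uses the tanh approximation of GeLU. The only cross-GPU step is the final sum, which is what the all-reduce computes.

```python
import math
import random


def matmul(a, b):
    """(n x k) @ (k x m) with nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def gelu(v):
    """Tanh approximation of GeLU (elementwise)."""
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))


def elementwise(f, m):
    return [[f(v) for v in row] for row in m]


def mlp(x, w1, w2):
    """Unsharded reference: GeLU(X W1) W2."""
    return matmul(elementwise(gelu, matmul(x, w1)), w2)


def tp_mlp(x, w1, w2, n):
    """Megatron-style MLP on n simulated GPUs: W1 split by columns, W2 by rows."""
    d_ff = len(w2)
    assert d_ff % n == 0, "hidden size must divide evenly across GPUs"
    k = d_ff // n
    partials = []
    for r in range(n):
        w1_cols = [row[r * k:(r + 1) * k] for row in w1]   # column shard of W1
        w2_rows = w2[r * k:(r + 1) * k]                    # matching row shard of W2
        h_r = elementwise(gelu, matmul(x, w1_cols))        # local, no communication
        partials.append(matmul(h_r, w2_rows))              # partial output on GPU r
    # All-reduce (sum) of the partial outputs across the n GPUs.
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(partials[0]))]


def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
x, w1, w2 = rand_matrix(2, 3), rand_matrix(3, 4), rand_matrix(4, 3)
reference = mlp(x, w1, w2)
for n in (1, 2, 4):
    out = tp_mlp(x, w1, w2, n)
    err = max(abs(a - b) for ra, rb in zip(reference, out) for a, b in zip(ra, rb))
    print(f"N={n}: max |difference| = {err:.2e}")
```

The differences come only from floating-point summation order. Because GeLU is elementwise, applying it to each column slice gives exactly the matching slice of the unsharded hidden layer.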
Every transformer layer needs two all-reduces in the forward pass (one after attention, one after the MLP) and two in the backward pass. For a 32-layer model that is 128 all-reduces per training step, whatever the TP degree. With a ~50 MB tensor per all-reduce and 600 GB/s NVLink, each takes ~0.17 ms (about 2 × 50 MB / 600 GB/s), which is acceptable. At 12 GB/s, roughly one 100 Gb/s InfiniBand link, the same operation takes ~8 ms, making TP across nodes prohibitively slow. Rule: keep TP within a node, over NVLink.
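For inference this rule carries over directly: vLLM's tensor parallelism is normally configured within a node. Suppose you have two nodes. If the model fits on one node, you run two independent replicas (one TP group per node) and route requests between them. If it does not fit, the usual layout is TP within each node plus pipeline parallelism across nodes.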
Pipeline parallelism (PP) partitions the model's layer stack vertically: GPU 0 holds layers 1–8, GPU 1 holds layers 9–16, and so on. A batch enters GPU 0 and flows through its layers. The activations at the stage boundary are then sent point-to-point to GPU 1, which continues. No GPU needs to store all layers. Each GPU's model memory is 1/P of the full model, where P is the pipeline depth (the number of stages).
The naive pipeline is strictly serial: GPU 1 must wait for GPU 0 to finish before it can start, and sits idle in the meantime. This idle time is called the pipeline bubble. In a pipeline of depth P, the naive bubble fraction is (P-1)/P. With 8 stages, 7/8 of each GPU's time is wasted, which makes the naive schedule useless for production.
The fix is to split each batch into M micro-batches that flow through the pipeline in a staggered schedule, so different stages work on different micro-batches at the same time. The bubble fraction drops to (P-1)/(M+P-1), which reduces to the naive (P-1)/P when M = 1. GPipe runs all micro-batches forward, then all backward. 1F1B (one-forward-one-backward) starts each micro-batch's backward pass as soon as it clears the last stage. In steady state, each stage then alternates one forward with one backward. Its bubble is the same as GPipe's, but it keeps pipeline memory lower: a stage holds activations for at most P in-flight micro-batches instead of all M.
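The bubble formula can be checked by drawing the schedule as a grid of (stage, time-slot) cells and counting the idle ones. This sketch assumes equal time per stage and looks only at the forward sweep of a GPipe-style schedule. The activation comparison covers stage 0 only. All function names are hypothetical.

```python
def pipeline_bubble(p, m):
    """Idle fraction of a micro-batched pipeline: (P-1)/(M+P-1), equal stage times assumed."""
    return (p - 1) / (m + p - 1)


def simulated_forward_bubble(p, m):
    """Count idle (stage, time-slot) cells in a GPipe-style forward sweep."""
    slots = m + p - 1                          # the last micro-batch leaves the last stage here
    busy = [[False] * slots for _ in range(p)]
    for stage in range(p):
        for mb in range(m):
            busy[stage][stage + mb] = True     # micro-batch mb reaches `stage` at slot stage+mb
    idle = sum(not cell for row in busy for cell in row)
    return idle / (p * slots)


def peak_activation_microbatches(p, m, schedule):
    """Micro-batches whose activations stage 0 holds at once (GPipe vs 1F1B)."""
    return m if schedule == "gpipe" else min(m, p)


for p in (4, 8):
    for m in (1, 4, 16, 64):
        print(f"P={p} M={m:>2}: bubble {pipeline_bubble(p, m):.1%}, "
              f"stage-0 activations GPipe={peak_activation_microbatches(p, m, 'gpipe')}, "
              f"1F1B={peak_activation_microbatches(p, m, '1f1b')}")
```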
Pipeline parallelism sends activations only at stage boundaries. For a micro-batch of 32 tokens at hidden dim 4096, that is 32 × 4096 × 2 bytes = 256 KB per boundary, which is tiny. Tensor parallelism all-reduces a tensor of the same shape twice per layer. In the forward pass alone, across a 32-layer model with TP=8, it therefore moves on the order of 100× more data per micro-batch, and it repeats in the backward pass. This is why PP tolerates InfiniBand (25 GB/s) while TP needs NVLink (600 GB/s): volume × frequency tells you which link matters.
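Here is the volume-times-frequency comparison in numbers. This sketch assumes BF16 activations, a 32-layer model, and TP=8 with two all-reduces per layer in the forward pass. It counts only the bandwidth term of a ring all-reduce, and the function names are illustrative.

```python
def pp_boundary_bytes(tokens, hidden, bytes_per_elem=2):
    """Activation tensor sent point-to-point across one pipeline-stage boundary."""
    return tokens * hidden * bytes_per_elem


def tp_forward_bytes(tokens, hidden, n_layers, tp, allreduces_per_layer=2, bytes_per_elem=2):
    """Bytes each GPU sends in TP all-reduces during one forward pass (ring, bandwidth term)."""
    m = tokens * hidden * bytes_per_elem
    return n_layers * allreduces_per_layer * 2 * m * (tp - 1) / tp


pp_bytes = pp_boundary_bytes(32, 4096)
tp_bytes = tp_forward_bytes(32, 4096, n_layers=32, tp=8)
print(f"PP boundary: {pp_bytes / 1024:.0f} KiB, TP forward: {tp_bytes / 2**20:.1f} MiB, "
      f"ratio {tp_bytes / pp_bytes:.0f}x")
```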
Frontier-scale training (GPT-3, LLaMA-2 70B, Falcon 180B) uses all three parallelisms simultaneously. The assignment follows hardware topology. TP runs within a node (NVLink), and PP runs across a small number of nodes (moderate-bandwidth InfiniBand). DP runs across the remaining replica groups, which sit on the lowest-bandwidth links, for example between data-center pods. On top of this, ZeRO shards the optimizer state across the DP group. The product TP × PP × DP gives the total GPU count.
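How a 3D layout maps onto ranks is a convention, not a law. The sketch below assumes one common choice. TP varies fastest, so consecutive ranks (usually on the same node) form a TP group. PP comes next, and DP is outermost, so each model replica (TP × PP GPUs) occupies a contiguous block of ranks, matching the topology described above. Frameworks let you configure this order, and the helper names here are hypothetical.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Return (dp_rank, pp_rank, tp_rank) for a global rank; TP fastest, DP outermost."""
    assert 0 <= rank < tp * pp * dp, "rank out of range"
    tp_rank = rank % tp
    pp_rank = (rank // tp) % pp
    dp_rank = rank // (tp * pp)
    return dp_rank, pp_rank, tp_rank


def tp_group(rank, tp):
    """Ranks that share a TP group with `rank` (contiguous under this layout)."""
    start = rank - rank % tp
    return list(range(start, start + tp))


TP, PP, DP = 8, 2, 4          # 64 GPUs: 8 nodes of 8 if each node holds one TP group
print(rank_to_coords(13, TP, PP, DP))   # (dp, pp, tp) coordinates of rank 13
print(tp_group(13, TP))
```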
NVLink 4.0 (H100 SXM) delivers ~900 GB/s bidirectional per GPU within a node. InfiniBand HDR/NDR delivers 200–400 Gb/s (25–50 GB/s) per port across nodes, with much higher latency, and that bandwidth is shared by all traffic crossing the fabric. The newer Blackwell B200/GB200 generation raises both per-GPU memory (192 GB of HBM3e) and NVLink bandwidth; the same partitioning math applies, just with larger per-GPU capacity. Per GPU, intra-node bandwidth is roughly an order of magnitude higher. This is why TP, with all-reduces every layer, must stay within the NVLink domain, while PP, with one small activation tensor per micro-batch per boundary, can span nodes.
The parallelisms you learned in training reappear at inference time with the same logic. Tensor parallelism shards the model's weight matrices across GPUs within a node, and serving frameworks replicate this exactly. Pipeline parallelism assigns transformer blocks to pipeline stages. Each stage keeps the KV cache for its own layers, so a request's KV cache is split by layer across the stages. With TP, the KV cache is sharded by head instead: each GPU in the TP group holds the KV entries for its subset of attention heads. Understanding these mechanics is why an inference engineer needs to know distributed training: the checkpoint format, the sharding layout, and the communication patterns are the same artifact at training and inference time.
Activations dominate memory for long sequences, but most of them are needed only briefly during the backward pass. Gradient checkpointing (also called activation recomputation) stores activations only at a few checkpoint boundaries, typically the input to each transformer block. It recomputes each block's intermediate activations during backpropagation instead of storing them. The tradeoff is one extra forward pass per step, roughly +33% compute since the backward pass costs about twice the forward. In exchange, you keep one boundary tensor per layer plus one block's internals at a time, rather than every layer's full set of intermediates.
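```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(512, 2048), torch.nn.GELU(),
                            torch.nn.Linear(2048, 512))  # stand-in for a transformer block
x = torch.randn(8, 512, requires_grad=True)
y = block(x)  # without checkpointing: stores all intermediate activations
# With gradient checkpointing: stores only x (input to block); recomputes internals during backward.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```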
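Here is a rough memory model for the checkpointing tradeoff, reusing the 7B activation example from earlier. Two assumptions are labeled here rather than measured. The first is one saved boundary tensor of shape (batch, seq, hidden) per block. The second is that a block's full internals are about 12× that size, mirroring the rough multiplier quoted for the companion notebook. The function name is hypothetical.

```python
def activation_memory_bytes(n_layers, boundary_bytes, internal_factor, checkpointing):
    """Activation memory with and without per-block gradient checkpointing."""
    per_layer_internals = internal_factor * boundary_bytes
    if not checkpointing:
        return n_layers * per_layer_internals            # every block keeps all internals
    # Keep one boundary tensor per block, plus one block's internals while recomputing.
    return n_layers * boundary_bytes + per_layer_internals


boundary = 16 * 2048 * 4096 * 2       # batch x seq x hidden x BF16 bytes, per block
full = activation_memory_bytes(32, boundary, 12, checkpointing=False)
ckpt = activation_memory_bytes(32, boundary, 12, checkpointing=True)
print(f"no checkpointing: {full / 1e9:.0f} GB, per-block checkpointing: {ckpt / 1e9:.1f} GB")
```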
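Modern implementations go further with selective recomputation. Megatron-LM recomputes only activations that take a lot of memory but little compute to rebuild, such as attention softmax outputs and dropout masks. It stores those that are costly to recompute, such as the outputs of the large linear layers. FlashAttention-2 applies a similar idea inside the attention kernel: it never stores the full attention matrix and recomputes it in the backward pass. This keeps most of checkpointing's memory savings at a much smaller compute overhead (roughly 5–10% instead of 33%).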
You might wonder why an inference curriculum spends a full day on training parallelisms. The answer is that the techniques are not separate: serving a 70B model uses the same TP and PP machinery as training it. Weight tensors often arrive from training in sharded form, since a ZeRO-3 or FSDP checkpoint typically stores separate shards rather than one monolithic file. An inference engineer loading such a checkpoint must re-shard it correctly onto the serving GPUs, which may use a different TP degree than training did.
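Re-sharding a column-parallel weight from one TP degree to another amounts to concatenating the shards along the split dimension and slicing again. The sketch below assumes a simple layout: one d_in × d_out matrix split evenly by output columns. The helper names are hypothetical. Real checkpoints (FSDP flat parameters, fused QKV weights, grouped-query heads) need format-specific handling, and row-parallel weights would be split along rows instead.

```python
def split_columns(w, n):
    """Column-parallel sharding: rank r keeps output columns [r*k, (r+1)*k)."""
    d_out = len(w[0])
    assert d_out % n == 0, "output dim must divide evenly"
    k = d_out // n
    return [[row[r * k:(r + 1) * k] for row in w] for r in range(n)]


def merge_columns(shards):
    """Undo split_columns by concatenating each row's pieces in rank order."""
    return [[v for shard in shards for v in shard[i]] for i in range(len(shards[0]))]


def reshard_columns(shards, new_n):
    """Move a column-sharded weight from len(shards) ranks to new_n ranks."""
    return split_columns(merge_columns(shards), new_n)


w = [[10 * i + j for j in range(8)] for i in range(3)]   # toy 3 x 8 weight
train_shards = split_columns(w, 4)                  # e.g. saved with TP=4
serve_shards = reshard_columns(train_shards, 2)     # loaded with TP=2
print(serve_shards[0])   # [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]
```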
When you serve a model with TP degree N, each GPU in the TP group holds 1/N of the attention heads. During generation, the KV cache for each request follows the same partition: GPU 0 stores KV entries for its head subset, GPU 1 stores its subset, etc. This means the KV cache memory is naturally distributed, which is one reason TP at inference is not just about weight memory — it also distributes the per-request KV-cache footprint.
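The per-GPU KV-cache footprint under TP follows directly from the head split. The configuration below is an assumption meant to resemble a Llama-3-70B-class model with grouped-query attention: 80 layers, 8 KV heads, head dimension 128, BF16. When the TP degree exceeds the number of KV heads, engines typically replicate heads rather than split them, and this sketch approximates that with max(1, ...). The helper name is hypothetical.

```python
def kv_bytes_per_token_per_gpu(n_layers, n_kv_heads, head_dim, tp, bytes_per_elem=2):
    """K and V bytes for one token on one GPU of a TP group (hypothetical helper)."""
    assert n_kv_heads % tp == 0 or tp % n_kv_heads == 0, "uneven head split"
    heads_here = max(1, n_kv_heads // tp)   # replicate when tp > n_kv_heads
    return 2 * n_layers * heads_here * head_dim * bytes_per_elem   # 2 = K and V


cfg = dict(n_layers=80, n_kv_heads=8, head_dim=128)   # assumed 70B-class GQA config
for tp in (1, 2, 4, 8):
    per_token = kv_bytes_per_token_per_gpu(tp=tp, **cfg)
    print(f"TP={tp}: {per_token / 1024:.0f} KiB/token/GPU, "
          f"{per_token * 8192 / 2**30:.2f} GiB for one 8192-token request")
```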
Prefill processes the whole prompt in parallel and is compute-bound, so pipeline parallelism can help for very long prompts. The prompt's tokens can be split into chunks and streamed through the pipeline in a micro-batch pattern analogous to training. The autoregressive decode phase is trickier. Each decode step produces a single token per request, so a single batch leaves most stages idle and the pipeline bubble is severe. This is why most serving systems today use TP (within a node) for decode and reserve PP for large-scale batch prefill or models too large for one node.
vLLM's tensor_parallel_size parameter does what Megatron-LM's TP does during training: it splits attention heads and MLP weight matrices across GPUs within a node. Serving Llama-3 70B on a single 8×H100 node commonly uses TP=8, the same column-parallel/row-parallel split you learned in this lesson. At load time, each rank reads the checkpoint and keeps only its own slice of every sharded weight.
Companion notebook: day-11-distributed-training.ipynb.
1. Implement training_memory(N, bytes_per_param=16) and break it into weights / gradients / optimizer. Confirm a 7B model needs ~112 GB. Find the largest model that fits on 24 GB (single GPU).
2. Add zero_stage and num_gpus parameters. Print a table of per-GPU memory for stages 0–3 and GPU counts 1, 2, 4, 8, 16 for a 7B model.
3. Implement bubble_fraction(P, M) and plot bubble fraction vs M for P=4 and P=8. Find the minimum M that brings the bubble below 10% for each P.
4. Compare a transformer block with and without checkpoint(). On CPU, time both. On GPU (if available), measure torch.cuda.max_memory_allocated() and compare.

Close the page and answer from memory. If you can't, re-read the relevant section.
"Data parallel splits the batch. Tensor parallel splits the layer. Pipeline parallel splits the stack. ZeRO refuses to store anything twice. The interconnect decides which combination you can afford."
The distributed-training canon:

- The Zero Redundancy Optimizer (ZeRO). Stages 1/2/3 defined here; includes memory analysis and large-scale benchmarks.
- Tensor parallelism for transformer attention and MLP: column-parallel and row-parallel linear layers, with benchmarks up to 8.3B parameters on 512 GPUs.
- The 3D-parallelism paper: combines TP, PP, and DP, uses the 1F1B schedule and introduces an interleaved variant, and achieves 52% MFU on 3072 GPUs.
- Pipeline parallelism with micro-batches. First rigorous analysis of the bubble and how M micro-batches shrink it.
- PyTorch-native ZeRO-3 (FSDP). Covers auto_wrap_policy, mixed precision config, and checkpoint saving/loading.
- DistributedDataParallel and torchrun, hands-on. Covers bucketing, gradient hooks, and multi-node setup.
- Hugging Face's interactive guide to 3D parallelism, memory math, and hardware topology. Best visual reference available.
- Comprehensive survey of parallelism strategies, gradient checkpointing, mixed precision, and memory optimization, with clean diagrams.