LLM Inference Engineer · Day 02
Day 02 · Week 1 · Foundations

Tensors, Autograd, PyTorch & MLX

Tensors as the universal data structure — shape, dtype, device, strides, memory layout. How autograd actually works: forward DAG, backward walk, chain rule on graphs. PyTorch and MLX internals compared side by side. Every symbol decoded, every gradient verified by hand.

Time: ~150 min
Difficulty: Medium
Prerequisite: Day 1 math

Why This Lesson

Two abstractions, every framework.

Yesterday you saw the math: matmuls, softmax, chain rule, cross-entropy. Today we wire all of that into the two abstractions every modern ML framework is built on:

  1. Tensors — n-dimensional arrays with dtype, shape, device. Where the numbers live.
  2. Autograd — automatic computation of derivatives. Where the chain rule lives.

Get fluent with these and every framework — PyTorch, MLX, JAX, TensorFlow, NumPy + a hand-rolled autograd — feels familiar. We focus on PyTorch (cross-platform standard, what NVIDIA users mostly write) and MLX (Apple Silicon native, what we'll use whenever Mac performance matters). They look 90% alike; the 10% difference matters when you read source code.

By the end of the lesson you should be able to read a forward pass written in either framework, predict its shapes at every line, explain in plain English what loss.backward() does under the hood, and connect dtype choices to real inference memory costs.

Learning objectives

  1. Define what a tensor is (shape, dtype, device, stride) and create one in PyTorch and MLX.
  2. Predict the output shape of broadcasting, reductions, matmul, and einsum operations, and verify by running code.
  3. Distinguish a view (no data copy, different strides) from a copy, explain when .view() fails, and explain why a transposed tensor is non-contiguous.
  4. Calculate the exact VRAM footprint of a model given its parameter count and dtype, using the formula numel × bytes_per_element.
  5. Move tensors between devices on NVIDIA (CPU ↔ CUDA via PCIe) and explain why this is unnecessary on Apple Silicon (unified memory).
  6. Walk a small computational graph forward and then backward, computing every intermediate gradient by hand and matching what autograd produces.
  7. Explain what grad_fn is, why gradients accumulate at branch points, and why inference should use torch.inference_mode().
  8. Read a PyTorch training loop and an MLX training loop and identify the equivalent steps in each.
  9. State three concrete differences between PyTorch's eager autograd and MLX's lazy graph-based autograd.

By the end of this section you'll be able to…

…look at any line of transformer forward-pass code, say exactly what shape comes out, how many bytes that tensor occupies in memory, whether it holds a gradient, and what device it lives on. That's the entire mental model for inference engineering.

The One-Picture Mental Model

Forward builds a graph. Backward walks it.

Pin this picture before diving in:

[Figure: tensors (shape, dtype, device) → ops on tensors (+, @, softmax, ...) → computational graph (recorded by autograd) → loss (scalar). .backward() walks the graph in reverse, applying the chain rule; grads are accumulated into x.grad.]
Forward (left → right): build a graph by running ops on tensors. Backward (right → left): autograd applies the chain rule along the recorded graph.

Every concept today either lives on the tensor side of this picture (shape, dtype, device, strides, broadcasting) or on the autograd side (graph, grad_fn, accumulation, no-grad).

The Universal Object

A tensor is an n-dimensional array with metadata.

A tensor is an n-dimensional array of numbers with three pieces of metadata: dtype, shape, device.

  • 0-D (scalar): 5.0 — shape ()
  • 1-D (vector): [1, 2, 3] — shape (3,)
  • 2-D (matrix): [[1, 2], [3, 4]] — shape (2, 2)
  • 3-D: sequence batches (B, T, D), image (C, H, W)
  • 4-D: image batch (B, C, H, W), attention scores (B, h, T, T)
  • 5-D+: rare, but legal (e.g., video (B, T, C, H, W))

A scalar is a special case of a vector, which is a special case of a matrix, which is a special case of a tensor. Every number in deep learning lives in a tensor.

import torch
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
print(x.shape, x.dtype, x.device)
# torch.Size([2, 3]) torch.float32 cpu

Tensor anatomy: storage, shape, strides

A tensor is a view over a flat block of memory. Two pieces of metadata interpret that flat block as multi-dimensional:

  • shape — the dimensions, e.g., (2, 3).
  • strides — how many elements you step in flat memory when you advance by one along each axis.

For a contiguous (2, 3) matrix in row-major order, the flat layout is:

[Figure: the logical (2, 3) matrix [[1, 2, 3], [4, 5, 6]] stored in flat row-major memory as 1 2 3 4 5 6 at offsets [0] to [5]. strides = (3, 1): step 3 to the next row, step 1 to the next column. addr(i, j) = i · stride[0] + j · stride[1], so addr(1, 2) = 1·3 + 2·1 = 5 ✓. .T (transpose) keeps the same storage and swaps the strides to (1, 3): no copy.]
A tensor is metadata over flat memory. Strides are why transpose, slice, and reshape are usually free.

Strides are why transpose, slicing, and reshape are usually free — they just change the metadata, not the bytes. And they're why operations on non-contiguous tensors sometimes need a .contiguous() call: a kernel that expects row-major data won't work on transposed strides.

x = torch.arange(6).reshape(2, 3)
print(x.shape, x.stride())          # torch.Size([2, 3]) (3, 1)
print(x.T.shape, x.T.stride())      # torch.Size([3, 2]) (1, 3)  same storage
print(x.T.is_contiguous())          # False
print(x.T.contiguous().stride())    # (2, 1)  copied to a new buffer
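
To make the stride arithmetic concrete, here is a minimal pure-Python sketch of the addr(i, j) formula above. The helper name flat_offset and the list standing in for storage are illustrative assumptions, not PyTorch internals.

```python
def flat_offset(index, strides):
    """Map a multi-dimensional index to a position in flat storage."""
    return sum(i * s for i, s in zip(index, strides))

storage = [1, 2, 3, 4, 5, 6]      # flat buffer behind the (2, 3) matrix
strides = (3, 1)                  # row-major strides for shape (2, 3)
strides_t = (1, 3)                # transposed view: same buffer, swapped strides

# Element (1, 2) of the original matrix: offset 1*3 + 2*1 = 5
elem = storage[flat_offset((1, 2), strides)]

# Element (2, 1) of the transpose is the same number, read from the same buffer
elem_t = storage[flat_offset((2, 1), strides_t)]

# Reading the (3, 2) transpose row by row jumps around in storage
transposed_rows = [[storage[flat_offset((i, j), strides_t)] for j in range(2)]
                   for i in range(3)]
print(elem, elem_t, transposed_rows)   # 6 6 [[1, 4], [2, 5], [3, 6]]
```

The transpose never touches the buffer; only the strides used to index it change, which is why it is free and also why it is no longer contiguous.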

Reshape vs view vs transpose — and why contiguity matters

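Four operations come up whenever you change a tensor's shape or layout. .view() and .T never copy; they only rewrite metadata. .reshape() copies only when the strides leave it no choice. .contiguous() copies whenever the tensor is not already contiguous, and returns the tensor unchanged when it is: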

[Figure: starting from x with shape (2, 3), strides (3, 1), contiguous.]
  • x.view(6): shape (6,), strides (1,), same storage, contiguous. No copy.
  • x.reshape(6): same storage if the input is contiguous, copies if not (friendly). Maybe a copy.
  • x.T (transpose): shape (3, 2), strides (1, 3), same storage, non-contiguous. No copy, but .view() will fail here.
  • x.T.contiguous(): shape (3, 2), strides (2, 1), new storage, contiguous. Copies; .view() works again.
The rule: view is strict (needs compatible strides, never copies); reshape is friendly (copies if needed); transpose is free but non-contiguous; contiguous() copies unless the tensor is already contiguous.
.view() and .T are free (no copy, just different strides). .T produces a non-contiguous tensor — calling .view() on it raises a RuntimeError. Use .reshape() or .contiguous().view() to recover.

dtype — bytes per number, and why it matters for LLMs

Every tensor has a dtype (data type) that determines how many bytes each element occupies and what range of values it can represent. Before any formulas, here's the concrete picture: an fp32 number uses 32 bits (1 sign + 8 exponent + 23 mantissa); bf16 truncates the mantissa to 7 bits but keeps the same 8-bit exponent — preserving fp32's wide range while halving the bytes.

[Figure: bit layouts, one cell per bit.]
  • fp32 (32 bits, 4 bytes): 1 sign, 8 exponent, 23 mantissa.
  • bf16 (16 bits, 2 bytes): 1 sign, 8 exponent, 7 mantissa; same range as fp32.
  • fp16 (16 bits, 2 bytes): 1 sign, 5 exponent, 10 mantissa; narrower range, overflow risk.
  • int8 (8 bits, 1 byte): integer −128 … 127; quantized inference only.
bf16 keeps fp32's 8-bit exponent (wide range) but truncates the mantissa to 7 bits. fp16 has only 5 exponent bits — narrower range, overflow risk in activations. Both cost 2 bytes. This choice drives the KV-cache and weight-matrix memory budget for every LLM you deploy.
dtype | bits | bytes | exponent bits | mantissa bits | typical use
float64 | 64 | 8 | 11 | 52 | scientific computing; almost never in DL
float32 | 32 | 4 | 8 | 23 | training default, full precision
bfloat16 | 16 | 2 | 8 | 7 | training & inference; same range as fp32
float16 | 16 | 2 | 5 | 10 | inference; narrower range (overflow risk)
int8 | 8 | 1 | - | - | quantized inference (Day 22)
int4 | 4 | 0.5 | - | - | aggressively quantized inference

Memory formula. Total bytes for a tensor = numel() × bytes_per_element. For a model's weights: num_parameters × bytes_per_dtype. You can read this off in one line:

x = torch.randn(1024, 4096, dtype=torch.bfloat16)
print(x.numel() * x.element_size())   # 1024 * 4096 * 2 = 8,388,608 bytes = 8 MiB

Concrete VRAM math. A 7B-parameter model in fp32 needs 7 × 10⁹ × 4 = 28 GB of VRAM just for the weights. In bf16: 14 GB. In int8: 7 GB. In int4: 3.5 GB. dtype is the single biggest lever for "does this model fit on this GPU." We'll see this lever pulled hard on Days 22 and 24.
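
The arithmetic above fits in a few lines of plain Python. This is a sketch of the weights-only formula, assuming 1 GB = 10⁹ bytes; the names BYTES_PER_PARAM and weight_gb are made up for illustration.

```python
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2, "int8": 1, "int4": 0.5}

def weight_gb(num_params, dtype):
    """Weights-only footprint in GB (1 GB = 1e9 bytes); ignores activations and KV cache."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

footprints = {d: weight_gb(7e9, d) for d in ("fp32", "bf16", "int8", "int4")}
print(footprints)   # {'fp32': 28.0, 'bf16': 14.0, 'int8': 7.0, 'int4': 3.5}
```

Real deployments need headroom on top of this number for activations, the KV cache, and framework overhead, so treat it as a lower bound.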

The word "bfloat16" stands for Brain Float 16 — named after Google Brain, which developed it to enable training on TPUs with fp32 dynamic range. Its key advantage: you can simply truncate the lower 16 bits of an fp32 to get bf16. Conversely, you can widen bf16 to fp32 with a bit-shift and zero-fill — no rounding needed. fp16 requires a full format conversion.
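
The truncate-and-widen trick can be tried with nothing but the standard library. This is a sketch under one stated assumption: it truncates, as the paragraph above describes, whereas production conversions typically round to nearest instead. The helper names are illustrative.

```python
import struct

def fp32_bits(x):
    """Return the 32-bit pattern of x stored as an IEEE-754 float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def to_bf16_bits(x):
    """Truncate fp32 to bf16 by dropping the low 16 bits (no rounding)."""
    return fp32_bits(x) >> 16

def from_bf16_bits(b):
    """Widen bf16 back to fp32: shift left by 16 and zero-fill the low mantissa bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

x = 3.14159
b = to_bf16_bits(x)
roundtrip = from_bf16_bits(b)
print(hex(b), roundtrip)   # 0x4049 3.140625
```

Only about three significant decimal digits survive (3.14159 becomes 3.140625), which is the price of keeping 7 mantissa bits; the exponent, and therefore the range, is untouched.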

device — where the tensor physically lives

Backend | device string | Where it lives
CPU | "cpu" | system RAM, accessible to any framework
NVIDIA GPU | "cuda", "cuda:0" | GPU's HBM (separate memory bus)
Apple GPU via PyTorch MPS | "mps" | unified memory, accessed through Metal
Apple GPU via MLX | (implicit) | unified memory, native

Two tensors on different devices can't directly interact. PyTorch raises RuntimeError: Expected all tensors to be on the same device. Move first, op second.

The tensor calculus behind the name comes from mathematical physics: Gregorio Ricci-Curbastro and Tullio Levi-Civita developed it around 1900, and Einstein built general relativity (1915) on it decades before tensors appeared in machine learning. PyTorch tensors borrow the name loosely: they are plain n-dimensional arrays, without the coordinate-transformation rules a physicist's tensor must obey, with autograd bolted on.

Operations

Element-wise, reductions, matmul, broadcasting, einsum.

Three classes of op cover ~90% of what you'll do:

Element-wise

Same shape in, same shape out. a + b, a * b, torch.exp(a), torch.relu(a). Concrete: [1,2,3] + [10,20,30] = [11,22,33].

Reductions

Collapse one or more dimensions. a.sum(dim=-1), a.mean(dim=0), a.max(dim=1). Concrete: a (2, 3) tensor reduced along dim=1 (last) becomes shape (2,) — one value per row.

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
print(a.sum(dim=-1))   # tensor([ 6., 15.])  shape (2,)
print(a.sum(dim=0))    # tensor([5., 7., 9.])  shape (3,)
print(a.sum())         # tensor(21.)  shape ()  full reduction

Matmul

The workhorse from Day 1. (m, k) @ (k, n) → (m, n). Batched: (B, m, k) @ (B, k, n) → (B, m, n).

A = torch.randn(8, 16)
W = torch.randn(16, 4)
(A @ W).shape           # torch.Size([8, 4])

Ab = torch.randn(2, 8, 16)
(Ab @ W).shape          # torch.Size([2, 8, 4])  batch dim flows through

Broadcasting deep dive

Broadcast — read: "stretch" — lets you add a (3,) bias vector to every row of a (5, 3) matrix without writing a loop or manually copying data. The framework stretches the smaller array virtually — no extra memory, no copy.

Concrete example first. We want to add a bias b = [10, 20, 30] to a matrix A with shape (4, 3). A has shape (4, 3); b has shape (3,). To apply the rule, align from the right:

A:    (4, 3)
b:       (3,)   ← virtually prepend a 1: (1, 3)
rule: dims match (4 vs 1 → stretch) and (3 == 3 → keep)
out:  (4, 3)    ← b added to every row

This is the workhorse pattern in transformers: x + bias where x is (B, T, D) and bias is (D,) — the bias broadcasts over both batch and time dimensions simultaneously.

When shapes don't match, frameworks try to align them from the right. Each pair of aligned dims must either be equal or have one of them be 1 (which is then "stretched" virtually — no copy, no extra memory). Otherwise: error.

[Figure: align from the right; dims must match or be 1. a has shape (8, 1, 4); b has shape (3, 4), virtually padded with a 1 on the left to (1, 3, 4). Result: (8, 3, 4). Size-1 dims are stretched, no copy.]
Broadcasting aligns shapes from the right; size-1 dims stretch virtually.

a = torch.randn(8, 1, 4)
b = torch.randn(3, 4)
(a + b).shape          # torch.Size([8, 3, 4])

# Bias added to every (b, t) — workhorse pattern in transformers
x = torch.randn(2, 5, 8)        # (B, T, D)
bias = torch.randn(8)           # (D,)
(x + bias).shape                # (2, 5, 8)

Failure case. Shapes (3, 4) and (2, 4): align from right, 4 == 4 ok, but 3 vs 2 — neither is 1 — error.
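
The alignment rule is small enough to write out by hand. The function below is a sketch of the rule as stated in this section, not PyTorch's actual implementation; broadcast_shape is a hypothetical helper name.

```python
def broadcast_shape(a, b):
    """Return the broadcast result shape of two shapes, or raise ValueError."""
    n = max(len(a), len(b))
    # Left-pad the shorter shape with 1s, which aligns both shapes from the right
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)          # equal, or b is stretched to match a
        elif da == 1:
            out.append(db)          # a is stretched to match b
        else:
            raise ValueError(f"cannot broadcast dim {da} with dim {db}")
    return tuple(out)

print(broadcast_shape((8, 1, 4), (3, 4)))   # (8, 3, 4)
print(broadcast_shape((2, 5, 8), (8,)))     # (2, 5, 8)
```

Calling broadcast_shape((3, 4), (2, 4)) raises ValueError, mirroring the failure case above: 3 vs 2, and neither is 1.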

einsum — the most expressive op

einsum lets you spell out indices and the framework figures out the loop. It is the clearest way to express tensor contractions; once it clicks you'll prefer it.

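# Example inputs so every line below runs
A, B = torch.randn(8, 16), torch.randn(16, 4)
Ab, Bb = torch.randn(2, 8, 16), torch.randn(2, 16, 4)
Q, K = torch.randn(2, 5, 8), torch.randn(2, 7, 8)   # (B, T, D), (B, S, D)
u, v = torch.randn(3), torch.randn(4)
M = torch.randn(4, 4)

# Plain matmul: sum over j
C = torch.einsum('ij,jk->ik', A, B)        # same as A @ B

# Batched matmul: keep batch b, sum over j
Cb = torch.einsum('bij,bjk->bik', Ab, Bb)  # (2, 8, 4)

# Attention scores: Q (B, T, D), K (B, S, D) → (B, T, S)
scores = torch.einsum('btd,bsd->bts', Q, K)
# Equivalent to: Q @ K.transpose(-1, -2)

# Outer product:
outer = torch.einsum('i,j->ij', u, v)      # (3, 4)

# Trace (sum of diagonal):
tr = torch.einsum('ii->', M)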

Rule of thumb: letters that appear on both sides are kept; letters that disappear are summed over.
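
One way to internalize the rule of thumb is to write a single einsum string as the loops it stands for. This is an illustrative pure-Python sketch over nested lists; real einsum implementations dispatch to optimized kernels instead of looping.

```python
def einsum_ij_jk_ik(A, B):
    """'ij,jk->ik' written as explicit loops over nested lists."""
    I, J, K = len(A), len(B), len(B[0])
    C = [[0.0] * K for _ in range(I)]
    for i in range(I):              # i appears in the output: kept
        for k in range(K):          # k appears in the output: kept
            for j in range(J):      # j is missing from the output: summed over
                C[i][k] += A[i][j] * B[j][k]
    return C

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = einsum_ij_jk_ik(A, B)
trace = sum(A[i][i] for i in range(len(A)))   # 'ii->': the repeated index i is summed
print(C, trace)   # [[19.0, 22.0], [43.0, 50.0]] 5.0
```

Every einsum string can be read this way: one loop per letter, with the letters absent from the output acting as the summation loops.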

Reshape, view, transpose, permute

Four close cousins. All change shape; only some change strides; none (usually) copy data.

x = torch.arange(24).reshape(2, 3, 4)   # (2, 3, 4), strides (12, 4, 1)
x.view(6, 4)                            # (6, 4) — works because contiguous
x.transpose(0, 1)                       # (3, 2, 4), strides (4, 12, 1) non-contig
x.permute(2, 0, 1)                      # (4, 2, 3) — arbitrary axis reorder

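.view() is strict: it never copies, so the existing strides must be able to express the new shape (a contiguous tensor always qualifies). .reshape() is friendly: it returns a view when it can and copies when it must. After .transpose(), the tensor is non-contiguous, and a .view() that merges the swapped axes will fail; use .reshape() or .contiguous().view(...).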

Devices and Data Movement

PCIe vs unified memory.

The single biggest engineering difference between NVIDIA and Apple Silicon is the memory model.

[Figure: NVIDIA + Linux: CPU with system RAM (~50 GB/s DDR5) connected over PCIe (~32 GB/s, copy = expensive) to an NVIDIA GPU with HBM/VRAM (~3 TB/s on H100). Apple Silicon: CPU and GPU share unified memory (~400 GB/s on M3 Max); no copy, same address space.]
NVIDIA's CPU and GPU each have their own memory; Apple Silicon shares one address space.

On NVIDIA, CPU and GPU each have their own memory. Moving a tensor across PCIe is slow (~32 GB/s) compared to HBM bandwidth (~3.3 TB/s on H100, ~8 TB/s on a Blackwell B200). The rule is: move data once, keep it on the GPU.

x_cpu = torch.randn(1000, 1000)
x_gpu = x_cpu.to("cuda")          # PCIe copy
y     = x_gpu @ x_gpu.T           # runs on GPU
back  = y.cpu().numpy()           # PCIe copy back

On Apple Silicon, CPU and GPU share the same physical RAM. There is no .to("device") that moves data — the GPU just reads the same bytes. This is why MLX has no device field on tensors at all.

import mlx.core as mx
x = mx.random.normal((1000, 1000))   # lives in unified memory
y = x @ x.T                          # GPU touches the same RAM

PyTorch's MPS backend works on Mac too (device="mps") but is often slower than MLX for LLM workloads: it maps PyTorch's eager, discrete-device design onto Metal, whereas MLX was designed around unified memory and lazy evaluation from the start.

Device moves cost time — keep data on the fast bus

The practical rule: move data once, then keep it resident. Every .to("cuda") crosses PCIe at ~32 GB/s. That sounds fast, but HBM runs at ~3.3 TB/s (≈8 TB/s on Blackwell), roughly 100× faster. Moving a 1 GB activation tensor from CPU to GPU costs ~31 ms (1 GB ÷ 32 GB/s); the matmul itself might cost 2 ms, so the transfer is the bottleneck. A common mistake is loading a batch on CPU, preprocessing it, and then moving it to the GPU inside the training loop: you pay the transfer on every step.

# dataset_tensor, batch_idx, model, and load_cpu_batch are placeholders.
# Good: move once, keep on device (assumes the dataset fits in GPU memory)
x_gpu = dataset_tensor.to("cuda")          # one transfer at the start
for step in range(1000):
    loss = model(x_gpu[batch_idx])         # GPU → GPU, no transfer

# Bad: transfer every step
for step in range(1000):
    x_batch = load_cpu_batch()             # CPU
    loss = model(x_batch.to("cuda"))       # PCIe every step
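
The ~31 ms figure and the ~100× ratio can be checked with back-of-the-envelope arithmetic. The sketch below assumes the approximate bandwidths quoted in this section and ignores latency and protocol overhead; transfer_ms is an illustrative helper.

```python
def transfer_ms(num_bytes, bandwidth_gb_per_s):
    """Idealized transfer time in milliseconds: bytes / bandwidth, ignoring latency."""
    return num_bytes / (bandwidth_gb_per_s * 1e9) * 1e3

one_gb = 1e9
pcie_ms = transfer_ms(one_gb, 32)      # PCIe at ~32 GB/s
hbm_ms = transfer_ms(one_gb, 3300)     # HBM at ~3.3 TB/s
print(round(pcie_ms, 2), round(hbm_ms, 2), round(pcie_ms / hbm_ms))   # 31.25 0.3 103
```

Measured numbers will differ with pinned memory, transfer size, and PCIe generation, so use this only to judge orders of magnitude.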

Why this matters for inference. The KV cache in an LLM inference server stores key and value tensors for every layer and every token in the context. For a LLaMA-2 7B-style model with a 4096-token context in fp16, the KV cache alone is about 2 GB per active sequence. It lives on the GPU throughout the request. If it spills to CPU, inference throughput collapses. This is why vLLM's PagedAttention (Day 24) manages KV cache memory like a virtual memory system — the dtype × context × layers math is unforgiving.
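
The "about 2 GB" figure follows from the dtype × context × layers product. The sketch below assumes a LLaMA-2-7B-style configuration of 32 layers, 32 KV heads, and head dimension 128; those numbers are assumptions not stated above, and kv_cache_bytes is an illustrative helper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem, batch=1):
    """Bytes for keys and values across all layers, for `batch` sequences."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # 2 = K and V
    return per_token * seq_len * batch

# Assumed 7B-style shape: 32 layers, 32 KV heads, head_dim 128, fp16 (2 bytes)
total = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128,
                       seq_len=4096, bytes_per_elem=2)
print(total, total / 2**30)   # 2147483648 2.0
```

That is 0.5 MiB per token, so sixteen concurrent 4096-token sequences would need about 32 GiB for the cache alone under these assumptions.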

Apple's MLX framework launched in December 2023, making it one of the youngest major ML frameworks. It was designed from scratch for a unified memory architecture by Awni Hannun and colleagues in Apple's machine learning research group.

How Autograd Actually Works

A computational graph, walked in reverse.

We've been calling derivatives "the chain rule applied to a chain of functions." A real neural network isn't a chain — it's a DAG (directed acyclic graph). Autograd is the engineering that makes the chain rule work on DAGs, automatically, for any program you write.

The mental model in three sentences:

  1. Forward pass: as you compute, autograd records every op into a graph. Each output tensor has a grad_fn pointer that knows the local derivative of that op.
  2. Backward pass: call .backward() on a scalar (the loss). Autograd walks the graph in reverse topological order, multiplying local derivatives via the chain rule.
  3. Leaves: the only tensors that keep a gradient are the leaves you marked with requires_grad=True. They get the accumulated dloss/dx deposited on x.grad.

That's it. Everything below is detail.

A worked DAG with a branch

Let's do a real walk on a graph with two inputs and a branch, so you see how gradients accumulate.

Forward. Define a = 2, b = 3. Compute:

c = a · b
d = a + b
e = c · d
L = e²

That's a DAG (not a chain) because a is used twice (in c and d), and so is b. Here it is, with values plugged in:

[Figure: forward pass, values flow left → right. Leaves with requires_grad=True: a = 2, b = 3. Then c = a·b = 6, d = a+b = 5, e = c·d = 30, and the scalar loss L = e² = 900.]
Forward values: c = 6, d = 5, e = 30, L = 900.

Backward. Goal: dL/da and dL/db. Walk the graph in reverse, multiplying local derivatives. Start at the output and seed dL/dL = 1. For every node, compute its local derivative w.r.t. each input, multiply by the gradient already arrived at the output side, and pass the result back along the edge. At any node where multiple paths meet, sum the contributions.

Step | Node | Local derivative(s) | Incoming grad | Outgoing grad
0 | L | seed | - | dL/dL = 1
1 | L = e² | dL/de = 2e | 1 | dL/de = 2·30 = 60
2 | e = c·d | de/dc = d, de/dd = c | 60 | dL/dc = 60·5 = 300, dL/dd = 60·6 = 360
3a | c = a·b | dc/da = b, dc/db = a | 300 | contribution to dL/da = 300·3 = 900; to dL/db = 300·2 = 600
3b | d = a+b | dd/da = 1, dd/db = 1 | 360 | contribution to dL/da = 360·1 = 360; to dL/db = 360·1 = 360
4 | accumulate at leaves | sum the two paths | - | dL/da = 900 + 360 = 1260, dL/db = 600 + 360 = 960

Same picture, drawn with the gradients flowing backward:

[Figure: backward pass, gradients flow right → left and accumulate at leaves. L grad = 1; e grad = 60 (edge ·2e = 60); c grad = 300 (edge ·d = 5); d grad = 360 (edge ·c = 6); a.grad = 1260 and b.grad = 960, each accumulated from 2 paths.]
Backward walks reverse-topological order; gradients accumulate at branch points (the multivariate chain rule).

Verify in PyTorch — same numbers fall out:

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b
d = a + b
e = c * d
L = e ** 2
L.backward()
print(a.grad)   # tensor(1260.)
print(b.grad)   # tensor(960.)

The two takeaways:

  1. grad_fn is per-tensor. Each non-leaf tensor stores a pointer to the function that created it. c.grad_fn is a MulBackward0 node, d.grad_fn is an AddBackward0 node, and so on. backward() follows these pointers.
  2. Gradients accumulate at branch points. When a is consumed by both c and d, each consumer pushes its own gradient back, and a.grad is the sum. This is just the multivariate chain rule: dL/da = (dL/dc)(dc/da) + (dL/dd)(dd/da).
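
Both takeaways fit in a toy reverse-mode engine of about thirty lines. The Value class below is a hand-rolled illustration, not how PyTorch is implemented: it supports only + and * on scalars, and all of its names are made up for this sketch.

```python
class Value:
    """A scalar that records how it was computed (toy reverse-mode autograd)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # inputs of the op that produced this value
        self._local_grads = local_grads  # d(self)/d(parent), one per parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def backward(self):
        # Build a topological order of the graph, then walk it in reverse
        order, seen = [], set()

        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    visit(p)
                order.append(v)

        visit(self)
        self.grad = 1.0                   # seed dL/dL = 1
        for v in reversed(order):
            for p, g in zip(v._parents, v._local_grads):
                p.grad += g * v.grad      # += is the accumulation at branch points

a, b = Value(2.0), Value(3.0)
c = a * b
d = a + b
e = c * d
L = e * e            # e squared, written as a product
L.backward()
print(a.grad, b.grad)   # 1260.0 960.0
```

The stored (parents, local_grads) pair plays the role of grad_fn, and the single += line is where the two paths into a and b are summed.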

Why grads accumulate (and why you call zero_grad)

Autograd accumulates into .grad across calls, too, not just at branch points within one backward pass. This is on purpose: it lets you do gradient accumulation across mini-batches when you can't fit a big batch on one GPU. But if you forget to clear .grad before each step, every new gradient is added on top of the stale ones from earlier steps.

opt.zero_grad()        # or: for p in params: p.grad = None
loss.backward()
opt.step()
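
A short experiment makes the cross-call accumulation visible. This sketch assumes a recent PyTorch install and uses a single scalar parameter w chosen purely for illustration.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(w * 3).backward()
first = w.grad.item()        # d(3w)/dw = 3

(w * 3).backward()           # nothing cleared in between: the new gradient is added
second = w.grad.item()       # 3 + 3 = 6

w.grad = None                # clear the gradient, as zeroing it before a step would
(w * 3).backward()
third = w.grad.item()        # back to 3
print(first, second, third)  # 3.0 6.0 3.0
```

Deliberate gradient accumulation uses exactly this behavior: call backward() on several micro-batches, then step and clear once.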

Vector-Jacobian products: what backward really computes

For a single scalar loss, what we want is the gradient vector. But internally autograd doesn't materialize huge Jacobians; for each op it only knows how to compute a vector-Jacobian product (VJP, sometimes called "backward op"): given the upstream gradient as a vector, push it back through the local op without ever forming the full Jacobian matrix.

This is why backward costs a small constant multiple of forward (roughly 2× is typical) instead of O(parameters × outputs): Jacobians can be enormous (billions × billions), and VJPs sidestep them. We don't need the math here; just remember that every differentiable op has a forward kernel and a paired VJP kernel. When you write a custom op (Days 17+ for CUDA kernels), you write both.
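
A tiny example shows what "never forming the Jacobian" buys. For the element-wise op y = x², the Jacobian is an n × n diagonal matrix, but the VJP needs only n multiplications. The functions below are illustrative pure Python, not a framework API.

```python
def square_vjp(x, upstream):
    """VJP of y = x**2 (element-wise): upstream times J, without building J."""
    return [g * 2.0 * xi for g, xi in zip(upstream, x)]

def square_jacobian(x):
    """The full n x n Jacobian of y = x**2, built here only to check the VJP."""
    n = len(x)
    return [[2.0 * x[i] if i == j else 0.0 for j in range(n)] for i in range(n)]

x = [1.0, 2.0, 3.0]
upstream = [0.5, -1.0, 2.0]           # dL/dy arriving from later in the graph

fast = square_vjp(x, upstream)        # O(n) work and memory
J = square_jacobian(x)                # O(n^2) memory
slow = [sum(upstream[i] * J[i][j] for i in range(len(x))) for j in range(len(x))]
print(fast, slow)   # [1.0, -4.0, 12.0] [1.0, -4.0, 12.0]
```

With n in the billions, the second route is not merely slower; the matrix would not fit in any memory.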

Inference: turn autograd off

During inference, there are no gradients to compute. Every op in the forward pass would normally build a graph node, save input tensors for the backward pass, and bump version counters. All of that is wasted work when you only want a prediction. Skip it with one of two context managers:

# PyTorch
with torch.no_grad():
    out = model(x)              # no graph recorded; less memory; faster

# Even faster — disables version counters too:
with torch.inference_mode():
    out = model(x)

What's the difference? no_grad stops graph recording but still tracks tensor versions and views, so its outputs are ordinary tensors that can safely flow back into autograd code later. inference_mode also skips version counters and view tracking, which makes it the cheaper of the two, but the tensors it creates are marked as inference tensors and cannot later be used in computations that autograd records. Use inference_mode for production inference; use no_grad when the results must interact with autograd afterwards or with code that checks tensor versions.
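
The three modes can be told apart by inspecting the outputs. This is a small sketch that assumes a reasonably recent PyTorch (Tensor.is_inference() is available there); the Linear layer and input are arbitrary stand-ins for a real model.

```python
import torch

lin = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

y_train = lin(x)                       # normal mode: graph recorded
with torch.no_grad():
    y_nograd = lin(x)                  # no graph, but a regular tensor
with torch.inference_mode():
    y_infer = lin(x)                   # no graph, and marked as an inference tensor

print(y_train.grad_fn is not None)                        # True
print(y_nograd.requires_grad, y_nograd.is_inference())    # False False
print(y_infer.requires_grad, y_infer.is_inference())      # False True
```

The is_inference() flag is what later stops such a tensor from taking part in recorded autograd computations.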

Aspect | full autograd | torch.no_grad() | torch.inference_mode()
grad_fn recorded | ✓ on | ✗ off | ✗ off
version counters | ✓ on | ✓ on | ✗ off
view tracking | ✓ on | ✓ on | ✗ off
extra activation memory | saved for backward | not saved | not saved
use for | training | eval / in-place ops | production inference
inference_mode disables all autograd overhead — the cheapest non-model change you can make to speed up inference. Always pair it with model.eval().

This is one of the cheapest performance wins in the whole stack, and skipping autograd bookkeeping is one small part of why an "inference engine" can be 2-3× faster than a naïve forward call.

Autograd as a concept dates to the 1960s and 1970s under the name "automatic differentiation"; the reverse-mode variant used for backpropagation was described by Seppo Linnainmaa in 1970. It was popularized in deep learning by HIPS autograd (2014), built by members of Harvard's Intelligent Probabilistic Systems group. They showed you could write Python that looked like NumPy and get gradients for free.

PyTorch in More Detail

Modules, parameters, buffers, state_dict.

import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Tensors ---
x = torch.tensor([1.0, 2.0, 3.0])             # from list
x = torch.zeros(3, 4)                          # zeros
x = torch.ones(3, 4)                           # ones
x = torch.randn(3, 4)                          # standard normal
x = torch.arange(10).reshape(2, 5)             # 0..9 reshaped
x = torch.empty(3, 4)                          # uninitialized (faster, garbage values)

# --- Devices ---
device = ("cuda" if torch.cuda.is_available()
          else "mps" if torch.backends.mps.is_available()
          else "cpu")
x = x.to(device)

Modules: the LEGO blocks

Every layer, sub-network, or model in PyTorch is an nn.Module. Modules own three things: parameters (learnable tensors), buffers (non-learnable state that should still be saved/loaded — e.g., running batch-norm stats), and submodules.

class TwoLayer(nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_hidden)         # submodule
        self.fc2 = nn.Linear(d_hidden, d_out)        # submodule
        self.register_buffer('step', torch.zeros(1)) # buffer (saved, not trained)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

m = TwoLayer(128, 256, 10).to(device)

for name, p in m.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)
# fc1.weight (256, 128) True
# fc1.bias   (256,)     True
# fc2.weight (10, 256)  True
# fc2.bias   (10,)      True

# Save/load
torch.save(m.state_dict(), 'model.pt')
m2 = TwoLayer(128, 256, 10)
m2.load_state_dict(torch.load('model.pt'))

state_dict() is just an OrderedDict of parameter and buffer tensors keyed by name. Hugging Face checkpoints store the same name → tensor mapping in their weight files; the JSON files next to them hold configuration. When we load LLaMA weights on Day 27, we'll be doing exactly this lookup.

.train() vs .eval()

Some layers (Dropout, BatchNorm) behave differently in train vs eval; LayerNorm keeps no running statistics and behaves the same in both. Toggle with:

m.train()   # enable dropout, update batchnorm stats
m.eval()    # disable dropout, freeze batchnorm stats

This is independent of no_grad. Real inference: m.eval() and torch.inference_mode().

A complete training step

opt = torch.optim.AdamW(m.parameters(), lr=1e-3)

for step in range(100):
    x = torch.randn(8, 128, device=device)
    target = torch.randint(0, 10, (8,), device=device)
    logits = m(x)
    loss = F.cross_entropy(logits, target)

    opt.zero_grad()
    loss.backward()
    opt.step()

The four steps forward → zero_grad → backward → step are the heartbeat of every PyTorch training loop ever written (zero_grad can equally come first, as long as it runs before backward). We'll see them again on Days 4, 9, 10.

What PyTorch is doing under the hood

PyTorch is eager and dynamic: every Python line runs immediately, and the graph is rebuilt every forward pass. The dispatcher routes each op (e.g., add) to a backend-specific kernel: CPU, CUDA, MPS, XLA. Most kernels live in ATen (a C++ tensor library); higher-level orchestration is in Python. Autograd is often described as a "tape", but the record is really a graph: each op creates a backward node, reachable from its output's grad_fn, that links to the nodes of its inputs; backward() traverses those links in reverse topological order.

Two consequences worth knowing:

  • Dynamic graphs make debugging easy. Standard Python control flow (if, for) just works.
  • They prevent some optimizations. torch.compile (Day 11+) introduces a tracer that captures a static graph, fuses kernels, and can be 1.5-3× faster. Worth knowing it exists; we'll use it later.

PyTorch was open-sourced by Facebook AI Research in October 2016. Before that, the dominant framework was Theano (2007–2017), then briefly TensorFlow 1.x (which made you build static graphs ahead of time). PyTorch's big innovation: dynamic graphs that are built as you run code.

MLX in More Detail

Lazy graphs and function transformations.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# --- Arrays (no .device, no .to()) ---
x = mx.array([1.0, 2.0, 3.0])
x = mx.zeros((3, 4))
x = mx.random.normal((3, 4))

# --- Lazy evaluation: ops build a graph, don't compute yet ---
y = x * 2 + 1                  # nothing has been computed
mx.eval(y)                     # forces evaluation
print(y)                       # also forces evaluation

Lazy eval is the single biggest mental shift coming from PyTorch. In MLX, y = x * 2 + 1 does not run x * 2 and then + 1. It records "compute this expression," and only on mx.eval(y) (or anything that must read the value, like print, .item(), or a conversion to NumPy) does the runtime walk the graph, schedule the work on the GPU, and execute it. Deferring work this way means unused results are never computed, and it gives transformations such as mx.compile a whole graph to optimize and fuse. This is how MLX gets a lot of its speed on Apple Silicon.
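
The record-now, compute-later idea can be mimicked in plain Python. The Lazy class below is a toy model of deferred evaluation written for this lesson; it says nothing about how MLX implements its graph, and every name in it is hypothetical.

```python
class Lazy:
    """Toy deferred expression: records ops, computes only when evaluated."""
    evaluations = 0   # counts how many nodes have actually been computed

    def __init__(self, fn, *inputs):
        self.fn, self.inputs, self.value = fn, inputs, None

    def __mul__(self, k):
        return Lazy(lambda v: v * k, self)

    def __add__(self, k):
        return Lazy(lambda v: v + k, self)

    def eval(self):
        if self.value is None:                       # compute once, then cache
            args = [i.eval() for i in self.inputs]
            self.value = self.fn(*args)
            Lazy.evaluations += 1
        return self.value

x = Lazy(lambda: 3.0)          # a "leaf" holding a constant
y = x * 2 + 1                  # builds two graph nodes; nothing computed yet
before = Lazy.evaluations      # still 0
result = y.eval()              # walks the graph: 3.0 * 2 + 1
after = Lazy.evaluations       # leaf, multiply, add
print(before, result, after)   # 0 7.0 3
```

Because the whole expression exists as data before anything runs, a runtime is free to reorder, skip, or combine nodes, which eager execution cannot do.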

Function transformations: grad, value_and_grad, vmap, compile

MLX's autograd is functional: grad(f) returns a new function that computes derivatives. There is no .backward() on the array; there is no .grad attribute either. Gradients are returned as data structures (pytrees) that mirror your model.

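# Example inputs so the snippet runs: a 128 → 64 linear map on one vector
params = {'W': mx.random.normal((64, 128)), 'b': mx.zeros((64,))}
x = mx.random.normal((128,))
target = mx.random.normal((64,))

def loss_fn(params, x, target):
    pred = params['W'] @ x + params['b']
    return mx.mean((pred - target) ** 2)

grad_fn = mx.grad(loss_fn)              # new function: dloss/dparams (first argument)
grads = grad_fn(params, x, target)      # dict with the same keys as params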

For models, nn.value_and_grad returns both the loss and the grads in one shot:

linear = nn.Linear(128, 64)             # the model being trained

def loss_fn(model, x, target):
    return mx.mean((model(x) - target) ** 2)

loss_and_grad_fn = nn.value_and_grad(linear, loss_fn)
opt = optim.AdamW(learning_rate=1e-3)

for _ in range(100):
    x = mx.random.normal((8, 128))
    target = mx.random.normal((8, 64))
    loss, grads = loss_and_grad_fn(linear, x, target)
    opt.update(linear, grads)
    mx.eval(linear.parameters(), opt.state)   # force the update

Other transformations (we'll touch these on Days 19 and 21):

  • mx.vmap(f) — auto-batch a function written for a single example.
  • mx.compile(f) — capture and fuse the graph; equivalent in spirit to torch.compile.
  • mx.jvp(...) / mx.vjp(...) — explicit forward/reverse-mode building blocks.

Lazy eval tax: when to call mx.eval

Lazy eval is great for fusion, but it can quietly grow the deferred graph if you never force evaluation. Two rules:

  1. At the end of every training step, force eval on parameters and optimizer state: mx.eval(model.parameters(), opt.state). This caps the graph size.
  2. When timing, force eval before stopping the timer. Otherwise you measure graph construction, not execution.

import time
x = mx.random.normal((4096, 4096))
t0 = time.time()
y = x @ x.T          # graph built, not run
mx.eval(y)           # actually run on GPU
print(time.time() - t0)

Unified memory implications

Because there's no PCIe to cross, MLX skips the entire .to(device) / .cuda() / .cpu() dance. The price is reach: MLX is built first and foremost for Apple Silicon, and it does not run across the range of hardware that PyTorch does.

PyTorch ↔ MLX

Translation table and design differences.

The two frameworks share ~90% of their surface area. The 10% difference is architectural. Understanding it saves hours of debugging when you port code or read source.

Core design comparison

Dimension | PyTorch | MLX
Evaluation model | Eager: every op runs immediately as Python executes | Lazy: ops build a graph; execution deferred until mx.eval() or a value is needed
Graph | Rebuilt every forward pass (dynamic). Easy to debug. | Captured lazily. Runtime can optimize the graph before executing.
Autograd style | Graph-based: each tensor stores a grad_fn; call loss.backward() | Functional: mx.grad(f) returns a new function; grads returned as pytree
Memory model | Discrete (NVIDIA): CPU RAM ↔ GPU HBM via PCIe | Unified (Apple Silicon): CPU and GPU share one address space
Device field | tensor.device: "cpu", "cuda", "mps" | No device field. All arrays are in unified memory.
Zero-grad needed | Yes: opt.zero_grad() before each backward | No: functional; each mx.grad call is fresh
Compile/fuse | torch.compile(model): traces and fuses | mx.compile(f): captures the lazy graph explicitly
Save weights | torch.save(m.state_dict(), 'w.pt') | m.save_weights('w.safetensors')
Target hardware | NVIDIA, AMD, Intel, CPU | Apple Silicon (primary target)

API translation cheat sheet

When porting code, this is the cheat sheet:

| Concept | PyTorch | MLX |
| --- | --- | --- |
| Tensor type | torch.Tensor | mx.array |
| Random normal | torch.randn(3, 4) | mx.random.normal((3, 4)) |
| Move to GPU | x.to('cuda') | (no-op: unified memory) |
| Matmul | A @ B | A @ B |
| Reshape | x.reshape(...) / .view(...) | x.reshape(...) |
| Reduction | x.sum(dim=-1) | x.sum(axis=-1) |
| einsum | torch.einsum(...) | mx.einsum(...) |
| Module base | nn.Module | nn.Module |
| Linear layer | nn.Linear(d_in, d_out) | nn.Linear(d_in, d_out) |
| Trainable params iter | m.parameters() (returns an iterator) | m.parameters() (returns a nested dict) |
| Loss scalar | loss.item() | loss.item() |
| Gradients | loss.backward() then p.grad | grad_fn(...) returns grads; no .grad |
| Optimizer step | opt.step() after loss.backward() | opt.update(model, grads) |
| Zero grads | opt.zero_grad() | not needed (functional) |
| No-grad inference | with torch.no_grad(): ... | just don't call grad; mx.stop_gradient(x) is the analogue of x.detach() |
| Force compute | always eager | mx.eval(...) |
| Compile / fuse | torch.compile(m) | mx.compile(f) |
| Save state | torch.save(m.state_dict(), 'p.pt') | m.save_weights('p.safetensors') |
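
The "Zero grads" row is easiest to believe after watching .grad accumulate. A minimal CPU-only sketch, assuming PyTorch 2.x is installed; the variable names are illustrative:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

(w * w).backward()            # d(w^2)/dw = 2w = 6
first = w.grad.item()

(w * w).backward()            # a second backward ADDS into w.grad
second = w.grad.item()

w.grad = None                 # what opt.zero_grad() does by default in PyTorch 2.x
(w * w).backward()
third = w.grad.item()

print(first, second, third)   # 6.0 12.0 6.0
```

In MLX there is nothing to reset, because the gradient function returns new arrays on every call instead of writing into the parameters.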

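Rule of thumb. If a PyTorch line uses device, drop it for MLX. If a PyTorch loop uses loss.backward() + opt.step(), build value_and_grad_fn = nn.value_and_grad(model, loss_fn) once, then replace the pair inside the loop with loss, grads = value_and_grad_fn(...); opt.update(model, grads); mx.eval(model.parameters(), opt.state).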
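
The rule of thumb can be sketched as a complete MLX loop. This is a minimal example under stated assumptions, not a reference implementation: the one-layer model, the synthetic y = 2x data, the learning rate, and the `train_mlx` and `HAVE_MLX` names are all made up for illustration. The import guard is there only so the file still runs on a machine without MLX.

```python
try:
    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    HAVE_MLX = True
except ImportError:  # MLX not installed on this machine
    HAVE_MLX = False


def train_mlx(steps=20):
    """Fit y = 2x with one linear layer. Returns per-step losses, or None without MLX."""
    if not HAVE_MLX:
        return None

    model = nn.Linear(1, 1)                  # no device argument anywhere
    opt = optim.SGD(learning_rate=0.1)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    # Stands in for loss.backward(): returns (loss, grads) instead of filling .grad
    loss_and_grad = nn.value_and_grad(model, loss_fn)

    x = mx.random.normal((64, 1))
    y = 2.0 * x
    losses = []
    for _ in range(steps):
        loss, grads = loss_and_grad(model, x, y)
        opt.update(model, grads)                 # stands in for opt.step()
        mx.eval(model.parameters(), opt.state)   # force the lazy graph to run
        losses.append(loss.item())
    return losses


losses = train_mlx()
if losses is not None:
    print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Note what is missing compared with the PyTorch version: no .to(device), no zero_grad(), and no .grad attribute on any parameter.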

Profiling and Timing

Sync, or you measure nothing.

GPU work is asynchronous. The CPU enqueues kernels; the GPU executes them later. If you time naively, you measure how long it took to launch the work, not to do it.

PyTorch / CUDA

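import time
import torch
x = torch.randn(4096, 4096, device='cuda')
_ = x @ x.T                     # warm-up: the first call pays one-time setup costs
torch.cuda.synchronize()        # drain pending work before starting the clock
t0 = time.perf_counter()        # monotonic, high-resolution timer
y = x @ x.T
torch.cuda.synchronize()        # wait for the GPU to finish
print(time.perf_counter() - t0)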

MLX

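MLX is lazy, so evaluate the input before starting the timer (otherwise the random generation is timed too) and evaluate the result before stopping it:

import time
import mlx.core as mx
x = mx.random.normal((4096, 4096))
mx.eval(x)                      # materialize x so its generation isn't timed
t0 = time.perf_counter()
y = x @ x.T
mx.eval(y)                      # force the lazy matmul to run
print(time.perf_counter() - t0)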
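
The same discipline can be wrapped in a small framework-agnostic helper so it is harder to forget. This is a sketch: `bench` and its parameters are hypothetical names, and `sync` is whatever callable blocks until the device is idle (for example torch.cuda.synchronize on CUDA).

```python
import statistics
import time


def bench(fn, sync=lambda: None, warmup=3, iters=10):
    """Return the median wall time in seconds of fn(), syncing around each run."""
    for _ in range(warmup):      # early calls may include one-time setup costs
        fn()
    times = []
    for _ in range(iters):
        sync()                   # earlier work must not be billed to this run
        t0 = time.perf_counter()
        fn()
        sync()                   # wait for the device before stopping the clock
        times.append(time.perf_counter() - t0)
    return statistics.median(times)


# CPU-only demo so the sketch runs anywhere; plain Python work needs no sync.
median_s = bench(lambda: sum(i * i for i in range(10_000)))
print(f"median: {median_s * 1e3:.3f} ms")
```

On CUDA this might be called as bench(lambda: x @ x.T, sync=torch.cuda.synchronize). For MLX, let fn force evaluation itself, e.g. bench(lambda: mx.eval(x @ x.T)). Reporting the median rather than a single run keeps one slow outlier from dominating the number.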

Profilers: PyTorch Profiler, Nsight Systems (NVIDIA), and Xcode's Metal debugger with traces captured via mx.metal.start_capture (MLX). We'll use these heavily in Week 3.

"80% of ML bugs are shape mismatches. Add print(x.shape) liberally."

A practical truth · Day 2
Exercise

Eight tasks to build muscle memory.

  1. Tensor basics. Create a (2, 3, 4) tensor of standard-normal values. Reshape to (6, 4) and (2, 12). Take the mean along the last axis. Try x.sum(dim=(0, 2)) — predict the shape, then verify.
  2. Strides and views. x = torch.arange(12).reshape(3, 4). Print x.stride(). y = x.T; print y.stride() and y.is_contiguous(). Try y.view(12) — read the error. Then y.reshape(12) — why does that work?
  3. Broadcasting. a = torch.arange(12).reshape(3, 4). Add a row vector [10, 20, 30, 40] to every row. Add a column vector [[100], [200], [300]] to every column. Try (torch.zeros(3, 4) + torch.zeros(2, 4)) — read the error.
  4. einsum. Q, K = torch.randn(2, 5, 8), torch.randn(2, 5, 8). Compute attention scores (2, 5, 5) with einsum. Verify via Q @ K.transpose(-1, -2).
  5. Autograd by hand vs by machine. Use the worked DAG above (a=2, b=3). Run the PyTorch snippet, confirm 1260, 960. Now change L = e² to L = e³; predict dL/da and dL/db on paper, then verify.
  6. Branching gradient accumulation. x = torch.tensor(2.0, requires_grad=True); y = x * x + x; y.backward(); print(x.grad). Predict before running.
  7. inference_mode timing. Run a forward pass through nn.Linear(4096, 4096) on a 4096×4096 input with and without torch.inference_mode(). Call torch.cuda.synchronize() before reading the clock if on CUDA. Compare wall time and peak memory.
  8. PyTorch ↔ MLX port (Mac only). Translate a tiny nn.Linear → ReLU → nn.Linear → MSE training step from PyTorch to MLX. Use the translation table above. Run for 100 steps; watch loss go down in both.
Further Reading

Go deeper.

Hand-picked references for this lesson. Free where possible. Books and papers where the depth is irreplaceable.

YouTube · Free · 2.5 hrs

Karpathy — building micrograd

Builds autograd from scratch in pure Python. Watch this if anything in this lesson felt magical.

Watch on YouTube
GitHub · ~150 lines

karpathy/micrograd

The whole concept of automatic differentiation in 150 lines you can read in one sitting.

Open repo
Blog · Reference

Edward Yang — PyTorch Internals

Best one-hour read on how PyTorch is built. Tensors, autograd engine, dispatcher, ATen.

Read post
Book · Manning

Deep Learning with PyTorch

Stevens, Antiga, Viehmann. Best PyTorch-focused book — covers tensors, autograd, training loop in depth.

Open page
Apple · Documentation

MLX Official Docs

API reference, examples, and conceptual guides for the MLX framework.

Open docs
GitHub · Apple

mlx-examples

Standalone MLX example implementations: LLM inference, Whisper, Stable Diffusion, and more.

Open repo
Paper · Survey

Baydin et al. — Automatic Differentiation

The canonical survey paper on AD in machine learning. Forward mode, reverse mode, dual numbers, the lot.

Read on arXiv
Library · Tutorial

einops — Pretty Tensor Ops

Readable tensor rearrangement notation. Works with PyTorch, NumPy, MLX, JAX, TensorFlow.

Open docs
Tutorial · Reference

Tim Rocktäschel — Einsum Is All You Need

The canonical einsum tutorial. Worth committing to memory.

Read post
Apple · Blog

Apple — MLX intro post

December 2023 launch announcement explaining the design philosophy behind MLX.

Read post
Docs · PyTorch

PyTorch Autograd Mechanics

The official deep-dive on how PyTorch's autograd engine works: grad_fn, leaf tensors, accumulation, no_grad, inference_mode, and custom backward functions.

Read docs
Docs · MLX

MLX — Function Transforms

How mx.grad, mx.vmap, mx.vjp, mx.jvp, and mx.compile work. The functional autograd design that makes zero_grad unnecessary.

Read docs
Self-Check

Ten questions before moving on.

Close the page and answer from memory. If you can't, re-read the relevant section.

  1. A (3, 4) tensor x is transposed to get y = x.T. What are y.shape, y.stride(), and y.is_contiguous()? Why does y.view(12) raise a RuntimeError? How do you flatten it anyway?
  2. A 13B-parameter model is loaded in bfloat16. How many bytes of VRAM does it occupy for weights alone? Show the arithmetic. What is the formula in general?
  3. Why does bf16 preserve fp32's dynamic range while fp16 does not? Give a one-sentence answer in terms of exponent bits.
  4. Shapes (8, 1, 4) and (3, 4) are added. What is the output shape? Walk the broadcasting rules step by step. Now try (8, 1, 4) + (2, 4) — does it work? Why or why not?
  5. In the worked DAG (a=2, b=3), walk both paths that contribute to a.grad. What are the two numerical contributions, and why do they add rather than one overwriting the other?
  6. Why does tensor.grad accumulate across multiple .backward() calls, not just within one? Name a situation where this is a feature and one where it's a bug.
  7. What is a grad_fn? Which tensors have one? Which do not? What does .backward() do with it?
  8. What two things does torch.inference_mode() disable that torch.no_grad() does not? When would you still prefer no_grad?
  9. If you have a (B, T, D) tensor and want a (B, D) tensor of mean activations across the time dimension, write the call in PyTorch and in MLX (note the different keyword for axis).
  10. Name three concrete design differences between PyTorch and MLX. For each one, explain the downstream consequence for how you write a training loop.