LLM Inference Engineer · Day 09
Day 09 · Week 2 · Training & Architectures

Build a Tiny GPT, End to End

Today is the Week 2 capstone: you assemble all the pieces from Days 5–8 into a trainable character-level GPT, train it on TinyShakespeare on your own machine, and sample text that improves from random noise to Shakespearean cadence. Along the way you will learn to read the loss curve, tune sampling, estimate parameter counts and FLOPs, and understand exactly why serving this model token-by-token is expensive, and how the KV cache (Week 3) fixes it.

Time: ~210 min
Difficulty: Medium-Hard
Prerequisite: Days 5–8

Why This Lesson

Every piece you built this week snaps together, and the model starts learning.

Day 5 gave you tokenization and embeddings. Day 6 gave you attention. Day 7 gave you the Transformer block. Day 8 gave you the pre-training objective and scaling laws. Today you put them all in a room together, add a training loop, and watch a model learn language for the first time. This is the moment that makes the entire preceding week feel real.

We deliberately stay small: a roughly 10M-parameter GPT on roughly 1MB of Shakespeare. The model trains in minutes on a laptop GPU, in tens of minutes on Apple Silicon, and in a few hours on CPU. Small scale is a feature — you can run the loop twenty times, break things intentionally, and build the intuition that no amount of reading can substitute for. The loop you write today is exactly the loop that trains GPT-3; the only difference is bigger numbers and more machines.

We also spend serious time on inference cost. After training, you will count FLOPs per token, measure memory footprint, and understand why naive autoregressive decoding scales as O(T²). That analysis sets up the KV cache — the central optimization of Week 3 — and connects the training capstone to the inference theme of the whole course.

Learning objectives

  1. Trace the full tensor flow from token IDs through embeddings, position encodings, stacked pre-norm decoder blocks, final norm, and LM head — with exact shapes at every step.
  2. Understand each hyperparameter in GPTConfig and its effect on model size, memory, and quality.
  3. Read the loss-at-init sanity check: know what it is, why it works, and use it as your first debugging tool.
  4. Write a complete training loop with AdamW, gradient clipping, and a cosine LR schedule with linear warmup.
  5. Implement autoregressive generation with greedy, temperature, top-k, and top-p sampling; understand the tradeoffs.
  6. Read a loss curve and diagnose the four main failure modes.
  7. Estimate FLOPs per token and memory footprint for your model; articulate why naive generation is O(T²) and what the KV cache does.
  8. Save and reload a checkpoint for both resuming training and pure inference.
Full Architecture

From token IDs to logits: the end-to-end tensor journey.

Before writing a single line of code, draw the whole graph in your head. A batch of token-ID sequences enters with shape (B, T), where B is batch size and T is context length. It exits as logits, one unnormalized score per vocabulary entry at every position, with shape (B, T, V); a softmax over the last axis turns them into next-token probabilities. Everything in between is a progression of tensor transformations, each with a known shape. Understanding the shapes makes bugs obvious.

[Figure: GPT tensor flow]
  Token IDs (B, T), int64
  → tok_emb lookup (V, D) + pos_emb lookup (T, D)
  → sum + dropout (B, T, D)
  → Block × N: LN → Attn → residual, LN → MLP → residual (B, T, D); shape unchanged through all blocks
  → ln_f, final LayerNorm (B, T, D)
  → LM head, Linear(D, V), weight tied to tok_emb: head.W = tok_emb.W (saves V×D params)
  → logits (B, T, V), float32
Legend: B = batch size (e.g. 64), T = context length (e.g. 256), D = n_embd (e.g. 384), V = vocab size (e.g. 65), N = n_layer (e.g. 6). The block stack is the model's depth; D is its width. The causal mask inside attention ensures position t sees only positions 0..t.
The full GPT tensor flow. Token IDs (B,T) are looked up in two embedding tables and summed; the result travels through N pre-norm decoder blocks — shape unchanged throughout — then a final LayerNorm and a tied linear head to produce logits (B,T,V). Every box is differentiable; the whole graph is trained end-to-end by backprop.

Pre-norm vs post-norm

The original 2017 Transformer paper put LayerNorm after the residual addition (post-norm). GPT-2 and most decoder-only LLMs since put it before each sublayer (pre-norm). The difference matters. In post-norm, every LayerNorm sits on the residual path itself, so the signal and its gradient must pass through a normalization at every layer, which makes deep networks hard to train without careful learning-rate warmup. In pre-norm, only the sublayer input is normalized; the residual path is a plain identity connection, so gradients reach early layers directly and training stays stable even at depth 48 or 96. Pre-norm is now the default for GPT-style models; expect to see it everywhere.

Weight tying

The LM head is a linear map from the model dimension D to the vocabulary V. The embedding table is also a matrix of shape (V, D). Weight tying shares these two matrices: the head's weight is the very same tensor as the embedding table, so the logit for a token is the dot product between the final hidden state and that token's embedding. This saves V×D parameters (for our config: 65×384 = 24,960, a small saving; for a 128k-token BPE vocab with D=4096, it saves roughly 500M parameters). It also constrains the model to use the same geometry for encoding and decoding tokens, which often helps quality in small models.
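
A minimal sketch of what tying looks like in PyTorch, using the default config's sizes (V=65, D=384). The variable names are illustrative only and are not part of the lesson's model class; the point is that after the assignment both modules hold one shared Parameter, so it is counted (and updated) once.

```python
import torch.nn as nn

V, D = 65, 384                      # vocab size and width from the default config
tok_emb = nn.Embedding(V, D)        # weight shape (V, D)
head = nn.Linear(D, V, bias=False)  # weight shape (V, D) too: Linear stores (out, in)

untied = tok_emb.weight.numel() + head.weight.numel()

head.weight = tok_emb.weight        # tie: both modules now hold the same Parameter

# Module.parameters() de-duplicates shared tensors, so count through a container.
tied = sum(p.numel() for p in nn.ModuleList([tok_emb, head]).parameters())

print(head.weight is tok_emb.weight)   # same object, not a copy
print(untied - tied)                   # parameters saved by tying: V * D
```

Because nn.Linear stores its weight as (out_features, in_features), the two shapes already match and no explicit transpose is needed.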

The full model code

from dataclasses import dataclass
import torch, torch.nn as nn, torch.nn.functional as F

@dataclass
class GPTConfig:
    vocab_size: int   = 65     # set from data
    block_size: int   = 256    # context length T
    n_layer:    int   = 6      # depth
    n_head:     int   = 6      # attention heads
    n_embd:     int   = 384    # d_model = head_dim * n_head
    dropout:    float = 0.1

class CausalSelfAttention(nn.Module):
    def __init__(self, c):
        super().__init__()
        assert c.n_embd % c.n_head == 0
        self.n_head, self.n_embd = c.n_head, c.n_embd
        self.qkv  = nn.Linear(c.n_embd, 3 * c.n_embd, bias=False)
        self.proj = nn.Linear(c.n_embd, c.n_embd, bias=False)
        self.drop = nn.Dropout(c.dropout)
        self.p    = c.dropout

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        dh = C // self.n_head
        def split_heads(t):
            return t.view(B, T, self.n_head, dh).transpose(1, 2)
        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        # PyTorch 2.0+ fused attention (FlashAttention when available)
        y = F.scaled_dot_product_attention(
                q, k, v, is_causal=True,
                dropout_p=self.p if self.training else 0.0)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.drop(self.proj(y))

class MLP(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.fc   = nn.Linear(c.n_embd, 4 * c.n_embd)
        self.proj = nn.Linear(4 * c.n_embd, c.n_embd)
        self.drop = nn.Dropout(c.dropout)
    def forward(self, x):
        return self.drop(self.proj(F.gelu(self.fc(x))))

class Block(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(c.n_embd), nn.LayerNorm(c.n_embd)
        self.attn, self.mlp = CausalSelfAttention(c), MLP(c)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # pre-norm residual
        x = x + self.mlp(self.ln2(x))
        return x

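class GPT(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.cfg     = c
        self.tok_emb = nn.Embedding(c.vocab_size, c.n_embd)
        self.pos_emb = nn.Embedding(c.block_size, c.n_embd)
        self.drop    = nn.Dropout(c.dropout)
        self.blocks  = nn.ModuleList([Block(c) for _ in range(c.n_layer)])
        self.ln_f    = nn.LayerNorm(c.n_embd)
        self.head    = nn.Linear(c.n_embd, c.vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight   # weight tying
        # Small-std init (GPT-2 / nanoGPT convention). With tying, the default
        # N(0, 1) embedding init would make the step-0 logits far from uniform.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)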

    def forward(self, idx, targets=None):
        B, T = idx.shape
        pos   = torch.arange(T, device=idx.device)
        x     = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for blk in self.blocks:
            x = blk(x)
        logits = self.head(self.ln_f(x))         # (B, T, V)
        loss   = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
Config & Parameter Count

Every knob in GPTConfig, what it does, and what it costs.

A GPTConfig dataclass is not just convenience — it is a contract. Every architectural decision lives in one place, is passed to every submodule, and is saved in the checkpoint alongside the weights. You cannot load weights without knowing the config. Here is what each field controls.

| Field | Default | Effect on model | Effect on memory / speed |
|---|---|---|---|
| n_layer | 6 | Depth (number of blocks). The main lever for quality. | Linear in memory; roughly linear in FLOPs/token. |
| n_embd | 384 | Model width D. Quadratic in parameters per block. | Dominates parameter count and activation memory. |
| n_head | 6 | Head dimension = D/n_head = 64. More heads = more parallel views of the sequence. | No effect on total params; affects fused-attention efficiency. |
| block_size | 256 | Context length T. The maximum sequence length at train time. | KV matrices scale as T×D; attention compute as T²×D. |
| dropout | 0.1 | Regularization. Drop 10% of activations during training. | Zero at inference (disabled by model.eval()). |
| vocab_size | 65 | Character count for this corpus. BPE models use 50k–128k. | Affects embedding table and LM head size. |

Parameter count breakdown

You should be able to compute the parameter count for any GPT config on the back of an envelope. The formula: one block has an attention sublayer and an MLP sublayer.

Attention: Q, K, V, O projections = 4 × D² (no bias, all square)
MLP: fc + proj = D×4D + 4D×D = 8D²
One block: 4D² + 8D² = 12D²
N blocks: N × 12D²
Embeddings: vocab_size × D + block_size × D (head tied to tok_emb, not counted twice)

Total ≈ N × 12D² + (vocab_size + block_size) × D

With N=6, D=384, V=65, T=256:
  blocks:     6 × 12 × 384² = 10,616,832
  tok_emb:    65 × 384 = 24,960
  pos_emb:    256 × 384 = 98,304
  LN params:  2 LayerNorms × 2 (γ, β) × D × 6 blocks + 2 × D final = 9,984 (small)
  MLP biases: (4D + D) × 6 blocks = 11,520 (small)
  Total:      ≈ 10.76 M parameters
[Figure: parameter breakdown, 10.76M total (N=6, D=384, V=65)]
  Per block (×6 blocks): attention (Q, K, V, O) 4 × D² = 589,824 params; MLP (fc + proj) 8 × D² = 1,179,648 params; LN (γ, β) 1,536 params.
  Embeddings: tok_emb 24,960; pos_emb 98,304; head tied to tok_emb, so no extra params.
Parameter breakdown for the default config. The MLP sublayer (8D²) uses 2× the parameters of attention (4D²), making MLP the dominant cost per block. Embeddings are negligible for a 65-character vocab; for a 128k BPE vocabulary at D=4096 the token table alone would be roughly 500M parameters, which is no longer negligible.

The rule of thumb is 12D² per block. Double D, and parameters quadruple. Halve n_layer, and parameters halve. This is why scaling laws (Day 8) say that D and n_layer should be scaled together — neither dimension dominates if you scale jointly.
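
The envelope calculation can be checked in a few lines. This is a sketch that mirrors the model code in this lesson (bias-free attention, MLP with biases, tied head); the function names are made up for illustration.

```python
def gpt_param_count(n_layer, n_embd, vocab_size, block_size):
    """Exact count for this lesson's model (tied head, bias-free attention)."""
    D = n_embd
    attn = 4 * D * D                  # fused qkv (3 D^2) + output proj (D^2), no biases
    mlp = 8 * D * D + 4 * D + D       # fc and proj weights plus their biases
    ln = 2 * (2 * D)                  # two LayerNorms per block, gamma and beta each
    blocks = n_layer * (attn + mlp + ln)
    final_ln = 2 * D
    embeddings = (vocab_size + block_size) * D   # head is tied, so not counted again
    return blocks + final_ln + embeddings


def gpt_param_estimate(n_layer, n_embd, vocab_size, block_size):
    """Back-of-envelope version: 12 D^2 per block plus embeddings."""
    return n_layer * 12 * n_embd**2 + (vocab_size + block_size) * n_embd


exact = gpt_param_count(6, 384, 65, 256)
approx = gpt_param_estimate(6, 384, 65, 256)
print(f"exact  {exact:,}")    # 10,761,600
print(f"approx {approx:,}")   # 10,740,096
```

The exact figure should match sum(p.numel() for p in model.parameters()) for the default config; the 12D² estimate lands within about 0.2% of it.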

Initialization & Sanity Check

The single most valuable debugging check in all of deep learning.

Before you run your first training step, there is one check you must always do: verify that the loss at initialization is approximately ln(vocab_size). This is not a nice-to-have — it is the first line of defense against a wide class of bugs. Here is the intuition.

A freshly initialized model with Xavier/Kaiming weights will assign roughly equal logits to each vocabulary token. The softmax of equal logits is a uniform distribution. For a uniform distribution over V tokens, the cross-entropy loss is exactly:

H(uniform) = ln(V)

For our char-level vocab: ln(65) ≈ 4.174
For GPT-2 50k BPE vocab:  ln(50257) ≈ 10.82
For LLaMA 32k vocab:      ln(32000) ≈ 10.37

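If your step-0 loss is far from ln(vocab_size), you have a bug before any training has happened. Common culprits: a weight initialization that is wildly off-scale (for example, a tied LM head that inherits N(0, 1) embedding weights), a loss that is summed rather than averaged over positions, or (for BPE tokenizers) a mismatch between the tokenizer's vocabulary size and the model's vocab_size. Note what the check does not catch: a near-uniform predictor scores about ln(vocab_size) against any targets, so a target-shift bug (using x itself as the targets instead of the sequence shifted by one position) usually passes at step 0 and shows up later, as a training loss that collapses toward zero because the model only has to copy its input.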

This check is called out explicitly by Karpathy in his "A Recipe for Training Neural Networks" and has probably saved thousands of hours of wasted training compute. A wrong init loss means something is broken structurally — no amount of training will fix it. Always check before you train.

import math

xb, yb = get_batch(train_data)
with torch.no_grad():
    _, loss0 = model(xb, yb)

expected = math.log(cfg.vocab_size)
print(f"loss at init: {loss0.item():.4f}   expected: {expected:.4f}")
assert abs(loss0.item() - expected) < 0.5, \
    "Init loss is off! Check model/data pipeline."
print("Sanity check passed.")

The tolerance of 0.5 nats is generous — a well-initialized model usually lands within 0.1 of the theoretical value. If you are off by more than 0.5, investigate before training.
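
To see where ln(V) comes from without a model, here is a pure-Python sketch of the cross-entropy for a single position. The logit values are synthetic; the standard deviation of 20 is just an arbitrary stand-in for a badly scaled initialization.

```python
import math
import random


def cross_entropy(logits, target):
    """Cross-entropy in nats for one position: logsumexp(logits) - logits[target]."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


V = 65
uniform = [0.0] * V
print(cross_entropy(uniform, target=7), math.log(V))   # both equal ln(65)

# Off-scale init: widely spread logits give a much larger average loss.
random.seed(0)
losses = []
for _ in range(200):
    logits = [random.gauss(0.0, 20.0) for _ in range(V)]
    losses.append(cross_entropy(logits, target=random.randrange(V)))
mean_bad = sum(losses) / len(losses)
print(mean_bad)
```

With equal logits the loss is ln(V) no matter which target is chosen, which is why this check is sensitive to initialization scale and vocabulary size rather than to which token the targets point at.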

Weight initialization details

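PyTorch's defaults (Kaiming uniform for Linear, standard normal for Embedding) are a reasonable starting point for untied models, but with weight tying the LM head inherits the embedding table's N(0, 1) weights, which makes the step-0 logits far too large and pushes the init loss well above ln(vocab_size). The GPT-2 and nanoGPT convention avoids this by drawing every Linear and Embedding weight from a normal distribution with standard deviation 0.02 and zeroing the biases. nanoGPT additionally scales the initial weights of each residual block's output projection by 1/sqrt(2 * n_layer), which keeps the residual stream from growing with depth at initialization. That extra scaling matters most when n_layer is large; our small 6-layer model trains fine without it, but it is good practice to know about.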

The Training Loop

Forward, loss, backward, step — with clipping, warmup, and cosine decay.

The core loop is four operations. Everything else — learning rate schedules, gradient clipping, mixed precision, gradient accumulation — is engineering layered on top of those four operations. Let us build up from the bare minimum to something production-adjacent.

[Figure: the training loop]
  get_batch (x, y; random crop) → forward: model(x, y) → loss → backward: loss.backward() → grads → clip grads (max_norm=1.0) → opt.step() (update θ) → zero_grad (set_to_none=True) → repeat for max_steps.
  Every eval_interval steps: estimate train/val loss.
The complete training loop. zero_grad(set_to_none=True) frees the gradient memory rather than zeroing it in place — slightly faster. Gradient clipping happens after backward but before step. The eval branch runs every few hundred steps to track generalization without affecting training state.
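
The loop calls get_batch, which this page never defines. Below is one plausible sketch, assuming the corpus has been encoded into a 1-D LongTensor of character IDs. The stand-in text, the default batch_size and block_size, and the device argument are assumptions for illustration; the stoi and itos names follow the maps saved in the checkpoint later on.

```python
import torch

text = "First Citizen:\nBefore we proceed any further, hear me speak.\n"  # stand-in corpus
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}   # character -> id
itos = {i: ch for ch, i in stoi.items()}       # id -> character
data = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)


def get_batch(data, batch_size=4, block_size=8, device="cpu"):
    """Sample random crops; y is x shifted one position to the left."""
    starts = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in starts])
    y = torch.stack([data[i + 1 : i + 1 + block_size] for i in starts])
    return x.to(device), y.to(device)


xb, yb = get_batch(data)
print(xb.shape, yb.shape)                    # both (batch_size, block_size)
print(torch.equal(xb[:, 1:], yb[:, :-1]))    # targets are the inputs shifted by one
```

For the real run you would encode the full TinyShakespeare file the same way, split it into train_data and val_data, and use the config's block_size with a larger batch.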

AdamW: the standard optimizer

AdamW uses per-parameter adaptive learning rates based on running estimates of the first moment (mean of gradients) and second moment (mean of squared gradients). The "W" stands for decoupled weight decay: instead of adding an L2 penalty to the loss, where it would be rescaled by Adam's adaptive denominator, AdamW shrinks each parameter directly by a factor proportional to lr × weight_decay at every step. The standard GPT settings:

lr           = 3e-4         # peak learning rate; the cosine schedule decays from here
betas        = (0.9, 0.95)  # β₁ tracks gradient mean, β₂ tracks squared-gradient mean
eps          = 1e-8         # numerical stability; rarely tuned
weight_decay = 0.1          # only applied to weight matrices, not biases or LN params

The weight_decay of 0.1 is applied selectively: weight matrices get it; bias terms and LayerNorm parameters do not. This is standard practice: biases and LayerNorm parameters are few in number, and pulling LayerNorm gains toward zero would work against the normalization they implement. In PyTorch:

decay, no_decay = [], []
for pn, p in model.named_parameters():
    if p.ndim < 2:              # bias, LN gamma/beta
        no_decay.append(p)
    else:
        decay.append(p)         # weight matrices

opt = torch.optim.AdamW([
    {"params": decay,    "weight_decay": 0.1},
    {"params": no_decay, "weight_decay": 0.0},
], lr=3e-4, betas=(0.9, 0.95))

Gradient clipping

Gradient clipping bounds the L2 norm of the gradient vector before the optimizer step. If the norm exceeds max_norm, all gradients are scaled down proportionally. This prevents a single bad batch from causing a large destabilizing parameter update — especially important early in training when gradients can be chaotic.

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

A clip value of 1.0 is the standard; higher values (5.0) are used for some RNN-style training. If you see occasional loss spikes but overall good training, clipping (or reducing LR) is often the fix.
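
What global-norm clipping does can be shown on plain Python lists. This sketch follows the same rule as torch.nn.utils.clip_grad_norm_ (one norm over all gradients, one shared scale factor); the function name and the toy gradients are made up for illustration.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale every gradient by the same factor so the global L2 norm is <= max_norm."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]   # two toy "parameter tensors"; global norm is 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for t in clipped for g in t))
print(norm_before, norm_after)     # 5.0 before, about 1.0 after
```

The direction of the update is preserved because every gradient is multiplied by the same factor; gradients whose global norm is already below max_norm pass through unchanged.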

Learning rate schedule: warmup + cosine decay

Constant learning rate works for small models, but the production practice is a two-phase schedule: linear warmup from near zero to peak LR over the first few hundred steps, then cosine decay back down to a small fraction of peak LR. Warmup prevents the large early updates that happen when AdamW's running estimates are unreliable (they are initialized to zero). Cosine decay allows aggressive use of the full LR during most of training, then a smooth landing.

[Figure: learning rate vs. step] Linear warmup over warmup_steps up to lr_peak = 3e-4, then cosine decay down to lr_min = 3e-5:
  lr = lr_min + 0.5 × (lr_peak − lr_min) × (1 + cos(π · progress))
The warmup + cosine schedule used by most production LLM training runs. Warmup prevents early instability; cosine decay provides a smooth, principled annealing. The minimum LR is typically 10× smaller than the peak. For a 5,000-step run, 100–200 warmup steps is typical.
def get_lr(step, warmup_steps, max_steps, lr_peak, lr_min):
    if step < warmup_steps:
        # step + 1 so that step 0 trains with a small nonzero LR instead of 0
        return lr_peak * (step + 1) / warmup_steps
    if step > max_steps:
        return lr_min
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return lr_min + 0.5 * (lr_peak - lr_min) * (1 + math.cos(math.pi * progress))

# Apply inside the loop:
for step in range(max_steps):
    lr = get_lr(step, warmup_steps=100, max_steps=5000,
                lr_peak=3e-4, lr_min=3e-5)
    for param_group in opt.param_groups:
        param_group["lr"] = lr
    ...
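
A quick way to build intuition for the decay phase is to evaluate the cosine formula at a few points. This sketch isolates just that phase; cosine_phase is an illustrative helper, not a function from the notebook, and progress is the fraction of the post-warmup run that has elapsed.

```python
import math


def cosine_phase(progress, lr_peak=3e-4, lr_min=3e-5):
    """LR during the decay phase; progress runs from 0 (end of warmup) to 1 (last step)."""
    return lr_min + 0.5 * (lr_peak - lr_min) * (1 + math.cos(math.pi * progress))


for progress in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"progress {progress:.2f}  lr {cosine_phase(progress):.2e}")
```

The schedule starts at lr_peak, passes through the midpoint of lr_peak and lr_min exactly halfway, and lands on lr_min at the end, spending most of the run at a relatively high learning rate.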

Mixed precision (brief mention)

Training in float16 or bfloat16 roughly halves activation memory and speeds up matrix multiplications on modern hardware. The standard approach is PyTorch's torch.autocast context manager, which runs eligible operations in the lower-precision type automatically while the parameters and the optimizer step stay in float32. float16 additionally needs loss scaling to avoid gradient underflow; bfloat16 usually does not. Day 10 covers this in detail; for today's small run it is optional.

The complete training loop

model = GPT(cfg).to(device)

# Selective weight decay (weight matrices only)
decay, no_decay = [], []
for pn, p in model.named_parameters():
    (no_decay if p.ndim < 2 else decay).append(p)
opt = torch.optim.AdamW([
    {"params": decay,    "weight_decay": 0.1},
    {"params": no_decay, "weight_decay": 0.0},
], lr=3e-4, betas=(0.9, 0.95))

max_steps     = 5000
warmup_steps  = 100
eval_interval = 500
eval_iters    = 50

@torch.no_grad()
def estimate_loss():
    model.eval()
    out = {}
    for name, d in [("train", train_data), ("val", val_data)]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(d)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[name] = losses.mean().item()
    model.train()
    return out

history = []
for step in range(max_steps + 1):
    # LR schedule
    lr = get_lr(step, warmup_steps, max_steps, 3e-4, 3e-5)
    for g in opt.param_groups:
        g["lr"] = lr

    # Periodic eval
    if step % eval_interval == 0:
        m = estimate_loss()
        history.append((step, m["train"], m["val"]))
        print(f"step {step:>5}  train {m['train']:.3f}  val {m['val']:.3f}  lr {lr:.2e}")

    # Training step
    xb, yb = get_batch(train_data)
    _, loss = model(xb, yb)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
Sampling & Generation

Greedy, temperature, top-k, top-p — and why the loop is O(T²).

The model produces logits (V,) at each step. Converting those logits to a single next token is a decision that dramatically shapes the output. You have four main strategies, each with distinct tradeoffs.

| Strategy | How it works | Effect on output | When to use |
|---|---|---|---|
| Greedy | Always pick argmax(logits). | Deterministic, often repetitive. May loop. | Debugging, benchmarks where reproducibility matters. |
| Temperature (τ) | Divide logits by τ before softmax. τ<1 sharpens, τ>1 flattens. | τ→0 approaches greedy; τ=1 → model distribution; τ>1 → more random. | Creative text (τ=0.7–1.2); use as your main dial. |
| Top-k | Mask out all but the k highest-logit tokens (set the rest to −∞), then sample. | Prevents sampling very unlikely tokens regardless of τ. | k=40–200 is a good default; combine with τ. |
| Top-p (nucleus) | Sort tokens by probability; keep the smallest set whose cumulative probability ≥ p; sample from that set. | Adapts the candidate set size to the model's confidence. Wider when uncertain, narrower when confident. | p=0.9 or 0.95 is more principled than fixed top-k. |
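
The three dials are easy to compare on a toy distribution. The sketch below uses plain Python and made-up logits for a 5-token vocabulary; it filters probabilities and renormalizes, which gives the same distribution as masking logits to −∞ before the softmax.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


def top_p_filter(probs, p):
    """Keep the smallest high-probability prefix whose cumulative mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cum = [], 0.0
    for i in order:
        keep.append(i)
        cum += probs[i]
        if cum >= p:
            break
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


logits = [2.0, 1.0, 0.5, 0.0, -1.0]   # made-up logits for a 5-token vocabulary
for t in (0.5, 1.0, 2.0):
    print(t, [round(x, 3) for x in softmax(logits, t)])
probs = softmax(logits)
print("top-k=2  ", [round(x, 3) for x in top_k_filter(probs, 2)])
print("top-p=0.9", [round(x, 3) for x in top_p_filter(probs, 0.9)])
```

On these logits top-k=2 always keeps two candidates, while top-p=0.9 keeps four, because the first three tokens together hold just under 90% of the probability mass.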

The autoregressive generation loop

[Figure: context at each generation step, growing by one token per step]
  t₀ → model → logits[t₀] → sample t₁
  t₀ t₁ → model → logits[t₁] → sample t₂
  t₀ t₁ t₂ → model → logits[t₂] → sample t₃
Naive autoregressive generation re-runs the full forward pass over the whole context at every step. At a context of length t, that step pushes t tokens through every linear layer (cost ∝ t) and recomputes the full t×t attention score matrix (cost ∝ t²). Generating T tokens therefore costs O(T²) token-forwards through the linear layers and O(T³) attention-score operations in total. The KV cache (Week 3) stores K and V from previous steps and computes them only for the new token: the linear-layer work per step drops from O(t) to O(1), and the attention work per step drops from O(t²) to O(t). This is the central inference optimization.
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
    for _ in range(max_new_tokens):
        # Crop to context window
        idx_cond = idx[:, -model.cfg.block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature        # (B, V), last position only

        # Top-k filter
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")

        # Top-p (nucleus) filter
        if top_p is not None:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative prob above top_p (shift by 1 to keep first)
            sorted_logits[cumprobs - F.softmax(sorted_logits, dim=-1) > top_p] = -float("inf")
            logits = torch.zeros_like(logits).scatter(1, sorted_idx, sorted_logits)

        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx

Run this at several points in training and you will watch quality climb. At step 0: completely random characters. At step 500: spaces and punctuation are roughly correct. At step 2000: word shapes appear. At step 5000: recognizable character names, line breaks, and Shakespearean rhythm. The model will not be good — it is 10M parameters on 1MB of text — but the trajectory is unmistakable and deeply satisfying.

Reading the Loss Curve

A healthy run falls fast, then settles. Learn to recognize the failure modes.

The loss curve is your most powerful diagnostic tool. Every failure mode has a characteristic shape. Here is how to read it.

[Figure: loss-curve archetypes. All panels share x = training step, y = cross-entropy loss (nats).]
  Healthy: train and val drop steeply, then slowly flatten.
  LR too high: spikes or NaN → lower LR.
  LR too low: loss barely moves → raise LR.
  Overfitting: train ↓ while val ↑ → add dropout or data.
  Broken pipeline: flat from step 0 → check the data and target pipeline.
Expected results (TinyShakespeare): init loss ~4.17 (= ln(65)); val loss after 5k steps ~1.4–1.6; perplexity ~4–5 (exp(1.5) ≈ 4.5). Perplexity = exp(loss); lower is better. A perplexity of 4.5 means the model is about as uncertain as choosing from 4–5 equally likely options.
Loss curve archetypes. A healthy run (top-left) is your baseline: steep drop, slow convergence, train and val tracking closely. Every other shape is a signal — read it before tweaking hyperparameters.

Expected final loss for a 10M-parameter GPT on TinyShakespeare after 5,000 steps: roughly 1.4–1.6 nats, corresponding to a perplexity of ~4–5. That means the model is on average as uncertain as choosing uniformly among 4–5 equally likely characters. Not great literature, but definitive proof that the model learned the structure of the corpus.

Overfitting on tiny data

TinyShakespeare is ~1MB — small enough for a 10M-parameter model to partially memorize. You will typically see a small but stable train-val gap. If you reduce the dataset to a few kilobytes and train long enough, you can drive train loss to near zero while val loss rises sharply. This is pedagogically useful: it shows exactly what overfitting looks like and why regularization (dropout, weight decay) matters.

Perplexity as a metric

Perplexity is just exp(loss). A loss of 1.5 nats gives perplexity exp(1.5) ≈ 4.5. The interpretation: on average, the model is as uncertain as if it were choosing uniformly among 4–5 characters. A character-level model on English text trained to its limit reaches roughly perplexity 2–3 (because English has low character-level entropy). Our tiny model with limited data sits at 4–5, which is respectable.
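
The conversion is one line of arithmetic. This sketch also expresses each loss in bits per character, a unit that is common in the compression literature; the three loss values are the init loss, the expected final loss from this lesson, and an arbitrary better value for comparison.

```python
import math

for loss_nats in (math.log(65), 1.5, 1.0):
    ppl = math.exp(loss_nats)          # effective number of equally likely choices
    bits = loss_nats / math.log(2)     # the same loss expressed in bits per character
    print(f"loss {loss_nats:.3f} nats  perplexity {ppl:.2f}  {bits:.2f} bits/char")
```

At initialization the perplexity equals the vocabulary size (65), so a drop to about 4.5 means training has narrowed the model's effective choice from 65 characters to fewer than five.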

Inference Cost

Why is this model expensive to serve, and what does it cost per token?

Training this model once costs some GPU-hours. Serving it for inference — generating one token at a time, possibly for many concurrent users — is a different and ongoing cost. Understanding the cost structure now sets up everything in Weeks 3–4.

FLOPs per forward pass

For a GPT-style model, the approximate floating-point operations for a single forward pass on a sequence of length T is:

FLOPs per forward ≈ 2 × N × (12 × D² × T + 2 × D × T²)

The first term (12D²T) covers the linear layers (Q, K, V, O projections + MLP).
The second term (2DT²) covers the attention score matrix (T×T at each head, across N layers).

For our tiny model (N=6, D=384, T=256):
  linear layers: 2 × 6 × 12 × 384² × 256 ≈ 5.4 billion FLOPs
  attention:     2 × 6 × 2 × 384 × 256² ≈ 0.6 billion FLOPs
  total per forward: ~6 billion FLOPs

For inference (generate 1 new token from a length-T context):
  each token step re-runs the full forward on T tokens → ~6G FLOPs per token (for T=256)

For GPT-3 (N=96, D=12288, T=2048):
  block parameters: 12 × 96 × 12288² ≈ 174 billion
  forward, linear layers: 2 × 174 billion ≈ 350 GFLOPs per token, or ≈ 7 × 10¹⁴ FLOPs for a full 2048-token sequence
  training: ≈ 6 × 175 billion ≈ 1 trillion FLOPs per token

The general rule of thumb: approximately 6 × (parameter count) FLOPs per token for the linear layers (factor of 2 for multiply-add, factor of 3 for forward + backward in training). For inference, it is roughly 2 × params FLOPs per token from linear layers, plus the quadratic attention cost.
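
The estimate is simple enough to script. This sketch implements the formula from this section as written; forward_flops is an illustrative helper name, and the result is an approximation that ignores embeddings, LayerNorm, softmax, and the LM head.

```python
def forward_flops(n_layer, n_embd, seq_len):
    """Approximate forward FLOPs for one sequence, split into linear and attention terms."""
    linear = 2 * n_layer * 12 * n_embd**2 * seq_len       # 2 FLOPs per weight per token
    attention = 2 * n_layer * 2 * n_embd * seq_len**2     # QK^T scores plus weighted sum of V
    return linear, attention


linear, attention = forward_flops(n_layer=6, n_embd=384, seq_len=256)
print(f"linear    {linear / 1e9:.2f} GFLOPs")      # 5.44
print(f"attention {attention / 1e9:.2f} GFLOPs")   # 0.60
print(f"total     {(linear + attention) / 1e9:.2f} GFLOPs")

# Per token, the linear layers cost 2 FLOPs per block parameter in the forward pass.
per_param = linear / 256 / (6 * 12 * 384**2)
print(per_param)                                   # 2.0
```

At T=256 the attention term is about a tenth of the total; because it grows with T² while the linear term grows with T, the two become equal when T = 6D, which is 2,304 tokens for D=384.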

Memory: weights vs activations

For inference, you need to hold two things in memory: the weights (static) and the activations of the current forward pass (dynamic, grows with batch size and sequence length).

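| Item | Our tiny model | GPT-3 (175B) | Notes |
|---|---|---|---|
| Weights (float32) | ~43 MB | ~700 GB | 10.76M × 4 bytes |
| Weights (float16/bf16) | ~22 MB | ~350 GB | 2 bytes per param |
| KV cache (float16, B=1) | ~2.4 MB at T=256 | ~9.7 GB at T=2048 | 2 × T × D × n_layer × 2 bytes; grows linearly with batch size, so tens of concurrent sequences reach hundreds of GB for GPT-3 |
| Activations (training, float32) | ~several GB | Terabytes | Needed for backward pass; not stored at inference |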

This is why serving large models requires specialized hardware. A 70B-parameter model in float16 needs ~140 GB, more than a single 80GB A100/H100, so it requires multiple GPUs or quantization (an 80GB GPU holds roughly a 30–40B model in fp16). The KV cache alone for a long context can consume as much memory as the weights. Weeks 3 and 4 cover the techniques (quantization, KV cache management, paged attention, speculative decoding) that make this tractable.
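
These memory figures follow from two multiplications. The sketch below uses the KV-cache formula from the table's Notes column (K and V, one T×D matrix each per layer, 2 bytes per value) and assumes standard multi-head attention with no MQA or GQA; the helper names are illustrative.

```python
def weight_bytes(n_params, bytes_per_param):
    return n_params * bytes_per_param


def kv_cache_bytes(n_layer, n_embd, seq_len, batch_size=1, bytes_per_value=2):
    """K and V, one (seq_len, n_embd) matrix each per layer, per sequence in the batch."""
    return 2 * n_layer * seq_len * n_embd * batch_size * bytes_per_value


MB, GB = 1e6, 1e9
print(f"tiny weights fp32:        {weight_bytes(10.76e6, 4) / MB:.1f} MB")
print(f"tiny weights fp16:        {weight_bytes(10.76e6, 2) / MB:.1f} MB")
print(f"tiny KV cache (T=256):    {kv_cache_bytes(6, 384, 256) / MB:.2f} MB")
print(f"GPT-3 weights fp16:       {weight_bytes(175e9, 2) / GB:.0f} GB")
print(f"GPT-3 KV cache (T=2048):  {kv_cache_bytes(96, 12288, 2048) / GB:.2f} GB")
```

The cache scales linearly with both context length and batch size, so the per-sequence figure is multiplied by the number of concurrent requests a server holds in memory.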

Why naive generation is O(T²) and the KV cache fixes it

In naive generation, to produce token t+1 you run the full forward pass on tokens 0..t. Every one of those t+1 positions goes through every linear layer again, so the step costs O(t) token-forwards, and the attention sublayer recomputes QK^T/sqrt(d_h) for all pairs, which costs O(t²) per layer. Summed over the T tokens you generate, that is O(T²) token-forwards, which is the headline O(T²), plus O(T³) attention-score work. The KV cache short-circuits this: once you have computed K and V for positions 0..t-1, you store them. For the next token you compute Q, K, and V for the new position only and attend over the cached K, V. The marginal cost per new token drops from O(t) to O(1) in the linear layers and from O(t²) to O(t) in the attention sublayer. The catch: you now need memory proportional to T to store the cache, and the memory bandwidth to read that cache dominates latency, which motivates quantization, multi-query attention (MQA), and grouped-query attention (GQA), all of which you will see in Week 3.
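
A small counting exercise makes the asymptotics concrete. This sketch uses a simplified cost model that is an assumption for illustration: generation starts from an empty prompt, "positions" counts tokens pushed through the linear layers, and "pairs" counts causal query-key pairs scored by attention in one layer.

```python
def naive_cost(new_tokens):
    """Re-run the whole prefix at every step (no cache)."""
    steps = range(1, new_tokens + 1)
    positions = sum(t for t in steps)               # t tokens re-processed at step t
    pairs = sum(t * (t + 1) // 2 for t in steps)    # full causal score matrix at step t
    return positions, pairs


def cached_cost(new_tokens):
    """Process only the newest token; attend over cached keys and values."""
    steps = range(1, new_tokens + 1)
    positions = new_tokens                          # one new token per step
    pairs = sum(t for t in steps)                   # one query against t cached keys
    return positions, pairs


for T in (16, 64, 256):
    print(T, naive_cost(T), cached_cost(T))
```

Doubling T roughly quadruples the naive position count and multiplies the naive pair count by about eight, while with the cache the position count only doubles and the pair count quadruples.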

Checkpointing

Save everything you need to resume — or to serve.

A trained model you cannot reload is a model you must retrain. Checkpointing is trivial but discipline-forming: it makes you think clearly about which state is essential for training versus inference, and it foreshadows the weight-loading pipeline in your Day 27 capstone.

from dataclasses import asdict

# Save: everything needed to resume training.
# The config is stored as a plain dict so the file also loads with weights_only=True.
torch.save({
    "model":     model.state_dict(),
    "optimizer": opt.state_dict(),
    "config":    asdict(cfg),
    "step":      step,
    "stoi":      stoi,     # character-to-index map
    "itos":      itos,     # index-to-character map
}, "tinygpt.pt")

# Load for resuming training (need optimizer state + step).
ckpt  = torch.load("tinygpt.pt", map_location=device, weights_only=False)
model = GPT(GPTConfig(**ckpt["config"])).to(device)
model.load_state_dict(ckpt["model"])
# Rebuild `opt` over the new model's parameters (same param groups as before), then:
opt.load_state_dict(ckpt["optimizer"])
step  = ckpt["step"]

# Load for inference only (just weights + config).
ckpt  = torch.load("tinygpt.pt", map_location=device, weights_only=True)
model = GPT(GPTConfig(**ckpt["config"])).to(device)
model.load_state_dict(ckpt["model"])
model.eval()

Save the optimizer state (including AdamW's running moment estimates) if you intend to resume training. Without it, the moment estimates restart from zero and the first steps after reload behave like a fresh run with no warmup, which often shows up as a loss spike. For pure inference, only the model weights and config are needed. We will return to weight formats (safetensors, sharded checkpoints for models that do not fit in one file) in Week 4.

Hardware Notes

CUDA, Apple Silicon, or CPU — one device line changes everything.

| Path | Device string | Expected wall-clock (5k steps) | Notes |
|---|---|---|---|
| NVIDIA GPU (Ampere or newer) | "cuda" | ~2–5 min (RTX 4090); longer on older GPUs | FlashAttention via F.scaled_dot_product_attention; add autocast for bf16. |
| Apple Silicon (M-series) | "mps" | ~10–20 min on M2/M3 | MPS backend stable as of PyTorch 2.1. MLX is faster for Apple-native models. |
| CPU | "cpu" | Hours for full config; reduce to n_layer=4, n_head=4, n_embd=128, block_size=128 | Still produces a working model; the lesson survives the smaller scale. |
device = ("cuda" if torch.cuda.is_available()
          else "mps" if torch.backends.mps.is_available()
          else "cpu")
print("using device:", device)

If you are on CPU and time-constrained, shrink the config: n_layer=4, n_head=4, n_embd=128, block_size=128, batch_size=32. (n_head must divide n_embd, so the default n_head=6 does not work with n_embd=128.) Run for 1,000 steps instead of 5,000. The val loss will be worse (~2.0) and the samples rougher, but the loop is identical and the learning is real.

Exercises

Eight exercises, all in the companion notebook.

Companion notebook: day-9-tiny-gpt.ipynb.

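  1. Verify loss at init. Before training, confirm step-0 loss is within 0.1 of ln(vocab_size). Then break the target shift on purpose (use x itself as the targets instead of the sequence shifted by one position), re-run the check, and train for a few hundred steps. Compare the init loss and the early training loss against the correct run.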
  2. Train the full model. Run the complete loop with the cosine schedule and gradient clipping. Plot train and val loss. Confirm val loss lands in 1.4–1.6 after 5k steps.
  3. Sample across training. Checkpoint the model at steps 0, 500, 2000, and 5000. Generate 200 characters from each. Paste side by side and describe the trajectory.
  4. Sampling strategy comparison. Using the fully trained model, generate with (a) greedy, (b) temperature 0.7 + top-k 40, (c) temperature 1.0 + top-p 0.9, and (d) temperature 1.5. Describe how each feels qualitatively.
  5. Overfit on purpose. Set train_data = data[:2000] and train for 3,000 steps. Plot the train vs val gap. At what step does val loss start rising?
  6. Parameter count audit. Use sum(p.numel() for p in model.parameters()) and compare to the 12D² formula. Then compute the FLOPs-per-token estimate. How does it compare to a single A100's ~310 TFLOPS?
  7. Checkpoint round-trip. Save, restart the kernel, reload, and generate 200 characters. Confirm the output matches a model you never unloaded (use the same seed).
  8. TinyStories + BPE (optional). Install tiktoken and the datasets library, then retrain on TinyStories. With the same 10M-parameter model you get coherent short sentences — proof that the architecture scales gracefully to richer data.
Self-Check

Ten questions before moving on.

Close the page and answer from memory. If you can't, re-read the relevant section.

  1. Trace the tensor shape from token IDs (B,T) through every stage of the model to the logits. What shape exits each major component?
  2. What should the loss be at initialization, and why? What does a wrong init loss tell you?
  3. Explain weight tying in one sentence. How many parameters does it save for a vocab of 65 vs 50,000?
  4. What are the four operations in the core training loop, in order? Why does zero_grad use set_to_none=True?
  5. Why does AdamW apply weight decay only to weight matrices, not to biases or LayerNorm parameters?
  6. Sketch the LR schedule used in a production training run. Why is linear warmup needed?
  7. What do temperature, top-k, and top-p each control during sampling? Which is most principled for variable-confidence situations?
  8. Why is naive autoregressive generation O(T²)? What does the KV cache do to reduce this, and what new cost does it introduce?
  9. Name four loss-curve pathologies and the most likely cause of each.
  10. Estimate the parameter count of a model with n_layer=12, n_embd=768 (GPT-2 small). What about n_layer=96, n_embd=12288 (GPT-3)?

"The loop is four lines: forward, backward, step, repeat. GPT-3 is this loop with bigger numbers and a thousand GPUs. The difference is scale, not kind."

Day 9 · Tiny GPT — Week 2 Capstone
Further Reading

Go deeper.

The canonical "build and train a GPT" references, plus inference-cost essentials.

Karpathy, "Let's build GPT from scratch" (YouTube, 2 hr). The definitive walkthrough of exactly today's build, on TinyShakespeare. The single best 2-hour investment for this lesson.

karpathy/nanoGPT (repo). The reference our notebook mirrors. Read train.py and model.py. The most legible production-quality GPT training code.

ng-video-lecture (nanogpt.py) (repo, Karpathy). The single-file version from the video, ~300 lines, fully legible. The ideal reference when you want to check your implementation.

Karpathy, "A Recipe for Training Neural Networks" (blog post). "Verify the loss at init" and other hard-won debugging wisdom. Read this once a year.

Hoffmann et al. 2022, "Chinchilla: Training Compute-Optimal LLMs" (paper, arXiv). The scaling law paper that changed how LLMs are trained. Directly relevant to the FLOPs and parameter-count discussion in this lesson.

Transformer Circuits, "In-context Learning and Induction Heads" (blog). What actually happens inside a small GPT as it learns. A mechanistic look at the model you just built.

TinyStories (dataset, 2023). Tiny models produce coherent text on it. Use it for Exercise 8; swapping in BPE tokenization is the natural next step after char-level.

mlx-examples, GPT-2 (repo, MLX). Apple Silicon-native training. The optional MLX track in the notebook follows this implementation.