{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Day 29 \u2014 Capstone Part 3: Continuous Batching and Paged KV Cache (Notebook)\n",
    "\n",
    "Hands-on companion to the Day 29 lesson. Single-sequence generation (Days 27\u201328) is a\n",
    "demo; a serving engine needs **request state, admission, block allocation, and\n",
    "continuous batching** over many variable-length sequences. We build the smallest\n",
    "honest scheduler that proves the ideas:\n",
    "\n",
    "- a **`BlockPool`** of physical KV blocks,\n",
    "- a per-sequence **`PageTable`** (logical\u2192physical block map),\n",
    "- a **`Sequence`** record and a **`Scheduler`** with `waiting / running / done` queues,\n",
    "- **admission gated on free blocks**, growth-on-boundary, and preemption under pressure.\n",
    "\n",
    "The scheduler drives the *real* cached engine from Day 28 (small random LLaMA, no\n",
    "download), so the headline correctness check is concrete: **batched scheduler output ==\n",
    "solo greedy output**, token for token.\n",
    "\n",
    "**Cell map**\n",
    "1. Setup + the Day 28 cached engine\n",
    "2. `BlockPool` + `PageTable` (the allocator) \u2014 reproduce the lesson's 30-vs-64 example\n",
    "3. `Sequence` + `Scheduler` (admission, decode, growth, preemption)\n",
    "4. Run continuous batching with timed arrivals; verify outputs == solo greedy\n",
    "5. Block-utilization plot vs static allocation\n",
    "6. Admission gating under a constrained pool\n",
    "7. Exercises & self-check"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup + the Day 28 cached engine"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import math, time\n",
    "from dataclasses import dataclass, field\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from transformers import LlamaConfig, LlamaForCausalLM\n",
    "\n",
    "torch.manual_seed(0)\n",
    "DEVICE, DTYPE = \"cpu\", torch.float32\n",
    "\n",
    "hf_config = LlamaConfig(\n",
    "    vocab_size=320, hidden_size=128, intermediate_size=256,\n",
    "    num_hidden_layers=4, num_attention_heads=8, num_key_value_heads=2,\n",
    "    max_position_embeddings=512, rms_norm_eps=1e-5,\n",
    "    tie_word_embeddings=False, attention_bias=False, mlp_bias=False,\n",
    ")\n",
    "reference = LlamaForCausalLM(hf_config).eval().to(DEVICE, DTYPE)\n",
    "weights = {k: v.to(DEVICE, DTYPE) for k, v in reference.state_dict().items()}\n",
    "\n",
    "@dataclass\n",
    "class LLMConfig:\n",
    "    vocab_size:int; hidden_size:int; intermediate_size:int; n_layers:int\n",
    "    n_heads:int; n_kv_heads:int; head_dim:int; rms_eps:float; rope_theta:float\n",
    "cfg = LLMConfig(hf_config.vocab_size, hf_config.hidden_size, hf_config.intermediate_size,\n",
    "                hf_config.num_hidden_layers, hf_config.num_attention_heads,\n",
    "                hf_config.num_key_value_heads, hf_config.head_dim, hf_config.rms_norm_eps,\n",
    "                hf_config.rope_parameters[\"rope_theta\"])\n",
    "\n",
    "# --- Day 28 cached engine (components + KVCache + cache-aware forward) ---\n",
    "def rms_norm(x, w, eps):\n",
    "    x32 = x.to(torch.float32); x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)\n",
    "    return (w * x32).to(x.dtype)\n",
    "def rope_tables(positions, hd, theta, device, dtype):\n",
    "    inv = 1.0 / (theta ** (torch.arange(0, hd, 2, device=device).float() / hd))\n",
    "    emb = torch.cat((torch.outer(positions.float(), inv),) * 2, dim=-1)\n",
    "    return emb.cos().to(dtype), emb.sin().to(dtype)\n",
    "def rotate_half(x):\n",
    "    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]; return torch.cat((-x2, x1), -1)\n",
    "def apply_rope(x, cos, sin): return x * cos + rotate_half(x) * sin\n",
    "\n",
    "class KVCache:\n",
    "    def __init__(self, n): self.k=[None]*n; self.v=[None]*n\n",
    "    def length(self): return 0 if self.k[0] is None else self.k[0].shape[2]\n",
    "    def append(self, i, k, v):\n",
    "        self.k[i] = k if self.k[i] is None else torch.cat([self.k[i], k], 2)\n",
    "        self.v[i] = v if self.v[i] is None else torch.cat([self.v[i], v], 2)\n",
    "        return self.k[i], self.v[i]\n",
    "\n",
    "class CachedLlamaEngine:\n",
    "    def __init__(self, cfg, w): self.cfg, self.w = cfg, w\n",
    "    @torch.no_grad()\n",
    "    def forward(self, input_ids, cache=None):\n",
    "        c, w = self.cfg, self.w\n",
    "        offset = cache.length() if cache is not None else 0\n",
    "        x = F.embedding(input_ids, w[\"model.embed_tokens.weight\"]); B, Tq, _ = x.shape\n",
    "        pos = torch.arange(offset, offset + Tq, device=x.device)\n",
    "        cos, sin = rope_tables(pos, c.head_dim, c.rope_theta, x.device, x.dtype)\n",
    "        for i in range(c.n_layers):\n",
    "            p = f\"model.layers.{i}.\"\n",
    "            h = rms_norm(x, w[p + \"input_layernorm.weight\"], c.rms_eps)\n",
    "            q = (h @ w[p + \"self_attn.q_proj.weight\"].T).view(B, Tq, c.n_heads, c.head_dim).transpose(1, 2)\n",
    "            k = (h @ w[p + \"self_attn.k_proj.weight\"].T).view(B, Tq, c.n_kv_heads, c.head_dim).transpose(1, 2)\n",
    "            v = (h @ w[p + \"self_attn.v_proj.weight\"].T).view(B, Tq, c.n_kv_heads, c.head_dim).transpose(1, 2)\n",
    "            q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)\n",
    "            if cache is not None: k, v = cache.append(i, k, v)\n",
    "            rep = c.n_heads // c.n_kv_heads\n",
    "            kf, vf = k.repeat_interleave(rep, 1), v.repeat_interleave(rep, 1)\n",
    "            s = (q @ kf.transpose(-1, -2)) / math.sqrt(c.head_dim)\n",
    "            qp = torch.arange(offset, offset + Tq, device=x.device).unsqueeze(1)\n",
    "            kp = torch.arange(offset + Tq, device=x.device).unsqueeze(0)\n",
    "            s = s.masked_fill(kp > qp, float(\"-inf\"))\n",
    "            o = (torch.softmax(s, -1) @ vf).transpose(1, 2).reshape(B, Tq, c.n_heads * c.head_dim)\n",
    "            x = x + o @ w[p + \"self_attn.o_proj.weight\"].T\n",
    "            h = rms_norm(x, w[p + \"post_attention_layernorm.weight\"], c.rms_eps)\n",
    "            g = F.silu(h @ w[p + \"mlp.gate_proj.weight\"].T) * (h @ w[p + \"mlp.up_proj.weight\"].T)\n",
    "            x = x + g @ w[p + \"mlp.down_proj.weight\"].T\n",
    "        return rms_norm(x, w[\"model.norm.weight\"], c.rms_eps) @ w[\"lm_head.weight\"].T\n",
    "\n",
    "engine = CachedLlamaEngine(cfg, weights)\n",
    "\n",
    "@torch.no_grad()\n",
    "def solo_greedy(prompt_ids, max_new_tokens):\n",
    "    \"\"\"Single-sequence greedy with its own cache (Day 28). The reference output.\"\"\"\n",
    "    cache = KVCache(cfg.n_layers); logits = engine.forward(prompt_ids, cache); out = []\n",
    "    for _ in range(max_new_tokens):\n",
    "        tok = int(logits[0, -1, :].argmax())\n",
    "        out.append(tok)\n",
    "        logits = engine.forward(torch.tensor([[tok]], device=prompt_ids.device), cache)\n",
    "    return out\n",
    "print(\"engine + solo_greedy ready.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. `BlockPool` + `PageTable` \u2014 the allocator\n",
    "\n",
    "A KV block holds `block_size` token positions. The **pool** owns a fixed set of physical\n",
    "blocks and hands them out. Each sequence's **page table** maps its logical block indices\n",
    "to physical block ids \u2014 `physical = page_table[pos // block_size]`. We first reproduce\n",
    "the lesson's exact accounting: lengths `[32,128,64,256]`, block size 16 \u2192 `2+8+4+16 = 30`\n",
    "blocks paged, versus `4 \u00d7 16 = 64` reserved by static max-length allocation."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "class BlockPool:\n",
    "    def __init__(self, num_blocks):\n",
    "        self.num_blocks = num_blocks\n",
    "        self.free = list(range(num_blocks))      # physical block ids\n",
    "    @property\n",
    "    def free_count(self): return len(self.free)\n",
    "    @property\n",
    "    def used(self): return self.num_blocks - len(self.free)\n",
    "    def allocate(self, n=1):\n",
    "        if n > len(self.free): raise RuntimeError(\"OOM: not enough free blocks\")\n",
    "        return [self.free.pop() for _ in range(n)]\n",
    "    def free_blocks(self, blocks):\n",
    "        self.free.extend(blocks)\n",
    "\n",
    "class PageTable:\n",
    "    \"\"\"Maps logical block index -> physical block id for one sequence.\"\"\"\n",
    "    def __init__(self, block_size): self.block_size = block_size; self.logical_to_physical = []\n",
    "    def num_blocks(self): return len(self.logical_to_physical)\n",
    "    def blocks_needed(self, n_tokens): return math.ceil(n_tokens / self.block_size)\n",
    "    def grow_to(self, n_tokens, pool):\n",
    "        \"\"\"Allocate blocks until this table covers n_tokens. Returns #blocks allocated.\"\"\"\n",
    "        need = self.blocks_needed(n_tokens) - self.num_blocks()\n",
    "        if need > 0:\n",
    "            self.logical_to_physical.extend(pool.allocate(need))\n",
    "        return max(0, need)\n",
    "    def physical_for(self, pos):\n",
    "        return self.logical_to_physical[pos // self.block_size]\n",
    "\n",
    "# Lesson example: paged vs static.\n",
    "BS = 16\n",
    "lengths = [32, 128, 64, 256]\n",
    "paged = sum(math.ceil(L / BS) for L in lengths)\n",
    "static = len(lengths) * math.ceil(max(lengths) / BS)\n",
    "print(f\"lengths {lengths}, block_size {BS}\")\n",
    "print(f\"paged blocks  : {' + '.join(str(math.ceil(L/BS)) for L in lengths)} = {paged}\")\n",
    "print(f\"static blocks : {len(lengths)} seqs x {math.ceil(max(lengths)/BS)} (max len) = {static}\")\n",
    "print(f\"paged uses {paged}/{static} = {paged/static*100:.0f}% of static reservation\")\n",
    "assert paged == 30 and static == 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. `Sequence` + `Scheduler`\n",
    "\n",
    "A `Sequence` tracks its prompt, generated tokens, page table, decode cache, and state.\n",
    "The `Scheduler` owns the queues and runs one iteration = one decode step per running\n",
    "sequence:\n",
    "\n",
    "1. **free** finished sequences (return their blocks to the pool),\n",
    "2. **admit** waiting requests whose arrival time has passed **and** for which the pool\n",
    "   has enough free blocks (and batch budget),\n",
    "3. **decode** one token per running sequence; **grow** its page table when a token\n",
    "   crosses a block boundary, **preempting** the longest sequence if the pool is full,\n",
    "4. **record** block utilization."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class Sequence:\n",
    "    sid: int\n",
    "    prompt: list\n",
    "    max_new_tokens: int\n",
    "    arrival: int = 0\n",
    "    generated: list = field(default_factory=list)\n",
    "    state: str = \"waiting\"\n",
    "    page_table: PageTable = None\n",
    "    cache: KVCache = None\n",
    "    logits = None\n",
    "    def total_tokens(self): return len(self.prompt) + len(self.generated)\n",
    "\n",
    "class Scheduler:\n",
    "    def __init__(self, engine, pool, block_size, max_batch=8):\n",
    "        self.engine, self.pool, self.block_size, self.max_batch = engine, pool, block_size, max_batch\n",
    "        self.waiting, self.running, self.done = [], [], []\n",
    "        self.t = 0\n",
    "        self.trace = []          # (t, blocks_used, n_running, n_done)\n",
    "        self.preemptions = 0\n",
    "\n",
    "    def add(self, seq): self.waiting.append(seq)\n",
    "\n",
    "    def _free_seq(self, seq):\n",
    "        self.pool.free_blocks(seq.page_table.logical_to_physical)\n",
    "        seq.page_table.logical_to_physical = []\n",
    "\n",
    "    def _admit(self):\n",
    "        for seq in sorted([s for s in self.waiting if s.arrival <= self.t], key=lambda s: s.arrival):\n",
    "            if len(self.running) >= self.max_batch:\n",
    "                break\n",
    "            need = math.ceil(len(seq.prompt) / self.block_size)\n",
    "            if self.pool.free_count < need:          # admission gate: enough blocks?\n",
    "                continue\n",
    "            seq.page_table = PageTable(self.block_size)\n",
    "            seq.cache = KVCache(self.engine.cfg.n_layers)\n",
    "            seq.page_table.grow_to(len(seq.prompt), self.pool)     # blocks for the prompt\n",
    "            seq.logits = self.engine.forward(torch.tensor([seq.prompt], device=DEVICE), seq.cache)\n",
    "            seq.state = \"running\"\n",
    "            self.waiting.remove(seq); self.running.append(seq)\n",
    "\n",
    "    def _decode_step(self):\n",
    "        finished = []\n",
    "        for seq in list(self.running):\n",
    "            # Would emitting the next token cross a block boundary that needs new memory?\n",
    "            need_block = seq.page_table.blocks_needed(seq.total_tokens() + 1) > seq.page_table.num_blocks()\n",
    "            if need_block and self.pool.free_count == 0:\n",
    "                # Recomputation preemption: roll this sequence back to `waiting`, free its\n",
    "                # blocks, and drop its progress. Greedy is deterministic, so when it is\n",
    "                # re-admitted it regenerates the identical tokens \u2014 correctness is preserved.\n",
    "                self._free_seq(seq)\n",
    "                seq.generated, seq.cache, seq.logits, seq.state = [], None, None, \"waiting\"\n",
    "                self.running.remove(seq); self.waiting.append(seq)\n",
    "                self.preemptions += 1\n",
    "                continue\n",
    "            tok = int(seq.logits[0, -1, :].argmax())     # greedy (matches solo reference)\n",
    "            seq.generated.append(tok)\n",
    "            if need_block:\n",
    "                seq.page_table.grow_to(seq.total_tokens(), self.pool)\n",
    "            if len(seq.generated) >= seq.max_new_tokens:\n",
    "                seq.state = \"done\"; finished.append(seq)\n",
    "            else:\n",
    "                seq.logits = self.engine.forward(torch.tensor([[tok]], device=DEVICE), seq.cache)\n",
    "        for seq in finished:\n",
    "            self.running.remove(seq); self._free_seq(seq); self.done.append(seq)\n",
    "\n",
    "    def step(self):\n",
    "        self._admit()\n",
    "        self._decode_step()\n",
    "        self.trace.append((self.t, self.pool.used, len(self.running), len(self.done)))\n",
    "        self.t += 1\n",
    "\n",
    "    def run(self, max_iters=10000):\n",
    "        while (self.waiting or self.running) and self.t < max_iters:\n",
    "            self.step()\n",
    "        return self.done\n",
    "\n",
    "print(\"Sequence + Scheduler defined.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Run continuous batching \u2014 and verify outputs == solo greedy\n",
    "\n",
    "Four requests arrive at times 0, 2, 4, 6 with different output lengths. We run the\n",
    "scheduler to completion, then check every sequence's scheduled output equals the\n",
    "single-sequence greedy reference. Identical tokens prove the paged/batched machinery\n",
    "did not corrupt the math."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def make_requests():\n",
    "    g = torch.Generator().manual_seed(123)\n",
    "    specs = [(0, 3, 10), (1, 6, 25), (2, 4, 8), (3, 5, 18)]   # (sid, prompt_len, max_new)\n",
    "    arrivals = [0, 2, 4, 6]\n",
    "    seqs = []\n",
    "    for (sid, plen, mx), arr in zip(specs, arrivals):\n",
    "        prompt = torch.randint(0, cfg.vocab_size, (plen,), generator=g).tolist()\n",
    "        seqs.append(Sequence(sid=sid, prompt=prompt, max_new_tokens=mx, arrival=arr))\n",
    "    return seqs\n",
    "\n",
    "BLOCK_SIZE = 4\n",
    "seqs = make_requests()\n",
    "pool = BlockPool(num_blocks=64)                 # generous pool (no preemption expected)\n",
    "sched = Scheduler(engine, pool, BLOCK_SIZE, max_batch=8)\n",
    "for s in seqs: sched.add(s)\n",
    "done = sched.run()\n",
    "\n",
    "print(f\"finished {len(done)} sequences in {sched.t} iterations; preemptions={sched.preemptions}\\n\")\n",
    "all_match = True\n",
    "for seq in sorted(done, key=lambda s: s.sid):\n",
    "    ref = solo_greedy(torch.tensor([seq.prompt], device=DEVICE), seq.max_new_tokens)\n",
    "    ok = ref == seq.generated\n",
    "    all_match &= ok\n",
    "    print(f\"  seq {seq.sid}: prompt_len={len(seq.prompt):>2} gen={len(seq.generated):>2} \"\n",
    "          f\"output==solo_greedy: {ok}\")\n",
    "assert all_match\n",
    "print(\"\\n[OK] Continuous-batching scheduler outputs match solo greedy decoding exactly.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Block-utilization evidence\n",
    "\n",
    "The trace records physical blocks in use each iteration. Paged allocation tracks the\n",
    "*actual* token counts, so peak usage sits well below the static worst case (every\n",
    "sequence reserved at the max length for the whole batch)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "peak = max(b for _, b, _, _ in sched.trace)\n",
    "static_max = len(seqs) * math.ceil(max(s.total_tokens() for s in done) / BLOCK_SIZE)\n",
    "print(f\"peak blocks in use (paged) : {peak}\")\n",
    "print(f\"static reservation         : {len(seqs)} seqs x {math.ceil(max(s.total_tokens() for s in done)/BLOCK_SIZE)} = {static_max}\")\n",
    "print(f\"paged peak is {peak}/{static_max} = {peak/static_max*100:.0f}% of static\\n\")\n",
    "\n",
    "try:\n",
    "    import matplotlib\n",
    "    matplotlib.use(\"Agg\")\n",
    "    import matplotlib.pyplot as plt\n",
    "    ts   = [t for t, *_ in sched.trace]\n",
    "    used = [b for _, b, *_ in sched.trace]\n",
    "    run  = [r for _, _, r, _ in sched.trace]\n",
    "    fig, ax = plt.subplots(figsize=(9, 4))\n",
    "    ax.plot(ts, used, color=\"#8B2635\", linewidth=2, label=\"blocks in use (paged)\")\n",
    "    ax.axhline(static_max, color=\"#6B6B6B\", linestyle=\"--\", label=f\"static reservation ({static_max})\")\n",
    "    ax.plot(ts, run, color=\"#B8932E\", linewidth=1.2, label=\"running sequences\")\n",
    "    ax.set_xlabel(\"scheduler iteration\"); ax.set_ylabel(\"count\")\n",
    "    ax.set_title(\"Paged KV utilization under continuous batching\")\n",
    "    ax.legend(); ax.grid(alpha=0.3)\n",
    "    plt.tight_layout(); plt.savefig(\"/tmp/day29_utilization.png\", dpi=110); plt.show()\n",
    "    print(\"saved plot to /tmp/day29_utilization.png\")\n",
    "except Exception as e:\n",
    "    print(\"matplotlib unavailable, table instead:\")\n",
    "    for t, b, r, d in sched.trace:\n",
    "        print(f\"  t={t:>2}  blocks={b:>2}  running={r}  done={d}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Admission gating under a constrained pool\n",
    "\n",
    "Shrink the pool so it cannot hold every prompt at once. The scheduler must **defer**\n",
    "admission until blocks free up \u2014 no OOM, all sequences still finish correctly. This is\n",
    "the whole point of the admission gate."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "seqs2 = make_requests()\n",
    "small_pool = BlockPool(num_blocks=8)            # tight: cannot run all 4 at once\n",
    "sched2 = Scheduler(engine, small_pool, BLOCK_SIZE, max_batch=8)\n",
    "for s in seqs2: sched2.add(s)\n",
    "\n",
    "sched2.run(max_iters=10000)\n",
    "max_concurrent = max(r for _, _, r, _ in sched2.trace)\n",
    "assert small_pool.used == 0, \"all blocks must be returned to the pool at the end\"\n",
    "\n",
    "print(f\"tight pool of {small_pool.num_blocks} blocks: finished {len(sched2.done)} seqs, \"\n",
    "      f\"max concurrent running = {max_concurrent}, preemptions = {sched2.preemptions}\")\n",
    "# Correctness still holds under gating/preemption.\n",
    "ok = all(solo_greedy(torch.tensor([s.prompt], device=DEVICE), s.max_new_tokens) == s.generated\n",
    "         for s in sched2.done)\n",
    "print(f\"outputs still match solo greedy: {ok}\")\n",
    "assert ok and len(sched2.done) == len(seqs2)\n",
    "print(\"\\n[OK] Admission gating prevented OOM; every block was reclaimed; outputs are correct.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Exercises\n",
    "\n",
    "1. **Preemption policy.** Replace \"preempt the longest\" with \"preempt the most recently\n",
    "   admitted\" and compare total iterations and preemption counts.\n",
    "2. **Throughput vs latency.** Lower `max_batch` and measure how per-sequence finish time\n",
    "   changes \u2014 the fairness/throughput tradeoff.\n",
    "3. **Prefix sharing.** Give two sequences an identical prompt prefix and share the prompt\n",
    "   blocks (copy-on-write) \u2014 record the extra blocks saved.\n",
    "4. **Bigger workload.** Generate 16 requests with random arrivals and lengths; plot blocks\n",
    "   in use and report peak vs static.\n",
    "5. **Real EOS.** Add an `eos_token_id` stop condition and confirm sequences end early when\n",
    "   the model emits it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Self-Check\n",
    "\n",
    "1. **How does the scheduler know a sequence is done?** It hits `eos_token_id` or\n",
    "   `max_new_tokens`; the sequence moves to `done` and its blocks return to the pool.\n",
    "2. **What happens when the block pool is exhausted?** Stop admitting new work, or\n",
    "   **preempt** a running sequence (free + requeue) per policy \u2014 never attend to memory\n",
    "   you do not own.\n",
    "3. **Why pad in the simple batched path?** Rectangular matmul shapes; masks hide padding.\n",
    "   (Here each sequence keeps its own cache, so no padding is needed at all.)\n",
    "4. **Page-table growth rule?** Allocate a new logical block when the next token would\n",
    "   cross a block boundary (`ceil(total_tokens / block_size) > num_blocks`).\n",
    "5. **What proves paged cache works?** Peak blocks are well below the static worst case\n",
    "   **and** outputs match solo decoding \u2014 both asserted above."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}