{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Day 28 \u2014 Capstone Part 2: KV Cache, Masks, and Sampling Loop (Notebook)\n",
    "\n",
    "Hands-on companion to the Day 28 lesson. Day 27 proved one forward pass is correct;\n",
    "here we turn it into a usable engine: **prefill** the prompt into a per-layer KV\n",
    "cache, **decode** one token at a time from cache, **sample** controllably, and stop\n",
    "on EOS. The non-negotiable invariant: *with the same weights, positions, and masks,\n",
    "cached and uncached logits match to floating-point noise.*\n",
    "\n",
    "As in Day 27 we use a small random LLaMA from `transformers` as the reference, so\n",
    "the whole notebook runs in seconds with no download. The cached engine is also\n",
    "cross-checked against HuggingFace on the prompt.\n",
    "\n",
    "**Cell map**\n",
    "1. Setup + reference model (same tiny LLaMA as Day 27)\n",
    "2. Components + a `KVCache` and a cache-aware forward (prefill & decode in one path)\n",
    "3. Verify: cached logits == uncached logits, and prefill == HuggingFace\n",
    "4. Samplers: greedy, temperature, top-k, top-p\n",
    "5. `generate(...)` with the cache + EOS handling\n",
    "6. Determinism check\n",
    "7. Benchmark: no-cache vs cached decode\n",
    "8. Exercises & self-check"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup + reference model"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import math, time\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from transformers import LlamaConfig, LlamaForCausalLM\n",
    "\n",
    "torch.manual_seed(0)\n",
    "DEVICE, DTYPE = \"cpu\", torch.float32\n",
    "\n",
    "hf_config = LlamaConfig(\n",
    "    vocab_size=320, hidden_size=128, intermediate_size=256,\n",
    "    num_hidden_layers=4, num_attention_heads=8, num_key_value_heads=2,\n",
    "    max_position_embeddings=512, rms_norm_eps=1e-5,\n",
    "    tie_word_embeddings=False, attention_bias=False, mlp_bias=False,\n",
    ")\n",
    "reference = LlamaForCausalLM(hf_config).eval().to(DEVICE, DTYPE)\n",
    "weights = {k: v.to(DEVICE, DTYPE) for k, v in reference.state_dict().items()}\n",
    "\n",
    "@dataclass\n",
    "class LLMConfig:\n",
    "    vocab_size:int; hidden_size:int; intermediate_size:int; n_layers:int\n",
    "    n_heads:int; n_kv_heads:int; head_dim:int; rms_eps:float; rope_theta:float\n",
    "\n",
    "cfg = LLMConfig(\n",
    "    vocab_size=hf_config.vocab_size, hidden_size=hf_config.hidden_size,\n",
    "    intermediate_size=hf_config.intermediate_size, n_layers=hf_config.num_hidden_layers,\n",
    "    n_heads=hf_config.num_attention_heads, n_kv_heads=hf_config.num_key_value_heads,\n",
    "    head_dim=hf_config.head_dim, rms_eps=hf_config.rms_norm_eps,\n",
    "    rope_theta=hf_config.rope_parameters[\"rope_theta\"],\n",
    ")\n",
    "print(\"torch\", torch.__version__, \"| layers\", cfg.n_layers, \"| GQA\", cfg.n_heads, \"/\", cfg.n_kv_heads)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Components + `KVCache` + a cache-aware forward\n",
    "\n",
    "The components (RMSNorm, RoPE, GQA, SwiGLU) are identical to Day 27. The new piece is\n",
    "the cache. Two subtleties make caching correct:\n",
    "\n",
    "- **RoPE position offset.** A token decoded after `cache_len` prompt tokens sits at\n",
    "  absolute position `cache_len`, not 0. We rotate Q and the *new* K at their absolute\n",
    "  positions, and store the **already-rotated** K in the cache (so cached keys never get\n",
    "  re-rotated).\n",
    "- **Mask by absolute position.** One general rule covers both phases: a query at\n",
    "  absolute position `p` may attend to a key at position `j` iff `j <= p`. For prefill\n",
    "  (`T` queries from 0) that is the triangular mask; for decode (1 query at `cache_len`)\n",
    "  every cached key is visible, so the mask removes nothing and one code path serves both."
   ]
  },
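  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short derivation of why storing the already-rotated K is safe, offered as a sketch\n",
    "rather than as part of the lesson. The symbols $R_p$ (the RoPE rotation for absolute\n",
    "position $p$) and $M_{p,j}$ (the additive form of the mask) are notation introduced\n",
    "here, not names from the code. Rotations compose, so $R_p^\\top R_j = R_{j-p}$, and the\n",
    "attention score depends only on the relative offset:\n",
    "\n",
    "$$\n",
    "(R_p\\,q)^\\top (R_j\\,k) \\;=\\; q^\\top R_p^\\top R_j\\,k \\;=\\; q^\\top R_{j-p}\\,k ,\n",
    "\\qquad\n",
    "M_{p,j} = \\begin{cases} 0 & j \\le p \\\\ -\\infty & j > p \\end{cases}\n",
    "$$\n",
    "\n",
    "A key rotated once at position $j$ therefore stays valid for every later query, as long\n",
    "as each query is rotated at its own absolute position $p$. For one-token decode, $p$\n",
    "equals `cache_len` and every key has $j \\le p$, so $M_{p,j} = 0$ throughout."
   ]
  },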
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def rms_norm(x, w, eps):\n",
    "    x32 = x.to(torch.float32)\n",
    "    x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)\n",
    "    return (w * x32).to(x.dtype)\n",
    "\n",
    "def rope_tables(positions, head_dim, theta, device, dtype):\n",
    "    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))\n",
    "    freqs = torch.outer(positions.float(), inv_freq)\n",
    "    emb = torch.cat((freqs, freqs), dim=-1)\n",
    "    return emb.cos().to(dtype), emb.sin().to(dtype)\n",
    "\n",
    "def rotate_half(x):\n",
    "    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]\n",
    "    return torch.cat((-x2, x1), dim=-1)\n",
    "\n",
    "def apply_rope(x, cos, sin):\n",
    "    return x * cos + rotate_half(x) * sin\n",
    "\n",
    "\n",
    "class KVCache:\n",
    "    \"\"\"Per-layer key/value cache. Stores K already rotated by RoPE.\"\"\"\n",
    "    def __init__(self, n_layers):\n",
    "        self.k = [None] * n_layers\n",
    "        self.v = [None] * n_layers\n",
    "\n",
    "    def length(self):\n",
    "        return 0 if self.k[0] is None else self.k[0].shape[2]   # (B, KV, T, hd)\n",
    "\n",
    "    def append(self, layer, k_new, v_new):\n",
    "        if self.k[layer] is None:\n",
    "            self.k[layer], self.v[layer] = k_new, v_new\n",
    "        else:\n",
    "            self.k[layer] = torch.cat([self.k[layer], k_new], dim=2)\n",
    "            self.v[layer] = torch.cat([self.v[layer], v_new], dim=2)\n",
    "        return self.k[layer], self.v[layer]\n",
    "\n",
    "\n",
    "class CachedLlamaEngine:\n",
    "    def __init__(self, cfg, weights):\n",
    "        self.cfg, self.w = cfg, weights\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def forward(self, input_ids, cache=None):\n",
    "        \"\"\"One forward over `input_ids`. If `cache` is given, append K/V and attend\n",
    "        over the full cached context (prefill when empty, decode when 1 token).\"\"\"\n",
    "        c, w = self.cfg, self.w\n",
    "        offset = cache.length() if cache is not None else 0       # absolute start position\n",
    "        x = F.embedding(input_ids, w[\"model.embed_tokens.weight\"])\n",
    "        B, Tq, _ = x.shape\n",
    "        positions = torch.arange(offset, offset + Tq, device=x.device)\n",
    "        cos, sin = rope_tables(positions, c.head_dim, c.rope_theta, x.device, x.dtype)\n",
    "\n",
    "        for i in range(c.n_layers):\n",
    "            p = f\"model.layers.{i}.\"\n",
    "            h = rms_norm(x, w[p + \"input_layernorm.weight\"], c.rms_eps)\n",
    "            q = (h @ w[p + \"self_attn.q_proj.weight\"].T).view(B, Tq, c.n_heads,    c.head_dim).transpose(1, 2)\n",
    "            k = (h @ w[p + \"self_attn.k_proj.weight\"].T).view(B, Tq, c.n_kv_heads, c.head_dim).transpose(1, 2)\n",
    "            v = (h @ w[p + \"self_attn.v_proj.weight\"].T).view(B, Tq, c.n_kv_heads, c.head_dim).transpose(1, 2)\n",
    "            q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)   # rotate before caching K\n",
    "            if cache is not None:\n",
    "                k, v = cache.append(i, k, v)                          # full (B, KV, offset+Tq, hd)\n",
    "            rep = c.n_heads // c.n_kv_heads\n",
    "            k_full = k.repeat_interleave(rep, dim=1)\n",
    "            v_full = v.repeat_interleave(rep, dim=1)\n",
    "            scores = (q @ k_full.transpose(-1, -2)) / math.sqrt(c.head_dim)   # (B,H,Tq,Lk)\n",
    "            Lk = offset + Tq\n",
    "            qpos = torch.arange(offset, offset + Tq, device=x.device).unsqueeze(1)   # (Tq,1)\n",
    "            kpos = torch.arange(Lk, device=x.device).unsqueeze(0)                     # (1,Lk)\n",
    "            scores = scores.masked_fill(kpos > qpos, float(\"-inf\"))   # attend iff key_pos <= query_pos\n",
    "            attn = torch.softmax(scores, dim=-1)\n",
    "            o = (attn @ v_full).transpose(1, 2).reshape(B, Tq, c.n_heads * c.head_dim)\n",
    "            x = x + o @ w[p + \"self_attn.o_proj.weight\"].T\n",
    "            h = rms_norm(x, w[p + \"post_attention_layernorm.weight\"], c.rms_eps)\n",
    "            g = F.silu(h @ w[p + \"mlp.gate_proj.weight\"].T) * (h @ w[p + \"mlp.up_proj.weight\"].T)\n",
    "            x = x + g @ w[p + \"mlp.down_proj.weight\"].T\n",
    "        x = rms_norm(x, w[\"model.norm.weight\"], c.rms_eps)\n",
    "        return x @ w[\"lm_head.weight\"].T\n",
    "\n",
    "engine = CachedLlamaEngine(cfg, weights)\n",
    "print(\"CachedLlamaEngine ready.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Verify: cached == uncached, and prefill == HuggingFace\n",
    "\n",
    "Three checks: (a) our prefill logits match HuggingFace on the prompt, (b) a cached\n",
    "prefill+decode run reproduces the uncached full-sequence logits at every new position,\n",
    "and (c) cached and uncached greedy produce the identical token stream."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "prompt = torch.tensor([[1, 5, 9, 12, 3, 7, 42, 100]], device=DEVICE)\n",
    "\n",
    "# (a) prefill vs HuggingFace\n",
    "cache = KVCache(cfg.n_layers)\n",
    "prefill_logits = engine.forward(prompt, cache)\n",
    "with torch.no_grad():\n",
    "    hf_logits = reference(prompt).logits\n",
    "print(f\"(a) prefill vs HF   max_abs_diff: {(prefill_logits - hf_logits).abs().max().item():.3e}\")\n",
    "assert (prefill_logits - hf_logits).abs().max().item() < 1e-4\n",
    "\n",
    "# (b) cached decode logits vs uncached full-forward logits, token by token\n",
    "seq = prompt.clone()\n",
    "step_logits = prefill_logits                                     # logits that predict the next token\n",
    "max_diff = 0.0\n",
    "for step in range(6):\n",
    "    next_tok = step_logits[:, -1, :].argmax(-1, keepdim=True)\n",
    "    seq = torch.cat([seq, next_tok], dim=1)\n",
    "    step_logits = engine.forward(next_tok, cache)                # decode one token from cache\n",
    "    uncached_logits = engine.forward(seq)                        # recompute whole sequence, no cache\n",
    "    diff = (step_logits[:, -1, :] - uncached_logits[:, -1, :]).abs().max().item()\n",
    "    assert diff < 1e-4, f\"cached != uncached at step {step}: {diff}\"\n",
    "    max_diff = max(max_diff, diff)\n",
    "print(f\"(b) cached vs uncached over 6 decode steps  max_abs_diff: {max_diff:.3e}\")\n",
    "\n",
    "# (c) greedy with cache vs greedy without cache \u2192 identical tokens\n",
    "def greedy_no_cache(ids, n):\n",
    "    ids = ids.clone()\n",
    "    for _ in range(n):\n",
    "        ids = torch.cat([ids, engine.forward(ids)[:, -1, :].argmax(-1, keepdim=True)], dim=1)\n",
    "    return ids\n",
    "def greedy_cached(ids, n):\n",
    "    cache = KVCache(cfg.n_layers)\n",
    "    logits = engine.forward(ids, cache)\n",
    "    out = ids.clone()\n",
    "    for _ in range(n):\n",
    "        nxt = logits[:, -1, :].argmax(-1, keepdim=True)\n",
    "        out = torch.cat([out, nxt], dim=1)\n",
    "        logits = engine.forward(nxt, cache)\n",
    "    return out\n",
    "a = greedy_no_cache(prompt, 15)\n",
    "b = greedy_cached(prompt, 15)\n",
    "print(f\"(c) greedy cached == greedy no-cache: {torch.equal(a, b)}\")\n",
    "assert torch.equal(a, b)\n",
    "print(\"\\n[OK] KV cache is numerically transparent: cached matches uncached within tolerance.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Samplers: greedy, temperature, top-k, top-p\n",
    "\n",
    "Sampling operates on the last-position logits, after the forward pass: temperature\n",
    "divides the logits, top-k keeps the k largest, and top-p (nucleus) keeps the smallest\n",
    "set of most-probable tokens whose cumulative probability is \u2265 p. The filters apply in\n",
    "that order; `temperature == 0` short-circuits to greedy (argmax)."
   ]
  },
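  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sampler below can be summarised in two formulas. This is a sketch; the symbols\n",
    "$z_i$ (logits), $\\tau$ (temperature), $\\pi_i$ (probabilities) and $\\pi_{(n)}$ (the $n$-th\n",
    "largest probability) are notation introduced here, not names from the code.\n",
    "\n",
    "$$\n",
    "\\pi_i = \\frac{\\exp(z_i/\\tau)}{\\sum_j \\exp(z_j/\\tau)},\n",
    "\\qquad\n",
    "\\text{keep token } (n) \\iff \\sum_{i=1}^{n-1} \\pi_{(i)} \\le p\n",
    "$$\n",
    "\n",
    "The keep rule always retains the top token and includes the token that crosses $p$.\n",
    "For example, with $\\pi = (0.5, 0.3, 0.15, 0.05)$ and $p = 0.9$ the preceding masses are\n",
    "$0, 0.5, 0.8, 0.95$, so the first three tokens survive (total mass $0.95$) and the\n",
    "fourth is dropped. The surviving logits are then renormalised by a final softmax."
   ]
  },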
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def sample_token(logits, temperature=1.0, top_k=0, top_p=1.0, generator=None):\n",
    "    \"\"\"logits: (vocab,). Returns an int token id.\"\"\"\n",
    "    logits = logits.clone()\n",
    "    if temperature == 0:                       # greedy\n",
    "        return int(logits.argmax())\n",
    "    logits = logits / temperature\n",
    "    if top_k and top_k < logits.numel():       # keep k largest\n",
    "        kth = torch.topk(logits, top_k).values[-1]\n",
    "        logits[logits < kth] = float(\"-inf\")\n",
    "    if top_p < 1.0:                            # nucleus\n",
    "        sorted_logits, sorted_idx = torch.sort(logits, descending=True)\n",
    "        cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)\n",
    "        remove = cum > top_p\n",
    "        remove[..., 1:] = remove[..., :-1].clone()   # keep the token that crosses p\n",
    "        remove[..., 0] = False\n",
    "        logits[sorted_idx[remove]] = float(\"-inf\")\n",
    "    probs = torch.softmax(logits, dim=-1)\n",
    "    return int(torch.multinomial(probs, 1, generator=generator))\n",
    "\n",
    "# Show each mode on the prompt's final logits.\n",
    "last = engine.forward(prompt)[0, -1, :]\n",
    "g = torch.Generator(device=DEVICE).manual_seed(0)\n",
    "print(\"greedy           :\", sample_token(last, temperature=0))\n",
    "print(\"temperature=0.8  :\", sample_token(last, temperature=0.8, generator=g))\n",
    "print(\"top_k=5          :\", sample_token(last, temperature=1.0, top_k=5, generator=g))\n",
    "print(\"top_p=0.9        :\", sample_token(last, temperature=1.0, top_p=0.9, generator=g))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. `generate(...)` with the cache\n",
    "\n",
    "Prefill the prompt ids \u2192 sample the first token \u2192 repeatedly decode one token from the\n",
    "cache, sample, and append; stop on `eos_token_id` or after `max_new_tokens`. The notebook\n",
    "works on raw token ids, so there is no tokenizer step."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def generate(engine, prompt_ids, max_new_tokens=20, eos_token_id=None,\n",
    "             temperature=1.0, top_k=0, top_p=1.0, seed=0):\n",
    "    cache = KVCache(engine.cfg.n_layers)\n",
    "    gen = torch.Generator(device=prompt_ids.device).manual_seed(seed)\n",
    "    logits = engine.forward(prompt_ids, cache)            # prefill\n",
    "    out = prompt_ids.clone()\n",
    "    new = []\n",
    "    for step in range(max_new_tokens):\n",
    "        tok = sample_token(logits[0, -1, :], temperature, top_k, top_p, generator=gen)\n",
    "        new.append(tok)\n",
    "        nxt = torch.tensor([[tok]], device=prompt_ids.device)\n",
    "        out = torch.cat([out, nxt], dim=1)               # keep `out` in sync with `new` (EOS included)\n",
    "        if tok == eos_token_id or step == max_new_tokens - 1:\n",
    "            break                                        # skip the forward whose logits would go unused\n",
    "        logits = engine.forward(nxt, cache)              # decode from cache\n",
    "    return out, new\n",
    "\n",
    "ids, new = generate(engine, prompt, max_new_tokens=20, temperature=0)   # greedy\n",
    "print(\"greedy generated :\", new)\n",
    "ids, new = generate(engine, prompt, max_new_tokens=20, temperature=0.9, top_k=20, top_p=0.95, seed=1)\n",
    "print(\"sampled generated:\", new)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Determinism check\n",
    "\n",
    "Greedy is deterministic; sampling is deterministic *given a fixed seed*. Both must\n",
    "reproduce exactly on a repeat run."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "_, g1 = generate(engine, prompt, max_new_tokens=100, temperature=0)\n",
    "_, g2 = generate(engine, prompt, max_new_tokens=100, temperature=0)\n",
    "print(\"greedy reproducible (100 tokens):\", g1 == g2)\n",
    "assert g1 == g2\n",
    "\n",
    "_, s1 = generate(engine, prompt, max_new_tokens=50, temperature=0.8, top_p=0.9, seed=7)\n",
    "_, s2 = generate(engine, prompt, max_new_tokens=50, temperature=0.8, top_p=0.9, seed=7)\n",
    "print(\"seeded sampling reproducible    :\", s1 == s2)\n",
    "assert s1 == s2\n",
    "print(\"\\n[OK] Deterministic given (mode, seed).\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Benchmark: no-cache vs cached decode\n",
    "\n",
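    "No-cache decode re-runs the full forward pass over the whole sequence at every step, so\n",
    "a run of T steps pushes O(T\u00b2) token positions through the model. Cached decode runs the\n",
    "projections and MLP for the one new token only; its attention still reads every cached\n",
    "key, and this simple cache re-copies K/V with `torch.cat`, so per-step cost grows with\n",
    "context but stays far below a full recompute. The speedup should grow with context\n",
    "length, though timings on a model this small are noisy. We use a modest grid so the\n",
    "notebook stays fast; Exercise 4 scales it to the lesson's `[32,128,512] \u00d7 [32,128,256]`."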
   ]
  },
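  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough count of how many token positions each timing loop in this section pushes\n",
    "through the projections and MLP. This is a sketch: $P$ (prompt length), $G$ (generated\n",
    "tokens) and $N$ (positions processed) are notation introduced here, and attention cost\n",
    "is ignored.\n",
    "\n",
    "$$\n",
    "N_{\\text{no-cache}} = \\sum_{t=0}^{G-1} (P + t) = GP + \\frac{G(G-1)}{2},\n",
    "\\qquad\n",
    "N_{\\text{cached}} = P + G\n",
    "$$\n",
    "\n",
    "For $P = 64$, $G = 64$ that is $6112$ positions versus $128$, a ratio of about $48$.\n",
    "The measured wall-clock speedup on this tiny CPU model will likely be smaller, because\n",
    "fixed per-call overhead weighs heavily when each forward pass is this cheap."
   ]
  },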
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def time_no_cache(prompt_len, gen_len):\n",
    "    ids = torch.randint(0, cfg.vocab_size, (1, prompt_len), device=DEVICE)\n",
    "    t0 = time.perf_counter()\n",
    "    for _ in range(gen_len):\n",
    "        ids = torch.cat([ids, engine.forward(ids)[:, -1, :].argmax(-1, keepdim=True)], dim=1)\n",
    "    return time.perf_counter() - t0\n",
    "\n",
    "def time_cached(prompt_len, gen_len):\n",
    "    ids = torch.randint(0, cfg.vocab_size, (1, prompt_len), device=DEVICE)\n",
    "    t0 = time.perf_counter()\n",
    "    cache = KVCache(cfg.n_layers)\n",
    "    logits = engine.forward(ids, cache)\n",
    "    for _ in range(gen_len):\n",
    "        nxt = logits[:, -1, :].argmax(-1, keepdim=True)\n",
    "        logits = engine.forward(nxt, cache)\n",
    "    return time.perf_counter() - t0\n",
    "\n",
    "print(f\"{'prompt':>7} {'gen':>5} {'no-cache(s)':>12} {'cached(s)':>10} {'speedup':>8}\")\n",
    "for plen in [16, 64]:\n",
    "    for glen in [32, 64]:\n",
    "        t_nc = time_no_cache(plen, glen)\n",
    "        t_c  = time_cached(plen, glen)\n",
    "        print(f\"{plen:>7} {glen:>5} {t_nc:>12.3f} {t_c:>10.3f} {t_nc / t_c:>7.1f}x\")\n",
    "print(\"\\nExpect the speedup to grow with prompt+gen length; that is why production engines cache.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Exercises\n",
    "\n",
    "1. **Cache correctness on your own prompt.** Pick any prompt and assert cached vs\n",
    "   uncached logits match at every decode step (Section 3 is the template).\n",
    "2. **Sampler ablation.** Hold the seed fixed and sweep `temperature \u2208 {0.5, 1.0, 1.5}`\n",
    "   and `top_p \u2208 {0.8, 0.95}`; observe how diversity changes.\n",
    "3. **Determinism at scale.** Generate 200 greedy tokens twice; confirm identical.\n",
    "4. **Full benchmark grid.** Extend Section 7 to prompt `[32,128,512]` \u00d7 output\n",
    "   `[32,128,256]` and plot the speedup vs total context length.\n",
    "5. **EOS handling.** Set `eos_token_id` to a token the model emits and confirm\n",
    "   generation stops early."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Self-Check\n",
    "\n",
    "1. **Why no triangular mask for one-token decode?** The single new query sits at the\n",
    "   newest position and may see *all* valid past keys; there is no future to mask.\n",
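    "2. **What does `cache_len` protect against?** Attending to uninitialized or stale cache\n",
    "   slots: in a preallocated cache only the first `cache_len` positions are valid. This\n",
    "   notebook's `KVCache` grows by concatenation, so `cache.length()` is always the valid length.\n",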
    "3. **Why compare greedy first?** It removes sampling randomness, isolating numeric bugs.\n",
    "4. **What causes RoPE drift?** Using position 0 for every decode step instead of the\n",
    "   absolute position `cache_len`. The engine avoids it by building `positions` from\n",
    "   `offset` before calling `rope_tables`.\n",
    "5. **Where does the speedup come from?** Not recomputing K/V projections and prefix\n",
    "   attention every step: decode projects only the new token and reads the prefix K/V\n",
    "   from the cache."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}