{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Day 27 \u2014 Capstone Part 1: Model Loader, Weights, and Single-Sequence Forward (Notebook)\n",
    "\n",
    "Hands-on companion to the Day 27 lesson. We build the smallest honest engine:\n",
    "read a config, bind weights by name, run **one** forward pass from scratch, and\n",
    "**verify the logits against HuggingFace Transformers** \u2014 correctness before speed.\n",
    "\n",
    "To keep the notebook self-contained and fast (no multi-GB download), we use a\n",
    "**small random LLaMA built with `transformers` as the reference checkpoint**. The\n",
    "exact same code path loads the real TinyLlama-1.1B weights (see the optional snippet\n",
    "in Section 3). Our hand-written forward must match the reference to within float32\n",
    "rounding error; if it does, the same code should also be correct on the real model,\n",
    "where only the dimensions and weights change.\n",
    "\n",
    "**Cell map**\n",
    "1. Setup\n",
    "2. Config \u2014 read a model config into a typed `LLMConfig`\n",
    "3. Loader \u2014 bind weights by name and assert layer-0 shapes (+ optional real TinyLlama)\n",
    "4. Model components \u2014 RMSNorm, RoPE, GQA attention, SwiGLU\n",
    "5. Forward pass + verification against Transformers\n",
    "6. Minimal greedy decode (no cache)\n",
    "7. Exercises & self-check"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import math\n",
    "from dataclasses import dataclass, field\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from transformers import LlamaConfig, LlamaForCausalLM\n",
    "\n",
    "torch.manual_seed(0)\n",
    "DEVICE = \"cpu\"   # CPU + float32 keeps the comparison tight and reproducible\n",
    "DTYPE  = torch.float32\n",
    "print(\"torch\", torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Config \u2014 read a model config into a typed `LLMConfig`\n",
    "\n",
    "A professional engine is **config-driven**: the same code loads any LLaMA-style\n",
    "variant by reading dimensions from `config.json` instead of hard-coding them. We\n",
    "mirror the handful of fields the forward pass actually needs."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class LLMConfig:\n",
    "    vocab_size: int\n",
    "    hidden_size: int          # model width D\n",
    "    intermediate_size: int    # SwiGLU hidden\n",
    "    n_layers: int\n",
    "    n_heads: int              # query heads\n",
    "    n_kv_heads: int           # key/value heads (GQA)\n",
    "    head_dim: int\n",
    "    rms_eps: float\n",
    "    rope_theta: float\n",
    "    tie_embeddings: bool\n",
    "\n",
    "    @classmethod\n",
    "    def from_hf(cls, c: LlamaConfig) -> \"LLMConfig\":\n",
    "        # transformers 5.x stores the RoPE base under rope_parameters; 4.x exposes it\n",
    "        # as rope_theta. Check both before falling back to the LLaMA default of 10000.\n",
    "        rp = getattr(c, \"rope_parameters\", None) or {}\n",
    "        theta = rp.get(\"rope_theta\") or getattr(c, \"rope_theta\", 10000.0)\n",
    "        # Older configs may lack head_dim (or leave it None); derive it from the width.\n",
    "        hd = getattr(c, \"head_dim\", None) or c.hidden_size // c.num_attention_heads\n",
    "        return cls(\n",
    "            vocab_size=c.vocab_size, hidden_size=c.hidden_size,\n",
    "            intermediate_size=c.intermediate_size, n_layers=c.num_hidden_layers,\n",
    "            n_heads=c.num_attention_heads, n_kv_heads=c.num_key_value_heads,\n",
    "            head_dim=hd, rms_eps=c.rms_norm_eps, rope_theta=theta,\n",
    "            tie_embeddings=c.tie_word_embeddings,\n",
    "        )\n",
    "\n",
    "\n",
    "# Build a small but structurally realistic LLaMA (GQA: 8 query / 2 KV heads) as our\n",
    "# stand-in checkpoint. This is exactly the architecture of TinyLlama, just smaller.\n",
    "hf_config = LlamaConfig(\n",
    "    vocab_size=320, hidden_size=128, intermediate_size=256,\n",
    "    num_hidden_layers=4, num_attention_heads=8, num_key_value_heads=2,\n",
    "    max_position_embeddings=128, rms_norm_eps=1e-5,\n",
    "    tie_word_embeddings=False, attention_bias=False, mlp_bias=False,\n",
    ")\n",
    "reference = LlamaForCausalLM(hf_config).eval().to(DEVICE, DTYPE)\n",
    "\n",
    "cfg = LLMConfig.from_hf(hf_config)\n",
    "assert cfg.head_dim == cfg.hidden_size // cfg.n_heads\n",
    "print(cfg)"
   ]
  },
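  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what config-driven means without `transformers`, the snippet\n",
    "below parses a `config.json`-style string by hand and derives the attention widths.\n",
    "The values are this notebook's toy dimensions, and `config_text` and `raw` are\n",
    "illustrative names, not part of the engine.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Assumption: a trimmed config.json using the HuggingFace LLaMA key names.\n",
    "config_text = '{\"hidden_size\": 128, \"num_attention_heads\": 8, \"num_key_value_heads\": 2}'\n",
    "raw = json.loads(config_text)\n",
    "\n",
    "# head_dim is optional in many configs; derive it from the width when absent.\n",
    "head_dim = raw.get('head_dim') or raw['hidden_size'] // raw['num_attention_heads']\n",
    "q_width = raw['num_attention_heads'] * head_dim      # q_proj output features\n",
    "kv_width = raw['num_key_value_heads'] * head_dim     # k_proj / v_proj output features\n",
    "print(head_dim, q_width, kv_width)                   # 16 128 32\n",
    "```\n",
    "\n",
    "Nothing here is tied to one model size, which is why the loader only needs the config."
   ]
  },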
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Loader \u2014 bind weights by name and assert layer-0 shapes\n",
    "\n",
    "The loader's job is boring correctness: list the tensors, map names to the\n",
    "submodules that consume them, and **fail loudly** if a shape is wrong. We treat\n",
    "`reference.state_dict()` as the bag of safetensors tensors (the real loader reads\n",
    "these from a `.safetensors` file with `safetensors.torch.load_file`)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "weights = {k: v.to(DEVICE, DTYPE) for k, v in reference.state_dict().items()}\n",
    "\n",
    "print(f\"{len(weights)} tensors. Top-level + layer-0 weight map:\\n\")\n",
    "D, KV, hd = cfg.hidden_size, cfg.n_kv_heads, cfg.head_dim\n",
    "expected_layer0 = {\n",
    "    \"model.embed_tokens.weight\":                 (cfg.vocab_size, D),\n",
    "    \"model.layers.0.input_layernorm.weight\":     (D,),\n",
    "    \"model.layers.0.self_attn.q_proj.weight\":    (cfg.n_heads * hd, D),\n",
    "    \"model.layers.0.self_attn.k_proj.weight\":    (KV * hd, D),\n",
    "    \"model.layers.0.self_attn.v_proj.weight\":    (KV * hd, D),\n",
    "    \"model.layers.0.self_attn.o_proj.weight\":    (D, cfg.n_heads * hd),\n",
    "    \"model.layers.0.post_attention_layernorm.weight\": (D,),\n",
    "    \"model.layers.0.mlp.gate_proj.weight\":       (cfg.intermediate_size, D),\n",
    "    \"model.layers.0.mlp.up_proj.weight\":         (cfg.intermediate_size, D),\n",
    "    \"model.layers.0.mlp.down_proj.weight\":       (D, cfg.intermediate_size),\n",
    "    \"model.norm.weight\":                         (D,),\n",
    "    \"lm_head.weight\":                            (cfg.vocab_size, D),\n",
    "}\n",
    "for name, want in expected_layer0.items():\n",
    "    got = tuple(weights[name].shape)\n",
    "    flag = \"OK\" if got == want else \"MISMATCH\"\n",
    "    print(f\"  {flag:8} {name:48s} {got}\")\n",
    "    assert got == want, f\"{name}: expected {want}, got {got}\"\n",
    "print(\"\\n[OK] All layer-0 / top-level shapes match the config (GQA: K/V are narrower than Q).\")"
   ]
  },
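  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Section 3 notes that a real loader reads tensors from a `.safetensors` file. A minimal\n",
    "sketch of that round trip, assuming the `safetensors` package is installed; the tensor\n",
    "names are LLaMA-style, but the tiny shapes and the file name `toy.safetensors` are made up.\n",
    "\n",
    "```python\n",
    "import os, tempfile\n",
    "import torch\n",
    "from safetensors import safe_open\n",
    "from safetensors.torch import load_file, save_file\n",
    "\n",
    "toy = {'model.embed_tokens.weight': torch.zeros(8, 4), 'model.norm.weight': torch.ones(4)}\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    path = os.path.join(tmp, 'toy.safetensors')\n",
    "    save_file(toy, path)\n",
    "    # List names and shapes without materializing the tensors.\n",
    "    with safe_open(path, framework='pt') as f:\n",
    "        for name in f.keys():\n",
    "            print(name, f.get_slice(name).get_shape())\n",
    "    loaded = load_file(path)          # dict of name -> torch.Tensor, same keys as `toy`\n",
    "\n",
    "assert torch.equal(loaded['model.norm.weight'], toy['model.norm.weight'])\n",
    "```\n",
    "\n",
    "Listing shapes first lets a loader fail on a mismatch before it reads any weight data."
   ]
  },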
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Optional: load the real TinyLlama-1.1B instead.** The snippet below fills the exact\n",
    "same `weights` dict and feeds the same forward pass; it just sources the tensors from\n",
    "a real checkpoint. It needs `pip install huggingface_hub safetensors` and a ~2.2 GB\n",
    "download, so it is left commented out for offline/CI runs. To use it, paste it into a\n",
    "code cell, uncomment it, and run it before Section 4.\n",
    "\n",
    "```python\n",
    "# from huggingface_hub import snapshot_download\n",
    "# from safetensors.torch import load_file\n",
    "# import os, glob\n",
    "#\n",
    "# path = snapshot_download(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\")\n",
    "# hf_config = LlamaConfig.from_pretrained(path)          # reads the real config.json\n",
    "# cfg = LLMConfig.from_hf(hf_config)\n",
    "# weights = {}\n",
    "# for shard in glob.glob(os.path.join(path, \"*.safetensors\")):\n",
    "#     weights.update({k: v.to(DEVICE, DTYPE) for k, v in load_file(shard).items()})\n",
    "# reference = LlamaForCausalLM.from_pretrained(path, dtype=DTYPE).eval().to(DEVICE)\n",
    "# # everything below runs unchanged.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Model components \u2014 RMSNorm, RoPE, GQA attention, SwiGLU\n",
    "\n",
    "We implement each component to match LLaMA / HuggingFace exactly. The subtle parts:\n",
    "- **RMSNorm** computes the reciprocal-RMS in **float32** then rescales (numerics matter).\n",
    "- **RoPE** uses the `rotate_half` convention (split the head dim in half, not even/odd),\n",
    "  with `inv_freq = theta^(-2i/head_dim)` and `emb = cat(freqs, freqs)`.\n",
    "- **GQA** repeats each KV head `n_heads / n_kv_heads` times before attention.\n",
    "- **SwiGLU** is `down( silu(gate(x)) * up(x) )`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def rms_norm(x, weight, eps):\n",
    "    \"\"\"LLaMA RMSNorm: normalize by RMS in float32, cast back, then scale by `weight`.\"\"\"\n",
    "    x32 = x.to(torch.float32)\n",
    "    x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)\n",
    "    return weight * x32.to(x.dtype)   # HF casts to the input dtype before applying the weight\n",
    "\n",
    "\n",
    "def rope_tables(seq_len, head_dim, theta, device, dtype):\n",
    "    \"\"\"Return (cos, sin) of shape (seq_len, head_dim) for the rotate_half convention.\"\"\"\n",
    "    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))\n",
    "    pos = torch.arange(seq_len, device=device).float()\n",
    "    freqs = torch.outer(pos, inv_freq)            # (T, head_dim/2)\n",
    "    emb = torch.cat((freqs, freqs), dim=-1)        # (T, head_dim)\n",
    "    return emb.cos().to(dtype), emb.sin().to(dtype)\n",
    "\n",
    "\n",
    "def rotate_half(x):\n",
    "    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]\n",
    "    return torch.cat((-x2, x1), dim=-1)\n",
    "\n",
    "\n",
    "def apply_rope(x, cos, sin):\n",
    "    # x: (B, n_heads, T, head_dim); cos/sin: (T, head_dim) broadcast over B and heads.\n",
    "    return x * cos + rotate_half(x) * sin\n",
    "\n",
    "\n",
    "class LlamaEngine:\n",
    "    \"\"\"Minimal config-driven LLaMA forward pass that reads weights by HF name.\"\"\"\n",
    "    def __init__(self, cfg: LLMConfig, weights: dict):\n",
    "        self.cfg = cfg\n",
    "        self.w = weights\n",
    "\n",
    "    def attention(self, x, layer, cos, sin):\n",
    "        c, w = self.cfg, self.w\n",
    "        B, T, _ = x.shape\n",
    "        p = f\"model.layers.{layer}.self_attn.\"\n",
    "        q = (x @ w[p + \"q_proj.weight\"].T).view(B, T, c.n_heads,    c.head_dim).transpose(1, 2)\n",
    "        k = (x @ w[p + \"k_proj.weight\"].T).view(B, T, c.n_kv_heads, c.head_dim).transpose(1, 2)\n",
    "        v = (x @ w[p + \"v_proj.weight\"].T).view(B, T, c.n_kv_heads, c.head_dim).transpose(1, 2)\n",
    "        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)\n",
    "        rep = c.n_heads // c.n_kv_heads                 # GQA: share each KV head\n",
    "        k = k.repeat_interleave(rep, dim=1)\n",
    "        v = v.repeat_interleave(rep, dim=1)\n",
    "        scores = (q @ k.transpose(-1, -2)) / math.sqrt(c.head_dim)\n",
    "        causal = torch.triu(torch.full((T, T), float(\"-inf\"), device=x.device, dtype=scores.dtype), 1)\n",
    "        attn = torch.softmax(scores + causal, dim=-1)\n",
    "        out = (attn @ v).transpose(1, 2).reshape(B, T, c.n_heads * c.head_dim)\n",
    "        return out @ w[p + \"o_proj.weight\"].T\n",
    "\n",
    "    def swiglu(self, x, layer):\n",
    "        w = self.w\n",
    "        p = f\"model.layers.{layer}.mlp.\"\n",
    "        gate = F.silu(x @ w[p + \"gate_proj.weight\"].T)\n",
    "        up   = x @ w[p + \"up_proj.weight\"].T\n",
    "        return (gate * up) @ w[p + \"down_proj.weight\"].T\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def forward(self, input_ids):\n",
    "        c, w = self.cfg, self.w\n",
    "        x = F.embedding(input_ids, w[\"model.embed_tokens.weight\"])     # (B, T, D)\n",
    "        cos, sin = rope_tables(x.shape[1], c.head_dim, c.rope_theta, x.device, x.dtype)\n",
    "        for i in range(c.n_layers):\n",
    "            p = f\"model.layers.{i}.\"\n",
    "            x = x + self.attention(rms_norm(x, w[p + \"input_layernorm.weight\"], c.rms_eps), i, cos, sin)\n",
    "            x = x + self.swiglu(rms_norm(x, w[p + \"post_attention_layernorm.weight\"], c.rms_eps), i)\n",
    "        x = rms_norm(x, w[\"model.norm.weight\"], c.rms_eps)             # final norm\n",
    "        # Tied checkpoints reuse the embedding matrix as the output head.\n",
    "        head = w[\"model.embed_tokens.weight\"] if c.tie_embeddings else w[\"lm_head.weight\"]\n",
    "        return x @ head.T                                              # (B, T, vocab)\n",
    "\n",
    "\n",
    "engine = LlamaEngine(cfg, weights)\n",
    "print(\"LlamaEngine ready.\")"
   ]
  },
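  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the GQA step concrete, here is a small sketch of how `repeat_interleave`\n",
    "maps KV heads to query heads. The head labels are illustrative; the point is that\n",
    "consecutive query heads share one KV head, which tiling with `repeat` would get wrong.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "n_heads, n_kv_heads = 8, 2\n",
    "rep = n_heads // n_kv_heads\n",
    "# Label each KV head with its index: shape (B=1, n_kv_heads, T=1, head_dim=1).\n",
    "k = torch.arange(n_kv_heads).view(1, n_kv_heads, 1, 1)\n",
    "\n",
    "grouped = k.repeat_interleave(rep, dim=1).flatten().tolist()\n",
    "tiled = k.repeat(1, rep, 1, 1).flatten().tolist()\n",
    "print(grouped)   # [0, 0, 0, 0, 1, 1, 1, 1]  query heads 0-3 read KV head 0\n",
    "print(tiled)     # [0, 1, 0, 1, 0, 1, 0, 1]  wrong pairing for LLaMA weights\n",
    "```\n",
    "\n",
    "Both results have the right shape, so only a logit comparison would catch the tiled variant."
   ]
  },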
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Forward pass + verification against Transformers\n",
    "\n",
    "The contract: same token IDs, same dtype, eval mode (no dropout), same positions.\n",
    "We run our engine and HuggingFace on the identical prompt and require the largest\n",
    "absolute logit difference to be tiny (< 1e-4)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "prompt_ids = torch.tensor([[1, 5, 9, 12, 3, 7, 42, 100]], device=DEVICE)\n",
    "\n",
    "ours = engine.forward(prompt_ids)\n",
    "with torch.no_grad():\n",
    "    ref_logits = reference(prompt_ids).logits\n",
    "\n",
    "max_abs_diff = (ours - ref_logits).abs().max().item()\n",
    "print(f\"our logits shape : {tuple(ours.shape)}   [B, T, vocab]\")\n",
    "print(f\"HF  logits shape : {tuple(ref_logits.shape)}\")\n",
    "print(f\"max_abs_diff     : {max_abs_diff:.3e}\")\n",
    "print(f\"argmax match     : {(ours.argmax(-1) == ref_logits.argmax(-1)).all().item()}\")\n",
    "assert max_abs_diff < 1e-4, \"Forward pass does not match the reference \u2014 debug component by component.\"\n",
    "print(\"\\n[OK] Single-sequence forward matches HuggingFace to floating-point precision.\")"
   ]
  },
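  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the assertion above fails, comparing intermediate tensors narrows the bug to one\n",
    "component. A minimal sketch of that ladder with forward hooks, on a hypothetical\n",
    "three-module toy network rather than the real model; `toy`, `captured`, and `mine` are\n",
    "illustrative names, and the ReLU is a deliberately planted bug.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "toy = nn.Sequential(nn.Linear(4, 8, bias=False), nn.SiLU(), nn.Linear(8, 4, bias=False)).eval()\n",
    "\n",
    "captured = {}\n",
    "def save_as(name):\n",
    "    def hook(module, inputs, output):\n",
    "        captured[name] = output.detach()\n",
    "    return hook\n",
    "\n",
    "handles = [m.register_forward_hook(save_as(n)) for n, m in toy.named_children()]\n",
    "x = torch.randn(2, 4)\n",
    "with torch.no_grad():\n",
    "    toy(x)\n",
    "    # Hand-written forward, one entry per reference module ('0', '1', '2').\n",
    "    mine = {'0': x @ toy[0].weight.T}\n",
    "    mine['1'] = torch.relu(mine['0'])          # planted bug: the reference uses SiLU\n",
    "    mine['2'] = mine['1'] @ toy[2].weight.T\n",
    "for h in handles:\n",
    "    h.remove()\n",
    "\n",
    "for name in ('0', '1', '2'):\n",
    "    diff = (mine[name] - captured[name]).abs().max().item()\n",
    "    print(name, f'{diff:.3e}')                 # the first large diff localizes the bug\n",
    "```\n",
    "\n",
    "On the real model, the same hooks can go on `reference.model.layers[i]` and its submodules."
   ]
  },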
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Minimal greedy decode (no cache)\n",
    "\n",
    "Greedy decode: take `argmax` of the **last** position's logits, append it, and run\n",
    "the full forward again. This recomputes the whole prefix every step \u2014 correct but\n",
    "slow (Day 28 adds the KV cache). We verify each generated token equals the\n",
    "reference model's choice at every step."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def greedy_decode(engine, reference, input_ids, max_new_tokens=20):\n",
    "    ids = input_ids.clone()\n",
    "    generated = []\n",
    "    for _ in range(max_new_tokens):\n",
    "        next_id = engine.forward(ids)[:, -1, :].argmax(-1, keepdim=True)   # (B, 1)\n",
    "        # Cross-check against the reference's greedy choice on the same context.\n",
    "        ref_next = reference(ids).logits[:, -1, :].argmax(-1, keepdim=True)\n",
    "        assert (next_id == ref_next).all(), \"Greedy token diverged from reference!\"\n",
    "        ids = torch.cat([ids, next_id], dim=1)\n",
    "        generated.append(next_id.item())   # single sequence: assumes B == 1\n",
    "    return ids, generated\n",
    "\n",
    "full_ids, new_tokens = greedy_decode(engine, reference, prompt_ids, max_new_tokens=20)\n",
    "print(\"prompt tokens   :\", prompt_ids[0].tolist())\n",
    "print(\"generated tokens:\", new_tokens)\n",
    "print(f\"\\n[OK] {len(new_tokens)} greedy tokens generated; every token matches the reference model.\")"
   ]
  },
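  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick back-of-the-envelope sketch of why the no-cache loop is slow, using this\n",
    "notebook's 8-token prompt and 20 new tokens. It counts token positions pushed through\n",
    "the model; the cached figure assumes a KV cache that processes each token exactly once.\n",
    "\n",
    "```python\n",
    "prompt_len, max_new_tokens = 8, 20\n",
    "\n",
    "# No cache: step t re-runs the whole context of prompt_len + t tokens.\n",
    "no_cache = sum(prompt_len + t for t in range(max_new_tokens))\n",
    "# With a cache: one prefill over the prompt, then one token per later step.\n",
    "with_cache = prompt_len + (max_new_tokens - 1)\n",
    "print(no_cache, with_cache)   # 350 27\n",
    "```\n",
    "\n",
    "The gap widens with every extra generated token, since the uncached count grows quadratically."
   ]
  },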
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Exercises\n",
    "\n",
    "1. **Real checkpoint.** Paste the Section-3 snippet into a code cell, uncomment it, and\n",
    "   re-run Sections 4\u20136 unchanged. Confirm `max_abs_diff` stays small (use float32 on CPU/MPS).\n",
    "2. **Break it on purpose.** Replace `rotate_half` with an even/odd interleave and watch\n",
    "   `max_abs_diff` explode; the RoPE convention is one of the most common sources of mismatch.\n",
    "3. **Component diff ladder.** Compare intermediate tensors (embedding \u2192 first norm \u2192\n",
    "   Q/K/V \u2192 attention out \u2192 FFN) against `reference` with hooks, to localize a bug.\n",
    "4. **GQA accounting.** Print the KV-cache width `n_kv_heads * head_dim` and the query width\n",
    "   `n_heads * head_dim`, and confirm that query width / KV width equals `n_heads / n_kv_heads`.\n",
    "5. **Save artifacts.** Save the prompt, generated tokens, and `max_abs_diff` to a JSON\n",
    "   file \u2014 the seed of a regression test for Days 28\u201330."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Self-Check\n",
    "\n",
    "1. **Why verify logits instead of generated text?** Text can match by accident; logits\n",
    "   expose numeric component bugs (a wrong scale or RoPE offset still often yields\n",
    "   plausible text).\n",
    "2. **What shape does GQA change?** The K/V projection output width is\n",
    "   `n_kv_heads * head_dim`, narrower than Q's `n_heads * head_dim`. Q and O keep full width.\n",
    "3. **Why no KV cache today?** Correct single-forward math must be proven before\n",
    "   optimizing the decode loop \u2014 that is Day 28.\n",
    "4. **Common source of RoPE bugs?** Wrong position offset, wrong `head_dim` pairing, or\n",
    "   the even/odd-vs-half rotation convention.\n",
    "5. **Why config-driven construction?** The same engine loads LLaMA-style variants\n",
    "   (TinyLlama, Llama-3) without hard-coded dimensions; only the weights and config change.\n",
    "   Families with extra parameters, such as the Q/K/V biases in Qwen2, need small additions."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}