{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Day 15 — Sampling From Scratch (Notebook)\n",
        "\n",
        "This notebook is the hands-on companion to Day 15. It focuses on the parts of inference you can test without a production serving stack:\n",
        "\n",
        "1. Stable softmax from raw logits.\n",
        "2. Greedy, temperature, top-k, and top-p sampling.\n",
        "3. A minimal `generate()` loop.\n",
        "4. Timing the loop at different prompt lengths.\n",
        "\n",
        "The final model is a tiny random bigram model so the notebook is self-contained. Swap it for your Day 9 GPT by keeping the same interface: `logits = model(idx)` with `idx: (B, T)` and `logits: (B, T, V)`."
      ]
    },
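    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before swapping in another model, it can help to check the `logits = model(idx)` contract directly. The cell below is a minimal sketch, not part of the lesson: `check_lm_interface` is a hypothetical helper that feeds random ids of shape `(B, T)` and asserts the output has shape `(B, T, V)`. It assumes the model has at least one parameter (used to find its device), and it uses a plain `nn.Embedding` as a stand-in model that already satisfies the contract."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "def check_lm_interface(model: nn.Module, vocab_size: int, B: int = 2, T: int = 5) -> tuple:\n",
        "    \"\"\"Hypothetical helper: assert that model(idx) maps (B, T) ids to (B, T, V) logits.\"\"\"\n",
        "    gen = torch.Generator().manual_seed(0)  # private RNG for the random ids\n",
        "    idx = torch.randint(0, vocab_size, (B, T), generator=gen)\n",
        "    model_device = next(model.parameters()).device\n",
        "    with torch.no_grad():\n",
        "        out = model(idx.to(model_device))\n",
        "    expected = (B, T, vocab_size)\n",
        "    assert tuple(out.shape) == expected, f\"expected {expected}, got {tuple(out.shape)}\"\n",
        "    return tuple(out.shape)\n",
        "\n",
        "# Stand-in model: an embedding table maps (B, T) ids to (B, T, V) rows.\n",
        "print(check_lm_interface(nn.Embedding(30, 30), vocab_size=30))"
      ]
    },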
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 1. Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math, time\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "torch.manual_seed(0)\n",
        "device = (\n",
        "    \"cuda\" if torch.cuda.is_available()\n",
        "    else \"mps\" if torch.backends.mps.is_available()\n",
        "    else \"cpu\"\n",
        ")\n",
        "print(\"torch :\", torch.__version__)\n",
        "print(\"device:\", device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 2. Stable Softmax on the Lesson Logits\n",
        "\n",
        "We use the same logits as the markdown lesson:\n",
        "\n",
        "```text\n",
        "[2.0, -1.5, 0.3, 4.2, -0.8]\n",
        "```\n",
        "\n",
        "The expected probabilities at temperature 1.0 are approximately `[0.097, 0.003, 0.018, 0.876, 0.006]`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "logits = torch.tensor([[2.0, -1.5, 0.3, 4.2, -0.8]], device=device)\n",
        "\n",
        "def stable_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:\n",
        "    shifted = x - x.max(dim=dim, keepdim=True).values\n",
        "    exps = shifted.exp()\n",
        "    return exps / exps.sum(dim=dim, keepdim=True)\n",
        "\n",
        "for tau in [1.0, 0.7, 2.0]:\n",
        "    probs = stable_softmax(logits / tau)\n",
        "    print(f\"tau={tau}:\", [round(x, 6) for x in probs[0].tolist()], \"sum=\", probs.sum().item())\n",
        "\n",
        "greedy = logits.argmax(dim=-1).item()\n",
        "print(\"greedy token:\", greedy)"
      ]
    },
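    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why subtract the max? Softmax does not change when the same constant `c` is subtracted from every logit, because the factor `exp(-c)` cancels between numerator and denominator. In float32, however, `exp` overflows to `inf` for inputs above about 88. The sketch below uses deliberately large, made-up logits: a naive softmax returns `nan` (from `inf / inf`), while the shifted version stays finite and agrees with `torch.softmax`. The names `naive_softmax` and `shifted_softmax` are illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def naive_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:\n",
        "    exps = x.exp()  # overflows to inf in float32 once x exceeds about 88\n",
        "    return exps / exps.sum(dim=dim, keepdim=True)\n",
        "\n",
        "def shifted_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:\n",
        "    # Same idea as stable_softmax: subtract the row max so the largest exponent is exp(0) = 1.\n",
        "    shifted = x - x.max(dim=dim, keepdim=True).values\n",
        "    exps = shifted.exp()\n",
        "    return exps / exps.sum(dim=dim, keepdim=True)\n",
        "\n",
        "big = torch.tensor([[1000.0, 999.0, 990.0]])  # made-up logits (float32)\n",
        "print(\"naive  :\", naive_softmax(big))\n",
        "print(\"shifted:\", shifted_softmax(big))\n",
        "print(\"matches torch.softmax:\", torch.allclose(shifted_softmax(big), torch.softmax(big, dim=-1)))\n",
        "print(\"shift-invariant:\", torch.allclose(shifted_softmax(big), shifted_softmax(big - 1000.0)))"
      ]
    },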
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 3. Sampling Helpers\n",
        "\n",
        "`temperature = 0` is handled as greedy decoding. For top-p, the implementation keeps the first token that crosses the cumulative threshold; otherwise the kept mass could remain below `p`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def sample_next_token(\n",
        "    logits: torch.Tensor,\n",
        "    *,\n",
        "    temperature: float = 1.0,\n",
        "    top_k: int | None = None,\n",
        "    top_p: float | None = None,\n",
        ") -> torch.Tensor:\n",
        "    \"\"\"logits: (B, V) -> sampled ids: (B, 1).\"\"\"\n",
        "    if temperature == 0:\n",
        "        return logits.argmax(dim=-1, keepdim=True)\n",
        "    if temperature < 0:\n",
        "        raise ValueError(\"temperature must be non-negative\")\n",
        "\n",
        "    logits = logits / temperature\n",
        "\n",
        "    if top_k is not None:\n",
        "        values, _ = torch.topk(logits, k=top_k, dim=-1)\n",
        "        kth_value = values[:, [-1]]\n",
        "        logits = torch.where(logits < kth_value, torch.full_like(logits, float(\"-inf\")), logits)\n",
        "\n",
        "    if top_p is not None:\n",
        "        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)\n",
        "        sorted_probs = F.softmax(sorted_logits, dim=-1)\n",
        "        cumulative = sorted_probs.cumsum(dim=-1)\n",
        "\n",
        "        remove_sorted = cumulative > top_p\n",
        "        remove_sorted[:, 1:] = remove_sorted[:, :-1].clone()\n",
        "        remove_sorted[:, 0] = False\n",
        "\n",
        "        remove = torch.zeros_like(remove_sorted).scatter(1, sorted_idx, remove_sorted)\n",
        "        logits = logits.masked_fill(remove, float(\"-inf\"))\n",
        "\n",
        "    probs = F.softmax(logits, dim=-1)\n",
        "    return torch.multinomial(probs, num_samples=1)\n",
        "\n",
        "\n",
        "def kept_tokens_after_top_p(logits: torch.Tensor, top_p: float):\n",
        "    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)\n",
        "    sorted_probs = F.softmax(sorted_logits, dim=-1)\n",
        "    cumulative = sorted_probs.cumsum(dim=-1)\n",
        "    remove_sorted = cumulative > top_p\n",
        "    remove_sorted[:, 1:] = remove_sorted[:, :-1].clone()\n",
        "    remove_sorted[:, 0] = False\n",
        "    keep_sorted = ~remove_sorted\n",
        "    return sorted_idx[keep_sorted].tolist()"
      ]
    },
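    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The note about keeping the crossing token can be checked on the lesson logits with `p = 0.9`. In the sketch below, `top_p_keep` and its `keep_crossing` flag are hypothetical names used only for this comparison. With the naive rule (keep only tokens whose cumulative mass stays at or below `p`), only token 3 survives, carrying about 0.876 of the mass, which is less than 0.9. Keeping the crossing token retains tokens 3 and 0, about 0.973 of the mass."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def top_p_keep(logits: torch.Tensor, p: float, keep_crossing: bool = True) -> list[int]:\n",
        "    \"\"\"Hypothetical helper: token ids kept by a top-p cutoff on a 1-D logits vector.\"\"\"\n",
        "    probs = torch.softmax(logits, dim=-1)\n",
        "    sorted_probs, sorted_idx = torch.sort(probs, descending=True)\n",
        "    cumulative = sorted_probs.cumsum(dim=-1)\n",
        "    if keep_crossing:\n",
        "        # Keep a token while the mass *before* it is still <= p, so the crossing token survives.\n",
        "        keep = (cumulative - sorted_probs) <= p\n",
        "    else:\n",
        "        # Naive rule: keep tokens whose cumulative mass stays <= p.\n",
        "        # It can even keep nothing when the top token alone exceeds p.\n",
        "        keep = cumulative <= p\n",
        "    return sorted_idx[keep].tolist()\n",
        "\n",
        "lesson_logits = torch.tensor([2.0, -1.5, 0.3, 4.2, -0.8])\n",
        "lesson_probs = torch.softmax(lesson_logits, dim=-1)\n",
        "for flag in (False, True):\n",
        "    kept = top_p_keep(lesson_logits, 0.9, keep_crossing=flag)\n",
        "    print(f\"keep_crossing={flag}: kept={kept} mass={lesson_probs[kept].sum().item():.3f}\")"
      ]
    },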
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 4. Verify Top-k and Top-p on the Lesson Logits"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "probs = stable_softmax(logits)\n",
        "top2_vals, top2_idx = torch.topk(logits, k=2, dim=-1)\n",
        "print(\"softmax probs:\", [round(x, 6) for x in probs[0].tolist()])\n",
        "print(\"top-k k=2 tokens:\", top2_idx[0].tolist())\n",
        "print(\"top-k renorm probs:\", [round(x, 6) for x in F.softmax(top2_vals, dim=-1)[0].tolist()])\n",
        "print(\"top-p p=0.9 kept tokens:\", kept_tokens_after_top_p(logits, 0.9))\n",
        "\n",
        "torch.manual_seed(123)\n",
        "print(\"greedy sample:\", sample_next_token(logits, temperature=0).item())\n",
        "print(\"temperature sample:\", sample_next_token(logits, temperature=2.0).item())\n",
        "print(\"top-k sample:\", sample_next_token(logits, temperature=1.0, top_k=2).item())\n",
        "print(\"top-p sample:\", sample_next_token(logits, temperature=1.0, top_p=0.9).item())"
      ]
    },
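    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One draw per method cannot show that a sampler matches its target distribution. A quick Monte Carlo check, sketched below, repeats the top-k `k = 2` case many times on the lesson logits and compares the empirical frequencies with the renormalized probabilities (about 0.900 for token 3 and 0.100 for token 0). It draws from a private `torch.Generator`, so the global seed used by later cells is untouched; the sample count of 20,000 is an arbitrary choice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "demo_logits = torch.tensor([[2.0, -1.5, 0.3, 4.2, -0.8]])\n",
        "gen = torch.Generator().manual_seed(0)  # private RNG; leaves the global torch seed alone\n",
        "\n",
        "# Top-k with k=2: mask everything below the 2nd-largest logit, then renormalize with softmax.\n",
        "kth = torch.topk(demo_logits, k=2, dim=-1).values[:, [-1]]\n",
        "masked = demo_logits.masked_fill(demo_logits < kth, float(\"-inf\"))\n",
        "target = torch.softmax(masked, dim=-1)[0]\n",
        "\n",
        "n_draws = 20_000\n",
        "draws = torch.multinomial(target, num_samples=n_draws, replacement=True, generator=gen)\n",
        "empirical = torch.bincount(draws, minlength=target.numel()).float() / n_draws\n",
        "\n",
        "print(\"target   :\", [round(x, 3) for x in target.tolist()])\n",
        "print(\"empirical:\", [round(x, 3) for x in empirical.tolist()])"
      ]
    },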
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 5. A Tiny Self-Contained Model\n",
        "\n",
        "This is not a Transformer. It is a bigram language model: the logits for the next token depend only on the current token. That is enough to test the mechanics of `generate()` without requiring a trained Day 9 checkpoint."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "vocab = list(\"abcdefghijklmnopqrstuvwxyz .,\\n\")\n",
        "stoi = {ch: i for i, ch in enumerate(vocab)}\n",
        "itos = {i: ch for ch, i in stoi.items()}\n",
        "\n",
        "def encode(text: str) -> torch.Tensor:\n",
        "    ids = [stoi.get(ch, stoi[\" \"]) for ch in text.lower()]\n",
        "    return torch.tensor([ids], dtype=torch.long, device=device)\n",
        "\n",
        "def decode(ids: torch.Tensor) -> str:\n",
        "    return \"\".join(itos[int(i)] for i in ids.flatten().tolist())\n",
        "\n",
        "class TinyBigramLM(nn.Module):\n",
        "    def __init__(self, vocab_size: int):\n",
        "        super().__init__()\n",
        "        self.table = nn.Embedding(vocab_size, vocab_size)\n",
        "        with torch.no_grad():\n",
        "            self.table.weight.normal_(mean=0.0, std=0.3)\n",
        "            # Give spaces, vowels, and common letters a slight global bias.\n",
        "            for ch, bonus in {\" \": 1.0, \"e\": 0.8, \"t\": 0.6, \"a\": 0.5, \"o\": 0.5, \"n\": 0.4}.items():\n",
        "                self.table.weight[:, stoi[ch]] += bonus\n",
        "\n",
        "    def forward(self, idx: torch.Tensor) -> torch.Tensor:\n",
        "        return self.table(idx)  # (B, T, V)\n",
        "\n",
        "model = TinyBigramLM(len(vocab)).to(device)\n",
        "prompt = encode(\"the \")\n",
        "print(\"prompt ids:\", prompt.tolist())\n",
        "print(\"prompt text:\", repr(decode(prompt)))\n",
        "print(\"logits shape:\", tuple(model(prompt).shape))"
      ]
    },
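    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because `TinyBigramLM.forward` is a single embedding lookup, the logits at each position depend only on the token at that position. The self-contained sketch below uses a random `(V, V)` table as a stand-in for `model.table.weight` (drawn from a private generator, so `model` and the global seed are untouched) and checks that two contexts with different histories but the same last token produce identical last-position logits. A Transformer would generally fail this check, since its last position attends to the whole context."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "V = 30  # size of the notebook's character vocabulary\n",
        "gen = torch.Generator().manual_seed(0)\n",
        "table = torch.randn(V, V, generator=gen)  # stand-in for TinyBigramLM.table.weight\n",
        "\n",
        "def bigram_logits(idx: torch.Tensor) -> torch.Tensor:\n",
        "    # With default settings, nn.Embedding's forward reduces to this row lookup: (B, T) -> (B, T, V).\n",
        "    return F.embedding(idx, table)\n",
        "\n",
        "ctx_a = torch.tensor([[1, 2, 3, 7]])\n",
        "ctx_b = torch.tensor([[9, 0, 5, 7]])  # different history, same last token\n",
        "print(\"same last-position logits:\", torch.equal(bigram_logits(ctx_a)[:, -1], bigram_logits(ctx_b)[:, -1]))\n",
        "print(\"equal to table row 7:\", torch.equal(bigram_logits(ctx_a)[0, -1], table[7]))"
      ]
    },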
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 6. Minimal Generate Loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def generate(\n",
        "    model: nn.Module,\n",
        "    idx: torch.Tensor,\n",
        "    *,\n",
        "    max_new_tokens: int,\n",
        "    block_size: int,\n",
        "    temperature: float = 1.0,\n",
        "    top_k: int | None = None,\n",
        "    top_p: float | None = None,\n",
        ") -> torch.Tensor:\n",
        "    model.eval()\n",
        "    for _ in range(max_new_tokens):\n",
        "        idx_cond = idx[:, -block_size:]\n",
        "        logits = model(idx_cond)          # (B, T_context, V)\n",
        "        next_logits = logits[:, -1, :]    # (B, V)\n",
        "        next_id = sample_next_token(next_logits, temperature=temperature, top_k=top_k, top_p=top_p)\n",
        "        idx = torch.cat([idx, next_id], dim=1)\n",
        "    return idx\n",
        "\n",
        "settings = [\n",
        "    (\"greedy\", dict(temperature=0)),\n",
        "    (\"temperature 0.8\", dict(temperature=0.8)),\n",
        "    (\"temperature 1.4\", dict(temperature=1.4)),\n",
        "    (\"top-p 0.9\", dict(temperature=1.0, top_p=0.9)),\n",
        "]\n",
        "\n",
        "for name, kwargs in settings:\n",
        "    torch.manual_seed(7)\n",
        "    out = generate(model, prompt.clone(), max_new_tokens=80, block_size=32, **kwargs)\n",
        "    print(\"=\" * 80)\n",
        "    print(name)\n",
        "    print(repr(decode(out)))"
      ]
    },
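    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The line `idx_cond = idx[:, -block_size:]` caps the context the model sees at `block_size` tokens, even as the sequence keeps growing. The sketch below checks this with a hypothetical `LengthRecorder` stand-in model that returns all-zero logits and records the context length it receives. It repeats the body of the `generate()` loop with a greedy pick instead of sampling, so no random numbers are involved."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "class LengthRecorder(nn.Module):\n",
        "    \"\"\"Hypothetical stand-in model: all-zero logits, records the length T of every input.\"\"\"\n",
        "    def __init__(self, vocab_size: int):\n",
        "        super().__init__()\n",
        "        self.vocab_size = vocab_size\n",
        "        self.seen: list[int] = []\n",
        "\n",
        "    def forward(self, idx: torch.Tensor) -> torch.Tensor:\n",
        "        B, T = idx.shape\n",
        "        self.seen.append(T)\n",
        "        return torch.zeros(B, T, self.vocab_size)\n",
        "\n",
        "recorder = LengthRecorder(vocab_size=30)\n",
        "seq = torch.zeros(1, 5, dtype=torch.long)  # 5-token prompt\n",
        "ctx_limit = 8                              # plays the role of block_size\n",
        "for _ in range(6):\n",
        "    step_logits = recorder(seq[:, -ctx_limit:])                   # crop, as in generate()\n",
        "    next_id = step_logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick: (B, 1)\n",
        "    seq = torch.cat([seq, next_id], dim=1)\n",
        "print(\"context lengths seen:\", recorder.seen)\n",
        "print(\"final sequence length:\", seq.shape[1])"
      ]
    },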
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 7. Timing Prompt Lengths\n",
        "\n",
        "This toy model is tiny, so the absolute speeds are not meaningful. The measurement exists to practice the interface and to show how you would time a generation loop. With a real Transformer and no KV cache, longer prompts get more expensive because each step runs over the current context."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def time_generation(prompt_len: int, new_tokens: int = 128):\n",
        "    idx = torch.full((1, prompt_len), stoi[\" \"], dtype=torch.long, device=device)\n",
        "    if device == \"cuda\":\n",
        "        torch.cuda.synchronize()\n",
        "    start = time.perf_counter()\n",
        "    _ = generate(model, idx, max_new_tokens=new_tokens, block_size=256, temperature=1.0, top_p=0.9)\n",
        "    if device == \"cuda\":\n",
        "        torch.cuda.synchronize()\n",
        "    elapsed = time.perf_counter() - start\n",
        "    return elapsed, new_tokens / elapsed\n",
        "\n",
        "for prompt_len in [16, 64, 256, 512]:\n",
        "    elapsed, tps = time_generation(prompt_len)\n",
        "    print(f\"prompt_len={prompt_len:>3} elapsed={elapsed:.4f}s tokens/sec={tps:.1f}\")"
      ]
    },
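    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the cost claim concrete, count the token positions the model processes to generate 128 new tokens with `block_size=256`, as in the timing cell. Without a KV cache, step `i` reruns the model on `min(prompt_len + i, block_size)` positions, so the total grows with prompt length and plateaus once the context is cropped. With a KV cache, the prompt is processed once and each later step adds one position. The helper names below are hypothetical, and position counts are only a rough proxy, since in a Transformer the attention work per position also grows with context length."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def positions_without_cache(prompt_len: int, new_tokens: int, block_size: int) -> int:\n",
        "    # Step i reruns the model on the cropped context: min(prompt_len + i, block_size) positions.\n",
        "    return sum(min(prompt_len + i, block_size) for i in range(new_tokens))\n",
        "\n",
        "def positions_with_cache(prompt_len: int, new_tokens: int) -> int:\n",
        "    # Prefill the prompt once (this yields the first new token), then one position per later step.\n",
        "    # Assumes prompt_len + new_tokens <= block_size, so no cropping is involved.\n",
        "    return prompt_len + (new_tokens - 1)\n",
        "\n",
        "for plen in [16, 64, 256, 512]:\n",
        "    print(f\"prompt_len={plen:>3} positions without cache={positions_without_cache(plen, 128, 256):>6}\")\n",
        "print(\"prompt_len= 16 positions with cache   =\", positions_with_cache(16, 128))"
      ]
    },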
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Exercise Extension\n",
        "\n",
        "1. Replace `TinyBigramLM` with your Day 9 GPT.\n",
        "2. Run the same prompt with greedy, temperature 0.8, temperature 1.2, and top-p 0.9.\n",
        "3. Add `eos_token_id` support to `generate()`.\n",
        "4. Add a repetition penalty and observe how outputs change.\n",
        "5. Record TTFT separately from decode speed by timing the first forward pass apart from the loop."
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.14.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
