{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
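      "source": "# Day 20 - KV Cache Implementation\n\nThis notebook estimates KV-cache memory for LLaMA-style models (MHA vs. GQA), checks with NumPy that attention computed from a KV cache matches full recomputation, and simulates a sliding-window cache.\n"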
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "import math, time\nimport numpy as np\nnp.set_printoptions(precision=4, suppress=True)\nprint('numpy', np.__version__)\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
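      "source": "## 1. KV Cache Memory Calculator\n\nEach layer caches one K and one V tensor per token, so the cache size in bytes is\n\n$$B \\cdot T \\cdot L \\cdot 2 \\cdot H_{kv} \\cdot d_{head} \\cdot s$$\n\nwhere $B$ is the batch size, $T$ the sequence length, $L$ the number of layers, $H_{kv}$ the number of KV heads, $d_{head}$ the head dimension and $s$ the bytes per element (2 for fp16/bf16). `kv_cache_gb` reports the result in GiB ($1024^3$ bytes). GQA stores fewer KV heads than MHA (8 instead of 64 for the 70B configuration below), which shrinks the cache by the same factor.\n"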
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def kv_cache_gb(B, T, n_layers, n_kv_heads, head_dim, dtype_bytes=2):\n",
        "    \"\"\"Return the size of a KV cache in GiB (units of 1024**3 bytes).\n",
        "\n",
        "    Every layer caches one K and one V tensor (the factor of 2), each holding\n",
        "    B * T * n_kv_heads * head_dim elements of `dtype_bytes` bytes\n",
        "    (the default of 2 matches fp16/bf16).\n",
        "    \"\"\"\n",
        "    k_and_v = 2\n",
        "    n_elements = B * T * n_layers * k_and_v * n_kv_heads * head_dim\n",
        "    return n_elements * dtype_bytes / 1024**3\n",
        "\n",
        "\n",
        "print('LLaMA-2 7B style, B=1,T=4096:',\n",
        "      kv_cache_gb(B=1, T=4096, n_layers=32, n_kv_heads=32, head_dim=128), 'GiB')\n",
        "\n",
        "# 70B-style config: 80 layers, head_dim 128; GQA stores 8 KV heads, MHA would store 64.\n",
        "for B in (1, 16, 64):\n",
        "    gqa, mha = (kv_cache_gb(B, 4096, 80, kv_heads, 128) for kv_heads in (8, 64))\n",
        "    print(f'70B B={B:2d}: GQA={gqa:7.2f} GiB  MHA={mha:7.2f} GiB')\n"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 2. Tiny Attention With and Without Cache\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "rng = np.random.default_rng(1)\n",
        "T, D = 6, 8\n",
        "x = rng.normal(size=(T, D)).astype(np.float32)\n",
        "Wq = rng.normal(scale=0.2, size=(D, D)).astype(np.float32)\n",
        "Wk = rng.normal(scale=0.2, size=(D, D)).astype(np.float32)\n",
        "Wv = rng.normal(scale=0.2, size=(D, D)).astype(np.float32)\n",
        "\n",
        "\n",
        "def softmax(a):\n",
        "    \"\"\"Softmax over the last axis.\"\"\"\n",
        "    z = a - a.max(axis=-1, keepdims=True)  # shift by the row max so exp() cannot overflow\n",
        "    e = np.exp(z)\n",
        "    return e / e.sum(axis=-1, keepdims=True)\n",
        "\n",
        "\n",
        "def full_next_output(prefix):\n",
        "    \"\"\"Attention output for the last row of `prefix`, recomputing K and V for every row.\n",
        "\n",
        "    No causal mask is needed: the query is the final position, which may attend to\n",
        "    every row of `prefix`.\n",
        "    \"\"\"\n",
        "    q = prefix[-1:] @ Wq\n",
        "    k = prefix @ Wk\n",
        "    v = prefix @ Wv\n",
        "    p = softmax(q @ k.T / math.sqrt(D))\n",
        "    return p @ v\n",
        "\n",
        "\n",
        "def cached_next_output(K_cache, V_cache, new):\n",
        "    \"\"\"Append the K/V rows of `new` to the caches and attend from the query of `new`.\n",
        "\n",
        "    Expects `new` to be a single token of shape (1, D); several rows at once would\n",
        "    attend to each other without a causal mask. Only `new` is projected, and the\n",
        "    cached rows are reused as-is. Returns (output, K_cache, V_cache) with the grown\n",
        "    caches. The input arrays are not modified because np.concatenate returns new arrays.\n",
        "    \"\"\"\n",
        "    K_cache = np.concatenate([K_cache, new @ Wk], axis=0)\n",
        "    V_cache = np.concatenate([V_cache, new @ Wv], axis=0)\n",
        "    scores = (new @ Wq) @ K_cache.T / math.sqrt(D)\n",
        "    return softmax(scores) @ V_cache, K_cache, V_cache\n",
        "\n",
        "\n",
        "# Prefill the cache with the first T-1 tokens, then decode the last token from it.\n",
        "prefix = x[:-1]\n",
        "K_cache = prefix @ Wk\n",
        "V_cache = prefix @ Wv\n",
        "new = x[-1:]\n",
        "cached_out, K_cache2, V_cache2 = cached_next_output(K_cache, V_cache, new)\n",
        "full_out = full_next_output(x)\n",
        "print('max diff:', np.max(np.abs(cached_out - full_out)))\n",
        "assert np.allclose(cached_out, full_out, atol=1e-6)\n",
        "\n",
        "# The check above covers a single append. Also decode all T tokens one at a time\n",
        "# from an empty cache, so bugs that only appear after several appends are caught.\n",
        "K_step = np.zeros((0, D), dtype=np.float32)\n",
        "V_step = np.zeros((0, D), dtype=np.float32)\n",
        "for t in range(T):\n",
        "    step_out, K_step, V_step = cached_next_output(K_step, V_step, x[t:t + 1])\n",
        "    assert np.allclose(step_out, full_next_output(x[:t + 1]), atol=1e-6), f'mismatch at step {t}'\n",
        "assert K_step.shape == V_step.shape == (T, D)\n"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 3. Sliding Window Simulation\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def append_sliding(cache, row, window):\n",
        "    \"\"\"Append `row` to `cache` and keep only the most recent `window` rows.\n",
        "\n",
        "    `row` is concatenated along axis 0, so it must have the same trailing shape as\n",
        "    `cache` (here (1, D)). Returns a new array; `cache` itself is not modified.\n",
        "\n",
        "    Raises ValueError if window < 1. With window == 0 the slice `cache[-0:]` is\n",
        "    `cache[0:]`, which would silently keep the entire cache, and a negative window\n",
        "    would slice from the wrong end.\n",
        "    \"\"\"\n",
        "    if window < 1:\n",
        "        raise ValueError(f'window must be >= 1, got {window}')\n",
        "    cache = np.concatenate([cache, row], axis=0)\n",
        "    return cache[-window:]\n",
        "\n",
        "\n",
        "WINDOW = 4\n",
        "cache = np.zeros((0, D), dtype=np.float32)\n",
        "for i in range(10):\n",
        "    cache = append_sliding(cache, np.ones((1, D), dtype=np.float32) * i, window=WINDOW)\n",
        "    print(i, 'cache rows:', cache[:, 0].astype(int).tolist())\n",
        "assert cache.shape[0] == WINDOW\n",
        "# Only the last WINDOW token ids (6, 7, 8, 9) should survive.\n",
        "assert cache[:, 0].astype(int).tolist() == list(range(10 - WINDOW, 10))\n"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 4. Exercise Assertions\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# 70B-style GQA config: 80 layers, 8 KV heads, head_dim 128, T=4096, 2-byte dtype.\n",
        "b1 = kv_cache_gb(1, 4096, 80, 8, 128)\n",
        "b64 = kv_cache_gb(64, 4096, 80, 8, 128)\n",
        "print('70B GQA B=1 GiB:', b1)\n",
        "print('70B GQA B=64 GiB:', b64)\n",
        "# Per sequence: 4096 * 80 * 2 * 8 * 128 * 2 bytes = 1.25 GiB.\n",
        "assert math.isclose(b1, 1.25)\n",
        "# The cache grows linearly with batch size, so B=64 needs exactly 64x the B=1 memory.\n",
        "assert math.isclose(b64, 64 * b1)\n"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "pygments_lexer": "ipython3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
