{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "# Day 21 - Online Softmax and FlashAttention\n\nThe core online-softmax verification uses NumPy and should pass everywhere. CUDA SDPA benchmarking is optional.\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "import math, time\nimport numpy as np\nnp.set_printoptions(precision=6, suppress=True)\nprint('numpy', np.__version__)\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 1. Stable Softmax\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "scores = np.array([2.0, -1.5, 0.3, 4.2], dtype=np.float64)\n\ndef stable_softmax(x):\n    z = x - np.max(x)\n    e = np.exp(z)\n    return e / e.sum()\n\np = stable_softmax(scores)\nprint(p, 'sum=', p.sum())\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 2. Online Softmax\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "def online_softmax(x, block_size):\n    m = -np.inf\n    d = 0.0\n    blocks = []\n    for start in range(0, len(x), block_size):\n        block = x[start:start+block_size]\n        m_new = max(m, float(np.max(block)))\n        d = d * math.exp(m - m_new) + float(np.exp(block - m_new).sum())\n        m = m_new\n        blocks.append((m, d))\n    probs = np.exp(x - m) / d\n    return probs, blocks\n\np_online, states = online_softmax(scores, 2)\nprint('states:', states)\nprint('online:', p_online)\nprint('max diff:', np.max(np.abs(p_online - p)))\nassert np.allclose(p_online, p, atol=1e-7)\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 3. HBM Traffic Estimate\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "def standard_attention_mb(T, d, dtype_bytes=2):\n    # Approx S and P writes/reads as two T^2 matrices; rough pedagogical estimate.\n    return 2 * T * T * dtype_bytes / 1e6\n\ndef flash_attention_mb(T, d, dtype_bytes=2):\n    # Q,K,V,O linear traffic, rough estimate.\n    return 4 * T * d * dtype_bytes / 1e6\n\nfor T in [512, 1024, 2048, 4096]:\n    std = standard_attention_mb(T, 128)\n    fla = flash_attention_mb(T, 128)\n    print(f'T={T:4d}: standard~{std:7.1f} MB/head  flash~{fla:5.1f} MB/head  ratio~{std/fla:5.1f}x')\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 4. Tiny NumPy Attention Sanity Check\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "rng = np.random.default_rng(0)\nT,d = 8,4\nQ = rng.normal(size=(T,d))\nK = rng.normal(size=(T,d))\nV = rng.normal(size=(T,d))\nS = Q @ K.T / math.sqrt(d)\nP = np.apply_along_axis(stable_softmax, -1, S)\nO = P @ V\nprint(O.shape, O[0])\nassert O.shape == (T,d)\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 5. Optional PyTorch SDPA Benchmark\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "try:\n    import torch\n    import torch.nn.functional as F\n    if not torch.cuda.is_available():\n        print('CUDA unavailable; SDPA flash benchmark skipped.')\n    else:\n        device = 'cuda'\n        B,H,T,d = 1,8,1024,64\n        q = torch.randn(B,H,T,d, device=device, dtype=torch.float16)\n        k = torch.randn(B,H,T,d, device=device, dtype=torch.float16)\n        v = torch.randn(B,H,T,d, device=device, dtype=torch.float16)\n        torch.cuda.synchronize()\n        t0=time.perf_counter(); out=F.scaled_dot_product_attention(q,k,v,is_causal=True); torch.cuda.synchronize()\n        print('SDPA ms:', (time.perf_counter()-t0)*1000, out.shape)\nexcept Exception as e:\n    print('Skipping SDPA benchmark:', e)\n",
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "pygments_lexer": "ipython3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
