{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Day 23 - Speculative Decoding Simulator"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "rng=np.random.default_rng(0)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Expected tokens"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def expected_tokens(gamma,K): return (1-gamma**(K+1))/(1-gamma)\n",
        "for K in [3,5,7]:\n",
        "    print('K=',K)\n",
        "    for gamma in [0.6,0.7,0.8,0.9]: print(f'  gamma={gamma:.1f}: {expected_tokens(gamma,K):.2f}')"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Monte Carlo"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def simulate(gamma,K,trials=50000):\n",
        "    total=0\n",
        "    for _ in range(trials):\n",
        "        accepted=0\n",
        "        for _ in range(K):\n",
        "            if rng.random()<gamma: accepted+=1\n",
        "            else: break\n",
        "        total+=accepted+1\n",
        "    return total/trials\n",
        "for gamma in [0.6,0.8,0.9]: print(gamma, simulate(gamma,5), expected_tokens(gamma,5))"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Ratio test"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "draft=np.array([0.50,0.20,0.20,0.10]); target=np.array([0.05,0.40,0.35,0.20])\n",
        "for t in range(len(draft)): print(t,'accept prob',min(1,target[t]/draft[t]))"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "pygments_lexer": "ipython3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}