{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
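        "# Day 22 - INT8/INT4 Quantization\n",
        "\n",
        "Min/max (asymmetric affine) quantization of float weights to low-bit unsigned integer codes, comparing per-tensor, per-channel and per-group scaling, plus the weight-memory footprint of each precision."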
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "# Fixed-point float display with 6 decimals for every array printed below.\n",
        "np.set_printoptions(suppress=True, precision=6)\n",
        "print(\"numpy\", np.__version__)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Per-tensor INT8"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "w = np.array([0.10, -0.30, 0.70, -1.00, 0.50], dtype=np.float32)\n",
        "\n",
        "def quantize_affine(x, bits=8):\n",
        "    \"\"\"Asymmetric min/max quantization of `x` to unsigned `bits`-bit codes.\n",
        "\n",
        "    Returns (q, xhat, scale, mn): the integer codes, the dequantized float32\n",
        "    values xhat = mn + q * scale, the step size, and the offset (min of x).\n",
        "    Raises ValueError if `bits` is outside [1, 16].\n",
        "    \"\"\"\n",
        "    if not 1 <= bits <= 16:\n",
        "        raise ValueError(f'bits must be in [1, 16], got {bits}')\n",
        "    qmin, qmax = 0, 2**bits - 1\n",
        "    # uint8 only holds 0..255; wider codes need uint16 or they silently wrap around.\n",
        "    qdtype = np.uint8 if bits <= 8 else np.uint16\n",
        "    mn, mx = float(x.min()), float(x.max())\n",
        "    # A constant tensor has zero range: any scale works and every element maps to code 0.\n",
        "    scale = (mx - mn) / (qmax - qmin) if mx != mn else 1.0\n",
        "    q = np.round((x - mn) / scale).clip(qmin, qmax).astype(qdtype)\n",
        "    xhat = mn + q.astype(np.float32) * scale\n",
        "    return q, xhat, scale, mn\n",
        "\n",
        "q, xhat, scale, mn = quantize_affine(w)\n",
        "print(scale, mn); print(q); print(xhat); print('mse', np.mean((w - xhat)**2))\n",
        "# Round-to-nearest keeps every element within half a quantization step.\n",
        "assert np.all(np.abs(w - xhat) <= scale / 2 + 1e-6)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
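        "## Per-channel versus per-tensor\n",
        "\n",
        "The rows of `W` below have very different spreads. A single per-tensor scale must cover the widest row, so the narrow rows get only a few distinct levels; quantizing each row (channel) with its own scale and offset gives a lower reconstruction error."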
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "rng = np.random.default_rng(0)\n",
        "# Three rows (channels) whose spreads differ by 50x: one shared scale must span\n",
        "# the widest row, leaving the narrow rows far fewer distinct levels.\n",
        "W = np.stack([rng.normal(0, std, 16) for std in (0.02, 0.20, 1.00)]).astype(np.float32)\n",
        "\n",
        "def mse(a, b):\n",
        "    \"\"\"Mean squared error between two arrays, returned as a Python float.\"\"\"\n",
        "    return float(np.mean((a - b) ** 2))\n",
        "\n",
        "W_per_tensor = quantize_affine(W)[1]  # one scale/offset shared by the whole matrix\n",
        "W_per_channel = np.vstack([quantize_affine(row)[1] for row in W])  # one per row\n",
        "err_tensor, err_channel = mse(W, W_per_tensor), mse(W, W_per_channel)\n",
        "print('per tensor', err_tensor); print('per channel', err_channel)\n",
        "assert err_channel < err_tensor"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
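        "## Memory table\n",
        "\n",
        "Storage for the weights alone, in decimal GB (10^9 bytes) at each precision. Activations, KV cache and the per-group scale/offset metadata of quantized formats are not included."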
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Weights only, in decimal GB (1e9 bytes); quantization scales/offsets are not counted.\n",
        "for params in (7e9, 13e9, 70e9):\n",
        "    print(f\"\\n{params/1e9:.0f}B params\")\n",
        "    for name, nbytes in (('FP32', 4), ('FP16/BF16', 2), ('INT8/FP8', 1), ('INT4/NF4', 0.5)):\n",
        "        print(f\"  {name:10s}: {params*nbytes/1e9:6.1f} GB\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
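        "## Group INT4 sketch\n",
        "\n",
        "Flatten the matrix and quantize each contiguous group of `group_size` values (default 8) to 4 bits with its own scale and offset. This is finer-grained than per-channel, at the cost of storing one scale and offset per group."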
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def quantize_groups(x, group_size=8, bits=4):\n",
        "    \"\"\"Quantize the flattened `x` in contiguous groups of `group_size` values.\n",
        "\n",
        "    Each group is passed to quantize_affine with its own min/max, so it gets its\n",
        "    own scale and offset; the last group may be shorter. Returns (xhat, scales):\n",
        "    the dequantized values reshaped to x.shape and one scale per group (per-group\n",
        "    offsets are computed but not returned). Raises ValueError if group_size < 1.\n",
        "    \"\"\"\n",
        "    if group_size < 1:\n",
        "        # range() with step <= 0 would either raise or skip the loop entirely and\n",
        "        # return an uninitialized buffer.\n",
        "        raise ValueError(f'group_size must be >= 1, got {group_size}')\n",
        "    flat = x.reshape(-1)\n",
        "    # Always a float buffer: empty_like would truncate xhat for integer inputs.\n",
        "    out = np.empty(flat.shape, dtype=np.result_type(flat.dtype, np.float32))\n",
        "    scales = []\n",
        "    for start in range(0, flat.size, group_size):\n",
        "        group = slice(start, start + group_size)\n",
        "        _, xhat, s, _ = quantize_affine(flat[group], bits)\n",
        "        out[group] = xhat\n",
        "        scales.append(s)\n",
        "    return out.reshape(x.shape), np.array(scales)\n",
        "\n",
        "W4, scales = quantize_groups(W)\n",
        "print('INT4 group mse', mse(W, W4), 'num scales', len(scales))\n",
        "assert W4.shape == W.shape and len(scales) == -(-W.size // 8)  # ceil(W.size / group_size) groups"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "pygments_lexer": "ipython3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}