{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
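      "source": "# Day 19 - Apple Silicon, MLX, MPS\n\nThe core calculations (sections 1-3) need only NumPy and run on any machine. The optional PyTorch cell (section 4) uses the MPS device when it is available and falls back to CPU otherwise; both it and the MLX cell (section 5) print a skip message when their package is not installed.\n"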
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "import math, time\nimport numpy as np\nnp.set_printoptions(precision=4, suppress=True)\nprint('numpy', np.__version__)\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
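      "source": "## 1. Model Footprint Calculator\n\nWeight storage only: $\\text{GB} = N_{\\text{params}} \\times \\text{bits per weight} / 8 / 10^9$ (decimal GB). KV cache, activations and quantization metadata are not included.\n"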
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "def model_gb(params_b, bits):\n    return params_b * 1e9 * bits / 8 / 1e9\n\nfor params in [7, 13, 70, 405]:\n    print(f'{params:>3}B: fp16={model_gb(params,16):6.1f} GB  int4={model_gb(params,4):6.1f} GB')\n\nassert model_gb(70, 4) == 35.0\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
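      "source": "## 2. Bandwidth-Limited Decode Estimate\n\nWhen a model generates one token at a time, each new token has to read (roughly) all of the model weights from memory, so memory bandwidth caps decode speed: $\\text{tok/s} \\le \\dfrac{\\text{bandwidth (GB/s)}}{\\text{GB read per token}}$. This is a ceiling; measured throughput will be at or below it.\n"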
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "def upper_bound_tok_s(bandwidth_gb_s, model_gb_per_token):\n    return bandwidth_gb_s / model_gb_per_token\n\nprint('M2 Ultra-style 800GB/s over 35GB weights:', upper_bound_tok_s(800, 35), 'tok/s upper bound')\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
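      "source": "## 3. NumPy Transformer Block Reference\n\nThe attention sublayer of a transformer block: causal multi-head self-attention applied to RMS-normalized input, without the residual connection or MLP.\n"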
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "rng = np.random.default_rng(0)\nB,T,D,H = 1,4,8,2\nhead_dim = D // H\nx = rng.normal(size=(B,T,D)).astype(np.float32)\nWq = rng.normal(scale=0.1, size=(D,D)).astype(np.float32)\nWk = rng.normal(scale=0.1, size=(D,D)).astype(np.float32)\nWv = rng.normal(scale=0.1, size=(D,D)).astype(np.float32)\nWo = rng.normal(scale=0.1, size=(D,D)).astype(np.float32)\n\ndef rms_norm(x, eps=1e-5):\n    return x / np.sqrt(np.mean(x*x, axis=-1, keepdims=True) + eps)\n\ndef softmax(a, axis=-1):\n    z = a - a.max(axis=axis, keepdims=True)\n    e = np.exp(z)\n    return e / e.sum(axis=axis, keepdims=True)\n\ndef attention_block_np(x):\n    q = x @ Wq; k = x @ Wk; v = x @ Wv\n    q = q.reshape(B,T,H,head_dim).transpose(0,2,1,3)\n    k = k.reshape(B,T,H,head_dim).transpose(0,2,1,3)\n    v = v.reshape(B,T,H,head_dim).transpose(0,2,1,3)\n    scores = q @ k.transpose(0,1,3,2) / math.sqrt(head_dim)\n    mask = np.triu(np.ones((T,T), dtype=bool), 1)\n    scores = np.where(mask, -1e9, scores)\n    out = softmax(scores) @ v\n    out = out.transpose(0,2,1,3).reshape(B,T,D)\n    return out @ Wo\n\nout = attention_block_np(rms_norm(x))\nprint(out.shape, out[0, -1, :3])\nassert out.shape == (B,T,D)\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 4. Optional PyTorch MPS Check\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "try:\n    import torch\n    device = 'mps' if torch.backends.mps.is_available() else 'cpu'\n    print('torch', torch.__version__, 'device', device)\n    tx = torch.tensor(x, device=device)\n    print('tensor device:', tx.device, 'mean:', float(tx.mean().cpu()))\nexcept Exception as e:\n    print('Skipping torch/MPS:', e)\n",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": "## 5. Optional MLX Check\n"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": "try:\n    import mlx.core as mx\n    a = mx.array([1.0, 2.0, 3.0])\n    y = a * 2 + 1\n    mx.eval(y)\n    print('MLX result:', y)\nexcept Exception as e:\n    print('MLX unavailable; install with: pip install mlx mlx-lm')\n",
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "pygments_lexer": "ipython3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
