diff --git a/README.md b/README.md index 785f6fbab7..9404315528 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni - **Topology-Aware Auto Parallel**: exo figures out the best way to split your model across all available devices based on a realtime view of your device topology. It takes into account device resources and network latency/bandwidth between each link. - **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices. - **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication. +- **Experimental tinygrad support**: exo uses [tinygrad](https://github.com/tinygrad/tinygrad) as an inference backend for non-Apple systems, with support for AMD (ROCm/HIP) and NVIDIA (CUDA) GPU backends. The current implementation runs on a single instance only; **model sharding will be tested in the near future.** ## Dashboard @@ -136,6 +137,8 @@ This starts the exo dashboard and API at http://localhost:52415/ **Installation methods:** **Option 1: Using system package manager (Ubuntu/Debian example):** + +- For Ubuntu/Debian-based distros ```bash # Install Node.js and npm sudo apt update @@ -149,6 +152,15 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh rustup toolchain install nightly ``` +- For Arch-based distros +```bash +sudo pacman -S uv rustup + +# Install the nightly toolchain +rustup toolchain install nightly +``` + + **Option 2: Using Homebrew on Linux (if preferred):** ```bash # Install Homebrew on Linux @@ -164,7 +176,149 @@ rustup toolchain install nightly **Note:** The `macmon` package is macOS-only and not required for Linux. 
-Clone the repo, build the dashboard, and run exo: +**Install GPU drivers for tinygrad** + +The tinygrad inference backend supports multiple GPU backends. Install the appropriate +drivers for your hardware. The backend is selected at runtime via the `DEV` environment +variable (e.g. `DEV=HIP`, `DEV=CUDA`). If not set, tinygrad auto-detects the best +available backend. + +
+AMD GPUs (ROCm / HIP) + +ROCm provides GPU compute support for AMD GPUs. The tinygrad backend uses different +device modes depending on your GPU architecture: + +| GPU Architecture | Device | Notes | +|---|---|---| +| **RDNA 3+** (RX 7000 series, MI300, etc.) | `DEV=AMD` | Optimized HCQ backend | +| **RDNA 2** (RX 6000 series) | `DEV=HIP` | HIP runtime backend | + +> **Important:** The `DEV=AMD` backend (HCQ) only supports RDNA 3 and newer. RDNA 2 users +> **must** use `DEV=HIP`. Both modes require a ROCm installation. + +**Ubuntu / Debian:** + +```bash +# Install prerequisites +sudo apt update && sudo apt install -y wget gpg + +# Add the AMD ROCm repository +# Check https://rocm.docs.amd.com/en/latest/deploy/linux/install.html for the latest version +wget https://repo.radeon.com/amdgpu-install/6.4/ubuntu/$(lsb_release -cs)/amdgpu-install_6.4.60400-1_all.deb +sudo apt install -y ./amdgpu-install_6.4.60400-1_all.deb + +# Install ROCm HIP libraries +sudo amdgpu-install --usecase=hip --no-dkms + +# Add your user to the render and video groups +sudo usermod -aG render,video $USER +``` + +> For Debian, replace `$(lsb_release -cs)` with the equivalent Ubuntu codename +> (e.g. `noble` for Debian Trixie). Check [AMD's documentation](https://rocm.docs.amd.com/en/latest/deploy/linux/install.html) +> for supported versions. 
+ +**Fedora:** + +```bash +# Install prerequisites +sudo dnf install -y kernel-headers kernel-devel + +# Add the AMD ROCm repository +# Check https://rocm.docs.amd.com/en/latest/deploy/linux/install.html for the latest version +sudo tee /etc/yum.repos.d/amdgpu.repo <<'EOF' +[amdgpu] +name=amdgpu +baseurl=https://repo.radeon.com/amdgpu/6.4/el/9/main/x86_64/ +enabled=1 +gpgcheck=1 +gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key +EOF + +sudo tee /etc/yum.repos.d/rocm.repo <<'EOF' +[rocm] +name=rocm +baseurl=https://repo.radeon.com/rocm/el9/6.4/main +enabled=1 +gpgcheck=1 +gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key +EOF + +sudo dnf install -y rocm-hip-runtime hip-devel + +# Add your user to the render and video groups +sudo usermod -aG render,video $USER +``` + +**Arch:** + +```bash +# Install ROCm HIP from official repositories +sudo pacman -S rocm-hip-runtime rocm-hip-sdk + +# Add your user to the render and video groups +sudo usermod -aG render,video $USER +``` + +**Running exo with AMD GPU:** + +```bash +# RDNA 3+ GPUs +DEV=AMD uv run exo + +# RDNA 2 GPUs (RX 6000 series) +DEV=HIP uv run exo +``` + +
+ +
+NVIDIA GPUs (CUDA) + +NVIDIA GPUs require proprietary drivers and the CUDA toolkit. + +**Ubuntu / Debian:** + +```bash +# Option 1: Install via ubuntu-drivers (Ubuntu only, simplest method) +sudo add-apt-repository ppa:graphics-drivers/ppa && sudo apt update +sudo ubuntu-drivers autoinstall + +# Option 2: Install CUDA toolkit directly (Ubuntu / Debian) +# Check https://developer.nvidia.com/cuda-downloads for the latest version +sudo apt install -y nvidia-driver-560 nvidia-cuda-toolkit +``` + +**Fedora:** + +```bash +# Enable RPM Fusion repositories +sudo dnf install -y \ + https://download1.rpmfusion.org/free/fedora/rpmfusion-free-release-$(rpm -E %fedora).noarch.rpm \ + https://download1.rpmfusion.org/nonfree/fedora/rpmfusion-nonfree-release-$(rpm -E %fedora).noarch.rpm + +# Install NVIDIA drivers and CUDA +sudo dnf install -y akmod-nvidia xorg-x11-drv-nvidia-cuda +``` + +**Arch:** + +```bash +# Install NVIDIA drivers and CUDA toolkit +sudo pacman -S nvidia cuda +``` + +**Running exo with NVIDIA GPU:** + +```bash +DEV=CUDA uv run exo +``` + +
+ +> **Tip:** You can verify your GPU is detected by running with debug logging: +> `DEBUG=1 DEV=HIP uv run exo` (substitute your device backend). ```bash # Clone exo @@ -179,8 +333,6 @@ uv run exo This starts the exo dashboard and API at http://localhost:52415/ -**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one. - **Configuration Options:** - `--no-worker`: Run exo without the worker component. Useful for coordinator-only nodes that handle networking and orchestration but don't execute inference tasks. This is helpful for machines without sufficient GPU resources but with good network connectivity. @@ -426,10 +578,10 @@ The tool outputs performance metrics including prompt tokens per second (prompt_ ## Hardware Accelerator Support -On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working on extending hardware accelerator support. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community. +On macOS, exo uses the GPU via [MLX](https://github.com/ml-explore/mlx). On Linux, exo uses the GPU via [tinygrad](https://github.com/tinygrad/tinygrad) with support for AMD (ROCm/HIP) and NVIDIA (CUDA) backends. See the [Linux installation guide](#run-from-source-linux) for GPU driver setup. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community. --- ## Contributing -See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo. 
\ No newline at end of file +See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo. diff --git a/dashboard/package-lock.json b/dashboard/package-lock.json index 345c73d257..e28a6b6003 100644 --- a/dashboard/package-lock.json +++ b/dashboard/package-lock.json @@ -865,6 +865,7 @@ "integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@standard-schema/spec": "^1.0.0", "@sveltejs/acorn-typescript": "^1.0.5", @@ -904,6 +905,7 @@ "integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@sveltejs/vite-plugin-svelte-inspector": "^4.0.1", "debug": "^4.4.1", @@ -1520,6 +1522,7 @@ "integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -1529,6 +1532,7 @@ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -1941,6 +1945,7 @@ "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", "dev": true, "license": "ISC", + "peer": true, "engines": { "node": ">=12" } @@ -2648,6 +2653,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -2690,6 +2696,7 @@ "integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==", "dev": true, "license": "MIT", + "peer": true, "bin": { "prettier": "bin/prettier.cjs" }, @@ -2862,6 +2869,7 @@ 
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz", "integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==", "license": "MIT", + "peer": true, "dependencies": { "@jridgewell/remapping": "^2.3.4", "@jridgewell/sourcemap-codec": "^1.5.0", @@ -3006,6 +3014,7 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -3027,6 +3036,7 @@ "integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.4.4", diff --git a/dashboard/src/lib/components/ModelCard.svelte b/dashboard/src/lib/components/ModelCard.svelte index b432b7a87b..5b1edc2f10 100644 --- a/dashboard/src/lib/components/ModelCard.svelte +++ b/dashboard/src/lib/components/ModelCard.svelte @@ -21,7 +21,7 @@ } | null; nodes?: Record; sharding?: "Pipeline" | "Tensor"; - runtime?: "MlxRing" | "MlxJaccl"; + runtime?: "MlxRing" | "MlxJaccl" | "Tinygrad"; onLaunch?: () => void; tags?: string[]; apiPreview?: PlacementPreview | null; diff --git a/dashboard/src/lib/stores/app.svelte.ts b/dashboard/src/lib/stores/app.svelte.ts index 00c6669f23..3c382540d9 100644 --- a/dashboard/src/lib/stores/app.svelte.ts +++ b/dashboard/src/lib/stores/app.svelte.ts @@ -168,7 +168,7 @@ export interface ModelDownloadStatus { export interface PlacementPreview { model_id: string; sharding: "Pipeline" | "Tensor"; - instance_meta: "MlxRing" | "MlxJaccl"; + instance_meta: "MlxRing" | "MlxJaccl" | "Tinygrad"; instance: unknown | null; memory_delta_by_node: Record | null; error: string | null; @@ -580,6 +580,7 @@ class AppStore { debugMode = $state(false); topologyOnlyMode = $state(false); chatSidebarVisible = $state(true); // Shown by default + enableLogprobs = 
$state(false); // Image generation params imageGenerationParams = $state({ @@ -608,6 +609,7 @@ class AppStore { this.loadTopologyOnlyModeFromStorage(); this.loadChatSidebarVisibleFromStorage(); this.loadImageGenerationParamsFromStorage(); + this.loadEnableLogprobsFromStorage(); } } @@ -677,6 +679,25 @@ class AppStore { } } + private loadEnableLogprobsFromStorage() { + try { + const stored = localStorage.getItem("exo-enable-logprobs"); + if (stored !== null) { + this.enableLogprobs = stored === "true"; + } + } catch (error) { + console.error("Failed to load enable logprobs:", error); + } + } + + private saveEnableLogprobsToStorage() { + try { + localStorage.setItem("exo-enable-logprobs", this.enableLogprobs ? "true" : "false"); + } catch (error) { + console.error("Failed to save enable logprobs:", error); + } + } + private loadTopologyOnlyModeFromStorage() { try { const stored = localStorage.getItem("exo-topology-only-mode"); @@ -1213,6 +1234,20 @@ class AppStore { this.saveDebugModeToStorage(); } + getEnableLogprobs(): boolean { + return this.enableLogprobs; + } + + setEnableLogprobs(enabled: boolean) { + this.enableLogprobs = enabled; + this.saveEnableLogprobsToStorage(); + } + + toggleEnableLogprobs() { + this.enableLogprobs = !this.enableLogprobs; + this.saveEnableLogprobsToStorage(); + } + getTopologyOnlyMode(): boolean { return this.topologyOnlyMode; } @@ -1658,8 +1693,7 @@ class AppStore { model: modelToUse, messages: apiMessages, stream: true, - logprobs: true, - top_logprobs: 5, + ...(this.enableLogprobs && { logprobs: true, top_logprobs: 5 }), }), }); @@ -1865,8 +1899,7 @@ class AppStore { model: modelToUse, messages: apiMessages, stream: true, - logprobs: true, - top_logprobs: 5, + ...(this.enableLogprobs && { logprobs: true, top_logprobs: 5 }), }), }); @@ -2345,8 +2378,7 @@ class AppStore { messages: apiMessages, temperature: 0.7, stream: true, - logprobs: true, - top_logprobs: 5, + ...(this.enableLogprobs && { logprobs: true, top_logprobs: 5 }), 
...(enableThinking != null && { enable_thinking: enableThinking, }), @@ -3254,6 +3286,10 @@ export const toggleSidebar = () => appStore.toggleSidebar(); export const toggleDebugMode = () => appStore.toggleDebugMode(); export const setDebugMode = (enabled: boolean) => appStore.setDebugMode(enabled); +export const enableLogprobs = () => appStore.getEnableLogprobs(); +export const toggleEnableLogprobs = () => appStore.toggleEnableLogprobs(); +export const setEnableLogprobs = (enabled: boolean) => + appStore.setEnableLogprobs(enabled); export const toggleTopologyOnlyMode = () => appStore.toggleTopologyOnlyMode(); export const setTopologyOnlyMode = (enabled: boolean) => appStore.setTopologyOnlyMode(enabled); diff --git a/dashboard/src/routes/+page.svelte b/dashboard/src/routes/+page.svelte index 7f374dd748..9d2023faf2 100644 --- a/dashboard/src/routes/+page.svelte +++ b/dashboard/src/routes/+page.svelte @@ -815,7 +815,7 @@ return model.tasks.includes("ImageToImage"); } let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline"); - type InstanceMeta = "MlxRing" | "MlxJaccl"; + type InstanceMeta = "MlxRing" | "MlxJaccl" | "Tinygrad"; // Launch defaults persistence const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults-v2"; @@ -882,6 +882,7 @@ } let selectedInstanceType = $state("MlxRing"); + let instanceTypeInitialized = $state(false); let selectedMinNodes = $state(1); let minNodesInitialized = $state(false); let launchingModelId = $state(null); @@ -1074,9 +1075,11 @@ } const matchesSelectedRuntime = (runtime: InstanceMeta): boolean => - selectedInstanceType === "MlxRing" - ? runtime === "MlxRing" - : runtime === "MlxJaccl"; + selectedInstanceType === "Tinygrad" + ? runtime === "Tinygrad" + : selectedInstanceType === "MlxRing" + ? 
runtime === "MlxRing" + : runtime === "MlxJaccl"; // Helper to check if a model can be launched (has valid placement with >= minNodes) function canModelFit(modelId: string): boolean { @@ -2492,6 +2495,22 @@ } }); + // Auto-detect instance type from available placement previews (e.g., Tinygrad on Linux) + $effect(() => { + if (instanceTypeInitialized || previewsData.length === 0) return; + const hasCurrentType = previewsData.some( + (p: PlacementPreview) => matchesSelectedRuntime(p.instance_meta), + ); + if (!hasCurrentType) { + // Switch to whatever the server provides + const firstMeta = previewsData[0]?.instance_meta; + if (firstMeta) { + selectedInstanceType = firstMeta; + } + } + instanceTypeInitialized = true; + }); + // Calculate total memory usage across all nodes const clusterMemory = $derived(() => { if (!data) return { used: 0, total: 0 }; @@ -4985,6 +5004,29 @@ RDMA (Fast) + diff --git a/pyproject.toml b/pyproject.toml index c4bfa55052..cf11937df0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "python-multipart>=0.0.21", "msgspec>=0.19.0", "zstandard>=0.23.0", + "tinygrad>=0.12.0; sys_platform == 'linux'", ] [project.scripts] diff --git a/src/exo/main.py b/src/exo/main.py index 60729cc692..528dd401fb 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -43,7 +43,7 @@ class Node: @classmethod async def create(cls, args: "Args") -> Self: keypair = get_node_id_keypair() - node_id = NodeId(keypair.to_node_id()) + node_id = NodeId(keypair.to_peer_id()) session_id = SessionId(master_node_id=node_id, election_clock=0) router = Router.create(keypair) await router.register_topic(topics.GLOBAL_EVENTS) diff --git a/src/exo/master/api.py b/src/exo/master/api.py index 5ecc1575d2..5e4f807480 100644 --- a/src/exo/master/api.py +++ b/src/exo/master/api.py @@ -2,6 +2,7 @@ import contextlib import json import random +import sys import time from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator from datetime import 
datetime, timezone @@ -47,7 +48,9 @@ ) from exo.master.event_log import DiskEventLog from exo.master.image_store import ImageStore -from exo.master.placement import place_instance as get_instance_placements +from exo.master.placement import ( + place_instance as get_instance_placements, +) from exo.shared.apply import apply from exo.shared.constants import ( DASHBOARD_DIR, @@ -169,7 +172,12 @@ ) from exo.shared.types.state import State from exo.shared.types.worker.downloads import DownloadCompleted -from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta +from exo.shared.types.worker.instances import ( + Instance, + InstanceId, + InstanceMeta, + default_instance_meta, +) from exo.shared.types.worker.shards import Sharding from exo.utils.banner import print_startup_banner from exo.utils.channels import Receiver, Sender, channel @@ -394,9 +402,11 @@ async def get_placement( self, model_id: ModelId, sharding: Sharding = Sharding.Pipeline, - instance_meta: InstanceMeta = InstanceMeta.MlxRing, + instance_meta: InstanceMeta | None = None, min_nodes: int = 1, ) -> Instance: + if instance_meta is None: + instance_meta = default_instance_meta() model_card = await ModelCard.load(model_id) try: @@ -446,8 +456,13 @@ async def get_placement_previews( status_code=400, detail=f"Failed to load model card: {exc}" ) from exc instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = [] + instance_metas = ( + (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl) + if sys.platform == "darwin" + else (InstanceMeta.Tinygrad,) + ) for sharding in (Sharding.Pipeline, Sharding.Tensor): - for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl): + for instance_meta in instance_metas: instance_combinations.extend( [ (sharding, instance_meta, i) diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index 5f191b846f..57b39f61d3 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -41,6 +41,7 @@ InstanceMeta, 
MlxJacclInstance, MlxRingInstance, + TinygradInstance, ) from exo.shared.types.worker.shards import Sharding @@ -142,7 +143,7 @@ def place_instance( instance_id = InstanceId() target_instances = dict(deepcopy(current_instances)) - if len(selected_cycle) == 1: + if len(selected_cycle) == 1 and command.instance_meta != InstanceMeta.Tinygrad: command.instance_meta = InstanceMeta.MlxRing match command.instance_meta: @@ -192,6 +193,11 @@ def get_device_rank(node_id: NodeId) -> int: hosts_by_node=hosts_by_node, ephemeral_port=ephemeral_port, ) + case InstanceMeta.Tinygrad: + target_instances[instance_id] = TinygradInstance( + instance_id=instance_id, + shard_assignments=shard_assignments, + ) return target_instances diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index cad495ead2..06042f677b 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -25,6 +25,7 @@ InstanceMeta, MlxJacclInstance, MlxRingInstance, + TinygradInstance, ) from exo.shared.types.worker.runners import ShardAssignments from exo.shared.types.worker.shards import Sharding @@ -456,3 +457,25 @@ def test_tensor_rdma_backend_connectivity_matrix( else: ip_part = coordinator.split(":")[0] assert len(ip_part.split(".")) == 4 + +def test_place_tinygrad_single_node(model_card: ModelCard): + """Tinygrad placement should create a TinygradInstance for a single node.""" + topology = Topology() + node_id = NodeId() + topology.add_node(node_id) + node_memory = {node_id: create_node_memory(model_card.storage_size.in_bytes * 2)} + node_network = {node_id: create_node_network()} + + command = PlaceInstance( + command_id=CommandId(), + model_card=model_card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.Tinygrad, + min_nodes=1, + ) + + placements = place_instance(command, topology, {}, node_memory, node_network) + + assert len(placements) == 1 + instance = list(placements.values())[0] + assert isinstance(instance, 
TinygradInstance) diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index d71275b763..083baba074 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -221,7 +221,7 @@ def get_node_id_keypair( Obtain the :class:`PeerId` by from it. """ # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates - return Keypair.generate() + return Keypair.generate_ed25519() def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: return Path(str(path) + ".lock") @@ -241,6 +241,6 @@ def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: # if no valid credentials, create new ones and persist with open(path, "w+b") as f: - keypair = Keypair.generate_ed25519() + keypair = Keypair.generate() f.write(keypair.to_bytes()) return keypair diff --git a/src/exo/shared/architecture/__init__.py b/src/exo/shared/architecture/__init__.py new file mode 100644 index 0000000000..670bdeaef8 --- /dev/null +++ b/src/exo/shared/architecture/__init__.py @@ -0,0 +1,72 @@ +from typing import Any, Literal + +from pydantic import BaseModel + +AttentionType = Literal["multi_head", "grouped_query", "multi_latent"] +MLPType = Literal["swiglu", "moe_top_k"] +NormType = Literal["rms_norm", "layer_norm"] +RoPEType = Literal["standard", "ntk_aware", "yarn"] + +class ArchitectureSpec(BaseModel, frozen=True, strict=True): + name: str + attention_type: AttentionType + mlp_type: MLPType + norm_type: NormType + rope_type: RoPEType + + # Weight key templates ({layer_idx} per layer) + layer_prefix: str + q_proj_key: str + k_proj_key: str + v_proj_key: str + o_proj_key: str + gate_proj_key: str + up_proj_key: str + down_proj_key: str + input_norm_key: str + post_attn_norm_key: str + final_norm_key: str + lm_head_key: str + + # Optional normalised keys + + q_norm_key: str | None = None + k_norm_key: str | None = None + + # MoE (optional, follow-up) + router_key: str | None = None + expert_prefix: str | None = None + + 
# MLA (optional, follow-up) + kv_lora_rank: int | None = None + q_lora_rank: int | None = None + + # Embedding key + embed_key: str | None = None + + # RoPE default + rope_theta: float = 10000.0 + + # Tokenizer + tokenizer_type: Literal["huggingface", "tiktoken"] = "huggingface" + +ARCHITECTURE_REGISTRY: dict[str, ArchitectureSpec] = {} + +def register(huggingface_name: str, spec: ArchitectureSpec) -> None: + ARCHITECTURE_REGISTRY[huggingface_name] = spec + +def detect_architecture(raw_config: dict[str, Any]) -> ArchitectureSpec: + architectures: list[str] = raw_config.get("architectures", []) # pyright: ignore[reportAny] + for arch_name in architectures: + if arch_name in ARCHITECTURE_REGISTRY: + return ARCHITECTURE_REGISTRY[arch_name] + + model_type: str = raw_config.get("model_type", "") # pyright: ignore[reportAny] + if model_type in ARCHITECTURE_REGISTRY: + return ARCHITECTURE_REGISTRY[model_type] + raise ValueError( + f"Unsupported Architecture: {raw_config.get('architectures')}" + ) + +from . import llama as _llama # noqa: F401, E402 # pyright: ignore[reportUnusedImport] +from . import qwen as _qwen # noqa: F401, E402 # pyright: ignore[reportUnusedImport] diff --git a/src/exo/shared/architecture/config.py b/src/exo/shared/architecture/config.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/exo/shared/architecture/llama.py b/src/exo/shared/architecture/llama.py new file mode 100644 index 0000000000..c1fedaf9a8 --- /dev/null +++ b/src/exo/shared/architecture/llama.py @@ -0,0 +1,25 @@ +from . 
import ArchitectureSpec, register + +LLAMA_SPEC = ArchitectureSpec( + name = "llama", + attention_type = "grouped_query", + mlp_type = "swiglu", + norm_type = "rms_norm", + rope_type = "standard", + layer_prefix = "model.layers.{layer_idx}", + q_proj_key = "self_attn.q_proj", + k_proj_key = "self_attn.k_proj", + v_proj_key = "self_attn.v_proj", + o_proj_key = "self_attn.o_proj", + gate_proj_key = "mlp.gate_proj", + up_proj_key = "mlp.up_proj", + down_proj_key = "mlp.down_proj", + input_norm_key = "input_layernorm", + post_attn_norm_key = "post_attention_layernorm", + embed_key = "model.embed_tokens", + final_norm_key = "model.norm", + lm_head_key = "lm_head", +) + +register("LlamaForCausalLM", LLAMA_SPEC) +register("MistralForCausalLM", LLAMA_SPEC) diff --git a/src/exo/shared/architecture/qwen.py b/src/exo/shared/architecture/qwen.py new file mode 100644 index 0000000000..8f698127fe --- /dev/null +++ b/src/exo/shared/architecture/qwen.py @@ -0,0 +1,48 @@ +from . import ArchitectureSpec, register + +QWEN_DENSE_SPEC = ArchitectureSpec( + name = "llama", + attention_type = "grouped_query", + mlp_type = "swiglu", + norm_type = "rms_norm", + rope_type = "standard", + layer_prefix = "model.layers.{layer_idx}", + q_proj_key = "self_attn.q_proj", + k_proj_key = "self_attn.k_proj", + v_proj_key = "self_attn.v_proj", + o_proj_key = "self_attn.o_proj", + gate_proj_key = "mlp.gate_proj", + up_proj_key = "mlp.up_proj", + down_proj_key = "mlp.down_proj", + input_norm_key = "input_layernorm", + post_attn_norm_key = "post_attention_layernorm", + embed_key = "model.embed_tokens", + final_norm_key = "model.norm", + lm_head_key = "lm_head", +) + +QWEN3_DENSE_SPEC = ArchitectureSpec( + name = "llama", + attention_type = "grouped_query", + mlp_type = "swiglu", + norm_type = "rms_norm", + rope_type = "standard", + layer_prefix = "model.layers.{layer_idx}", + q_proj_key = "self_attn.q_proj", + k_proj_key = "self_attn.k_proj", + v_proj_key = "self_attn.v_proj", + o_proj_key = 
"self_attn.o_proj", + gate_proj_key = "mlp.gate_proj", + up_proj_key = "mlp.up_proj", + down_proj_key = "mlp.down_proj", + input_norm_key = "input_layernorm", + post_attn_norm_key = "post_attention_layernorm", + embed_key = "model.embed_tokens", + final_norm_key = "model.norm", + lm_head_key = "lm_head", + q_norm_key = "self_attn.q_norm", + k_norm_key = "self_attn.k_norm", +) + +register("Qwen2ForCausalLM", QWEN_DENSE_SPEC) +register("Qwen3ForCausalLM", QWEN3_DENSE_SPEC) diff --git a/src/exo/shared/model_config.py b/src/exo/shared/model_config.py new file mode 100644 index 0000000000..0abd6509ce --- /dev/null +++ b/src/exo/shared/model_config.py @@ -0,0 +1,77 @@ +import json +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, PositiveInt + +from exo.shared.architecture import ArchitectureSpec, detect_architecture + + +class QuantizationConfig(BaseModel, frozen=True, strict=True): + bits: int + group_size: int + +class ModelConfig(BaseModel, frozen=True, strict=True): + architecture_spec: ArchitectureSpec + num_hidden_layers: PositiveInt + hidden_size: PositiveInt + intermediate_size: PositiveInt + num_attention_heads: PositiveInt + num_key_value_heads: PositiveInt + vocab_size: PositiveInt + head_dim: PositiveInt + rope_theta: float + rope_scaling: dict[str, Any] | None + max_position_embeddings: PositiveInt + rms_norm_eps: float + tie_word_embeddings: bool + quantization_config: QuantizationConfig | None + +def parse_model_config(config_path: Path) -> ModelConfig: # noqa: C901 + with open(config_path) as f: + raw: dict[str, Any] = json.load(f) # pyright: ignore[reportAny] + + arch_spec = detect_architecture(raw) + hidden_size: int = raw["hidden_size"] # pyright: ignore[reportAny] + num_attention_heads: int = raw["num_attention_heads"] # pyright: ignore[reportAny] + num_key_value_heads: int = raw.get("num_key_value_heads", num_attention_heads) # pyright: ignore[reportAny] + + intermediate_size: int = raw.get("intermediate_size", 
hidden_size * 4) # pyright: ignore[reportAny] + + head_dim: int = raw.get("head_dim", hidden_size // num_attention_heads) # pyright: ignore[reportAny] + num_hidden_layers: int = raw.get("num_hidden_layers") or raw.get("num_heads") or 32 + + rope_theta: float = raw.get("rope_theta", arch_spec.rope_theta) # pyright: ignore[reportAny] + + max_position_embeddings: int = raw.get("max_position_embeddings", 4096) # pyright: ignore[reportAny] + rms_norm_eps: float = raw.get("rms_norm_eps", 1e-6) # pyright: ignore[reportAny] + tie_word_embeddings: bool = raw.get("tie_word_embeddings", False) # pyright: ignore[reportAny] + + quant_raw: dict[str, Any] | None = raw.get("quantization_config") + quantization_config: QuantizationConfig | None = None + + if quant_raw and "bits" in quant_raw and "group_size" in quant_raw: + quantization_config = QuantizationConfig( + bits = int(quant_raw["bits"]), # pyright: ignore[reportAny] + group_size = int(quant_raw["group_size"]), # pyright: ignore[reportAny] + ) + + vocab_size: int = raw["vocab_size"] # pyright: ignore[reportAny] + rope_scaling: dict[str, Any] | None = raw.get("rope_scaling") + + return ModelConfig( + architecture_spec = arch_spec, + num_attention_heads = num_attention_heads, + num_hidden_layers = num_hidden_layers, + num_key_value_heads = num_key_value_heads, + hidden_size = hidden_size, + intermediate_size = intermediate_size, + vocab_size = vocab_size, + head_dim = head_dim, + rope_theta = rope_theta, + rope_scaling = rope_scaling, + max_position_embeddings = max_position_embeddings, + rms_norm_eps = rms_norm_eps, + tie_word_embeddings = tie_word_embeddings, + quantization_config = quantization_config, + ) diff --git a/src/exo/shared/tests/test_architecture.py b/src/exo/shared/tests/test_architecture.py new file mode 100644 index 0000000000..c977f49f23 --- /dev/null +++ b/src/exo/shared/tests/test_architecture.py @@ -0,0 +1,158 @@ +import json +from pathlib import Path + +import pytest +from pydantic import 
ValidationError + + +def test_registered_architecture(): + """ + Tests that the following architectures are present in the + architecture registry: + - Llama + - Qwen + - Mistral + """ + + from exo.shared.architecture import ARCHITECTURE_REGISTRY + + architectures = [ + "LlamaForCausalLM", + "Qwen2ForCausalLM", + "MistralForCausalLM", + ] + + assert all(arch in ARCHITECTURE_REGISTRY for arch in architectures) + +def test_detect_llama(): + """ + detect_architecture() should find Llama from a new config + dict. + """ + + from exo.shared.architecture import detect_architecture + + raw = { + "architectures": ["LlamaForCausalLM"], + "model_type": "llama", + } + + spec = detect_architecture(raw) + + assert spec.name == "llama" + assert spec.attention_type == "grouped_query" + assert spec.mlp_type == "swiglu" + assert spec.norm_type == "rms_norm" + +def test_detect_unknown_raises(): + """ + Unsupported architectures should raise ValueError. + """ + from exo.shared.architecture import detect_architecture + + with pytest.raises(ValueError, match="Unsupported Architecture"): + detect_architecture({"architectures": ["FakeModelForCausalLM"]}) + +def test_architecture_spec_is_frozen(): + """ + Architecture specs must be immutable (frozen=True). + """ + from exo.shared.architecture import ARCHITECTURE_REGISTRY + + spec = ARCHITECTURE_REGISTRY["LlamaForCausalLM"] + with pytest.raises(ValidationError): + spec.name = "modified" # type: ignore[misc] + + +def test_parse_model_config(tmp_path: Path): + """ + parse_model_config() must produce a correct ModelConfig + from JSON. 
+ """ + + from exo.shared.model_config import parse_model_config + + config = { + "architectures" : ["LlamaForCausalLM"], + "hidden_size": 2048, + "intermediate_size": 8192, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 16, + "vocab_size": 128256, + "rope_theta": 500000.0, + "max_position_embeddings": 4096, + "rms_norm_eps": 1e-5, + } + + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + parsed = parse_model_config(config_path) + assert parsed.num_attention_heads == 32 + assert parsed.num_key_value_heads == 8 + assert parsed.head_dim == 64 # 2048 // 32 + assert parsed.hidden_size == 2048 + assert parsed.quantization_config is None + + +def test_parse_model_config_with_quantization(tmp_path: Path): + """ + Quantized models should populate quantization_config. + """ + from exo.shared.model_config import parse_model_config + + config = { + "architectures": ["LlamaForCausalLM"], + "hidden_size": 2048, + "intermediate_size": 8192, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 16, + "vocab_size": 128256, + "quantization_config": {"bits": 4, "group_size": 64}, + } + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + parsed = parse_model_config(config_path) + assert parsed.quantization_config is not None + assert parsed.quantization_config.bits == 4 + assert parsed.quantization_config.group_size == 64 + + +def test_parse_model_config_defaults(tmp_path: Path): + """ + Missing optional fields should use sensible defaults. 
+ """ + from exo.shared.model_config import parse_model_config + + config = { + "architectures": ["LlamaForCausalLM"], + "hidden_size": 2048, + "num_attention_heads": 32, + "vocab_size": 128256, + } + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + parsed = parse_model_config(config_path) + assert parsed.num_key_value_heads == 32 # defaults to num_attention_heads + assert parsed.tie_word_embeddings is False + assert parsed.rope_scaling is None + + +def test_model_config_is_frozen(tmp_path: Path): + """ + ModelConfig must be immutable (frozen=True). + """ + from exo.shared.model_config import parse_model_config + + config = { + "architectures": ["LlamaForCausalLM"], + "hidden_size": 2048, + "num_attention_heads": 32, + "vocab_size": 128256, + } + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + parsed = parse_model_config(config_path) + with pytest.raises(ValidationError): + parsed.hidden_size = 4096 # type: ignore[misc] diff --git a/src/exo/shared/tests/test_tokenizer_shared.py b/src/exo/shared/tests/test_tokenizer_shared.py new file mode 100644 index 0000000000..29798441ad --- /dev/null +++ b/src/exo/shared/tests/test_tokenizer_shared.py @@ -0,0 +1,88 @@ +from typing import Any +from unittest.mock import MagicMock + +from exo.shared.types.common import ModelId + + +def test_kimi_eos() -> None: + """Kimi K2 has a known EOS token ID.""" + from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model + + assert get_eos_token_ids_for_model(ModelId("moonshotai/Kimi-K2-Instruct")) == [163586] + + +def test_glm_flash_eos() -> None: + """GLM-4.7-flash has specific EOS token IDs.""" + from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model + + result = get_eos_token_ids_for_model(ModelId("THUDM/glm-4.7-flash")) + assert result == [154820, 154827, 154829] + + +def test_glm_generic_eos() -> None: + """Generic GLM models have different EOS tokens than flash.""" + from 
exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model + + result = get_eos_token_ids_for_model(ModelId("THUDM/glm-4-9b")) + assert result == [151336, 151329, 151338] + + +def test_unknown_model_returns_none() -> None: + """Unknown models should return None (use tokenizer default).""" + from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model + + assert get_eos_token_ids_for_model(ModelId("unknown-org/UnknownModel-7B")) is None + + +def test_chat_template_chatml_fallback() -> None: + """Without apply_chat_template, fall back to ChatML format.""" + from exo.shared.tokenizer.chat_template import apply_chat_template + from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams + + tokenizer = MagicMock(spec=[]) + + task = TextGenerationTaskParams( + model=ModelId("test-model"), + input=[InputMessage(role="user", content="Hello")], + ) + + result: Any = apply_chat_template(tokenizer, task) + assert "<|im_start|>user" in result + assert "Hello" in result + assert "<|im_end|>" in result + assert result.endswith("<|im_start|>assistant\n") # pyright: ignore[reportAny] + + +def test_chat_template_delegates_to_tokenizer() -> None: + """When tokenizer has apply_chat_template, use it.""" + from exo.shared.tokenizer.chat_template import apply_chat_template + from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams + + tokenizer: Any = MagicMock() + tokenizer.apply_chat_template.return_value = "" # pyright: ignore[reportAny] + + task = TextGenerationTaskParams( + model=ModelId("test-model"), + input=[InputMessage(role="user", content="Hello")], + ) + + result: Any = apply_chat_template(tokenizer, task) + assert result == "" + tokenizer.apply_chat_template.assert_called_once() # pyright: ignore[reportAny] + + +def test_normalize_tool_calls_parses_json_string() -> None: + """Tool call arguments stored as strings should be parsed to dicts.""" + from exo.shared.tokenizer.chat_template import ( + 
_normalize_tool_calls, # pyright: ignore[reportPrivateUsage] + ) + + msg: dict[str, Any] = { + "role": "assistant", + "tool_calls": [ + {"function": {"name": "get_weather", "arguments": '{"city": "London"}'}} + ], + } + _normalize_tool_calls(msg) + tool_calls: Any = msg["tool_calls"] # pyright: ignore[reportAny] + assert tool_calls[0]["function"]["arguments"] == {"city": "London"} diff --git a/src/exo/shared/tokenizer/chat_template.py b/src/exo/shared/tokenizer/chat_template.py new file mode 100644 index 0000000000..bee80ffe4f --- /dev/null +++ b/src/exo/shared/tokenizer/chat_template.py @@ -0,0 +1,84 @@ +import contextlib +import json +from typing import Any, cast + +from loguru import logger + +from exo.shared.types.text_generation import TextGenerationTaskParams + + +def apply_chat_template( + tokenizer: Any, # pyright: ignore[reportAny] + task_params: TextGenerationTaskParams, +) -> str: + """Convert TextGenerationTaskParams to a chat template prompt. + + Converts the internal format (input + instructions) to a messages list + that can be processed by the tokenizer's chat template. + + When chat_template_messages is available (from Chat Completions API), + uses those directly to preserve tool_calls, thinking, and other fields. + Otherwise builds messages from the task params input/instructions. + """ + formatted_messages: list[dict[str, Any]] = [] + if task_params.chat_template_messages is not None: + formatted_messages = list(task_params.chat_template_messages) + for msg in formatted_messages: + _normalize_tool_calls(msg) + else: + if task_params.instructions: + formatted_messages.append( + {"role": "system", "content": task_params.instructions} + ) + + for msg in task_params.input: + if not msg.content: + logger.warning("Received message with empty content, skipping") + continue + formatted_messages.append({"role": msg.role, "content": msg.content}) + + # For assistant prefilling, append content after templating to avoid a closing turn token. 
+ partial_assistant_content: str | None = None + if formatted_messages and formatted_messages[-1].get("role") == "assistant": + partial_assistant_content = cast(str, formatted_messages[-1].get("content", "")) + formatted_messages = formatted_messages[:-1] + + if hasattr(tokenizer, "apply_chat_template"): # pyright: ignore[reportAny] + prompt: str = tokenizer.apply_chat_template( # pyright: ignore[reportAny] + formatted_messages, + tokenize=False, + add_generation_prompt=True, + tools=task_params.tools, + ) + if partial_assistant_content: + prompt += partial_assistant_content + return prompt + + # Fallback: ChatML format + parts: list[str] = [] + for msg in formatted_messages: + parts.append(f"<|im_start|>{msg['role']}\n{msg.get('content', '')}<|im_end|>\n") + parts.append("<|im_start|>assistant\n") + result = "".join(parts) + if partial_assistant_content: + result += partial_assistant_content + return result + + +def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None: + tool_calls: list[Any] | None = msg_dict.get("tool_calls") + if not tool_calls: + return + + for tool_call in tool_calls: # pyright: ignore[reportAny] + if not isinstance(tool_call, dict): + continue + + func: dict[str, Any] | None = tool_call.get("function") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + if not isinstance(func, dict): + continue + + args: Any = func.get("arguments") # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + if isinstance(args, str): + with contextlib.suppress(json.JSONDecodeError): + func["arguments"] = json.loads(args) diff --git a/src/exo/shared/tokenizer/eos_tokens.py b/src/exo/shared/tokenizer/eos_tokens.py new file mode 100644 index 0000000000..4f05ccbb3e --- /dev/null +++ b/src/exo/shared/tokenizer/eos_tokens.py @@ -0,0 +1,16 @@ +from exo.shared.models.model_cards import ModelId + + +def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None: + model_id_lower = model_id.lower() + + if "kimi-k2" in 
model_id_lower: + return [163586] + elif "glm-4.7-flash" in model_id_lower: + return [154820, 154827, 154829] + elif "glm" in model_id_lower: + return [151336, 151329, 151338] + elif "llama" in model_id_lower: + return [128001, 128008, 128009] + + return None diff --git a/src/exo/shared/tokenizer/loader.py b/src/exo/shared/tokenizer/loader.py new file mode 100644 index 0000000000..fb56f54950 --- /dev/null +++ b/src/exo/shared/tokenizer/loader.py @@ -0,0 +1,40 @@ +import sys +from pathlib import Path +from typing import Any + +from exo.shared.models.model_cards import ModelId + + +def _apply_transformers_5x_compact() -> None: + """Monkey patch for transformers 5.x compatibility (Kimi tokenizer).""" + try: + import transformers.models.gpt2.tokenization_gpt2 as gpt2_tok + from transformers.convert_slow_tokenizer import bytes_to_unicode + + if not hasattr(gpt2_tok, "bytes_to_unicode"): + gpt2_tok.bytes_to_unicode = bytes_to_unicode # pyright: ignore[reportAttributeAccessIssue] + except ImportError: + pass + + +def load_tokenizer_for_model(model_id: ModelId, model_path: Path) -> Any: # pyright: ignore[reportAny] + model_id_lower = model_id.lower() + + if "kimi-k2" in model_id_lower: + _apply_transformers_5x_compact() + sys.path.insert(0, str(model_path)) + + from tokenization_kimi import ( # pyright: ignore[reportMissingImports] + TikTokenTokenizer, # pyright: ignore[reportUnknownVariableType] + ) + hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + + def _patched_encode(text: str, **_kwargs: object) -> list[int]: + return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] + + hf_tokenizer.encode = _patched_encode + return hf_tokenizer # pyright: ignore[reportUnknownVariableType] + + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) # 
pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] diff --git a/src/exo/shared/tokenizer/wrapper.py b/src/exo/shared/tokenizer/wrapper.py new file mode 100644 index 0000000000..8e296b0000 --- /dev/null +++ b/src/exo/shared/tokenizer/wrapper.py @@ -0,0 +1,102 @@ +from collections.abc import Callable +from typing import Any, final + +_THINK_TOKENS: list[tuple[str, str]] = [ + ("", ""), + ("", ""), +] + + +@final +class TokenizerWrapper: + """Engine-agnostic tokenizer wrapper that satisfies the Tokenizer protocol. + + Wraps a raw HuggingFace tokenizer and adds tool-calling and thinking + properties expected by the runner. All other attribute access is proxied + to the underlying tokenizer. + """ + + _tokenizer: Any + _tool_call_start: str | None + _tool_call_end: str | None + _tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None + _think_start: str | None + _think_start_id: int | None + _think_end_id: int | None + + def __init__( + self, + tokenizer: Any, # pyright: ignore[reportAny] + *, + tool_call_start: str | None = None, + tool_call_end: str | None = None, + tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None = None, + ) -> None: + object.__setattr__(self, "_tokenizer", tokenizer) + object.__setattr__(self, "_tool_call_start", tool_call_start) + object.__setattr__(self, "_tool_call_end", tool_call_end) + object.__setattr__(self, "_tool_parser", tool_parser) + + # Detect thinking tokens from vocabulary + think_start: str | None = None + think_start_id: int | None = None + think_end_id: int | None = None + vocab: dict[str, int] = tokenizer.get_vocab() # pyright: ignore[reportAny] + for start_tok, end_tok in _THINK_TOKENS: + if start_tok in vocab and end_tok in vocab: + think_start = start_tok + think_start_id = vocab[start_tok] + think_end_id = vocab[end_tok] + break + object.__setattr__(self, "_think_start", think_start) + object.__setattr__(self, "_think_start_id", think_start_id) + 
object.__setattr__(self, "_think_end_id", think_end_id) + + # Disable tool calling if tokens aren't in the vocabulary + if (tool_call_start and tool_call_start not in vocab) or ( + tool_call_end and tool_call_end not in vocab + ): + object.__setattr__(self, "_tool_call_start", None) + object.__setattr__(self, "_tool_call_end", None) + object.__setattr__(self, "_tool_parser", None) + + @property + def has_tool_calling(self) -> bool: + return self._tool_call_start is not None + + @property + def has_call_start(self) -> bool: + return self._tool_call_start is not None + + @property + def tool_call_start(self) -> str | None: + return self._tool_call_start + + @property + def tool_call_end(self) -> str | None: + return self._tool_call_end + + @property + def tool_parser(self) -> Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None: + return self._tool_parser + + @property + def think_start(self) -> str | None: + return self._think_start + + @property + def think_start_id(self) -> int | None: + return self._think_start_id + + @property + def think_end(self) -> int | None: + return self._think_end_id + + def __getattr__(self, name: str) -> Any: # pyright: ignore[reportAny] + return getattr(self._tokenizer, name) # pyright: ignore[reportAny] + + def __setattr__(self, name: str, value: Any) -> None: # pyright: ignore[reportAny] + if name.startswith("_"): + object.__setattr__(self, name, value) + else: + setattr(self._tokenizer, name, value) # pyright: ignore[reportAny] diff --git a/src/exo/shared/types/api.py b/src/exo/shared/types/api.py index 23ca9b7b6a..842a5e90d4 100644 --- a/src/exo/shared/types/api.py +++ b/src/exo/shared/types/api.py @@ -8,7 +8,12 @@ from exo.shared.models.model_cards import ModelCard, ModelId from exo.shared.types.common import CommandId, NodeId from exo.shared.types.memory import Memory -from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta +from exo.shared.types.worker.instances import ( + Instance, + 
InstanceId, + InstanceMeta, + default_instance_meta, +) from exo.shared.types.worker.shards import Sharding, ShardMetadata from exo.utils.pydantic_ext import CamelCaseModel @@ -224,7 +229,7 @@ class HuggingFaceSearchResult(BaseModel): class PlaceInstanceParams(BaseModel): model_id: ModelId sharding: Sharding = Sharding.Pipeline - instance_meta: InstanceMeta = InstanceMeta.MlxRing + instance_meta: InstanceMeta = Field(default_factory=default_instance_meta) min_nodes: int = 1 diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index 76bd6fd4e6..f504ec12c8 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -1,3 +1,4 @@ +import sys from enum import Enum from pydantic import model_validator @@ -15,6 +16,13 @@ class InstanceId(Id): class InstanceMeta(str, Enum): MlxRing = "MlxRing" MlxJaccl = "MlxJaccl" + Tinygrad = "Tinygrad" + + +def default_instance_meta() -> InstanceMeta: + if sys.platform == "darwin": + return InstanceMeta.MlxRing + return InstanceMeta.Tinygrad class BaseInstance(TaggedModel): @@ -34,9 +42,12 @@ class MlxJacclInstance(BaseInstance): jaccl_devices: list[list[str | None]] jaccl_coordinators: dict[NodeId, str] +class TinygradInstance(BaseInstance): + hosts_by_node: dict[NodeId, list[Host]] | None = None + ephemeral_port: int | None = None # TODO: Single node instance -Instance = MlxRingInstance | MlxJacclInstance +Instance = MlxRingInstance | MlxJacclInstance | TinygradInstance class BoundInstance(CamelCaseModel): diff --git a/src/exo/shared/types/worker/tokenizer.py b/src/exo/shared/types/worker/tokenizer.py new file mode 100644 index 0000000000..28b71d093a --- /dev/null +++ b/src/exo/shared/types/worker/tokenizer.py @@ -0,0 +1,36 @@ +from collections.abc import Callable +from typing import Any, Protocol + + +class Tokenizer(Protocol): + @property + def has_tool_calling(self) -> bool: ... + + @property + def has_call_start(self) -> bool: ... 
+ + @property + def tool_call_start(self) -> str | None: ... + + @property + def tool_call_end(self) -> str | None: ... + + @property + def tool_parser(self) -> Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None: ... + + @property + def think_start(self) -> str | None: ... + + @property + def think_start_id(self) -> int | None: ... + + @property + def think_end(self) -> int | None: ... + + +class MutableTokenizer(Tokenizer, Protocol): + """Extended protocol for tokenizers whose tool-calling fields can be reassigned.""" + + _tool_call_start: str | None + _tool_call_end: str | None + _tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None diff --git a/src/exo/worker/engines/engine_factory.py b/src/exo/worker/engines/engine_factory.py new file mode 100644 index 0000000000..2500997c4e --- /dev/null +++ b/src/exo/worker/engines/engine_factory.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Callable, Generator +from typing import Any + +from pydantic import BaseModel + +from exo.shared.types.worker.instances import BoundInstance +from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse + +""" +The frozen=True class argument makes the class immutable, preventing +side effects at runtime. This is critical since we are working with +heterogeneous consumer-grade devices. 
+""" + +class Engine(BaseModel, frozen=True): + # (BoundInstance) -> context + initialize: Callable[[BoundInstance], Any] + + # (BoundInstance, context) -> model, tokenizer + load: Callable[..., tuple[Any, Any]] + + # (model, tokenizer, task, prompt) -> Generator[GenerationResponse] + generate: Callable[..., Generator[GenerationResponse | ToolCallResponse]] + + apply_chat_template: Callable[..., str] + detect_thinking_prompt_suffix: Callable[..., bool] + + # (model, tokenizer) -> initialize + warmup: Callable[..., int] + cleanup: Callable[[], None] + +def create_engine(bound_instance: BoundInstance) -> Engine: + from exo.shared.types.worker.instances import ( + MlxJacclInstance, + MlxRingInstance, + TinygradInstance, + ) + + match bound_instance.instance: + case MlxRingInstance() | MlxJacclInstance(): + # Lazy import - MLX must be loaded only on MacOS. + from exo.worker.engines.mlx.generator.generate import ( + mlx_generate_with_postprocessing, + warmup_inference, + ) + from exo.worker.engines.mlx.utils_mlx import ( + apply_chat_template, + detect_thinking_prompt_suffix, + initialize_mlx, + load_mlx_items, + ) + + return Engine( + initialize = initialize_mlx, + load = load_mlx_items, + generate=mlx_generate_with_postprocessing, + apply_chat_template=apply_chat_template, + detect_thinking_prompt_suffix=detect_thinking_prompt_suffix, + warmup = warmup_inference, + cleanup = _mlx_cleanup, + ) + + case TinygradInstance(): + from exo.shared.tokenizer.chat_template import ( + apply_chat_template, + ) + from exo.worker.engines.tinygrad.generator.generate import ( + tinygrad_generate, + warmup_inference, + ) + from exo.worker.engines.tinygrad.utils_tinygrad import ( + initialize_tinygrad, + load_tinygrad_items, + ) + + return Engine( + initialize=initialize_tinygrad, + load=load_tinygrad_items, + generate=tinygrad_generate, + apply_chat_template=apply_chat_template, + detect_thinking_prompt_suffix=_tinygrad_detect_thinking, + warmup=warmup_inference, + 
cleanup=_tinygrad_cleanup, + ) + + case _: # pyright: ignore[reportUnnecessaryComparison] + raise ValueError(f"Unsupported Instance: {type(bound_instance.instance)}") + + +def _mlx_cleanup() -> None: + from mlx.core import clear_cache + clear_cache() + + +def _tinygrad_cleanup() -> None: + from exo.worker.engines.tinygrad.generator.generate import cleanup_jit_state + cleanup_jit_state() + + +def _tinygrad_detect_thinking(prompt: str, tokenizer: Any) -> bool: # pyright: ignore[reportAny] + return False diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index 6e92620513..d6e83fe1f1 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -1,18 +1,27 @@ import time from copy import deepcopy -from typing import Callable, Generator, cast, get_args +from functools import cache +from typing import Any, Callable, Generator, cast, get_args import mlx.core as mx from mlx_lm.generate import stream_generate from mlx_lm.models.cache import ArraysCache, RotatingKVCache +from mlx_lm.models.gpt_oss import Model as GptOssModel from mlx_lm.sample_utils import make_sampler from mlx_lm.tokenizer_utils import TokenizerWrapper +from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs] + HarmonyEncodingName, + Role, + StreamableParser, + load_harmony_encoding, +) from exo.shared.types.api import ( CompletionTokensDetails, FinishReason, GenerationStats, PromptTokensDetails, + ToolCallItem, TopLogprobItem, Usage, ) @@ -22,6 +31,7 @@ from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams from exo.shared.types.worker.runner_response import ( GenerationResponse, + ToolCallResponse, ) from exo.worker.engines.mlx import Model from exo.worker.engines.mlx.auto_parallel import set_pipeline_prefill @@ -494,3 +504,247 @@ def mlx_generate( # Limit accumulated_text to what's needed for stop sequence detection if max_stop_len > 0 and 
len(accumulated_text) > max_stop_len: accumulated_text = accumulated_text[-max_stop_len:] + + +def mlx_generate_with_postprocessing( + model: Model, + tokenizer: TokenizerWrapper, + task: TextGenerationTaskParams, + prompt: str, + model_id: str, + kv_prefix_cache: KVPrefixCache | None = None, + group: mx.distributed.Group | None = None, +) -> Generator[GenerationResponse | ToolCallResponse]: + """ + This wrapper function for mlx_generate() includes model specific + post-processing required by GPT-OSS, Kimi and GLM. + + This will ensure engine-dependent post-processing to be contained + within the engine package, requiring minimal changes in `runner.py` + file. + """ + + gen = mlx_generate(model=model, tokenizer=tokenizer, task=task, + prompt=prompt, kv_prefix_cache=kv_prefix_cache, group=group) + + # Kimi-K2 has tool call sections - we don't care about them + if "kimi" in model_id.lower(): + gen = filter_kimi_tokens(gen) + patch_kimi_tokenizer(tokenizer) + + # GLM models need patched parser (upstream has bug with None regex match) + elif "glm" in model_id.lower(): + patch_glm_tokenizer(tokenizer) + + # GPT-OSS specific parsing to match other model formats. 
+ elif isinstance(model, GptOssModel): + gen = parse_gpt_oss(gen) + + return gen + + +def parse_gpt_oss( + responses: Generator[GenerationResponse | ToolCallResponse], +) -> Generator[GenerationResponse | ToolCallResponse]: + encoding = get_gpt_oss_encoding() # pyright: ignore[reportAny] + stream: Any = StreamableParser(encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny] + thinking = False + current_tool_name: str | None = None + tool_arg_parts: list[str] = [] + + for response in responses: + assert isinstance(response, GenerationResponse) + stream.process(response.token) # pyright: ignore[reportAny] + + delta: str | None = stream.last_content_delta # pyright: ignore[reportAny] + ch: str | None = stream.current_channel # pyright: ignore[reportAny] + recipient: str | None = stream.current_recipient # pyright: ignore[reportAny] + + if recipient != current_tool_name: + if current_tool_name is not None: + prefix = "functions." + if current_tool_name.startswith(prefix): + current_tool_name = current_tool_name[len(prefix) :] + yield ToolCallResponse( + tool_calls=[ + ToolCallItem( + name=current_tool_name, + arguments="".join(tool_arg_parts).strip(), + ) + ], + usage=response.usage, + ) + tool_arg_parts = [] + current_tool_name = recipient + + # If inside a tool call, accumulate arguments + if current_tool_name is not None: + if delta: + tool_arg_parts.append(delta) + continue + + if ch == "analysis" and not thinking: + thinking = True + yield response.model_copy(update={"text": ""}) + + if ch != "analysis" and thinking: + thinking = False + yield response.model_copy(update={"text": ""}) + + if delta: + yield response.model_copy(update={"text": delta}) + + if response.finish_reason is not None: + if thinking: + yield response.model_copy(update={"text": ""}) + yield response + + +@cache +def get_gpt_oss_encoding() -> Any: # pyright: ignore[reportAny] + return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + +def filter_kimi_tokens( + responses: 
Generator[GenerationResponse | ToolCallResponse], +) -> Generator[GenerationResponse]: + for resp in responses: + assert isinstance(resp, GenerationResponse) + if ( + resp.text == "<|tool_calls_section_begin|>" + or resp.text == "<|tool_calls_section_end|>" + ): + continue + yield resp + + +def patch_glm_tokenizer(tokenizer: TokenizerWrapper): + """ + Fixed version of mlx_lm's glm47 tool parser that handles regex match failures. + """ + import ast + import json + from typing import Any + + import regex as re + + _func_name_regex = re.compile(r"^(.*?)", re.DOTALL) + _func_arg_regex = re.compile( + r"(.*?)(?:\n|\s)*(.*?)(?:|(?=)|$)", + re.DOTALL, + ) + + tool_call_start = "" + tool_call_end = "" + + def _is_string_type( + tool_name: str, + arg_name: str, + tools: list[Any] | None, + ) -> bool: + if tools is None: + return False + for tool in tools: # pyright: ignore[reportAny] + func = tool["function"] # pyright: ignore[reportAny] + if func["name"] == tool_name: + params = func["parameters"] # pyright: ignore[reportAny] + if params is None: + return False + props = params.get("properties", {}) # pyright: ignore[reportAny] + arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny] + arg_type = arg_props.get("type", None) # pyright: ignore[reportAny] + return arg_type == "string" # pyright: ignore[reportAny] + return False + + def _deserialize(value: str) -> Any: # pyright: ignore[reportAny] + try: + return json.loads(value) # pyright: ignore[reportAny] + except Exception: + pass + try: + return ast.literal_eval(value) # pyright: ignore[reportAny] + except Exception: + pass + return value + + + def parse_tool_call(text: str, tools: list[Any] | None = None): + func_name_match = _func_name_regex.search(text) + if func_name_match is None: + raise ValueError(f"Could not parse function name from tool call: {text!r}") + func_name = func_name_match.group(1) + + pairs = _func_arg_regex.findall(text) + arg_dct: dict[str, Any] = {} + for key, value in pairs: # pyright: 
ignore[reportAny] + arg_key = key.strip() # pyright: ignore[reportAny] + arg_val = value.strip() # pyright: ignore[reportAny] + if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny] + arg_val = _deserialize(arg_val) # pyright: ignore[reportAny] + arg_dct[arg_key] = arg_val + return dict(name=func_name, arguments=arg_dct) + + tokenizer._tool_call_start = tool_call_start + tokenizer._tool_call_end = tool_call_end + tokenizer._tool_parser = parse_tool_call + + +def patch_kimi_tokenizer(tokenizer: TokenizerWrapper): + """ + Version of to-be-upstreamed kimi-k2 tool parser + """ + import ast + import json + from typing import Any + + import regex as re + + # kimi has a fixed function naming scheme, with a json formatted arg + # functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3} + # Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."} + _func_name_regex = re.compile( + r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL + ) + _func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL) + + # kimi has a tool_calls_section - we're leaving this up to the caller to handle + tool_call_start = "<|tool_call_begin|>" + tool_call_end = "<|tool_call_end|>" + + def _deserialize(value: str) -> Any: # pyright: ignore[reportAny] + try: + return json.loads(value) # pyright: ignore[reportAny] + except Exception: + pass + + try: + return ast.literal_eval(value) # pyright: ignore[reportAny] + except Exception: + pass + return value + + def parse_tool_call(text: str, tools: Any | None = None): + func_name_match = _func_name_regex.search(text) + if func_name_match is None: + raise ValueError(f"Could not parse function name from tool call: {text!r}") + original_func_name = func_name_match.group(1) + tool_id = func_name_match.group(2) + # strip off the `functions.` prefix, if it exists. 
+ func_name = original_func_name[original_func_name.find(".") + 1 :] + + func_args_match = _func_arg_regex.search(text) + if func_args_match is None: + raise ValueError(f"Could not parse function args from tool call: {text!r}") + func_args = func_args_match.group(1) + # the args should be valid json - no need to check against our tools to deserialize + arg_dct = _deserialize(func_args) # pyright: ignore[reportAny] + + return dict( + id=f"{original_func_name}:{tool_id}", + name=func_name, + arguments=arg_dct, # pyright: ignore[reportAny] + ) + + tokenizer._tool_call_start = tool_call_start + tokenizer._tool_call_end = tool_call_end + tokenizer._tool_parser = parse_tool_call diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 48b902ff8c..bedece01d4 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -149,6 +149,9 @@ def mlx_distributed_init( os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator group = mx.distributed.init(backend="jaccl", strict=True) + case _: + raise TypeError(f"Unsupported instance type for MLX: {type(bound_instance.instance)}") + logger.info(f"Rank {rank} mlx distributed initialization complete") return group diff --git a/src/exo/worker/engines/tinygrad/__init__.py b/src/exo/worker/engines/tinygrad/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/exo/worker/engines/tinygrad/cache.py b/src/exo/worker/engines/tinygrad/cache.py new file mode 100644 index 0000000000..b837a03c3f --- /dev/null +++ b/src/exo/worker/engines/tinygrad/cache.py @@ -0,0 +1,73 @@ +from tinygrad.dtype import dtypes +from tinygrad.tensor import Tensor + + +class KVCache: + def __init__(self, + num_layers: int, + num_kv_heads: int, + head_dim: int, + max_seq_len: int, + ) -> None: + self.keys: list[Tensor] = [Tensor.zeros(1, num_kv_heads, max_seq_len, head_dim, dtype=dtypes.float16) for _ in range(num_layers)] # pyright: 
ignore[reportUnknownMemberType] + self.values: list[Tensor] = [Tensor.zeros(1, num_kv_heads, max_seq_len, head_dim, dtype=dtypes.float16) for _ in range(num_layers)] # pyright: ignore[reportUnknownMemberType] + + self.max_seq_len = max_seq_len + self._positions: Tensor = Tensor.arange(max_seq_len).reshape(1, 1, max_seq_len, 1) # pyright: ignore[reportUnknownMemberType] + self.col_indices: Tensor = Tensor.arange(max_seq_len).reshape(1, 1, 1, max_seq_len) # pyright: ignore[reportUnknownMemberType] + + def update( + self, + layers_idx: int, + key: Tensor, + value: Tensor, + position: int | Tensor = 0, + ) -> tuple[Tensor, Tensor]: + seq_len = key.shape[2] + + """ + Mask: selects the cache positions in the half-open range + [position, position + seq_len). + + We use mask + pad since tinygrad tensors + cannot handle slice assignment. + """ + + positions = self._positions + + if isinstance(position, Tensor): + mask = positions == position + self.keys[layers_idx] = Tensor.where( + mask, key.half(), self.keys[layers_idx] + ) + self.values[layers_idx] = Tensor.where( + mask, value.half(), self.values[layers_idx] + ) + else: + mask = (positions >= position) & (positions < position + seq_len) + pad_prev = position + pad_next = self.max_seq_len - position - seq_len + new_k = key.pad( + ((0, 0), (0, 0), (pad_prev, pad_next), (0, 0)) + ).half() + new_v = value.pad( + ((0, 0), (0, 0), (pad_prev, pad_next), (0, 0)) + ).half() + + self.keys[layers_idx] = new_k + self.values[layers_idx] = new_v + + if position > 0: + self.keys[layers_idx] = Tensor.where( + mask, new_k, self.keys[layers_idx] + ) + + self.values[layers_idx] = Tensor.where( + mask, new_v, self.values[layers_idx] + ) + + return self.keys[layers_idx], self.values[layers_idx] + + @property + def seq_len(self) -> int: + return int(self.keys[0].shape[2]) diff --git a/src/exo/worker/engines/tinygrad/constants.py b/src/exo/worker/engines/tinygrad/constants.py new file mode 100644 index 0000000000..d1276e486a --- /dev/null +++ 
b/src/exo/worker/engines/tinygrad/constants.py @@ -0,0 +1,4 @@ +DEFAULT_MAX_TOKENS: int = 32168 +DEFAULT_TEMPERATURE: float = 0.7 +DEFAULT_TOP_P: float = 1.0 +DEFAULT_TOP_LOGPROBS: int = 5 diff --git a/src/exo/worker/engines/tinygrad/forward.py b/src/exo/worker/engines/tinygrad/forward.py new file mode 100644 index 0000000000..92934e1861 --- /dev/null +++ b/src/exo/worker/engines/tinygrad/forward.py @@ -0,0 +1,117 @@ +from typing import NamedTuple + +from tinygrad.tensor import Tensor + +from exo.shared.model_config import ModelConfig +from exo.worker.engines.tinygrad.cache import KVCache +from exo.worker.engines.tinygrad.layers.attention import grouped_query_attention +from exo.worker.engines.tinygrad.layers.embedding import apply_embedding, apply_lm_head +from exo.worker.engines.tinygrad.layers.mlp import swiglu_mlp +from exo.worker.engines.tinygrad.layers.normalization import rms_norm +from exo.worker.engines.tinygrad.weight_loader import LayerWeights, TransformerWeights + + +class TransformerBlockBuilder(NamedTuple): + cos_freqs: Tensor + sin_freqs: Tensor + layer: LayerWeights + idx: int + offset: int | Tensor + +def forward_pass( + weights: TransformerWeights, + input_ids: Tensor, + cache: KVCache | None, + position_offset: int | Tensor = 0, + rope_cos: Tensor | None = None, + rope_sin: Tensor | None = None, +) -> tuple[Tensor, KVCache]: + config = weights.config + _batch, _seq_len = input_ids.shape + + x = apply_embedding(weights.embed_tokens, input_ids) + + if cache is None: + """ + I am reducing the max_seq_len down to 4096 to work + with consumer-grade GPUs. Unlike Apple systems, + most computers have memory statically partitioned + if using integrated memory. With discrete GPUs, the + VRAM issue becomes explicit. + + When testing out on my AMD RX6600M, this is my way + of handling OOM errors. 
+ """ + cache = KVCache( + num_layers = len(weights.layers), + num_kv_heads = config.num_key_value_heads, + head_dim = config.head_dim, + max_seq_len = min(config.max_position_embeddings, 4096), + ) + + cos = rope_cos if rope_cos is not None else weights.rope_cos + sin = rope_sin if rope_sin is not None else weights.rope_sin + + for layer_idx, layer in enumerate(weights.layers): + builder = TransformerBlockBuilder( + cos, sin, + layer, layer_idx, position_offset, + ) + x = _transformer_block(x, config, cache, builder) + + if isinstance(position_offset, int): + x = x.realize(cache.keys[layer_idx], cache.values[layer_idx]) + + x = rms_norm(x, weights.final_norm, config.rms_norm_eps) + logits = apply_lm_head(x, weights.lm_head) + + return logits, cache + +def _transformer_block( + x: Tensor, + config: ModelConfig, + cache: KVCache, + builder: TransformerBlockBuilder, +) -> Tensor: + residual = x + layer = builder.layer + x = rms_norm(x, layer.input_norm, config.rms_norm_eps) + + match config.architecture_spec.attention_type: + case "grouped_query" | "multi_head": + x = grouped_query_attention( + x, qkv_proj = layer.qkv_proj, + o_proj = layer.o_proj, + cos_freqs = builder.cos_freqs, + sin_freqs = builder.sin_freqs, + cache = cache, layer_idx = builder.idx, + position_offset = builder.offset, + cache_position = builder.offset, + num_heads = config.num_attention_heads, + num_kv_heads = config.num_key_value_heads, + head_dim = config.head_dim, + q_norm = layer.q_norm, + k_norm = layer.k_norm, + rms_norm_eps = config.rms_norm_eps, + ) + case "multi_latent": + raise NotImplementedError( + "MLA attention: not yet implemented" + ) + + x = x + residual + residual = x + + x = rms_norm(x, layer.post_attn_norm, config.rms_norm_eps) + match config.architecture_spec.mlp_type: + case "swiglu": + x = swiglu_mlp( + x, layer.gate_up_proj, layer.down_proj + ) + case "moe_top_k": + raise NotImplementedError( + "MoE MLP: not yet implemented" + ) + + x = x + residual + return x diff 
--git a/src/exo/worker/engines/tinygrad/generator/generate.py b/src/exo/worker/engines/tinygrad/generator/generate.py new file mode 100644 index 0000000000..7ba551af1e --- /dev/null +++ b/src/exo/worker/engines/tinygrad/generator/generate.py @@ -0,0 +1,340 @@ +import struct +import time +from collections.abc import Callable, Generator +from dataclasses import dataclass +from typing import Any + +from tinygrad.dtype import dtypes +from tinygrad.engine.jit import TinyJit +from tinygrad.helpers import Context +from tinygrad.tensor import Tensor + +from exo.shared.model_config import ModelConfig +from exo.shared.models.model_cards import ModelId +from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model +from exo.shared.types.api import ( + CompletionTokensDetails, + GenerationStats, + PromptTokensDetails, + TopLogprobItem, + Usage, +) +from exo.shared.types.memory import Memory +from exo.shared.types.text_generation import TextGenerationTaskParams +from exo.shared.types.worker.runner_response import GenerationResponse +from exo.worker.engines.tinygrad.constants import ( + DEFAULT_MAX_TOKENS, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_P, +) + +from ..cache import KVCache +from ..forward import forward_pass +from ..sampling import sample_token +from ..weight_loader import TransformerWeights + +_PREFILL_BUCKETS: list[int] = [32, 64, 128, 256, 512] + +def _pad_to_bucket(input_ids: list[int], pad_id: int = 0) -> list[int]: + """ + Padding input_ids to the nearest bucket size keeps the prefill tensor + sizes fixed. This increases the chance of hitting the cache, leading to + quicker time to first token. 
+ """ + + for bucket in _PREFILL_BUCKETS: + if len(input_ids) <= bucket: + return input_ids + [pad_id] * (bucket - len(input_ids)) + return input_ids + + +@dataclass +class _JitState: + """Persistent decode state reused across requests.""" + jit_decode: Callable[..., tuple[Tensor, ...]] + cache: KVCache + input_buffer: Tensor + position_buffer: Tensor + +_jit_registry: dict[int, _JitState] = {} + +def cleanup_jit_state() -> None: + """Called by engine cleanup to free all JIT state.""" + _jit_registry.clear() + + +def _build_jit_decode( + weights: TransformerWeights, + cache: KVCache, +) -> Callable[..., tuple[Tensor, ...]]: + num_layers = len(weights.layers) + + @TinyJit + def decode( + input_ids: Tensor, position: Tensor, + rope_cos_table: Tensor, rope_sin_table: Tensor, + *cache_kv: Tensor, + ) -> tuple[Tensor, ...]: + for i in range(num_layers): + cache.keys[i] = cache_kv[i] + cache.values[i] = cache_kv[num_layers + i] + + logits, _ = forward_pass( + weights, input_ids, cache, + position_offset=position, + rope_cos=rope_cos_table, rope_sin=rope_sin_table, + ) + + # Realize everything at once — JIT captures one fused kernel schedule + # instead of 32 fragmented ones. 
+ logits = logits.realize(*cache.keys, *cache.values) + + return (logits, *cache.keys, *cache.values) + + return decode + +def tinygrad_generate( + model: TransformerWeights, + tokenizer: Any, # pyright: ignore[reportAny] + task: TextGenerationTaskParams, + prompt: str, + kv_prefix_cache: Any = None, # pyright: ignore[reportAny] + on_prefill_progress: Callable[[int, int], None] | None = None, + group: None = None, +) -> Generator[GenerationResponse]: + input_ids = _encode_prompt(tokenizer, prompt) + + max_tokens = task.max_output_tokens or DEFAULT_MAX_TOKENS + temperature = task.temperature or DEFAULT_TEMPERATURE + top_p = task.top_p or DEFAULT_TOP_P + + request_logprobs = task.logprobs + top_logprobs_count = task.top_logprobs or 0 + + eos_ids = _get_eos_ids(tokenizer, model.config) + print(f"[DEBUG] eos_ids={eos_ids}") + prompt_tokens = len(input_ids) + input_ids = _pad_to_bucket(input_ids) + + model_key = id(model) + num_layers = len(model.layers) + state = _jit_registry.get(model_key) + + if state is None: + # First request: create cache, JIT, and pre-allocate buffers + config = model.config + cache = KVCache( + num_layers=len(model.layers), + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_seq_len=min(config.max_position_embeddings, 4096), + ) + # Realize cache tensors — Tensor.zeros() produces lazy/const tensors + # that TinyJit rejects as inputs. Force device buffer allocation. 
+ for i in range(len(model.layers)): + cache.keys[i] = cache.keys[i].contiguous().realize() # pyright: ignore[reportUnknownMemberType] + cache.values[i] = cache.values[i].contiguous().realize() # pyright: ignore[reportUnknownMemberType] + jit_decode = _build_jit_decode(model, cache) + input_buffer = Tensor.empty(1, 1, dtype=dtypes.int32).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + position_buffer = Tensor.empty(1, dtype=dtypes.int32).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + state = _JitState( + jit_decode=jit_decode, + cache=cache, + input_buffer=input_buffer, + position_buffer=position_buffer, + ) + _jit_registry[model_key] = state + + cache = state.cache + + # Batched prefill: process all prompt tokens in a single forward pass. + # Uses position_offset=0 (int) which triggers local attention (seq_len × seq_len) + # instead of full-cache attention, producing ~324 kernel dispatches total + # instead of 324 × N token-by-token dispatches. + # BEAM is disabled because prefill shapes vary per prompt length (not cacheable) + # and BEAM may select WMMA kernels incompatible with RDNA 2 (gfx1032). + if not input_ids: + raise ValueError("Prompt must contain at least one token") + + prefill_start = time.time() + prompt_tensor = Tensor(input_ids, dtype=dtypes.int32).reshape(1, -1).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + with Context(BEAM=0): + logits, _ = forward_pass( + model, prompt_tensor, cache, + position_offset=0, + rope_cos=model.rope_cos, rope_sin=model.rope_sin, + ) + # Take the real last token's logits (not the padded last position). + logits = logits[:, prompt_tokens - 1:prompt_tokens, :].contiguous() # pyright: ignore[reportUnknownMemberType] + # Make cache tensors contiguous for JIT compatibility. 
+ for i in range(num_layers): + cache.keys[i] = cache.keys[i].contiguous() # pyright: ignore[reportUnknownMemberType] + cache.values[i] = cache.values[i].contiguous() # pyright: ignore[reportUnknownMemberType] + # Realize everything at once — same pattern as _build_jit_decode. + logits = logits.realize(*cache.keys, *cache.values) + + # Rebuild the JIT after prefill. The batched prefill creates entirely new + # cache tensor objects (via Tensor.where + contiguous + realize) that differ + # from the JIT's captured output buffers. Rebuilding ensures the JIT + # re-captures with the correct buffer objects. The cost is 2 slow decode + # steps per request (cnt=0 jit-ignore, cnt=1 jit-capture), after which + # all subsequent tokens use fast JIT replay. + jit_decode = _build_jit_decode(model, cache) + state.jit_decode = jit_decode + + prefill_time = time.time() - prefill_start + prompt_tps = prompt_tokens / max(prefill_time, 1e-9) + + position = prompt_tokens + + # Decode + generation_start = time.time() + for token_idx in range(max_tokens): + result = sample_token( + logits, temperature=temperature, top_p=top_p, + top_logprobs_count=top_logprobs_count, + request_logprobs=request_logprobs, + ) + + token_text: str = tokenizer.decode([result.token_id]) # pyright: ignore[reportAny] + + is_eos = result.token_id in eos_ids + if is_eos: + print(f"[DEBUG] EOS detected: token_id={result.token_id}, text={token_text!r}") + if "<|eot_id|>" in token_text or "<|end" in token_text: + print(f"[DEBUG] Special token in text but is_eos={is_eos}, token_id={result.token_id}") + tokens_generated = token_idx + 1 + elapsed = time.time() - generation_start + generation_tps = tokens_generated / max(elapsed, 1e-9) + + finish_reason = None + stats = None + usage = None + + if is_eos: + finish_reason = "stop" + elif token_idx == max_tokens - 1: + finish_reason = "length" + + if finish_reason is not None: + stats = GenerationStats( + prompt_tps=prompt_tps, generation_tps=generation_tps, + 
prompt_tokens=prompt_tokens, + generation_tokens=tokens_generated, + peak_memory_usage=Memory.from_bytes(0), + ) + + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=tokens_generated, + total_tokens=prompt_tokens + tokens_generated, + prompt_tokens_details=PromptTokensDetails(), + completion_tokens_details=CompletionTokensDetails(), + ) + + logprob = result.logprob if request_logprobs else None + top_lps = None + if request_logprobs and task.top_logprobs: + top_lps = [ + TopLogprobItem( + token=str(tokenizer.decode([tok_id])), # pyright: ignore[reportAny] + logprob=lp, + bytes=list(str(tokenizer.decode([tok_id])).encode("utf-8")), # pyright: ignore[reportAny] + ) + for tok_id, lp in result.top_logprobs + ] + + if is_eos: + token_text = "" + + yield GenerationResponse( + text=token_text, token=result.token_id, + logprob=logprob, top_logprobs=top_lps, + finish_reason=finish_reason, stats=stats, usage=usage, + ) + + if finish_reason is not None: + break + + state.input_buffer._buffer().copyin(memoryview(bytearray(struct.pack('=i', result.token_id)))) # pyright: ignore[reportPrivateUsage] + state.position_buffer._buffer().copyin(memoryview(bytearray(struct.pack('=i', position)))) # pyright: ignore[reportPrivateUsage] + results = jit_decode( + state.input_buffer, state.position_buffer, + model.rope_cos, model.rope_sin, + *cache.keys, *cache.values, + ) + + logits = results[0] + for i in range(num_layers): + cache.keys[i] = results[1 + i] + cache.values[i] = results[1 + num_layers + i] + position += 1 + +def warmup_inference(model: TransformerWeights, tokenizer: Any, group: None = None) -> int: # pyright: ignore[reportAny] + """Run a full generation loop to warm up forward pass, KV cache, and sampling.""" + from exo.shared.tokenizer.chat_template import apply_chat_template + from exo.shared.types.common import ModelId as CommonModelId + from exo.shared.types.text_generation import InputMessage + + warmup_task = TextGenerationTaskParams( + 
model=CommonModelId("warmup"), + input=[InputMessage(role="user", content="Time to warm up!")], + ) + + prompt: str = apply_chat_template(tokenizer, warmup_task) + tokens_generated = 0 + + for _ in tinygrad_generate(model, tokenizer, warmup_task, prompt): + tokens_generated += 1 + if tokens_generated >= 5: + break + + _warmup_prefill_buckets(model) + + return tokens_generated + + +def _warmup_prefill_buckets(model: TransformerWeights) -> None: + """Pre-compile prefill kernels at each bucket size to avoid first-request compilation.""" + model_key = id(model) + state = _jit_registry.get(model_key) + if state is None: + return + + cache = state.cache + num_layers = len(model.layers) + + with Context(BEAM=0): + for bucket_size in _PREFILL_BUCKETS: + dummy = Tensor.zeros(1, bucket_size, dtype=dtypes.int32).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + logits, _ = forward_pass( + model, dummy, cache, + position_offset=0, + rope_cos=model.rope_cos, rope_sin=model.rope_sin, + ) + logits = logits[:, -1:, :].contiguous() # pyright: ignore[reportUnknownMemberType] + for i in range(num_layers): + cache.keys[i] = cache.keys[i].contiguous() # pyright: ignore[reportUnknownMemberType] + cache.values[i] = cache.values[i].contiguous() # pyright: ignore[reportUnknownMemberType] + logits.realize(*cache.keys, *cache.values) + +def _encode_prompt(tokenizer: Any, prompt: str) -> list[int]: # pyright: ignore[reportAny] + result: Any = tokenizer.encode(prompt) # pyright: ignore[reportAny] + + return result.ids if hasattr(result, "ids") else result # pyright: ignore[reportAny] + +def _get_eos_ids(tokenizer: Any, config: ModelConfig) -> set[int]: # pyright: ignore[reportAny] + eos_ids: set[int] = set() + + model_eos = get_eos_token_ids_for_model(ModelId(config.architecture_spec.name)) + + if model_eos: + eos_ids.update(model_eos) + + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: # pyright: ignore[reportAny] + 
eos_ids.add(int(tokenizer.eos_token_id)) # pyright: ignore[reportAny] + + if not eos_ids: + eos_ids.add(2) + + return eos_ids diff --git a/src/exo/worker/engines/tinygrad/layers/__init__.py b/src/exo/worker/engines/tinygrad/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/exo/worker/engines/tinygrad/layers/attention.py b/src/exo/worker/engines/tinygrad/layers/attention.py new file mode 100644 index 0000000000..04cf6266cb --- /dev/null +++ b/src/exo/worker/engines/tinygrad/layers/attention.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import math + +from tinygrad.tensor import Tensor + +from exo.worker.engines.tinygrad.cache import KVCache +from exo.worker.engines.tinygrad.layers.normalization import rms_norm +from exo.worker.engines.tinygrad.layers.rotary import apply_rope +from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear + +LinearWeight = Tensor | QuantizedLinear + +def linear_forward(x: Tensor, weight: LinearWeight) -> Tensor: + if isinstance(weight, QuantizedLinear): + return weight(x) + return x @ weight.T + +def grouped_query_attention( + x: Tensor, + qkv_proj: LinearWeight, + o_proj: LinearWeight, + cos_freqs: Tensor, + sin_freqs: Tensor, + cache: KVCache, + layer_idx: int, + position_offset: int | Tensor, + cache_position: int | Tensor, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_norm: Tensor | None = None, + k_norm: Tensor | None = None, + rms_norm_eps: float = 1e-6, +) -> Tensor: + _batch, seq_len, _ = x.shape + + q_dim = num_heads * head_dim + kv_dim = num_kv_heads * head_dim + qkv = linear_forward(x, qkv_proj) + q = qkv[..., :q_dim].reshape(int(_batch), seq_len, num_heads, head_dim).permute(0, 2, 1, 3) # pyright: ignore[reportUnknownMemberType] + k = qkv[..., q_dim:q_dim + kv_dim].reshape(int(_batch), seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3) # pyright: ignore[reportUnknownMemberType] + v = qkv[..., q_dim + kv_dim:].reshape(int(_batch), seq_len, 
num_kv_heads, head_dim).permute(0, 2, 1, 3) # pyright: ignore[reportUnknownMemberType] + + if q_norm is not None: + q = rms_norm(q, q_norm, rms_norm_eps) + if k_norm is not None: + k = rms_norm(k, k_norm, rms_norm_eps) + + q = apply_rope(q, cos_freqs, sin_freqs, position_offset) + k = apply_rope(k, cos_freqs, sin_freqs, position_offset) + + # Store K,V in cache for future decode steps + cache.update(layer_idx, k, v, position = cache_position) + + if isinstance(position_offset, int): + # Prefill: compute attention against local K,V (seq_len × seq_len). + # This avoids the wasteful (seq_len × max_seq_len) matmul that the + # full-cache path would produce — up to 80× less work for short prompts. + k_attn, v_attn = k, v + else: + # Decode (JIT): use full pre-allocated cache K,V. + # Shapes are fixed (max_seq_len) which is required for TinyJit replay. + k_attn = cache.keys[layer_idx] + v_attn = cache.values[layer_idx] + + if num_kv_heads < num_heads: + repeat_factor = num_heads // num_kv_heads + k_attn = k_attn.unsqueeze(2).expand( # pyright: ignore[reportUnknownMemberType] + int(_batch), num_kv_heads, repeat_factor, -1, head_dim, + ).reshape(int(_batch), num_heads, -1, head_dim) + v_attn = v_attn.unsqueeze(2).expand( # pyright: ignore[reportUnknownMemberType] + int(_batch), num_kv_heads, repeat_factor, -1, head_dim, + ).reshape(int(_batch), num_heads, -1, head_dim) + + scale = 1.0 / math.sqrt(head_dim) + scores: Tensor = (q @ k_attn.transpose(-2, -1)) * scale + + if isinstance(position_offset, int): + # Prefill: standard causal mask over (seq_len × seq_len). + # All positions in the local K,V are valid, so no unfilled-position mask needed. + if seq_len > 1: + causal_mask = Tensor.ones(seq_len, seq_len).triu(1).reshape(1, 1, seq_len, seq_len) # pyright: ignore[reportUnknownMemberType] + scores = scores + causal_mask * float("-1e9") + else: + # Decode: mask unfilled positions in the full cache. 
+ valid_len = cache_position + seq_len # pyright: ignore[reportOperatorIssue, reportUnknownVariableType] + col_indeces: Tensor = cache.col_indices + unfilled_mask: Tensor = col_indeces >= valid_len # pyright: ignore[reportOperatorIssue, reportUnknownVariableType] + scores = scores + unfilled_mask * float("-1e9") # pyright: ignore[reportUnknownVariableType] + + attn_weights: Tensor = scores.softmax(axis=-1) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + out: Tensor = attn_weights @ v_attn # pyright: ignore[reportUnknownVariableType] + out = out.permute(0, 2, 1, 3).reshape(int(_batch), seq_len, num_heads * head_dim) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + + return linear_forward(out, o_proj) # pyright: ignore[reportUnknownArgumentType] diff --git a/src/exo/worker/engines/tinygrad/layers/embedding.py b/src/exo/worker/engines/tinygrad/layers/embedding.py new file mode 100644 index 0000000000..808ac465c6 --- /dev/null +++ b/src/exo/worker/engines/tinygrad/layers/embedding.py @@ -0,0 +1,22 @@ +from tinygrad.tensor import Tensor + +from exo.worker.engines.tinygrad.quantization.layers import ( + QuantizedEmbedding, + QuantizedLinear, +) + +EmbedWeight = Tensor | QuantizedEmbedding +LinearWeight = Tensor | QuantizedLinear + +def apply_embedding( + embed: EmbedWeight, + input_ids: Tensor, +) -> Tensor: + if isinstance(embed, QuantizedEmbedding): + return embed(input_ids) + return embed[input_ids] + +def apply_lm_head(x: Tensor, lm_head_weight: LinearWeight) -> Tensor: + if isinstance(lm_head_weight, QuantizedLinear): + return lm_head_weight(x) + return x @ lm_head_weight.T diff --git a/src/exo/worker/engines/tinygrad/layers/mlp.py b/src/exo/worker/engines/tinygrad/layers/mlp.py new file mode 100644 index 0000000000..06b964c1fb --- /dev/null +++ b/src/exo/worker/engines/tinygrad/layers/mlp.py @@ -0,0 +1,16 @@ +from tinygrad.tensor import Tensor + +from exo.worker.engines.tinygrad.layers.attention import LinearWeight, 
linear_forward + + +def swiglu_mlp( + x: Tensor, + gate_up_proj: LinearWeight, + down_proj: LinearWeight, +) -> Tensor: + gate_up = linear_forward(x, gate_up_proj) + half = gate_up.shape[-1] // 2 + gate = gate_up[..., :half] + up = gate_up[..., half:] + activated = gate.silu() * up + return linear_forward(activated, down_proj) diff --git a/src/exo/worker/engines/tinygrad/layers/normalization.py b/src/exo/worker/engines/tinygrad/layers/normalization.py new file mode 100644 index 0000000000..e73ba92f3e --- /dev/null +++ b/src/exo/worker/engines/tinygrad/layers/normalization.py @@ -0,0 +1,8 @@ +from tinygrad.tensor import Tensor + + +def rms_norm(x: Tensor, weight: Tensor, eps: float) -> Tensor: + variance = (x * x).mean(axis=-1, keepdim=True) + x_normed = x * (variance + eps).rsqrt() + + return x_normed * weight diff --git a/src/exo/worker/engines/tinygrad/layers/rotary.py b/src/exo/worker/engines/tinygrad/layers/rotary.py new file mode 100644 index 0000000000..7e533ce14b --- /dev/null +++ b/src/exo/worker/engines/tinygrad/layers/rotary.py @@ -0,0 +1,52 @@ +from typing import Tuple + +from tinygrad.dtype import dtypes +from tinygrad.tensor import Tensor + + +def compute_rope_frequencies( + head_dim: int, + max_seq_len: int, + rope_theta: float = 10000.0, +) -> Tuple[Tensor, Tensor]: + dim_pairs = head_dim // 2 + freq_exponents = Tensor.arange(0, dim_pairs, dtype=dtypes.float32) * 2.0 / head_dim # pyright: ignore[reportUnknownMemberType] + inv_freq = 1.0 / (rope_theta ** freq_exponents) + + positions = Tensor.arange(0, max_seq_len, dtype=dtypes.float32) # pyright: ignore[reportUnknownMemberType] + angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) + + cos_freqs = angles.cos() + sin_freqs = angles.sin() + + return cos_freqs, sin_freqs # Each: (max_seq_len, head_dim // 2) + +def apply_rope( + x: Tensor, + cos_freqs: Tensor, + sin_freqs: Tensor, + position_offset: int | Tensor = 0, +) -> Tensor: + seq_len = x.shape[2] + + cos = cos_freqs + sin = sin_freqs + + if 
isinstance(position_offset, Tensor): + # position_offset is shape (1,) — index into full tables. + # cos_freqs is (max_seq_len, dim/2), result: (1, dim/2) → (1, 1, 1, dim/2) + cos = cos_freqs[position_offset].reshape(1, 1, seq_len, -1) # pyright: ignore[reportUnknownMemberType] + sin = sin_freqs[position_offset].reshape(1, 1, seq_len, -1) # pyright: ignore[reportUnknownMemberType] + else: + cos = cos_freqs[position_offset:position_offset + seq_len].reshape(1, 1, seq_len, -1) # pyright: ignore[reportUnknownMemberType] + sin = sin_freqs[position_offset:position_offset + seq_len].reshape(1, 1, seq_len, -1) # pyright: ignore[reportUnknownMemberType] + + half = x.shape[-1] // 2 + + x1 = x[..., :half] + x2 = x[..., half:] + + out1 = x1 * cos - x2 * sin + out2 = x2 * cos + x1 * sin + + return out1.cat(out2, dim=-1) diff --git a/src/exo/worker/engines/tinygrad/quantization/__init__.py b/src/exo/worker/engines/tinygrad/quantization/__init__.py new file mode 100644 index 0000000000..9ad8af517e --- /dev/null +++ b/src/exo/worker/engines/tinygrad/quantization/__init__.py @@ -0,0 +1,13 @@ +from .dequantization import affine_dequantize +from .layers import QuantizedEmbedding, QuantizedLinear +from .packing import PackedTensor, calculate_pack_factor, pack_bits, unpack_bits + +__all__ = [ + "PackedTensor", + "QuantizedEmbedding", + "QuantizedLinear", + "affine_dequantize", + "calculate_pack_factor", + "pack_bits", + "unpack_bits", +] diff --git a/src/exo/worker/engines/tinygrad/quantization/dequantization.py b/src/exo/worker/engines/tinygrad/quantization/dequantization.py new file mode 100644 index 0000000000..2a99d57790 --- /dev/null +++ b/src/exo/worker/engines/tinygrad/quantization/dequantization.py @@ -0,0 +1,19 @@ +from tinygrad.tensor import Tensor + + +def affine_dequantize( + quantized: Tensor, + scales: Tensor, + biases: Tensor, + group_size: int, +) -> Tensor: + """ + Apply group-wise affine dequantization: + out = quantized * scale + bias + This is the same algorithm 
used by MLX quantization + """ + + extended_scales = scales.repeat_interleave(group_size, dim=-1) + extended_biases = biases.repeat_interleave(group_size, dim=-1) + + return quantized * extended_scales + extended_biases diff --git a/src/exo/worker/engines/tinygrad/quantization/layers.py b/src/exo/worker/engines/tinygrad/quantization/layers.py new file mode 100644 index 0000000000..8711fc5c47 --- /dev/null +++ b/src/exo/worker/engines/tinygrad/quantization/layers.py @@ -0,0 +1,84 @@ +from typing import final + +from tinygrad.tensor import Tensor + +from .dequantization import affine_dequantize +from .packing import PackedTensor, unpack_bits + + +@final +class QuantizedLinear: + """ + This is a drop-in replacement for tinygrad's Linear class while + supporting quantization. + + Weights are eagerly dequantized in __init__ so that tinygrad's kernel + cache can reuse compiled kernels for same-shape weights, avoiding + redundant HIP/CUDA compilations during the first forward pass. + """ + + def __init__( + self, + weight_q: PackedTensor, + scales: Tensor, + biases: Tensor, + group_size: int = 64, + bias: Tensor | None = None, + ) -> None: + self.weight_q = weight_q + self.scales = scales + self.biases = biases + self.group_size = group_size + self.bias = bias + + def __call__(self, x: Tensor) -> Tensor: + result = x @ self.dequantize().T + if self.bias is not None: + result = result + self.bias + return result + + def dequantize(self) -> Tensor: + unpacked = unpack_bits(self.weight_q) + return affine_dequantize( + unpacked, self.scales, + self.biases, self.group_size) + + @property + def in_features(self) -> int: + return self.weight_q.original_shape[1] + + @property + def out_features(self) -> int: + return self.weight_q.original_shape[0] + +@final +class QuantizedEmbedding: + """ + This is a drop-in replacement for tinygrad's Embedding class providing + the embedding support. 
+ """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + weight_q: PackedTensor, + scales: Tensor, + biases: Tensor, + group_size: int = 64, + ) -> None: + self.num_embeddings = num_embeddings + self.weight_q = weight_q + self.embedding_dim = embedding_dim + self.scales = scales + self.biases = biases + self.group_size = group_size + + def __call__(self, indices: Tensor) -> Tensor: + return self.dequantize()[indices] + + def dequantize(self) -> Tensor: + unpacked = unpack_bits(self.weight_q) + return affine_dequantize( + unpacked, self.scales, self.biases, self.group_size + ) diff --git a/src/exo/worker/engines/tinygrad/quantization/packing.py b/src/exo/worker/engines/tinygrad/quantization/packing.py new file mode 100644 index 0000000000..fdbb2eca4d --- /dev/null +++ b/src/exo/worker/engines/tinygrad/quantization/packing.py @@ -0,0 +1,85 @@ +from math import ceil +from typing import final + +from pydantic import BaseModel, ConfigDict +from tinygrad.dtype import dtypes +from tinygrad.tensor import Tensor + + +@final +class PackedTensor(BaseModel): + # Immutable container for bit-packed weight data with unpacking metadata + + model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True) + + tensor: Tensor + original_shape: tuple[int, ...] + pack_factor: int + bits: int + + def __setattr__(self, name: str, value: object) -> None: + raise AttributeError(f"'{type(self).__name__}' is frozen") + +def calculate_pack_factor(bits: int) -> int: + """ + Pack factor is the number of values packed into one packed word. The general + formula is pack_factor = (packed_type_bits) // bits. + In this case, we are dealing with a uint32 value. + """ + + return 32 // bits + +def pack_bits(unpacked: Tensor, bits: int) -> PackedTensor: + """ + This function converts a float16 tensor into a packed uint32 tensor, as + used by MLX community for their quantized models. 
+ """ + + pack_factor = calculate_pack_factor(bits) + original_shape = tuple(int(d) for d in unpacked.shape) + last_dim = original_shape[-1] + padded_last_dim = ceil(last_dim / pack_factor) * pack_factor + packed_last_dim = padded_last_dim // pack_factor + packed_shape = (*original_shape[:-1], packed_last_dim) + + if last_dim != padded_last_dim: + pad_shape = (*original_shape[:-1], padded_last_dim - last_dim) + padding = Tensor.zeros(*pad_shape, dtype = unpacked.dtype) # pyright: ignore[reportUnknownMemberType] + unpacked = unpacked.cat(padding, dim = -1) + + packed = Tensor.zeros(*packed_shape, dtype=dtypes.uint32) # pyright: ignore[reportUnknownMemberType] + for slot in range(pack_factor): + values = unpacked[..., slot::pack_factor].cast(dtypes.uint32) + packed = packed | (values << (slot * bits)) + + return PackedTensor( + tensor = packed, + original_shape = original_shape, + pack_factor = pack_factor, + bits = bits, + ) + +def unpack_bits(packed_tensor: PackedTensor) -> Tensor: + """ + Unpack a bit-packed uint32 tensor back into individual float16 values + """ + packed = packed_tensor.tensor + original_shape = packed_tensor.original_shape + pack_factor = packed_tensor.pack_factor + bits = packed_tensor.bits + mask = (1 << bits) - 1 + + slices: list[Tensor] = [] + for slot in range(pack_factor): + slices.append((packed >> (slot * bits)) & mask) + + stacked = Tensor.stack(*slices, dim=0) + padded_last_dim = packed.shape[-1] * pack_factor + padded_shape = (*original_shape[:-1], padded_last_dim) + result = stacked.permute(*range(1, len(stacked.shape)), 0).reshape(padded_shape) # pyright: ignore[reportUnknownMemberType] + + if padded_last_dim != original_shape[-1]: + # Trim padding for padded, packed tensors + result = result[..., :original_shape[-1]] + + return result.cast(dtypes.float16) diff --git a/src/exo/worker/engines/tinygrad/quantization/shapes.py b/src/exo/worker/engines/tinygrad/quantization/shapes.py new file mode 100644 index 0000000000..138bb8ea15 --- 
/dev/null +++ b/src/exo/worker/engines/tinygrad/quantization/shapes.py @@ -0,0 +1,71 @@ +from typing import Literal + +from exo.shared.model_config import ModelConfig + +LayerType = Literal[ + "q_proj", "k_proj", "v_proj", "o_proj", "qkv_proj", + "gate_proj", "up_proj", "down_proj", + "embed_tokens", "lm_head", + "unknown", +] + +_LAYER_PATTERNS: dict[LayerType, list[str]] = { + "q_proj": [".q_proj.", ".q.", ".attn.q"], + "k_proj": [".k_proj.", ".k.", ".attn.k"], + "v_proj": [".v_proj.", ".v.", ".attn.v"], + "o_proj": [".o_proj.", ".o.", ".attn.o"], + "qkv_proj": [".c_attn.", ".qkv.", ".attn.c_attn"], + "gate_proj": [".gate_proj.", ".gate.", ".mlp.gate."], + "up_proj": [".up_proj.", ".up.", ".mlp.up."], + "down_proj": [".down_proj.", ".down.", ".mlp.down."], + "embed_tokens": ["embed_tokens", "wte"], + "lm_head": ["lm_head", "output_layer"], +} + +def detect_layer_type(weight_key: str) -> LayerType: + """ + Detect canonical layer type from a safetensors key. + """ + + key = weight_key.lower() + + for layer_type, patterns in _LAYER_PATTERNS.items(): + if any(pattern in key for pattern in patterns): + return layer_type + + return "unknown" + +def infer_weight_shape(weight_key: str, config: ModelConfig) -> tuple[int, ...]: + """ + Infer original weight shape from layer name and ModelConfig. + + Handles GQA correctly: k_proj and v_proj use kv_dim instead of + hidden_size. 
+ """ + + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + vocab_size = config.vocab_size + + kv_dim = config.num_key_value_heads * config.head_dim + q_dim = config.num_attention_heads * config.head_dim + + shape_map: dict[LayerType, tuple[int, ...]] = { + "q_proj": (q_dim, hidden_size), + "k_proj": (kv_dim, hidden_size), + "v_proj": (kv_dim, hidden_size), + "o_proj": (hidden_size, q_dim), + "qkv_proj": (q_dim + 2 * kv_dim, hidden_size), + "gate_proj": (intermediate_size, hidden_size), + "up_proj": (intermediate_size, hidden_size), + "down_proj": (hidden_size, intermediate_size), + "embed_tokens": (vocab_size, hidden_size), + "lm_head": (vocab_size, hidden_size), + } + + layer_type = detect_layer_type(weight_key) + if layer_type in shape_map: + return shape_map[layer_type] + + # Generic fallback + return (q_dim, hidden_size) diff --git a/src/exo/worker/engines/tinygrad/sampling.py b/src/exo/worker/engines/tinygrad/sampling.py new file mode 100644 index 0000000000..6fbcb6f15d --- /dev/null +++ b/src/exo/worker/engines/tinygrad/sampling.py @@ -0,0 +1,44 @@ +from typing import NamedTuple + +from tinygrad.tensor import Tensor + + +class SampleResult(NamedTuple): + token_id: int + logprob: float + top_logprobs: list[tuple[int, float]] # (token_id, logprob) + +def sample_token( + logits: Tensor, + temperature: float = 0.7, + top_p: float = 0.9, + top_logprobs_count: int = 0, + request_logprobs: bool = False, +) -> SampleResult: + last_logits = logits[0, -1, :] + + if temperature == 0: + token_id = int(last_logits.argmax().item()) # pyright: ignore[reportUnknownMemberType] + else: + # Gumbel-max trick for reducing GPU -> CPU sync. 
+ + scaled = last_logits / temperature + gumble_noise = -(-Tensor.rand(scaled.shape).log()).log() # pyright: ignore[reportUnknownMemberType] + token_id = int((scaled + gumble_noise).argmax().item()) # pyright: ignore[reportUnknownMemberType] + + selected_logprob: float = 0.0 + top_logprobs: list[tuple[int, float]] = [] + + if request_logprobs: + log_probs = last_logits.log_softmax(axis = -1) + selected_logprob = float(log_probs[token_id].item()) + + if top_logprobs_count > 0: + values, indices = log_probs.topk(top_logprobs_count) + top_logprobs = [(int(idx), float(val)) for val, idx in zip(values.tolist(), indices.tolist(), strict=True)] # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType, reportArgumentType] + + return SampleResult( + token_id=token_id, + logprob=selected_logprob, + top_logprobs=top_logprobs, + ) diff --git a/src/exo/worker/engines/tinygrad/utils_tinygrad.py b/src/exo/worker/engines/tinygrad/utils_tinygrad.py new file mode 100644 index 0000000000..b85ac6efec --- /dev/null +++ b/src/exo/worker/engines/tinygrad/utils_tinygrad.py @@ -0,0 +1,36 @@ +from collections.abc import Callable +from typing import Any + +from exo.download.download_utils import build_model_path +from exo.shared.model_config import parse_model_config +from exo.shared.tokenizer.loader import load_tokenizer_for_model +from exo.shared.tokenizer.wrapper import TokenizerWrapper +from exo.shared.types.worker.instances import BoundInstance, TinygradInstance + +from .weight_loader import TransformerWeights, load_transformer_weights + + +def initialize_tinygrad(bound_instance: BoundInstance) -> None: + instance = bound_instance.instance + assert isinstance(instance, TinygradInstance) + # For single instance, we can let tinygrad instance decide device. + # When sharding models, we may have to add further explicit initialization + # routines to ensure effective sharding. 
+ + +def load_tinygrad_items( + bound_instance: BoundInstance, group: None, + on_timeout: Callable[[], None] | None = None, +) -> tuple[TransformerWeights, Any]: + shard = bound_instance.bound_shard + model_id = shard.model_card.model_id + model_path = build_model_path(model_id) + config = parse_model_config(model_path / "config.json") + weights = load_transformer_weights( + model_path=model_path, config=config, + start_layer=shard.start_layer, end_layer=shard.end_layer, + ) + + tokenizer = TokenizerWrapper(load_tokenizer_for_model(model_id, model_path)) + + return weights, tokenizer diff --git a/src/exo/worker/engines/tinygrad/weight_loader.py b/src/exo/worker/engines/tinygrad/weight_loader.py new file mode 100644 index 0000000000..7d23fb1131 --- /dev/null +++ b/src/exo/worker/engines/tinygrad/weight_loader.py @@ -0,0 +1,269 @@ +from pathlib import Path +from typing import Literal, NamedTuple, overload + +from tinygrad.device import Device +from tinygrad.nn.state import safe_load +from tinygrad.tensor import Tensor + +from exo.shared.architecture import ArchitectureSpec +from exo.shared.model_config import ModelConfig +from exo.worker.engines.tinygrad.layers.rotary import compute_rope_frequencies +from exo.worker.engines.tinygrad.quantization.layers import ( + QuantizedEmbedding, + QuantizedLinear, +) +from exo.worker.engines.tinygrad.quantization.packing import PackedTensor +from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + +LinearWeight = Tensor | QuantizedLinear +EmbedWeight = Tensor | QuantizedEmbedding + +class LayerWeights(NamedTuple): + qkv_proj: LinearWeight # Merged Q+K+V + o_proj: LinearWeight + + gate_up_proj: LinearWeight # Merged gate+up + down_proj: LinearWeight + + input_norm: Tensor + post_attn_norm: Tensor + + # Optional layers + q_norm: Tensor | None = None + k_norm: Tensor | None = None + + # MoE (None for dense models) + router_weight: Tensor | None = None + expert_gate_projs: list[LinearWeight] | None = None + 
expert_up_projs: list[LinearWeight] | None = None + expert_down_projs: list[LinearWeight] | None = None + +class TransformerWeights(NamedTuple): + embed_tokens: EmbedWeight + lm_head: LinearWeight + final_norm: Tensor + layers: list[LayerWeights] + config: ModelConfig + rope_sin: Tensor + rope_cos: Tensor + +def load_transformer_weights( + model_path: Path, + config: ModelConfig, + start_layer: int = 0, + end_layer: int | None = None, +) -> TransformerWeights: + + if end_layer is None: + end_layer = config.num_hidden_layers + + spec = config.architecture_spec + raw_weights = _load_all_safetensors(model_path) + + embed_tokens = _build_weight(raw_weights, + f"{spec.embed_key}.weight", + config, is_embedding = True) + + lm_head: LinearWeight + if config.tie_word_embeddings: + if isinstance(embed_tokens, QuantizedEmbedding): + lm_head = QuantizedLinear( + weight_q = embed_tokens.weight_q, + scales = embed_tokens.scales, + biases = embed_tokens.biases, + group_size = embed_tokens.group_size, + ) + else: + lm_head = embed_tokens + else: + lm_head = _build_weight( + raw_weights, f"{spec.lm_head_key}.weight", config, + ) + + final_norm = raw_weights[f"{spec.final_norm_key}.weight"] + + layers: list[LayerWeights] = [] + + for layer_idx in range(start_layer, end_layer): + prefix = spec.layer_prefix.format(layer_idx=layer_idx) + layers.append(_build_layer_weights(raw_weights, prefix, spec, config)) + + rope_cos, rope_sin = compute_rope_frequencies( + head_dim = config.head_dim, + max_seq_len = config.max_position_embeddings, + rope_theta = config.rope_theta, + ) + + return TransformerWeights( + embed_tokens=embed_tokens, lm_head=lm_head, + final_norm=final_norm, layers=layers, config=config, + rope_cos = rope_cos.realize(), + rope_sin = rope_sin.realize(), + ) + +def _merge_linear_weights(*weights: LinearWeight) -> LinearWeight: + """Merge multiple LinearWeight objects by concatenating along the output dimension (dim 0). 
+ + For quantized weights, concatenates packed uint32, scales, and biases directly — + no extra memory from dequantization. For non-quantized or mixed weights, dequantizes + first then concatenates plain Tensors. + """ + if all(isinstance(w, QuantizedLinear) for w in weights): + qls = [w for w in weights if isinstance(w, QuantizedLinear)] + merged_tensor = qls[0].weight_q.tensor.cat( + *[w.weight_q.tensor for w in qls[1:]], dim=0 + ).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + merged_scales = qls[0].scales.cat( + *[w.scales for w in qls[1:]], dim=0 + ).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + merged_biases = qls[0].biases.cat( + *[w.biases for w in qls[1:]], dim=0 + ).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + return QuantizedLinear( + weight_q=PackedTensor( + tensor=merged_tensor, + original_shape=( + sum(w.weight_q.original_shape[0] for w in qls), + qls[0].weight_q.original_shape[1], + ), + pack_factor=qls[0].weight_q.pack_factor, + bits=qls[0].weight_q.bits, + ), + scales=merged_scales, + biases=merged_biases, + group_size=qls[0].group_size, + ) + # Non-quantized or mixed: dequantize if needed, then cat + tensors: list[Tensor] = [] + for w in weights: + if isinstance(w, QuantizedLinear): + tensors.append(w.dequantize()) + else: + tensors.append(w) + return tensors[0].cat(*tensors[1:], dim=0).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + +def _build_layer_weights( + raw: dict[str, Tensor], + prefix: str, + spec: ArchitectureSpec, + config: ModelConfig, +) -> LayerWeights: + def key(suffix: str) -> str: + return f"{prefix}.{suffix}.weight" + + q_norm = raw.get(f"{prefix}.{spec.q_norm_key}.weight") if spec.q_norm_key else None + k_norm = raw.get(f"{prefix}.{spec.k_norm_key}.weight") if spec.k_norm_key else None + + q_proj = _build_weight(raw, key(spec.q_proj_key), config) + k_proj = _build_weight(raw, key(spec.k_proj_key), config) + v_proj = _build_weight(raw, 
key(spec.v_proj_key), config) + qkv_proj = _merge_linear_weights(q_proj, k_proj, v_proj) + + gate_proj = _build_weight(raw, key(spec.gate_proj_key), config) + up_proj = _build_weight(raw, key(spec.up_proj_key), config) + gate_up_proj = _merge_linear_weights(gate_proj, up_proj) + + return LayerWeights( + qkv_proj=qkv_proj, + o_proj=_build_weight(raw, key(spec.o_proj_key), config), + gate_up_proj=gate_up_proj, + down_proj=_build_weight(raw, key(spec.down_proj_key), config), + input_norm=raw[f"{prefix}.{spec.input_norm_key}.weight"], + post_attn_norm=raw[f"{prefix}.{spec.post_attn_norm_key}.weight"], + q_norm=q_norm, + k_norm=k_norm, + ) + +@overload +def _build_weight(raw: dict[str, Tensor], key: str, config: ModelConfig, is_embedding: Literal[True]) -> EmbedWeight: ... +@overload +def _build_weight(raw: dict[str, Tensor], key: str, config: ModelConfig, is_embedding: Literal[False] = ...) -> LinearWeight: ... + +def _build_weight( + raw: dict[str, Tensor], + key: str, + config: ModelConfig, + is_embedding: bool = False, +) -> LinearWeight | EmbedWeight: + scales_key = key.replace(".weight", ".scales") + biases_key = key.replace(".weight", ".biases") + + # MLX quantized: .weight (packed uint32) + .scales + .biases + quantization_config + if key in raw and config.quantization_config is not None and scales_key in raw and biases_key in raw: + qcfg = config.quantization_config + packed = PackedTensor( + tensor = raw[key], + original_shape = infer_weight_shape(key, config), + pack_factor = 32 // qcfg.bits, + bits = qcfg.bits, + ) + + if is_embedding: + return QuantizedEmbedding( + num_embeddings = config.vocab_size, + embedding_dim = config.hidden_size, + weight_q = packed, + scales = raw[scales_key], + biases = raw[biases_key], + group_size = qcfg.group_size, + ) + + return QuantizedLinear( + weight_q = packed, + scales = raw[scales_key], + biases = raw[biases_key], + group_size = qcfg.group_size, + ) + + # Plain unquantized: .weight only + if key in raw: + return 
raw[key] + + # Legacy .qweight format + qweight_key = key.replace(".weight", ".qweight") + + if qweight_key in raw and config.quantization_config is not None: + qcfg = config.quantization_config + packed = PackedTensor( + tensor = raw[qweight_key], + original_shape = infer_weight_shape(key, config), + pack_factor = 32 // qcfg.bits, + bits = qcfg.bits, + ) + + if is_embedding: + return QuantizedEmbedding( + num_embeddings = config.vocab_size, + embedding_dim = config.hidden_size, + weight_q = packed, + scales = raw[scales_key], + biases = raw[biases_key], + group_size = qcfg.group_size, + ) + + return QuantizedLinear( + weight_q = packed, + scales = raw[scales_key], + biases = raw[biases_key], + group_size = qcfg.group_size, + ) + + raise KeyError(f"Weight key '{key}' not found (also tried {qweight_key})") + +def _load_all_safetensors(path: Path) -> dict[str, Tensor]: + from tinygrad.helpers import Context + + merged: dict[str, Tensor] = {} + + # Disable BEAM during weight loading. Copy kernels (DISK -> GPU) have unique + # shapes per tensor and don't benefit from beam search optimisation. 
+ with Context(BEAM=0): + for safetensor_file in sorted(path.glob("*.safetensors")): + shard = safe_load(str(safetensor_file)) + for key, tensor in shard.items(): + merged[key] = tensor.to(Device.DEFAULT).contiguous().realize() # pyright: ignore[reportUnknownMemberType] + + if not merged: + raise FileNotFoundError(f"No .safetensors file found in {path}") + + return merged diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index 9949cb7e65..746c993fcf 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -1,15 +1,41 @@ import os +import signal +import threading +from multiprocessing.connection import Connection import loguru from exo.shared.types.events import Event, RunnerStatusUpdated from exo.shared.types.tasks import Task, TaskId -from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance +from exo.shared.types.worker.instances import ( + BoundInstance, + MlxJacclInstance, + TinygradInstance, +) from exo.shared.types.worker.runners import RunnerFailed from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender logger: "loguru.Logger" = loguru.logger +def _start_parent_death_watchdog(conn: Connection) -> None: + """Watch for parent death via a multiprocessing.Connection. + + The parent holds the other end of the pipe. When the parent dies, + the OS closes its end and our recv() raises EOFError. + This is a POSIX and Windows guarantee, making it cross-platform. 
+ """ + + def _watchdog() -> None: + try: + conn.recv() + except (EOFError, OSError): + pass + finally: + conn.close() + os.kill(os.getpid(), signal.SIGTERM) + + t = threading.Thread(target=_watchdog, daemon=True) + t.start() def entrypoint( bound_instance: BoundInstance, @@ -17,27 +43,45 @@ def entrypoint( task_receiver: MpReceiver[Task], cancel_receiver: MpReceiver[TaskId], _logger: "loguru.Logger", + parent_death_conn: Connection, ) -> None: - fast_synch_override = os.environ.get("EXO_FAST_SYNCH") - if fast_synch_override == "on" or ( - fast_synch_override != "off" - and ( - isinstance(bound_instance.instance, MlxJacclInstance) - and len(bound_instance.instance.jaccl_devices) >= 2 - ) - ): - os.environ["MLX_METAL_FAST_SYNCH"] = "1" - else: - os.environ["MLX_METAL_FAST_SYNCH"] = "0" + _start_parent_death_watchdog(parent_death_conn) + + is_tinygrad = isinstance(bound_instance.instance, TinygradInstance) + + if not is_tinygrad: + fast_synch_override = os.environ.get("EXO_FAST_SYNCH") + if fast_synch_override == "on" or ( + fast_synch_override != "off" + and ( + isinstance(bound_instance.instance, MlxJacclInstance) + and len(bound_instance.instance.jaccl_devices) >= 2 + ) + ): + os.environ["MLX_METAL_FAST_SYNCH"] = "1" + else: + os.environ["MLX_METAL_FAST_SYNCH"] = "0" global logger logger = _logger - logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}") + if not is_tinygrad: + logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}") + + # Import main after setting global logger - this lets us just import logger from this module. + # Guard by instance type: TinygradInstance must never import MLX modules (macOS-only). + if is_tinygrad: + # Default JIT=1 and BEAM=2 for decode inference. Prefill/weight-loading + # paths override BEAM with Context(BEAM=0). User can still override via env. 
+ os.environ.setdefault("JIT", "1") + os.environ.setdefault("BEAM", "2") + + os.environ.setdefault("TC", "1") - # Import main after setting global logger - this lets us just import logger from this module try: - if bound_instance.is_image_model: + if is_tinygrad: + from exo.worker.runner.llm_inference.tinygrad_runner import main + elif bound_instance.is_image_model: from exo.worker.runner.image_models.runner import main else: from exo.worker.runner.llm_inference.runner import main diff --git a/src/exo/worker/runner/llm_inference/tinygrad_runner.py b/src/exo/worker/runner/llm_inference/tinygrad_runner.py new file mode 100644 index 0000000000..68db7569c1 --- /dev/null +++ b/src/exo/worker/runner/llm_inference/tinygrad_runner.py @@ -0,0 +1,403 @@ +import resource +import time +from collections.abc import Generator +from typing import TYPE_CHECKING + +from exo.shared.models.model_cards import ModelTask +from exo.shared.tokenizer.chat_template import apply_chat_template +from exo.shared.types.chunks import ( + ErrorChunk, + TokenChunk, + ToolCallChunk, +) +from exo.shared.types.events import ( + ChunkGenerated, + Event, + RunnerStatusUpdated, + TaskAcknowledged, + TaskStatusUpdated, +) +from exo.shared.types.tasks import ( + LoadModel, + Shutdown, + StartWarmup, + Task, + TaskId, + TaskStatus, + TextGeneration, +) +from exo.shared.types.text_generation import TextGenerationTaskParams +from exo.shared.types.worker.instances import BoundInstance +from exo.shared.types.worker.runner_response import ( + GenerationResponse, + ToolCallResponse, +) +from exo.shared.types.worker.runners import ( + RunnerFailed, + RunnerIdle, + RunnerLoaded, + RunnerLoading, + RunnerReady, + RunnerRunning, + RunnerShutdown, + RunnerShuttingDown, + RunnerStatus, + RunnerWarmingUp, +) +from exo.utils.channels import MpReceiver, MpSender +from exo.worker.engines.tinygrad.generator.generate import ( + cleanup_jit_state, + tinygrad_generate, + warmup_inference, +) +from 
exo.worker.engines.tinygrad.utils_tinygrad import ( + initialize_tinygrad, + load_tinygrad_items, +) +from exo.worker.engines.tinygrad.weight_loader import TransformerWeights +from exo.worker.runner.bootstrap import logger + +from .tool_parsers import ToolParser, make_mlx_parser + + +def main( + bound_instance: BoundInstance, + event_sender: MpSender[Event], + task_receiver: MpReceiver[Task], + cancel_receiver: MpReceiver[TaskId], +) -> None: + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard)) + + runner_id = bound_instance.bound_runner_id + shard_metadata = bound_instance.bound_shard + + logger.info("hello from the tinygrad runner") + + setup_start_time = time.time() + cancelled_tasks = set[TaskId]() + + inference_model: TransformerWeights | None = None + tokenizer = None + tool_parser: ToolParser | None = None + check_for_cancel_every: int | None = None + + current_status: RunnerStatus = RunnerIdle() + logger.info("tinygrad runner created") + event_sender.send( + RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status) + ) + seen = set[TaskId]() + with task_receiver as tasks: + for task in tasks: + if task.task_id in seen: + logger.warning("repeat task - potential error") + seen.add(task.task_id) + cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK")) + event_sender.send( + TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running) + ) + match task: + case LoadModel() if isinstance(current_status, (RunnerIdle, RunnerFailed)): + total_layers = shard_metadata.end_layer - shard_metadata.start_layer + current_status = RunnerLoading( + layers_loaded=0, total_layers=total_layers + ) + logger.info("tinygrad runner loading") + event_sender.send( + RunnerStatusUpdated( + runner_id=runner_id, runner_status=current_status + ) + ) + event_sender.send(TaskAcknowledged(task_id=task.task_id)) + + def on_model_load_timeout() -> None: + event_sender.send( + 
RunnerStatusUpdated( + runner_id=runner_id, + runner_status=RunnerFailed( + error_message="Model loading timed out" + ), + ) + ) + time.sleep(0.5) + + assert ( + ModelTask.TextGeneration in shard_metadata.model_card.tasks + ), f"Incorrect model task(s): {shard_metadata.model_card.tasks}" + + initialize_tinygrad(bound_instance) + + inference_model, tokenizer = load_tinygrad_items( # pyright: ignore[reportAny] + bound_instance, + None, + on_timeout=on_model_load_timeout, + ) + logger.info( + f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}" # pyright: ignore[reportAny] + ) + if tokenizer.has_tool_calling: # pyright: ignore[reportAny] + assert tokenizer.tool_call_start # pyright: ignore[reportAny] + assert tokenizer.tool_call_end # pyright: ignore[reportAny] + assert tokenizer.tool_parser # pyright: ignore[reportAny] + tool_parser = make_mlx_parser( + tokenizer.tool_call_start, # pyright: ignore[reportAny] + tokenizer.tool_call_end, # pyright: ignore[reportAny] + tokenizer.tool_parser, # pyright: ignore[reportAny] + ) + current_status = RunnerLoaded() + logger.info("tinygrad runner loaded") + + case StartWarmup() if isinstance(current_status, RunnerLoaded): + current_status = RunnerWarmingUp() + logger.info("tinygrad runner warming up") + event_sender.send( + RunnerStatusUpdated( + runner_id=runner_id, runner_status=current_status + ) + ) + event_sender.send(TaskAcknowledged(task_id=task.task_id)) + + assert inference_model + assert tokenizer + + t = time.monotonic() + toks = warmup_inference( + model=inference_model, + tokenizer=tokenizer, + ) + logger.info(f"warmed up by generating {toks} tokens") + check_for_cancel_every = min( + max(1, round(toks / max(time.monotonic() - t, 0.001))), 100 + ) + logger.info( + f"tinygrad runner checking for cancellation every {check_for_cancel_every} tokens" + ) + logger.info( + f"tinygrad runner initialized in {time.time() - setup_start_time} seconds" + ) + 
current_status = RunnerReady() + logger.info("tinygrad runner ready") + + case TextGeneration(task_params=task_params, command_id=command_id) if ( + isinstance(current_status, RunnerReady) + ): + logger.info(f"received chat request: {task}") + current_status = RunnerRunning() + logger.info("tinygrad runner running") + event_sender.send( + RunnerStatusUpdated( + runner_id=runner_id, runner_status=current_status + ) + ) + event_sender.send(TaskAcknowledged(task_id=task.task_id)) + assert inference_model + assert tokenizer + assert check_for_cancel_every + + try: + _check_for_debug_prompts(task_params) + + prompt = apply_chat_template(tokenizer, task_params) + + gen: Generator[GenerationResponse | ToolCallResponse] = tinygrad_generate( + model=inference_model, + tokenizer=tokenizer, + task=task_params, + prompt=prompt, + ) + + if tool_parser: + gen = _parse_tool_calls(gen, tool_parser) + + completion_tokens = 0 + tokens_since_last_cancel_check = check_for_cancel_every + for response in gen: + tokens_since_last_cancel_check += 1 + if tokens_since_last_cancel_check >= check_for_cancel_every: + tokens_since_last_cancel_check = 0 + cancelled_tasks.update(cancel_receiver.collect()) + want_to_cancel = (task.task_id in cancelled_tasks) or ( + TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks + ) + if want_to_cancel: + break + + match response: + case GenerationResponse(): + completion_tokens += 1 + if response.finish_reason == "error": + event_sender.send( + ChunkGenerated( + command_id=command_id, + chunk=ErrorChunk( + error_message=response.text, + model=shard_metadata.model_card.model_id, + ), + ) + ) + else: + assert response.finish_reason not in ( + "error", + "tool_calls", + "function_call", + ) + event_sender.send( + ChunkGenerated( + command_id=command_id, + chunk=TokenChunk( + model=shard_metadata.model_card.model_id, + text=response.text, + token_id=response.token, + usage=response.usage, + finish_reason=response.finish_reason, + stats=response.stats, + 
logprob=response.logprob, + top_logprobs=response.top_logprobs, + is_thinking=response.is_thinking, + ), + ) + ) + case ToolCallResponse(): + event_sender.send( + ChunkGenerated( + command_id=command_id, + chunk=ToolCallChunk( + tool_calls=response.tool_calls, + model=shard_metadata.model_card.model_id, + usage=response.usage, + stats=response.stats, + ), + ) + ) + + except Exception as e: + event_sender.send( + ChunkGenerated( + command_id=command_id, + chunk=ErrorChunk( + model=shard_metadata.model_card.model_id, + finish_reason="error", + error_message=str(e), + ), + ) + ) + raise + + current_status = RunnerReady() + logger.info("tinygrad runner ready") + + case Shutdown(): + current_status = RunnerShuttingDown() + logger.info("tinygrad runner shutting down") + if not TYPE_CHECKING: + del inference_model, tokenizer + cleanup_jit_state() + import gc + + gc.collect() + + event_sender.send( + RunnerStatusUpdated( + runner_id=runner_id, runner_status=current_status + ) + ) + event_sender.send(TaskAcknowledged(task_id=task.task_id)) + + current_status = RunnerShutdown() + + case _: + raise ValueError( + f"Received {task.__class__.__name__} outside of state machine in {current_status=}" + ) + + was_cancelled = (task.task_id in cancelled_tasks) or ( + TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks + ) + if not was_cancelled: + event_sender.send( + TaskStatusUpdated( + task_id=task.task_id, task_status=TaskStatus.Complete + ) + ) + event_sender.send( + RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status) + ) + + if isinstance(current_status, RunnerShutdown): + break + + +def _parse_tool_calls( + responses: Generator[GenerationResponse | ToolCallResponse], + tool_parser: ToolParser, +) -> Generator[GenerationResponse | ToolCallResponse]: + """Wrap a generation stream to detect and parse tool calls.""" + in_tool_call = False + tool_call_text_parts: list[str] = [] + for response in responses: + if not isinstance(response, GenerationResponse): + yield 
response + continue + + if response.text.startswith(tool_parser.start_parsing): + in_tool_call = True + + if in_tool_call: + tool_call_text_parts.append(response.text) + if response.text.endswith(tool_parser.end_parsing): + parsed = tool_parser.parse_tool_calls( + "".join(tool_call_text_parts).strip() + ) + logger.info(f"parsed {tool_call_text_parts=} into {parsed=}") + if parsed is not None: + yield ToolCallResponse( + tool_calls=parsed, usage=response.usage, stats=response.stats + ) + else: + logger.warning( + f"tool call parsing failed for text {''.join(tool_call_text_parts)}" + ) + response = response.model_copy( + update={"text": "".join(tool_call_text_parts)} + ) + yield response + + in_tool_call = False + tool_call_text_parts = [] + continue + + if response.finish_reason is not None: + logger.info( + "tool call parsing interrupted, yield partial tool call as text" + ) + response = response.model_copy( + update={ + "text": "".join(tool_call_text_parts), + "token": 0, + } + ) + yield response + + continue + + yield response + + +EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL" +EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT" + + +def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None: + if len(task_params.input) == 0: + return + prompt = task_params.input[0].content + + if not prompt: + return + + if EXO_RUNNER_MUST_FAIL in prompt: + logger.info("raising exception") + raise Exception("Artificial runner exception - for testing purposes only.") + if EXO_RUNNER_MUST_TIMEOUT in prompt: + time.sleep(100) diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py index e8a06a77c3..28aee80f05 100644 --- a/src/exo/worker/runner/runner_supervisor.py +++ b/src/exo/worker/runner/runner_supervisor.py @@ -1,7 +1,8 @@ import contextlib import signal from dataclasses import dataclass, field -from multiprocessing import Process +from multiprocessing import Pipe, Process +from multiprocessing.connection import 
Connection from typing import Self import anyio @@ -52,6 +53,7 @@ class RunnerSupervisor: pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False) completed: set[TaskId] = field(default_factory=set, init=False) cancelled: set[TaskId] = field(default_factory=set, init=False) + _death_conn: Connection @classmethod def create( @@ -65,6 +67,8 @@ def create( task_sender, task_recv = mp_channel[Task]() cancel_sender, cancel_recv = mp_channel[TaskId]() + parent_conn, child_conn = Pipe() + runner_process = Process( target=entrypoint, args=( @@ -73,8 +77,9 @@ def create( task_recv, cancel_recv, logger, + child_conn, ), - daemon=True, + daemon=False, ) shard_metadata = bound_instance.bound_shard @@ -88,6 +93,7 @@ def create( _task_sender=task_sender, _cancel_sender=cancel_sender, _event_sender=event_sender, + _death_conn=parent_conn, ) return self @@ -98,6 +104,7 @@ async def run(self): def shutdown(self): logger.info("Runner supervisor shutting down") + self._death_conn.close() self._ev_recv.close() self._task_sender.close() with contextlib.suppress(ClosedResourceError): diff --git a/src/exo/worker/tests/unittests/test_runner/test_runner_import.py b/src/exo/worker/tests/unittests/test_runner/test_runner_import.py new file mode 100644 index 0000000000..48eb8bcf4a --- /dev/null +++ b/src/exo/worker/tests/unittests/test_runner/test_runner_import.py @@ -0,0 +1,68 @@ +import inspect + +import pytest + + +def _mlx_backend_available() -> bool: + """Return True only if mlx.core can be fully loaded (native libs present).""" + try: + import mlx.core # noqa: F401 # pyright: ignore[reportUnusedImport] + return True + except (ImportError, OSError): + return False + +requires_mlx = pytest.mark.skipif( + not _mlx_backend_available(), + reason="MLX native backend not available (missing CUDA/Metal libraries)", +) + +def test_tinygrad_runner_imports_without_mlx() -> None: + """tinygrad_runner.py must be importable on Linux where MLX is absent.""" + from 
exo.worker.runner.llm_inference.tinygrad_runner import ( + main, # noqa: F401 # pyright: ignore[reportUnusedImport] + ) + +def test_engine_factory_importable() -> None: + """engine_factory.py must be importable on any platform.""" + from exo.worker.engines.engine_factory import ( + Engine, # noqa: F401 # pyright: ignore[reportUnusedImport] + create_engine, # noqa: F401 # pyright: ignore[reportUnusedImport] + ) + +def test_engine_is_immutable() -> None: + """Engine must be an immutable Pydantic model with the expected fields.""" + from pydantic import BaseModel + + from exo.worker.engines.engine_factory import Engine + assert issubclass(Engine, BaseModel) + field_names = set(Engine.model_fields.keys()) + + required_fields = [ + "initialize", "load", "generate", "warmup", "cleanup", + "apply_chat_template", "detect_thinking_prompt_suffix", + ] + + assert all(field in field_names for field in required_fields) + +def test_tokenizer_protocol_importable() -> None: + from exo.shared.types.worker.tokenizer import ( + Tokenizer, # noqa: F401 # pyright: ignore[reportUnusedImport] + ) + +@requires_mlx +def test_mlx_engine_has_postprocessing_importable() -> None: + from exo.worker.engines.mlx.generator.generate import ( + mlx_generate_with_postprocessing, # noqa: F401 # pyright: ignore[reportUnusedImport] + ) + +@requires_mlx +def test_mlx_engine_has_postprocessing_signature() -> None: + from exo.worker.engines.mlx.generator.generate import ( + mlx_generate_with_postprocessing, + ) + sig = inspect.signature(mlx_generate_with_postprocessing) + params = list(sig.parameters.keys()) + + expected_params = ["model", "tokenizer", "model_id"] + + assert all(param in params for param in expected_params) diff --git a/src/exo/worker/tests/unittests/test_tinygrad/__init__.py b/src/exo/worker/tests/unittests/test_tinygrad/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_cache.py 
b/src/exo/worker/tests/unittests/test_tinygrad/test_cache.py new file mode 100644 index 0000000000..a5e5833ebb --- /dev/null +++ b/src/exo/worker/tests/unittests/test_tinygrad/test_cache.py @@ -0,0 +1,58 @@ +# pyright: reportUnknownMemberType=false +import math + +from tinygrad.tensor import Tensor + + +def test_cache_initial_seq_len(): + """Fresh cache should have seq_len equal to max_seq_len (pre-allocated).""" + from exo.worker.engines.tinygrad.cache import KVCache + + cache = KVCache(num_layers=4, num_kv_heads=8, head_dim=64, max_seq_len=128) + assert cache.seq_len == 128 + + +def test_cache_update_writes_to_position(): + """Update should write keys/values at the specified position.""" + from exo.worker.engines.tinygrad.cache import KVCache + + cache = KVCache(num_layers=4, num_kv_heads=8, head_dim=64, max_seq_len=128) + k = Tensor.ones(1, 8, 3, 64) + v = Tensor.ones(1, 8, 3, 64) + k_out, _v_out = cache.update(0, k, v, position=0) + + # Output shape is always the full pre-allocated buffer + assert k_out.shape == (1, 8, 128, 64) + # The written positions should be non-zero + assert float(k_out[0, 0, 0, 0].item()) != 0.0 + assert float(k_out[0, 0, 2, 0].item()) != 0.0 + # Positions beyond the write should remain zero + assert float(k_out[0, 0, 3, 0].item()) == 0.0 + + +def test_cache_update_tensor_position(): + """Update with Tensor position should write a single token.""" + from exo.worker.engines.tinygrad.cache import KVCache + + cache = KVCache(num_layers=4, num_kv_heads=8, head_dim=64, max_seq_len=128) + k = Tensor.ones(1, 8, 1, 64) * 5.0 + v = Tensor.ones(1, 8, 1, 64) * 5.0 + pos = Tensor([10]).reshape(1, 1, 1, 1) + k_out, _v_out = cache.update(0, k, v, position=pos) + + # Position 10 should have the value we wrote + assert math.isclose(float(k_out[0, 0, 10, 0].item()), 5.0, rel_tol=1e-2) + # Other positions should remain zero + assert float(k_out[0, 0, 0, 0].item()) == 0.0 + + +def test_cache_layers_are_independent(): + """Updates to layer 0 should not 
affect layer 1.""" + from exo.worker.engines.tinygrad.cache import KVCache + + cache = KVCache(num_layers=4, num_kv_heads=8, head_dim=64, max_seq_len=128) + k = Tensor.ones(1, 8, 3, 64) + v = Tensor.ones(1, 8, 3, 64) + cache.update(0, k, v, position=0) + # Layer 1 should still be all zeros (pre-allocated but untouched) + assert float(cache.keys[1].sum().item()) == 0.0 diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_forward.py b/src/exo/worker/tests/unittests/test_tinygrad/test_forward.py new file mode 100644 index 0000000000..2be8f52274 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_tinygrad/test_forward.py @@ -0,0 +1,48 @@ +# pyright: reportUnknownMemberType=false +import pytest +from tinygrad.dtype import dtypes +from tinygrad.tensor import Tensor + + +@pytest.mark.slow +def test_forward_pass_shape(): + """Forward pass should produce logits of shape (batch, seq, vocab_size).""" + from pathlib import Path + + from exo.shared.model_config import parse_model_config + from exo.worker.engines.tinygrad.forward import forward_pass + from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights + + model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit" + if not model_path.exists(): + pytest.skip("Model not downloaded") + config = parse_model_config(model_path / "config.json") + weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2) + + input_ids = Tensor([[1, 2, 3, 4]], dtype=dtypes.int32) + _logits, cache = forward_pass(weights, input_ids, cache=None, position_offset=0) + assert _logits.shape == (1, 4, config.vocab_size) + assert cache.seq_len == 4 + +@pytest.mark.slow +def test_decode_step_after_prefill(): + """A single-token decode step should extend the cache by 1.""" + from pathlib import Path + + from exo.shared.model_config import parse_model_config + from exo.worker.engines.tinygrad.forward import forward_pass + from exo.worker.engines.tinygrad.weight_loader 
import load_transformer_weights + + model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit" + if not model_path.exists(): + pytest.skip("Model not downloaded") + config = parse_model_config(model_path / "config.json") + weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2) + + input_ids = Tensor([[1, 2, 3, 4]], dtype=dtypes.int32) + _logits, cache = forward_pass(weights, input_ids, cache=None, position_offset=0) + + next_input = Tensor([[5]], dtype=dtypes.int32) + logits2, cache2 = forward_pass(weights, next_input, cache, position_offset=4) + assert logits2.shape == (1, 1, config.vocab_size) + assert cache2.seq_len == 5 diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_generate.py b/src/exo/worker/tests/unittests/test_tinygrad/test_generate.py new file mode 100644 index 0000000000..be64c47b0d --- /dev/null +++ b/src/exo/worker/tests/unittests/test_tinygrad/test_generate.py @@ -0,0 +1,184 @@ +# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false +import pytest + + +def test_constants_exist(): + """Default constants should be defined.""" + from exo.worker.engines.tinygrad.constants import ( + DEFAULT_MAX_TOKENS, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_P, + ) + + assert isinstance(DEFAULT_MAX_TOKENS, int) + assert DEFAULT_MAX_TOKENS > 0 + assert isinstance(DEFAULT_TEMPERATURE, float) + assert isinstance(DEFAULT_TOP_P, float) + +def test_initialize_tinygrad_callable(): + """initialize_tinygrad must be importable and callable.""" + from exo.worker.engines.tinygrad.utils_tinygrad import initialize_tinygrad + + assert callable(initialize_tinygrad) + +def test_load_tinygrad_items_callable(): + """load_tinygrad_items must be importable and callable.""" + from exo.worker.engines.tinygrad.utils_tinygrad import load_tinygrad_items + + assert callable(load_tinygrad_items) + +def test_tinygrad_generate_accepts_prompt_parameter(): + """tinygrad_generate 
must accept prompt as a parameter (matching MLX pattern).""" + import inspect + + from exo.worker.engines.tinygrad.generator.generate import tinygrad_generate + + sig = inspect.signature(tinygrad_generate) + assert "prompt" in sig.parameters, "tinygrad_generate must accept a 'prompt' parameter" + assert "model" in sig.parameters + assert "tokenizer" in sig.parameters + assert "task" in sig.parameters + +@pytest.mark.slow +def test_generate_yields_generation_responses(): + """tinygrad_generate should yield GenerationResponse objects.""" + from pathlib import Path + from unittest.mock import MagicMock + + from exo.shared.model_config import parse_model_config + from exo.shared.types.worker.runner_response import GenerationResponse + from exo.worker.engines.tinygrad.generator.generate import tinygrad_generate + from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights + + model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit" + if not model_path.exists(): + pytest.skip("Model not downloaded") + + config = parse_model_config(model_path / "config.json") + weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2) + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + + task = MagicMock() + task.max_tokens = 3 + task.temperature = 0.0 + task.top_p = 0.9 + task.logprobs = False + task.top_logprobs = 0 + + # Runner computes prompt via apply_chat_template, then passes it + prompt = "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" + + responses = list(tinygrad_generate(weights, tokenizer, task, prompt=prompt)) + assert len(responses) > 0 + assert all(isinstance(r, GenerationResponse) for r in responses) + + # Last response should have finish_reason, stats, and usage + assert responses[-1].finish_reason is not None + assert responses[-1].stats is not None + assert responses[-1].usage is not None + assert 
responses[-1].usage.prompt_tokens > 0 + assert responses[-1].usage.completion_tokens > 0 + + # Intermediate responses should have no stats/usage + if len(responses) > 1: + assert responses[0].stats is None + assert responses[0].usage is None + +@pytest.mark.slow +def test_generate_populates_logprobs_when_requested(): + """When task.logprobs=True, responses should include logprob and top_logprobs.""" + from pathlib import Path + from unittest.mock import MagicMock + + from exo.shared.model_config import parse_model_config + from exo.worker.engines.tinygrad.generator.generate import tinygrad_generate + from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights + + model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit" + if not model_path.exists(): + pytest.skip("Model not downloaded") + + config = parse_model_config(model_path / "config.json") + weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2) + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + + task = MagicMock() + task.max_tokens = 2 + task.temperature = 0.0 + task.top_p = 0.9 + task.logprobs = True + task.top_logprobs = 3 + + prompt = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n" + responses = list(tinygrad_generate(weights, tokenizer, task, prompt=prompt)) + + for r in responses: + assert r.logprob is not None, "logprob must be set when task.logprobs=True" + assert r.logprob <= 0.0, "logprob must be <= 0" + assert r.top_logprobs is not None + assert len(r.top_logprobs) == 3 + +@pytest.mark.slow +def test_generate_omits_logprobs_when_not_requested(): + """When task.logprobs=False, responses should have logprob=None.""" + from pathlib import Path + from unittest.mock import MagicMock + + from exo.shared.model_config import parse_model_config + from exo.worker.engines.tinygrad.generator.generate import tinygrad_generate + from 
exo.worker.engines.tinygrad.weight_loader import load_transformer_weights + + model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit" + if not model_path.exists(): + pytest.skip("Model not downloaded") + + config = parse_model_config(model_path / "config.json") + weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2) + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + + task = MagicMock() + task.max_tokens = 2 + task.temperature = 0.0 + task.top_p = 0.9 + task.logprobs = False + task.top_logprobs = 0 + + prompt = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n" + responses = list(tinygrad_generate(weights, tokenizer, task, prompt=prompt)) + + for r in responses: + assert r.logprob is None + assert r.top_logprobs is None + +@pytest.mark.slow +def test_warmup_runs_full_generation(): + """warmup_inference should run a real generation loop, not just a forward pass.""" + from pathlib import Path + + from exo.shared.model_config import parse_model_config + from exo.worker.engines.tinygrad.generator.generate import warmup_inference + from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights + + model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit" + if not model_path.exists(): + pytest.skip("Model not downloaded") + + config = parse_model_config(model_path / "config.json") + weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2) + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + + tokens_generated = warmup_inference(model=weights, tokenizer=tokenizer) + assert tokens_generated >= 1, "warmup must generate at least 1 token" + assert tokens_generated <= 10, "warmup should be short (not full generation)" diff --git 
a/src/exo/worker/tests/unittests/test_tinygrad/test_layers.py b/src/exo/worker/tests/unittests/test_tinygrad/test_layers.py new file mode 100644 index 0000000000..5c0933cbf9 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_tinygrad/test_layers.py @@ -0,0 +1,120 @@ +# pyright: reportUnknownMemberType=false +from tinygrad.tensor import Tensor + +# --- RMS Norm --- + +def test_rms_norm_shape(): + """Output shape must match input shape.""" + from exo.worker.engines.tinygrad.layers.normalization import rms_norm + + x = Tensor.randn(1, 4, 2048) + weight = Tensor.ones(2048) + out = rms_norm(x, weight, eps=1e-6) + assert out.shape == (1, 4, 2048) + +def test_rms_norm_unit_weight(): + """With unit weight, RMS norm output should have the same shape as the input, and its per-row RMS should have shape (batch, seq).""" + from exo.worker.engines.tinygrad.layers.normalization import rms_norm + + x = Tensor.randn(1, 4, 2048) + weight = Tensor.ones(2048) + out = rms_norm(x, weight, eps=1e-6) + assert out.shape == x.shape + # Verify the computation graph produces the right intermediate shapes + rms = ((out * out).mean(axis=-1)).sqrt() + assert rms.shape == (1, 4) + +def test_rms_norm_scales_with_weight(): + """Doubling the weight should produce an output with the same shape.""" + from exo.worker.engines.tinygrad.layers.normalization import rms_norm + + x = Tensor.randn(1, 4, 2048) + weight_1x = Tensor.ones(2048) + weight_2x = Tensor.ones(2048) * 2.0 + out_1x = rms_norm(x, weight_1x, eps=1e-6) + out_2x = rms_norm(x, weight_2x, eps=1e-6) + assert out_1x.shape == out_2x.shape == (1, 4, 2048) + # Verify the difference tensor has the right shape + diff = (out_2x - out_1x * 2.0) + assert diff.shape == (1, 4, 2048) + +# --- RoPE --- + +def test_rope_frequency_shapes(): + """cos/sin frequency tables should have shape (max_seq_len, head_dim//2).""" + from exo.worker.engines.tinygrad.layers.rotary import compute_rope_frequencies + + cos_f, sin_f = compute_rope_frequencies(head_dim=64, max_seq_len=128) + assert cos_f.shape
== (128, 32) + assert sin_f.shape == (128, 32) + +def test_rope_changes_with_position(): + """apply_rope at different offsets should produce tensors of the same shape.""" + from exo.worker.engines.tinygrad.layers.rotary import ( + apply_rope, + compute_rope_frequencies, + ) + + cos_f, sin_f = compute_rope_frequencies(head_dim=64, max_seq_len=128) + x = Tensor.randn(1, 8, 4, 64) # (batch, heads, seq=4, head_dim) + out_0 = apply_rope(x, cos_f, sin_f, position_offset=0) + out_10 = apply_rope(x, cos_f, sin_f, position_offset=10) + assert out_0.shape == out_10.shape == (1, 8, 4, 64) + # Verify the difference tensor can be constructed (computation graph is valid) + diff = (out_0 - out_10).abs() + assert diff.shape == (1, 8, 4, 64) + +def test_rope_preserves_shape(): + """apply_rope must not change the tensor shape.""" + from exo.worker.engines.tinygrad.layers.rotary import ( + apply_rope, + compute_rope_frequencies, + ) + + cos_f, sin_f = compute_rope_frequencies(head_dim=64, max_seq_len=128) + x = Tensor.randn(1, 8, 4, 64) + out = apply_rope(x, cos_f, sin_f, position_offset=0) + assert out.shape == x.shape + +# --- Embedding --- + +def test_embedding_lookup(): + """Embedding lookup should index into the weight table.""" + from exo.worker.engines.tinygrad.layers.embedding import apply_embedding + + embed = Tensor.randn(100, 32) # vocab=100, dim=32 + ids = Tensor([[0, 1, 2]]) + out = apply_embedding(embed, ids) + assert out.shape == (1, 3, 32) + +def test_lm_head_shape(): + """LM head should project hidden_size to vocab_size.""" + from exo.worker.engines.tinygrad.layers.embedding import apply_lm_head + + x = Tensor.randn(1, 4, 2048) + lm_head = Tensor.randn(128256, 2048) # (vocab, hidden) + out = apply_lm_head(x, lm_head) + assert out.shape == (1, 4, 128256) + +# --- MLP --- + +def test_swiglu_mlp_shape(): + """SwiGLU MLP: hidden_size in, hidden_size out (merged gate+up proj).""" + from exo.worker.engines.tinygrad.layers.mlp import swiglu_mlp + + x = Tensor.randn(1, 4, 
2048) + gate_up = Tensor.randn(16384, 2048) # gate(8192) + up(8192) merged + down = Tensor.randn(2048, 8192) + out = swiglu_mlp(x, gate_up, down) + assert out.shape == (1, 4, 2048) + +# --- Attention --- + +def test_linear_forward_shape(): + """linear_forward(x, weight) should compute x @ weight.T.""" + from exo.worker.engines.tinygrad.layers.attention import linear_forward + + x = Tensor.randn(1, 4, 2048) + weight = Tensor.randn(512, 2048) # (out, in) + out = linear_forward(x, weight) + assert out.shape == (1, 4, 512) diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_quantization.py b/src/exo/worker/tests/unittests/test_tinygrad/test_quantization.py new file mode 100644 index 0000000000..89809f38c4 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_tinygrad/test_quantization.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +# pyright: reportUnknownMemberType=false +import pytest +from tinygrad.device import Device +from tinygrad.dtype import dtypes +from tinygrad.tensor import Tensor + +from exo.shared.model_config import ModelConfig + +Device.DEFAULT = "CPU" + +# ── Helpers ── + +def _make_model_config( + hidden_size: int = 2048, + num_attention_heads: int = 32, + num_key_value_heads: int | None = None, + intermediate_size: int = 8192, + vocab_size: int = 128256, +) -> ModelConfig: + """Build a ModelConfig with sensible defaults for quantization tests.""" + from exo.shared.architecture.llama import LLAMA_SPEC + + n_kv = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + return ModelConfig( + architecture_spec=LLAMA_SPEC, + num_hidden_layers=32, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=n_kv, + vocab_size=vocab_size, + head_dim=hidden_size // num_attention_heads, + rope_theta=10000.0, + rope_scaling=None, + max_position_embeddings=4096, + rms_norm_eps=1e-6, + tie_word_embeddings=False, + quantization_config=None, + ) + 
+# ── Packing / Unpacking ── + +def test_calculate_pack_factor(): + """pack_factor = 32 // bits.""" + from exo.worker.engines.tinygrad.quantization.packing import calculate_pack_factor + + assert calculate_pack_factor(4) == 8 + assert calculate_pack_factor(8) == 4 + assert calculate_pack_factor(2) == 16 + +def test_pack_bits_output_shape(): + """Packed last dim should be original_last_dim / pack_factor.""" + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + original = Tensor.zeros(64, 4096, dtype=dtypes.float16) + packed = pack_bits(original, bits=4) + assert packed.tensor.shape == (64, 512) + +def test_packed_tensor_metadata(): + """PackedTensor must carry original_shape, pack_factor, bits.""" + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + original = Tensor.zeros(64, 4096, dtype=dtypes.float16) + packed = pack_bits(original, bits=4) + assert packed.original_shape == (64, 4096) + assert packed.pack_factor == 8 + assert packed.bits == 4 + +def test_packed_tensor_is_frozen(): + """PackedTensor must be immutable (frozen dataclass).""" + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + packed = pack_bits(Tensor.zeros(4, 8, dtype=dtypes.float16), bits=4) + with pytest.raises(AttributeError): + packed.bits = 8 + +def test_pack_unpack_roundtrip_4bit(): + """pack then unpack should recover original values (4-bit).""" + from exo.worker.engines.tinygrad.quantization.packing import pack_bits, unpack_bits + + original = Tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=dtypes.float16) + packed = pack_bits(original, bits=4) + unpacked = unpack_bits(packed) + assert unpacked.shape == original.shape + for i in range(8): + assert abs(unpacked[0, i].item() - original[0, i].item()) < 1e-3 + +def test_pack_unpack_roundtrip_8bit(): + """pack then unpack should recover original values (8-bit).""" + from exo.worker.engines.tinygrad.quantization.packing import pack_bits, unpack_bits + + original = Tensor([[10, 20, 
30, 40]], dtype=dtypes.float16) + packed = pack_bits(original, bits=8) + unpacked = unpack_bits(packed) + assert unpacked.shape == original.shape + for i in range(4): + assert abs(unpacked[0, i].item() - original[0, i].item()) < 1e-3 + +def test_pack_unpack_non_divisible_dim(): + """Unpacking handles last dims not divisible by pack_factor.""" + from exo.worker.engines.tinygrad.quantization.packing import pack_bits, unpack_bits + + original = Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=dtypes.float16) + packed = pack_bits(original, bits=4) + unpacked = unpack_bits(packed) + assert unpacked.shape == (1, 10) + for i in range(10): + assert abs(unpacked[0, i].item() - original[0, i].item()) < 1e-3 + +def test_unpack_dtype_is_float16(): + """Unpacked tensor must be float16.""" + from exo.worker.engines.tinygrad.quantization.packing import pack_bits, unpack_bits + + original = Tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=dtypes.float16) + unpacked = unpack_bits(pack_bits(original, bits=4)) + assert unpacked.dtype == dtypes.float16 + +# ── Dequantization ── + +def test_affine_dequantize_identity(): + """scale=1, bias=0 → identity.""" + from exo.worker.engines.tinygrad.quantization.dequantization import ( + affine_dequantize, + ) + + quantized = Tensor([[1.0, 2.0, 3.0, 4.0]]) + scales = Tensor([[1.0]]) + biases = Tensor([[0.0]]) + result = affine_dequantize(quantized, scales, biases, group_size=4) + assert result.shape == (1, 4) + for i in range(4): + assert abs(result[0, i].item() - quantized[0, i].item()) < 1e-5 + +def test_affine_dequantize_scaling(): + """scale=2 should double values.""" + from exo.worker.engines.tinygrad.quantization.dequantization import ( + affine_dequantize, + ) + + quantized = Tensor([[1.0, 2.0, 3.0, 4.0]]) + scales = Tensor([[2.0]]) + biases = Tensor([[0.0]]) + result = affine_dequantize(quantized, scales, biases, group_size=4) + for i in range(4): + assert abs(result[0, i].item() - quantized[0, i].item() * 2.0) < 1e-5 + +def 
test_affine_dequantize_bias_offset(): + """bias=10 should offset by 10.""" + from exo.worker.engines.tinygrad.quantization.dequantization import ( + affine_dequantize, + ) + + quantized = Tensor([[0.0, 0.0, 0.0, 0.0]]) + scales = Tensor([[1.0]]) + biases = Tensor([[10.0]]) + result = affine_dequantize(quantized, scales, biases, group_size=4) + for i in range(4): + assert abs(result[0, i].item() - 10.0) < 1e-5 + +def test_affine_dequantize_multiple_groups(): + """Each group uses its own scale and bias.""" + from exo.worker.engines.tinygrad.quantization.dequantization import ( + affine_dequantize, + ) + + quantized = Tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + scales = Tensor([[2.0, 3.0]]) # 2 groups of 4 + biases = Tensor([[0.0, 10.0]]) + result = affine_dequantize(quantized, scales, biases, group_size=4) + assert abs(result[0, 0].item() - 2.0) < 1e-5 # group 0: 1*2+0 + assert abs(result[0, 4].item() - 13.0) < 1e-5 # group 1: 1*3+10 + +def test_affine_dequantize_2d_weight_matrix(): + """Dequantize a realistic 2D weight matrix [out_features, in_features].""" + from exo.worker.engines.tinygrad.quantization.dequantization import ( + affine_dequantize, + ) + + out_features, in_features, group_size = 32, 64, 32 + quantized = Tensor.ones(out_features, in_features) + num_groups = in_features // group_size + scales = Tensor.ones(out_features, num_groups) * 0.5 + biases = Tensor.zeros(out_features, num_groups) + result = affine_dequantize(quantized, scales, biases, group_size=group_size) + assert result.shape == (out_features, in_features) + assert abs(result[0, 0].item() - 0.5) < 1e-5 + +# ── QuantizedLinear ── + +def test_quantized_linear_is_callable(): + """QuantizedLinear instances must be callable via __call__.""" + from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + weight = Tensor.ones(16, 32, dtype=dtypes.float16) + packed = pack_bits(weight, bits=4) + layer 
= QuantizedLinear( + weight_q=packed, + scales=Tensor.ones(16, 1), + biases=Tensor.zeros(16, 1), + group_size=32, + ) + x = Tensor.randn(1, 4, 32) + out = layer(x) + assert out.shape == (1, 4, 16) + +def test_quantized_linear_output_shape(): + """QuantizedLinear(in=64, out=32) → (batch, seq, 32).""" + from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + weight = Tensor.ones(32, 64, dtype=dtypes.float16) + packed = pack_bits(weight, bits=4) + layer = QuantizedLinear( + weight_q=packed, + scales=Tensor.ones(32, 1), + biases=Tensor.zeros(32, 1), + group_size=64, + ) + x = Tensor.randn(1, 8, 64) + out = layer(x) + assert out.shape == (1, 8, 32) + +def test_quantized_linear_dequantize_returns_correct_shape(): + """dequantize() returns a tensor with the original weight shape.""" + from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + weight = Tensor.ones(16, 32, dtype=dtypes.float16) + packed = pack_bits(weight, bits=4) + layer = QuantizedLinear( + weight_q=packed, + scales=Tensor.ones(16, 1), + biases=Tensor.zeros(16, 1), + group_size=32, + ) + dequantized = layer.dequantize() + assert dequantized.shape == (16, 32) + +def test_quantized_linear_in_out_features(): + """in_features and out_features reflect the original weight shape.""" + from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + weight = Tensor.ones(128, 256, dtype=dtypes.float16) + packed = pack_bits(weight, bits=4) + layer = QuantizedLinear( + weight_q=packed, + scales=Tensor.ones(128, 4), + biases=Tensor.zeros(128, 4), + group_size=64, + ) + assert layer.out_features == 128 + assert layer.in_features == 256 + +def test_quantized_linear_no_bias_by_default(): + """QuantizedLinear should not add a bias term by 
default.""" + from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + weight = Tensor.ones(16, 32, dtype=dtypes.float16) + packed = pack_bits(weight, bits=4) + layer = QuantizedLinear( + weight_q=packed, + scales=Tensor.ones(16, 1), + biases=Tensor.zeros(16, 1), + group_size=32, + ) + assert layer.bias is None + +def test_quantized_linear_is_final(): + """QuantizedLinear must be decorated with @final.""" + from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear + + assert getattr(QuantizedLinear, "__final__", False) + +# ── QuantizedEmbedding ── + +def test_quantized_embedding_lookup(): + """QuantizedEmbedding must return correct shape for index lookup.""" + from exo.worker.engines.tinygrad.quantization.layers import QuantizedEmbedding + from exo.worker.engines.tinygrad.quantization.packing import pack_bits + + weight = Tensor.ones(100, 32, dtype=dtypes.float16) + packed = pack_bits(weight, bits=4) + layer = QuantizedEmbedding( + num_embeddings=100, + embedding_dim=32, + weight_q=packed, + scales=Tensor.ones(100, 1), + biases=Tensor.zeros(100, 1), + group_size=32, + ) + ids = Tensor([[0, 1, 2]]) + out = layer(ids) + assert out.shape == (1, 3, 32) + +def test_quantized_embedding_is_final(): + """QuantizedEmbedding must be decorated with @final.""" + from exo.worker.engines.tinygrad.quantization.layers import QuantizedEmbedding + + assert getattr(QuantizedEmbedding, "__final__", False) + +# ── Shape Inference (requires Phase 1 ModelConfig) ── + +def test_infer_shape_q_proj(): + """q_proj shape = (hidden_size, hidden_size).""" + from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + + config = _make_model_config(hidden_size=2048, num_attention_heads=32) + shape = infer_weight_shape("model.layers.0.self_attn.q_proj.weight", config) + assert shape == (2048, 2048) + +def test_infer_shape_k_proj_gqa(): + """k_proj with GQA = 
(num_kv_heads * head_dim, hidden_size).""" + from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + + config = _make_model_config( + hidden_size=2048, num_attention_heads=32, num_key_value_heads=8, + ) + # head_dim = 2048/32 = 64, kv_dim = 8*64 = 512 + shape = infer_weight_shape("model.layers.0.self_attn.k_proj.weight", config) + assert shape == (512, 2048) + +def test_infer_shape_v_proj_gqa(): + """v_proj with GQA = (num_kv_heads * head_dim, hidden_size).""" + from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + + config = _make_model_config( + hidden_size=2048, num_attention_heads=32, num_key_value_heads=8, + ) + shape = infer_weight_shape("model.layers.0.self_attn.v_proj.weight", config) + assert shape == (512, 2048) + +def test_infer_shape_k_proj_mha(): + """When num_kv_heads == num_heads (MHA), k_proj = (hidden, hidden).""" + from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + + config = _make_model_config( + hidden_size=2048, num_attention_heads=32, num_key_value_heads=32, + ) + shape = infer_weight_shape("model.layers.0.self_attn.k_proj.weight", config) + assert shape == (2048, 2048) + +def test_detect_layer_type_v_proj(): + """Must detect v_proj correctly (not w_proj).""" + from exo.worker.engines.tinygrad.quantization.shapes import detect_layer_type + + assert detect_layer_type("model.layers.0.self_attn.v_proj.weight") == "v_proj" + +def test_infer_shape_gate_proj(): + """gate_proj = (intermediate_size, hidden_size).""" + from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + + config = _make_model_config(hidden_size=2048, intermediate_size=8192) + shape = infer_weight_shape("model.layers.0.mlp.gate_proj.weight", config) + assert shape == (8192, 2048) + +def test_infer_shape_down_proj(): + """down_proj = (hidden_size, intermediate_size).""" + from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + + config = 
_make_model_config(hidden_size=2048, intermediate_size=8192) + shape = infer_weight_shape("model.layers.0.mlp.down_proj.weight", config) + assert shape == (2048, 8192) + +def test_infer_shape_embed_tokens(): + """embed_tokens = (vocab_size, hidden_size).""" + from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + + config = _make_model_config(hidden_size=2048, vocab_size=128256) + shape = infer_weight_shape("model.embed_tokens.weight", config) + assert shape == (128256, 2048) + + +def test_infer_shape_lm_head(): + """lm_head = (vocab_size, hidden_size).""" + from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape + + config = _make_model_config(hidden_size=2048, vocab_size=128256) + shape = infer_weight_shape("lm_head.weight", config) + assert shape == (128256, 2048) + +# ── Package Exports ── + +def test_package_exports_quantized_linear(): + from exo.worker.engines.tinygrad.quantization import ( + QuantizedLinear, # noqa: F401 # pyright: ignore[reportUnusedImport] + ) + +def test_package_exports_quantized_embedding(): + from exo.worker.engines.tinygrad.quantization import ( + QuantizedEmbedding, # noqa: F401 # pyright: ignore[reportUnusedImport] + ) + +def test_package_exports_packed_tensor(): + from exo.worker.engines.tinygrad.quantization import ( + PackedTensor, # noqa: F401 # pyright: ignore[reportUnusedImport] + ) diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_sampling.py b/src/exo/worker/tests/unittests/test_tinygrad/test_sampling.py new file mode 100644 index 0000000000..322681231d --- /dev/null +++ b/src/exo/worker/tests/unittests/test_tinygrad/test_sampling.py @@ -0,0 +1,94 @@ +# pyright: reportUnknownMemberType=false +from tinygrad.tensor import Tensor + + +def test_sample_result_structure(): + """sample_token should return a SampleResult with token_id, logprob, top_logprobs.""" + from exo.worker.engines.tinygrad.sampling import SampleResult, sample_token + + logits = Tensor([[[0.1, 0.2, 
0.9, 0.3]]]) + result = sample_token(logits, temperature=0.0) + assert isinstance(result, SampleResult) + assert isinstance(result.token_id, int) + assert isinstance(result.logprob, float) + assert isinstance(result.top_logprobs, list) + +def test_greedy_sampling(): + """Temperature=0 should return the argmax token.""" + from exo.worker.engines.tinygrad.sampling import sample_token + + logits = Tensor([[[0.1, 0.2, 0.9, 0.3]]]) # token 2 has highest logit + result = sample_token(logits, temperature=0.0) + assert result.token_id == 2 + +def test_temperature_zero_is_deterministic(): + """Greedy sampling must always produce the same result.""" + from exo.worker.engines.tinygrad.sampling import sample_token + + logits = Tensor.randn(1, 1, 1000) + r1 = sample_token(logits, temperature=0.0) + r2 = sample_token(logits, temperature=0.0) + assert r1.token_id == r2.token_id + +def test_sampling_returns_valid_token(): + """Sampled token must be in [0, vocab_size).""" + from exo.worker.engines.tinygrad.sampling import sample_token + + vocab_size = 100 + logits = Tensor.randn(1, 1, vocab_size) + result = sample_token(logits, temperature=0.7) + assert 0 <= result.token_id < vocab_size + +def test_logprob_is_negative(): + """Log-probabilities must be <= 0 (probabilities are in [0, 1]).""" + from exo.worker.engines.tinygrad.sampling import sample_token + + logits = Tensor.randn(1, 1, 100) + result = sample_token(logits, temperature=0.7, request_logprobs=True) + assert result.logprob <= 0.0 + +def test_greedy_logprob_is_highest(): + """Greedy token should have the highest logprob.""" + from exo.worker.engines.tinygrad.sampling import sample_token + + logits = Tensor([[[0.1, 0.2, 0.9, 0.3]]]) + result = sample_token(logits, temperature=0.0, top_logprobs_count=4, request_logprobs=True) + # The selected token's logprob should match the top entry + assert result.token_id == result.top_logprobs[0][0] + assert abs(result.logprob - result.top_logprobs[0][1]) < 1e-5 + +def 
test_top_logprobs_count(): + """top_logprobs should return exactly the requested count.""" + from exo.worker.engines.tinygrad.sampling import sample_token + + logits = Tensor.randn(1, 1, 100) + result = sample_token(logits, temperature=0.0, top_logprobs_count=5, request_logprobs=True) + assert len(result.top_logprobs) == 5 + +def test_top_logprobs_empty_when_not_requested(): + """top_logprobs should be an empty list when count is 0.""" + from exo.worker.engines.tinygrad.sampling import sample_token + + logits = Tensor.randn(1, 1, 100) + result = sample_token(logits, temperature=0.0, top_logprobs_count=0) + assert result.top_logprobs == [] + +def test_top_logprobs_sorted_descending(): + """top_logprobs entries should be sorted by logprob (highest first).""" + from exo.worker.engines.tinygrad.sampling import sample_token + + logits = Tensor.randn(1, 1, 100) + result = sample_token(logits, temperature=0.0, top_logprobs_count=5, request_logprobs=True) + logprobs = [lp for _, lp in result.top_logprobs] + assert logprobs == sorted(logprobs, reverse=True) + +def test_high_temperature_is_more_random(): + """Higher temperature should produce at least as much token diversity across runs.""" + from exo.worker.engines.tinygrad.sampling import sample_token + + logits = Tensor.randn(1, 1, 10) + low_temp_tokens = {sample_token(logits, temperature=0.01).token_id for _ in range(20)} + high_temp_tokens = {sample_token(logits, temperature=2.0).token_id for _ in range(20)} + # High temp should generally produce more unique tokens + # (probabilistic, so use a lenient check) + assert len(high_temp_tokens) >= len(low_temp_tokens) diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_weight_loader.py b/src/exo/worker/tests/unittests/test_tinygrad/test_weight_loader.py new file mode 100644 index 0000000000..1773ffdc44 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_tinygrad/test_weight_loader.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from pathlib import Path + +# pyright:
reportUnknownMemberType=false
+import pytest
+from tinygrad.device import Device
+from tinygrad.dtype import dtypes
+from tinygrad.tensor import Tensor
+
+from exo.shared.model_config import ModelConfig, QuantizationConfig
+
+Device.DEFAULT = "CPU"
+
+# ── Helpers ──
+
+def _make_model_config(
+    hidden_size: int = 64,
+    num_attention_heads: int = 4,
+    num_key_value_heads: int | None = None,
+    intermediate_size: int = 128,
+    vocab_size: int = 256,
+    quantized: bool = False,
+) -> ModelConfig:
+    """Build a ModelConfig with small defaults for weight-loader tests."""
+    from exo.shared.architecture.llama import LLAMA_SPEC
+
+    n_kv = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+    return ModelConfig(
+        architecture_spec=LLAMA_SPEC,
+        num_hidden_layers=2,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=n_kv,
+        vocab_size=vocab_size,
+        head_dim=hidden_size // num_attention_heads,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        max_position_embeddings=4096,
+        rms_norm_eps=1e-6,
+        tie_word_embeddings=False,
+        quantization_config=QuantizationConfig(bits=4, group_size=32) if quantized else None,
+    )
+
+
+def _mlx_quantized_raw(
+    key_prefix: str,
+    out_features: int,
+    in_features: int,
+    bits: int = 4,
+    group_size: int = 32,
+) -> dict[str, Tensor]:
+    """Create a synthetic MLX-format quantized weight dict (.weight + .scales + .biases)."""
+    pack_factor = 32 // bits
+    packed_dim = in_features // pack_factor
+    num_groups = in_features // group_size
+
+    return {
+        f"{key_prefix}.weight": Tensor.zeros(out_features, packed_dim, dtype=dtypes.uint32),
+        f"{key_prefix}.scales": Tensor.ones(out_features, num_groups, dtype=dtypes.float16),
+        f"{key_prefix}.biases": Tensor.zeros(out_features, num_groups, dtype=dtypes.float16),
+    }
+
+
+# ── NamedTuple structure tests (pre-existing) ──
+
+def test_layer_weights_is_named_tuple():
+    """LayerWeights should be a NamedTuple with
expected fields."""
+    from exo.worker.engines.tinygrad.weight_loader import LayerWeights
+
+    assert hasattr(LayerWeights, "_fields")
+    fields = LayerWeights._fields
+    assert "qkv_proj" in fields
+    assert "o_proj" in fields
+    assert "gate_up_proj" in fields
+    assert "down_proj" in fields
+    assert "input_norm" in fields
+    assert "post_attn_norm" in fields
+
+def test_transformer_weights_is_named_tuple():
+    """TransformerWeights should contain embed, lm_head, final_norm, layers, config."""
+    from exo.worker.engines.tinygrad.weight_loader import TransformerWeights
+
+    fields = TransformerWeights._fields
+    assert "embed_tokens" in fields
+    assert "lm_head" in fields
+    assert "final_norm" in fields
+    assert "layers" in fields
+    assert "config" in fields
+
+@pytest.mark.slow
+def test_load_llama_weights():
+    """Load 2 layers from a real Llama model and verify structure."""
+    from exo.shared.model_config import parse_model_config
+    from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+    model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+    if not model_path.exists():
+        pytest.skip("Model not downloaded")
+    config = parse_model_config(model_path / "config.json")
+    weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2)
+    assert len(weights.layers) == 2
+    assert weights.embed_tokens is not None
+    assert weights.final_norm is not None
+    assert weights.lm_head is not None
+
+@pytest.mark.slow
+def test_load_respects_layer_range():
+    """Loading layers 2-4 should produce exactly 2 LayerWeights."""
+    from exo.shared.model_config import parse_model_config
+    from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+    model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+    if not model_path.exists():
+        pytest.skip("Model not downloaded")
+    config = parse_model_config(model_path / "config.json")
+    weights =
load_transformer_weights(model_path, config, start_layer=2, end_layer=4)
+    assert len(weights.layers) == 2
+
+def test_load_missing_safetensors_raises(tmp_path: Path):
+    """Loading from an empty directory should raise FileNotFoundError."""
+    from exo.worker.engines.tinygrad.weight_loader import (
+        _load_all_safetensors,  # pyright: ignore[reportPrivateUsage]
+    )
+
+    with pytest.raises(FileNotFoundError, match="No .safetensors"):
+        _load_all_safetensors(tmp_path)
+
+# ── _build_weight: MLX quantized format ──
+
+def test_build_weight_mlx_quantized_linear():
+    """MLX quantized linear (.weight + .scales + .biases with quantization_config) → QuantizedLinear."""
+    from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+    from exo.worker.engines.tinygrad.weight_loader import (
+        _build_weight,  # pyright: ignore[reportPrivateUsage]
+    )
+
+    config = _make_model_config(quantized=True)
+    # o_proj shape: (hidden_size, hidden_size) = (64, 64)
+    raw = _mlx_quantized_raw(
+        "model.layers.0.self_attn.o_proj", out_features=64, in_features=64,
+    )
+    result = _build_weight(raw, "model.layers.0.self_attn.o_proj.weight", config)
+    assert isinstance(result, QuantizedLinear)
+
+def test_build_weight_mlx_quantized_embedding():
+    """MLX quantized embedding (.weight + .scales + .biases with is_embedding=True) → QuantizedEmbedding."""
+    from exo.worker.engines.tinygrad.quantization.layers import QuantizedEmbedding
+    from exo.worker.engines.tinygrad.weight_loader import (
+        _build_weight,  # pyright: ignore[reportPrivateUsage]
+    )
+
+    config = _make_model_config(quantized=True)
+    # embed_tokens shape: (vocab_size, hidden_size) = (256, 64)
+    raw = _mlx_quantized_raw(
+        "model.embed_tokens", out_features=256, in_features=64,
+    )
+    result = _build_weight(raw, "model.embed_tokens.weight", config, is_embedding=True)
+    assert isinstance(result, QuantizedEmbedding)
+
+def test_build_weight_unquantized_returns_tensor():
+    """Unquantized model (.weight only, no quantization_config)
→ plain Tensor."""
+    from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+    from exo.worker.engines.tinygrad.weight_loader import (
+        _build_weight,  # pyright: ignore[reportPrivateUsage]
+    )
+
+    config = _make_model_config(quantized=False)
+    raw = {"model.layers.0.self_attn.o_proj.weight": Tensor.zeros(64, 64)}
+    result = _build_weight(raw, "model.layers.0.self_attn.o_proj.weight", config)
+    assert isinstance(result, Tensor)
+    assert not isinstance(result, QuantizedLinear)
+
+def test_build_weight_quantized_config_but_no_companions():
+    """Quantized config present but no .scales/.biases companions → plain Tensor (not quantized)."""
+    from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+    from exo.worker.engines.tinygrad.weight_loader import (
+        _build_weight,  # pyright: ignore[reportPrivateUsage]
+    )
+
+    config = _make_model_config(quantized=True)
+    # Only .weight, no .scales or .biases
+    raw = {"model.layers.0.self_attn.o_proj.weight": Tensor.zeros(64, 64)}
+    result = _build_weight(raw, "model.layers.0.self_attn.o_proj.weight", config)
+    assert isinstance(result, Tensor)
+    assert not isinstance(result, QuantizedLinear)
+
+def test_build_weight_missing_key_raises():
+    """Empty dict → KeyError."""
+    from exo.worker.engines.tinygrad.weight_loader import (
+        _build_weight,  # pyright: ignore[reportPrivateUsage]
+    )
+
+    config = _make_model_config(quantized=False)
+    with pytest.raises(KeyError):
+        _build_weight({}, "model.layers.0.self_attn.o_proj.weight", config)
+
+def test_build_weight_mlx_quantized_shape_correctness():
+    """MLX quantized gate_proj should have correct in_features and out_features."""
+    from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+    from exo.worker.engines.tinygrad.weight_loader import (
+        _build_weight,  # pyright: ignore[reportPrivateUsage]
+    )
+
+    config = _make_model_config(quantized=True)
+    # gate_proj shape: (intermediate_size, hidden_size) = (128, 64)
+    raw =
_mlx_quantized_raw(
+        "model.layers.0.mlp.gate_proj", out_features=128, in_features=64,
+    )
+    result = _build_weight(raw, "model.layers.0.mlp.gate_proj.weight", config)
+    assert isinstance(result, QuantizedLinear)
+    assert result.out_features == 128
+    assert result.in_features == 64
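
Reviewer note: the shape arithmetic in `_mlx_quantized_raw` (`pack_factor = 32 // bits`, then `in_features // pack_factor` and `in_features // group_size`) is easier to check with a standalone sketch. The pure-Python example below is not part of this PR. `mlx_quantized_shapes`, `pack_row`, and `unpack_row` are hypothetical names introduced here for illustration, and the lowest-bits-first packing order is an assumption rather than something the tests above confirm.

```python
from __future__ import annotations


def mlx_quantized_shapes(
    out_features: int,
    in_features: int,
    bits: int = 4,
    group_size: int = 32,
) -> dict[str, tuple[int, int]]:
    """Return the tensor shapes that the _mlx_quantized_raw test helper would create."""
    if 32 % bits != 0:
        raise ValueError("bits must divide 32 so values pack evenly into a uint32")
    pack_factor = 32 // bits  # quantized values stored per uint32 word
    if in_features % pack_factor != 0 or in_features % group_size != 0:
        raise ValueError("in_features must be a multiple of pack_factor and group_size")
    num_groups = in_features // group_size  # one scale and one bias per group
    return {
        "weight": (out_features, in_features // pack_factor),
        "scales": (out_features, num_groups),
        "biases": (out_features, num_groups),
    }


def pack_row(values: list[int], bits: int = 4) -> list[int]:
    """Pack small unsigned ints into 32-bit words, lowest bits first (assumed order)."""
    pack_factor = 32 // bits
    mask = (1 << bits) - 1
    words: list[int] = []
    for start in range(0, len(values), pack_factor):
        word = 0
        for i, value in enumerate(values[start:start + pack_factor]):
            word |= (value & mask) << (i * bits)
        words.append(word)
    return words


def unpack_row(words: list[int], bits: int = 4) -> list[int]:
    """Inverse of pack_row: split each 32-bit word back into its packed values."""
    pack_factor = 32 // bits
    mask = (1 << bits) - 1
    return [(word >> (i * bits)) & mask for word in words for i in range(pack_factor)]


if __name__ == "__main__":
    # o_proj case from the tests: out_features=64, in_features=64, 4-bit, group_size=32
    print(mlx_quantized_shapes(64, 64))
    print(unpack_row(pack_row(list(range(16)))))
```

Under these assumptions, the 64x64 `o_proj` case packs eight 4-bit values per word, so the packed weight has 8 columns and each row carries 2 scale/bias groups. The 128x64 `gate_proj` case keeps the same column counts with 128 rows.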