diff --git a/README.md b/README.md
index 785f6fbab7..9404315528 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Topology-Aware Auto Parallel**: exo figures out the best way to split your model across all available devices based on a realtime view of your device topology. It takes into account device resources and network latency/bandwidth between each link.
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
+- **Experimental tinygrad support**: exo uses [tinygrad](https://github.com/tinygrad/tinygrad) as an inference backend on non-Apple systems, with support for AMD (ROCm/HIP) and NVIDIA (CUDA) GPU backends. The current implementation runs on a single instance only; **model sharding will be tested in the near future.**
## Dashboard
@@ -136,6 +137,8 @@ This starts the exo dashboard and API at http://localhost:52415/
**Installation methods:**
**Option 1: Using system package manager (Ubuntu/Debian example):**
+
+- For Ubuntu/Debian-based distros
```bash
# Install Node.js and npm
sudo apt update
@@ -149,6 +152,15 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
+- For Arch-based distros
+```bash
+sudo pacman -S uv rustup
+
+# Install the nightly toolchain
+rustup toolchain install nightly
+```
+
+
**Option 2: Using Homebrew on Linux (if preferred):**
```bash
# Install Homebrew on Linux
@@ -164,7 +176,149 @@ rustup toolchain install nightly
**Note:** The `macmon` package is macOS-only and not required for Linux.
-Clone the repo, build the dashboard, and run exo:
+**Install GPU drivers for tinygrad**
+
+The tinygrad inference backend supports multiple GPU backends. Install the appropriate
+drivers for your hardware. The backend is selected at runtime via the `DEV` environment
+variable (e.g. `DEV=HIP`, `DEV=CUDA`). If not set, tinygrad auto-detects the best
+available backend.
+
+
+**AMD GPUs (ROCm / HIP)**
+
+ROCm provides GPU compute support for AMD GPUs. The tinygrad backend uses different
+device modes depending on your GPU architecture:
+
+| GPU Architecture | Device | Notes |
+|---|---|---|
+| **RDNA 3+ / CDNA 3** (RX 7000 series, Instinct MI300, etc.) | `DEV=AMD` | Optimized HCQ backend |
+| **RDNA 2** (RX 6000 series) | `DEV=HIP` | HIP runtime backend |
+
+> **Important:** The `DEV=AMD` backend (HCQ) only supports RDNA 3 and newer. RDNA 2 users
+> **must** use `DEV=HIP`. Both modes require a ROCm installation.
+
+**Ubuntu / Debian:**
+
+```bash
+# Install prerequisites
+sudo apt update && sudo apt install -y wget gpg
+
+# Add the AMD ROCm repository
+# Check https://rocm.docs.amd.com/en/latest/deploy/linux/install.html for the latest version
+wget https://repo.radeon.com/amdgpu-install/6.4/ubuntu/$(lsb_release -cs)/amdgpu-install_6.4.60400-1_all.deb
+sudo apt install -y ./amdgpu-install_6.4.60400-1_all.deb
+
+# Install ROCm HIP libraries
+sudo amdgpu-install --usecase=hip --no-dkms
+
+# Add your user to the render and video groups
+sudo usermod -aG render,video $USER
+```
+
+> For Debian, replace `$(lsb_release -cs)` with the equivalent Ubuntu codename
+> (e.g. `noble` for Debian Trixie). Check [AMD's documentation](https://rocm.docs.amd.com/en/latest/deploy/linux/install.html)
+> for supported versions.
+
+**Fedora:**
+
+```bash
+# Install prerequisites
+sudo dnf install -y kernel-headers kernel-devel
+
+# Add the AMD ROCm repository
+# Check https://rocm.docs.amd.com/en/latest/deploy/linux/install.html for the latest version
+sudo tee /etc/yum.repos.d/amdgpu.repo <<'EOF'
+[amdgpu]
+name=amdgpu
+baseurl=https://repo.radeon.com/amdgpu/6.4/el/9/main/x86_64/
+enabled=1
+gpgcheck=1
+gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
+EOF
+
+sudo tee /etc/yum.repos.d/rocm.repo <<'EOF'
+[rocm]
+name=rocm
+baseurl=https://repo.radeon.com/rocm/el9/6.4/main
+enabled=1
+gpgcheck=1
+gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
+EOF
+
+sudo dnf install -y rocm-hip-runtime hip-devel
+
+# Add your user to the render and video groups
+sudo usermod -aG render,video $USER
+```
+
+**Arch:**
+
+```bash
+# Install ROCm HIP from official repositories
+sudo pacman -S rocm-hip-runtime rocm-hip-sdk
+
+# Add your user to the render and video groups
+sudo usermod -aG render,video $USER
+```
+
+**Running exo with AMD GPU:**
+
+```bash
+# RDNA 3+ GPUs
+DEV=AMD uv run exo
+
+# RDNA 2 GPUs (RX 6000 series)
+DEV=HIP uv run exo
+```
+
+
+
+
+**NVIDIA GPUs (CUDA)**
+
+NVIDIA GPUs require proprietary drivers and the CUDA toolkit.
+
+**Ubuntu / Debian:**
+
+```bash
+# Option 1: Install via ubuntu-drivers (Ubuntu only, simplest method)
+sudo add-apt-repository ppa:graphics-drivers/ppa && sudo apt update
+sudo ubuntu-drivers autoinstall
+
+# Option 2: Install CUDA toolkit directly (Ubuntu / Debian)
+# Check https://developer.nvidia.com/cuda-downloads for the latest version
+sudo apt install -y nvidia-driver-560 nvidia-cuda-toolkit
+```
+
+**Fedora:**
+
+```bash
+# Enable RPM Fusion repositories
+sudo dnf install -y \
+ https://download1.rpmfusion.org/free/fedora/rpmfusion-free-release-$(rpm -E %fedora).noarch.rpm \
+ https://download1.rpmfusion.org/nonfree/fedora/rpmfusion-nonfree-release-$(rpm -E %fedora).noarch.rpm
+
+# Install NVIDIA drivers and CUDA
+sudo dnf install -y akmod-nvidia xorg-x11-drv-nvidia-cuda
+```
+
+**Arch:**
+
+```bash
+# Install NVIDIA drivers and CUDA toolkit
+sudo pacman -S nvidia cuda
+```
+
+**Running exo with NVIDIA GPU:**
+
+```bash
+DEV=CUDA uv run exo
+```
+
+
+
+> **Tip:** You can verify your GPU is detected by running with debug logging:
+> `DEBUG=1 DEV=HIP uv run exo` (substitute your device backend).
```bash
# Clone exo
@@ -179,8 +333,6 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
-**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
-
**Configuration Options:**
- `--no-worker`: Run exo without the worker component. Useful for coordinator-only nodes that handle networking and orchestration but don't execute inference tasks. This is helpful for machines without sufficient GPU resources but with good network connectivity.
@@ -426,10 +578,10 @@ The tool outputs performance metrics including prompt tokens per second (prompt_
## Hardware Accelerator Support
-On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working on extending hardware accelerator support. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community.
+On macOS, exo uses the GPU via [MLX](https://github.com/ml-explore/mlx). On Linux, exo uses the GPU via [tinygrad](https://github.com/tinygrad/tinygrad) with support for AMD (ROCm/HIP) and NVIDIA (CUDA) backends. See the [Linux installation guide](#run-from-source-linux) for GPU driver setup. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community.
---
## Contributing
-See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
\ No newline at end of file
+See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
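Reviewer note: the README changes above select the tinygrad backend through the `DEV` environment variable. A minimal Python sketch of that mapping follows, assuming a launcher script wants to set `DEV` before starting exo. The family labels (`rdna3`, `rdna2`, `nvidia`) and the helper name `exo_env` are hypothetical; only the `DEV` values come from the README text.

```python
import os
from typing import Optional

# DEV values taken from the README table; the dictionary keys are
# illustrative labels, not identifiers defined by exo or tinygrad.
DEV_BY_GPU_FAMILY = {
    "rdna3": "AMD",    # HCQ backend
    "rdna2": "HIP",    # HIP runtime backend
    "nvidia": "CUDA",
}


def exo_env(gpu_family: Optional[str], base_env: Optional[dict] = None) -> dict:
    """Return a copy of the environment with DEV set for the given GPU family.

    Passing None removes DEV so that tinygrad can auto-detect a backend.
    """
    env = dict(os.environ if base_env is None else base_env)
    if gpu_family is None:
        env.pop("DEV", None)
        return env
    env["DEV"] = DEV_BY_GPU_FAMILY[gpu_family.lower()]
    return env
```

The returned mapping could then be passed as `env=` to `subprocess.run(["uv", "run", "exo"], ...)`, which is equivalent to the `DEV=HIP uv run exo` form documented in the README.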
diff --git a/dashboard/package-lock.json b/dashboard/package-lock.json
index 345c73d257..e28a6b6003 100644
--- a/dashboard/package-lock.json
+++ b/dashboard/package-lock.json
@@ -865,6 +865,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -904,6 +905,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,6 +1522,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1529,6 +1532,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
+ "peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,6 +1945,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
+ "peer": true,
"engines": {
"node": ">=12"
}
@@ -2648,6 +2653,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
+ "peer": true,
"engines": {
"node": ">=12"
},
@@ -2690,6 +2696,7 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
+ "peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2862,6 +2869,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
+ "peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3006,6 +3014,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
+ "peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3027,6 +3036,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",
diff --git a/dashboard/src/lib/components/ModelCard.svelte b/dashboard/src/lib/components/ModelCard.svelte
index b432b7a87b..5b1edc2f10 100644
--- a/dashboard/src/lib/components/ModelCard.svelte
+++ b/dashboard/src/lib/components/ModelCard.svelte
@@ -21,7 +21,7 @@
} | null;
nodes?: Record;
sharding?: "Pipeline" | "Tensor";
- runtime?: "MlxRing" | "MlxJaccl";
+ runtime?: "MlxRing" | "MlxJaccl" | "Tinygrad";
onLaunch?: () => void;
tags?: string[];
apiPreview?: PlacementPreview | null;
diff --git a/dashboard/src/lib/stores/app.svelte.ts b/dashboard/src/lib/stores/app.svelte.ts
index 00c6669f23..3c382540d9 100644
--- a/dashboard/src/lib/stores/app.svelte.ts
+++ b/dashboard/src/lib/stores/app.svelte.ts
@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
export interface PlacementPreview {
model_id: string;
sharding: "Pipeline" | "Tensor";
- instance_meta: "MlxRing" | "MlxJaccl";
+ instance_meta: "MlxRing" | "MlxJaccl" | "Tinygrad";
instance: unknown | null;
memory_delta_by_node: Record | null;
error: string | null;
@@ -580,6 +580,7 @@ class AppStore {
debugMode = $state(false);
topologyOnlyMode = $state(false);
chatSidebarVisible = $state(true); // Shown by default
+ enableLogprobs = $state(false);
// Image generation params
imageGenerationParams = $state({
@@ -608,6 +609,7 @@ class AppStore {
this.loadTopologyOnlyModeFromStorage();
this.loadChatSidebarVisibleFromStorage();
this.loadImageGenerationParamsFromStorage();
+ this.loadEnableLogprobsFromStorage();
}
}
@@ -677,6 +679,25 @@ class AppStore {
}
}
+ private loadEnableLogprobsFromStorage() {
+ try {
+ const stored = localStorage.getItem("exo-enable-logprobs");
+ if (stored !== null) {
+ this.enableLogprobs = stored === "true";
+ }
+ } catch (error) {
+ console.error("Failed to load enable logprobs:", error);
+ }
+ }
+
+ private saveEnableLogprobsToStorage() {
+ try {
+ localStorage.setItem("exo-enable-logprobs", this.enableLogprobs ? "true" : "false");
+ } catch (error) {
+ console.error("Failed to save enable logprobs:", error);
+ }
+ }
+
private loadTopologyOnlyModeFromStorage() {
try {
const stored = localStorage.getItem("exo-topology-only-mode");
@@ -1213,6 +1234,20 @@ class AppStore {
this.saveDebugModeToStorage();
}
+ getEnableLogprobs(): boolean {
+ return this.enableLogprobs;
+ }
+
+ setEnableLogprobs(enabled: boolean) {
+ this.enableLogprobs = enabled;
+ this.saveEnableLogprobsToStorage();
+ }
+
+ toggleEnableLogprobs() {
+ this.enableLogprobs = !this.enableLogprobs;
+ this.saveEnableLogprobsToStorage();
+ }
+
getTopologyOnlyMode(): boolean {
return this.topologyOnlyMode;
}
@@ -1658,8 +1693,7 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
- logprobs: true,
- top_logprobs: 5,
+ ...(this.enableLogprobs && { logprobs: true, top_logprobs: 5 }),
}),
});
@@ -1865,8 +1899,7 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
- logprobs: true,
- top_logprobs: 5,
+ ...(this.enableLogprobs && { logprobs: true, top_logprobs: 5 }),
}),
});
@@ -2345,8 +2378,7 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
- logprobs: true,
- top_logprobs: 5,
+ ...(this.enableLogprobs && { logprobs: true, top_logprobs: 5 }),
...(enableThinking != null && {
enable_thinking: enableThinking,
}),
@@ -3254,6 +3286,10 @@ export const toggleSidebar = () => appStore.toggleSidebar();
export const toggleDebugMode = () => appStore.toggleDebugMode();
export const setDebugMode = (enabled: boolean) =>
appStore.setDebugMode(enabled);
+export const enableLogprobs = () => appStore.getEnableLogprobs();
+export const toggleEnableLogprobs = () => appStore.toggleEnableLogprobs();
+export const setEnableLogprobs = (enabled: boolean) =>
+ appStore.setEnableLogprobs(enabled);
export const toggleTopologyOnlyMode = () => appStore.toggleTopologyOnlyMode();
export const setTopologyOnlyMode = (enabled: boolean) =>
appStore.setTopologyOnlyMode(enabled);
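Reviewer note: the three request builders in `app.svelte.ts` now attach `logprobs` and `top_logprobs` only when the `enableLogprobs` toggle is on. The same rule, written in Python to stay consistent with the backend language, might look like the sketch below. `build_chat_payload` is a hypothetical name; the field names and the value `5` are taken from the hunks above.

```python
from typing import Any


def build_chat_payload(
    model: str,
    messages: list,
    enable_logprobs: bool,
) -> dict:
    """Build a streaming chat request body, adding logprob fields only on request."""
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
    }
    if enable_logprobs:
        # Mirrors `...(this.enableLogprobs && { logprobs: true, top_logprobs: 5 })`.
        payload.update({"logprobs": True, "top_logprobs": 5})
    return payload
```

Leaving the keys out entirely, rather than sending `logprobs: false`, keeps the request identical to one from a client that has never heard of the option.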
diff --git a/dashboard/src/routes/+page.svelte b/dashboard/src/routes/+page.svelte
index 7f374dd748..9d2023faf2 100644
--- a/dashboard/src/routes/+page.svelte
+++ b/dashboard/src/routes/+page.svelte
@@ -815,7 +815,7 @@
return model.tasks.includes("ImageToImage");
}
let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
- type InstanceMeta = "MlxRing" | "MlxJaccl";
+ type InstanceMeta = "MlxRing" | "MlxJaccl" | "Tinygrad";
// Launch defaults persistence
const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults-v2";
@@ -882,6 +882,7 @@
}
let selectedInstanceType = $state("MlxRing");
+ let instanceTypeInitialized = $state(false);
let selectedMinNodes = $state(1);
let minNodesInitialized = $state(false);
let launchingModelId = $state(null);
@@ -1074,9 +1075,11 @@
}
const matchesSelectedRuntime = (runtime: InstanceMeta): boolean =>
- selectedInstanceType === "MlxRing"
- ? runtime === "MlxRing"
- : runtime === "MlxJaccl";
+ selectedInstanceType === "Tinygrad"
+ ? runtime === "Tinygrad"
+ : selectedInstanceType === "MlxRing"
+ ? runtime === "MlxRing"
+ : runtime === "MlxJaccl";
// Helper to check if a model can be launched (has valid placement with >= minNodes)
function canModelFit(modelId: string): boolean {
@@ -2492,6 +2495,22 @@
}
});
+ // Auto-detect instance type from available placement previews (e.g., Tinygrad on Linux)
+ $effect(() => {
+ if (instanceTypeInitialized || previewsData.length === 0) return;
+ const hasCurrentType = previewsData.some(
+ (p: PlacementPreview) => matchesSelectedRuntime(p.instance_meta),
+ );
+ if (!hasCurrentType) {
+ // Switch to whatever the server provides
+ const firstMeta = previewsData[0]?.instance_meta;
+ if (firstMeta) {
+ selectedInstanceType = firstMeta;
+ }
+ }
+ instanceTypeInitialized = true;
+ });
+
// Calculate total memory usage across all nodes
const clusterMemory = $derived(() => {
if (!data) return { used: 0, total: 0 };
@@ -4985,6 +5004,29 @@
RDMA (Fast)
+
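Reviewer note: the new `$effect` in `+page.svelte` keeps the selected instance type if any placement preview matches it, and otherwise switches to the first preview's `instance_meta`. A Python sketch of that decision is shown below. The function name is hypothetical, and the one-shot `instanceTypeInitialized` guard is left to the caller.

```python
def resolve_instance_type(selected: str, preview_metas: list) -> str:
    """Pick the instance type to show, given the metas the server returned.

    `preview_metas` holds the `instance_meta` value of each placement preview,
    in server order.
    """
    if not preview_metas:
        # No previews yet: nothing to detect, keep the current selection.
        return selected
    if selected in preview_metas:
        return selected
    # The server does not offer the selected runtime; fall back to its first one.
    return preview_metas[0]
```

On a Linux node whose previews are all `Tinygrad`, a default selection of `MlxRing` would therefore resolve to `Tinygrad`.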
diff --git a/pyproject.toml b/pyproject.toml
index c4bfa55052..cf11937df0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"python-multipart>=0.0.21",
"msgspec>=0.19.0",
"zstandard>=0.23.0",
+ "tinygrad>=0.12.0; sys_platform == 'linux'",
]
[project.scripts]
diff --git a/src/exo/main.py b/src/exo/main.py
index 60729cc692..528dd401fb 100644
--- a/src/exo/main.py
+++ b/src/exo/main.py
@@ -43,7 +43,7 @@ class Node:
@classmethod
async def create(cls, args: "Args") -> Self:
keypair = get_node_id_keypair()
- node_id = NodeId(keypair.to_node_id())
+ node_id = NodeId(keypair.to_peer_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)
diff --git a/src/exo/master/api.py b/src/exo/master/api.py
index 5ecc1575d2..5e4f807480 100644
--- a/src/exo/master/api.py
+++ b/src/exo/master/api.py
@@ -2,6 +2,7 @@
import contextlib
import json
import random
+import sys
import time
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator
from datetime import datetime, timezone
@@ -47,7 +48,9 @@
)
from exo.master.event_log import DiskEventLog
from exo.master.image_store import ImageStore
-from exo.master.placement import place_instance as get_instance_placements
+from exo.master.placement import (
+ place_instance as get_instance_placements,
+)
from exo.shared.apply import apply
from exo.shared.constants import (
DASHBOARD_DIR,
@@ -169,7 +172,12 @@
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
-from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
+from exo.shared.types.worker.instances import (
+ Instance,
+ InstanceId,
+ InstanceMeta,
+ default_instance_meta,
+)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -394,9 +402,11 @@ async def get_placement(
self,
model_id: ModelId,
sharding: Sharding = Sharding.Pipeline,
- instance_meta: InstanceMeta = InstanceMeta.MlxRing,
+ instance_meta: InstanceMeta | None = None,
min_nodes: int = 1,
) -> Instance:
+ if instance_meta is None:
+ instance_meta = default_instance_meta()
model_card = await ModelCard.load(model_id)
try:
@@ -446,8 +456,13 @@ async def get_placement_previews(
status_code=400, detail=f"Failed to load model card: {exc}"
) from exc
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
+ instance_metas = (
+ (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl)
+ if sys.platform == "darwin"
+ else (InstanceMeta.Tinygrad,)
+ )
for sharding in (Sharding.Pipeline, Sharding.Tensor):
- for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
+ for instance_meta in instance_metas:
instance_combinations.extend(
[
(sharding, instance_meta, i)
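Reviewer note: `get_placement_previews` now enumerates MLX runtimes on macOS and only Tinygrad elsewhere. The sketch below reproduces that enumeration with plain strings instead of the `Sharding` and `InstanceMeta` enums. The node-count range `1..node_count` is an assumption, since the hunk is cut off before the range expression.

```python
def instance_metas_for(platform: str) -> tuple:
    """Runtimes to preview for a given `sys.platform` value."""
    if platform == "darwin":
        return ("MlxRing", "MlxJaccl")
    return ("Tinygrad",)


def preview_combinations(platform: str, node_count: int) -> list:
    """All (sharding, instance_meta, min_nodes) triples to evaluate.

    Assumes min_nodes ranges over 1..node_count; the real range is not
    visible in the diff.
    """
    return [
        (sharding, meta, n)
        for sharding in ("Pipeline", "Tensor")
        for meta in instance_metas_for(platform)
        for n in range(1, node_count + 1)
    ]
```

Under that assumption a two-node Linux cluster yields four combinations and a two-node macOS cluster yields eight.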
diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py
index 5f191b846f..57b39f61d3 100644
--- a/src/exo/master/placement.py
+++ b/src/exo/master/placement.py
@@ -41,6 +41,7 @@
InstanceMeta,
MlxJacclInstance,
MlxRingInstance,
+ TinygradInstance,
)
from exo.shared.types.worker.shards import Sharding
@@ -142,7 +143,7 @@ def place_instance(
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
- if len(selected_cycle) == 1:
+ if len(selected_cycle) == 1 and command.instance_meta != InstanceMeta.Tinygrad:
command.instance_meta = InstanceMeta.MlxRing
match command.instance_meta:
@@ -192,6 +193,11 @@ def get_device_rank(node_id: NodeId) -> int:
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
+ case InstanceMeta.Tinygrad:
+ target_instances[instance_id] = TinygradInstance(
+ instance_id=instance_id,
+ shard_assignments=shard_assignments,
+ )
return target_instances
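Reviewer note: the changed condition in `place_instance` forces single-node placements to `MlxRing` unless Tinygrad was requested. A small Python sketch of that rule, using strings in place of the `InstanceMeta` enum and a hypothetical function name:

```python
def effective_instance_meta(requested: str, cycle_len: int) -> str:
    """Return the runtime that placement will actually use.

    A single-node cycle is downgraded to MlxRing, except when the request
    is for Tinygrad, which is left untouched.
    """
    if cycle_len == 1 and requested != "Tinygrad":
        return "MlxRing"
    return requested
```

Without the extra `!= Tinygrad` check, a single-node Tinygrad request would be rewritten to `MlxRing` and never reach the new `case InstanceMeta.Tinygrad` branch.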
diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py
index cad495ead2..06042f677b 100644
--- a/src/exo/master/tests/test_placement.py
+++ b/src/exo/master/tests/test_placement.py
@@ -25,6 +25,7 @@
InstanceMeta,
MlxJacclInstance,
MlxRingInstance,
+ TinygradInstance,
)
from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@@ -456,3 +457,25 @@ def test_tensor_rdma_backend_connectivity_matrix(
else:
ip_part = coordinator.split(":")[0]
assert len(ip_part.split(".")) == 4
+
+def test_place_tinygrad_single_node(model_card: ModelCard):
+ """Tinygrad placement should create a TinygradInstance for a single node."""
+ topology = Topology()
+ node_id = NodeId()
+ topology.add_node(node_id)
+ node_memory = {node_id: create_node_memory(model_card.storage_size.in_bytes * 2)}
+ node_network = {node_id: create_node_network()}
+
+ command = PlaceInstance(
+ command_id=CommandId(),
+ model_card=model_card,
+ sharding=Sharding.Pipeline,
+ instance_meta=InstanceMeta.Tinygrad,
+ min_nodes=1,
+ )
+
+ placements = place_instance(command, topology, {}, node_memory, node_network)
+
+ assert len(placements) == 1
+ instance = list(placements.values())[0]
+ assert isinstance(instance, TinygradInstance)
diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py
index d71275b763..083baba074 100644
--- a/src/exo/routing/router.py
+++ b/src/exo/routing/router.py
@@ -221,7 +221,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
- return Keypair.generate()
+ return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -241,6 +241,6 @@ def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
- keypair = Keypair.generate_ed25519()
+ keypair = Keypair.generate()
f.write(keypair.to_bytes())
return keypair
diff --git a/src/exo/shared/architecture/__init__.py b/src/exo/shared/architecture/__init__.py
new file mode 100644
index 0000000000..670bdeaef8
--- /dev/null
+++ b/src/exo/shared/architecture/__init__.py
@@ -0,0 +1,72 @@
+from typing import Any, Literal
+
+from pydantic import BaseModel
+
+AttentionType = Literal["multi_head", "grouped_query", "multi_latent"]
+MLPType = Literal["swiglu", "moe_top_k"]
+NormType = Literal["rms_norm", "layer_norm"]
+RoPEType = Literal["standard", "ntk_aware", "yarn"]
+
+class ArchitectureSpec(BaseModel, frozen=True, strict=True):
+ name: str
+ attention_type: AttentionType
+ mlp_type: MLPType
+ norm_type: NormType
+ rope_type: RoPEType
+
+ # Weight key templates ({layer_idx} per layer)
+ layer_prefix: str
+ q_proj_key: str
+ k_proj_key: str
+ v_proj_key: str
+ o_proj_key: str
+ gate_proj_key: str
+ up_proj_key: str
+ down_proj_key: str
+ input_norm_key: str
+ post_attn_norm_key: str
+ final_norm_key: str
+ lm_head_key: str
+
+ # Optional per-head Q/K norm weight keys (used by e.g. Qwen3)
+
+ q_norm_key: str | None = None
+ k_norm_key: str | None = None
+
+ # MoE (optional, follow-up)
+ router_key: str | None = None
+ expert_prefix: str | None = None
+
+ # MLA (optional, follow-up)
+ kv_lora_rank: int | None = None
+ q_lora_rank: int | None = None
+
+ # Embedding key
+ embed_key: str | None = None
+
+ # RoPE default
+ rope_theta: float = 10000.0
+
+ # Tokenizer
+ tokenizer_type: Literal["huggingface", "tiktoken"] = "huggingface"
+
+ARCHITECTURE_REGISTRY: dict[str, ArchitectureSpec] = {}
+
+def register(huggingface_name: str, spec: ArchitectureSpec) -> None:
+ ARCHITECTURE_REGISTRY[huggingface_name] = spec
+
+def detect_architecture(raw_config: dict[str, Any]) -> ArchitectureSpec:
+ architectures: list[str] = raw_config.get("architectures", []) # pyright: ignore[reportAny]
+ for arch_name in architectures:
+ if arch_name in ARCHITECTURE_REGISTRY:
+ return ARCHITECTURE_REGISTRY[arch_name]
+
+ model_type: str = raw_config.get("model_type", "") # pyright: ignore[reportAny]
+ if model_type in ARCHITECTURE_REGISTRY:
+ return ARCHITECTURE_REGISTRY[model_type]
+ raise ValueError(
+ f"Unsupported Architecture: {raw_config.get('architectures')}"
+ )
+
+from . import llama as _llama # noqa: F401, E402 # pyright: ignore[reportUnusedImport]
+from . import qwen as _qwen # noqa: F401, E402 # pyright: ignore[reportUnusedImport]
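Reviewer note: `detect_architecture` first checks each entry of `architectures`, then falls back to `model_type`. The sketch below replays that lookup order with a small frozen dataclass standing in for `ArchitectureSpec`. In the hunks shown, `register` is only called with `...ForCausalLM` class names, so the fallback can only match if a `model_type` string is registered too; the second `register` call here is a hypothetical addition to exercise that branch.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for ArchitectureSpec with a single field."""
    name: str


REGISTRY: dict = {}


def register(huggingface_name: str, spec: Spec) -> None:
    REGISTRY[huggingface_name] = spec


def detect(raw_config: dict) -> Spec:
    # 1. Prefer an exact match on any listed architecture class name.
    for arch_name in raw_config.get("architectures", []):
        if arch_name in REGISTRY:
            return REGISTRY[arch_name]
    # 2. Fall back to the model_type string.
    model_type: Any = raw_config.get("model_type", "")
    if model_type in REGISTRY:
        return REGISTRY[model_type]
    raise ValueError(f"Unsupported Architecture: {raw_config.get('architectures')}")


register("LlamaForCausalLM", Spec(name="llama"))
register("llama", Spec(name="llama"))  # hypothetical: enables the model_type fallback
```

A config whose class name is unknown but whose `model_type` is `llama` then still resolves, while a config with neither raises `ValueError`.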
diff --git a/src/exo/shared/architecture/config.py b/src/exo/shared/architecture/config.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/exo/shared/architecture/llama.py b/src/exo/shared/architecture/llama.py
new file mode 100644
index 0000000000..c1fedaf9a8
--- /dev/null
+++ b/src/exo/shared/architecture/llama.py
@@ -0,0 +1,25 @@
+from . import ArchitectureSpec, register
+
+LLAMA_SPEC = ArchitectureSpec(
+ name = "llama",
+ attention_type = "grouped_query",
+ mlp_type = "swiglu",
+ norm_type = "rms_norm",
+ rope_type = "standard",
+ layer_prefix = "model.layers.{layer_idx}",
+ q_proj_key = "self_attn.q_proj",
+ k_proj_key = "self_attn.k_proj",
+ v_proj_key = "self_attn.v_proj",
+ o_proj_key = "self_attn.o_proj",
+ gate_proj_key = "mlp.gate_proj",
+ up_proj_key = "mlp.up_proj",
+ down_proj_key = "mlp.down_proj",
+ input_norm_key = "input_layernorm",
+ post_attn_norm_key = "post_attention_layernorm",
+ embed_key = "model.embed_tokens",
+ final_norm_key = "model.norm",
+ lm_head_key = "lm_head",
+)
+
+register("LlamaForCausalLM", LLAMA_SPEC)
+register("MistralForCausalLM", LLAMA_SPEC)
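Reviewer note: the spec stores `layer_prefix` as a template with a `{layer_idx}` placeholder and the per-layer keys as relative paths. One way these could be combined into a full checkpoint key is sketched below. The helper name and the trailing `weight` suffix are assumptions, since the diff does not show how the loader joins them.

```python
def weight_key(
    layer_prefix: str,
    sub_key: str,
    layer_idx: int,
    suffix: str = "weight",  # assumed tensor suffix, not confirmed by the diff
) -> str:
    """Expand a layer template and append a per-layer key and tensor suffix."""
    prefix = layer_prefix.format(layer_idx=layer_idx)
    return f"{prefix}.{sub_key}.{suffix}"
```

With the Llama values above, layer 3's query projection would expand to `model.layers.3.self_attn.q_proj.weight`.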
diff --git a/src/exo/shared/architecture/qwen.py b/src/exo/shared/architecture/qwen.py
new file mode 100644
index 0000000000..8f698127fe
--- /dev/null
+++ b/src/exo/shared/architecture/qwen.py
@@ -0,0 +1,48 @@
+from . import ArchitectureSpec, register
+
+QWEN_DENSE_SPEC = ArchitectureSpec(
+ name = "qwen2",
+ attention_type = "grouped_query",
+ mlp_type = "swiglu",
+ norm_type = "rms_norm",
+ rope_type = "standard",
+ layer_prefix = "model.layers.{layer_idx}",
+ q_proj_key = "self_attn.q_proj",
+ k_proj_key = "self_attn.k_proj",
+ v_proj_key = "self_attn.v_proj",
+ o_proj_key = "self_attn.o_proj",
+ gate_proj_key = "mlp.gate_proj",
+ up_proj_key = "mlp.up_proj",
+ down_proj_key = "mlp.down_proj",
+ input_norm_key = "input_layernorm",
+ post_attn_norm_key = "post_attention_layernorm",
+ embed_key = "model.embed_tokens",
+ final_norm_key = "model.norm",
+ lm_head_key = "lm_head",
+)
+
+QWEN3_DENSE_SPEC = ArchitectureSpec(
+ name = "qwen3",
+ attention_type = "grouped_query",
+ mlp_type = "swiglu",
+ norm_type = "rms_norm",
+ rope_type = "standard",
+ layer_prefix = "model.layers.{layer_idx}",
+ q_proj_key = "self_attn.q_proj",
+ k_proj_key = "self_attn.k_proj",
+ v_proj_key = "self_attn.v_proj",
+ o_proj_key = "self_attn.o_proj",
+ gate_proj_key = "mlp.gate_proj",
+ up_proj_key = "mlp.up_proj",
+ down_proj_key = "mlp.down_proj",
+ input_norm_key = "input_layernorm",
+ post_attn_norm_key = "post_attention_layernorm",
+ embed_key = "model.embed_tokens",
+ final_norm_key = "model.norm",
+ lm_head_key = "lm_head",
+ q_norm_key = "self_attn.q_norm",
+ k_norm_key = "self_attn.k_norm",
+)
+
+register("Qwen2ForCausalLM", QWEN_DENSE_SPEC)
+register("Qwen3ForCausalLM", QWEN3_DENSE_SPEC)
diff --git a/src/exo/shared/model_config.py b/src/exo/shared/model_config.py
new file mode 100644
index 0000000000..0abd6509ce
--- /dev/null
+++ b/src/exo/shared/model_config.py
@@ -0,0 +1,77 @@
+import json
+from pathlib import Path
+from typing import Any
+
+from pydantic import BaseModel, PositiveInt
+
+from exo.shared.architecture import ArchitectureSpec, detect_architecture
+
+
+class QuantizationConfig(BaseModel, frozen=True, strict=True):
+ bits: int
+ group_size: int
+
+class ModelConfig(BaseModel, frozen=True, strict=True):
+ architecture_spec: ArchitectureSpec
+ num_hidden_layers: PositiveInt
+ hidden_size: PositiveInt
+ intermediate_size: PositiveInt
+ num_attention_heads: PositiveInt
+ num_key_value_heads: PositiveInt
+ vocab_size: PositiveInt
+ head_dim: PositiveInt
+ rope_theta: float
+ rope_scaling: dict[str, Any] | None
+ max_position_embeddings: PositiveInt
+ rms_norm_eps: float
+ tie_word_embeddings: bool
+ quantization_config: QuantizationConfig | None
+
+def parse_model_config(config_path: Path) -> ModelConfig: # noqa: C901
+ with open(config_path) as f:
+ raw: dict[str, Any] = json.load(f) # pyright: ignore[reportAny]
+
+ arch_spec = detect_architecture(raw)
+ hidden_size: int = raw["hidden_size"] # pyright: ignore[reportAny]
+ num_attention_heads: int = raw["num_attention_heads"] # pyright: ignore[reportAny]
+ num_key_value_heads: int = raw.get("num_key_value_heads", num_attention_heads) # pyright: ignore[reportAny]
+
+ intermediate_size: int = raw.get("intermediate_size", hidden_size * 4) # pyright: ignore[reportAny]
+
+ head_dim: int = raw.get("head_dim", hidden_size // num_attention_heads) # pyright: ignore[reportAny]
+ num_hidden_layers: int = raw.get("num_hidden_layers") or raw.get("num_layers") or 32
+
+ rope_theta: float = raw.get("rope_theta", arch_spec.rope_theta) # pyright: ignore[reportAny]
+
+ max_position_embeddings: int = raw.get("max_position_embeddings", 4096) # pyright: ignore[reportAny]
+ rms_norm_eps: float = raw.get("rms_norm_eps", 1e-6) # pyright: ignore[reportAny]
+ tie_word_embeddings: bool = raw.get("tie_word_embeddings", False) # pyright: ignore[reportAny]
+
+ quant_raw: dict[str, Any] | None = raw.get("quantization_config")
+ quantization_config: QuantizationConfig | None = None
+
+ if quant_raw and "bits" in quant_raw and "group_size" in quant_raw:
+ quantization_config = QuantizationConfig(
+ bits = int(quant_raw["bits"]), # pyright: ignore[reportAny]
+ group_size = int(quant_raw["group_size"]), # pyright: ignore[reportAny]
+ )
+
+ vocab_size: int = raw["vocab_size"] # pyright: ignore[reportAny]
+ rope_scaling: dict[str, Any] | None = raw.get("rope_scaling")
+
+ return ModelConfig(
+ architecture_spec = arch_spec,
+ num_attention_heads = num_attention_heads,
+ num_hidden_layers = num_hidden_layers,
+ num_key_value_heads = num_key_value_heads,
+ hidden_size = hidden_size,
+ intermediate_size = intermediate_size,
+ vocab_size = vocab_size,
+ head_dim = head_dim,
+ rope_theta = rope_theta,
+ rope_scaling = rope_scaling,
+ max_position_embeddings = max_position_embeddings,
+ rms_norm_eps = rms_norm_eps,
+ tie_word_embeddings = tie_word_embeddings,
+ quantization_config = quantization_config,
+ )
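Reviewer note: `parse_model_config` derives several dimensions when the JSON omits them. The sketch below isolates three of those defaults so the arithmetic is easy to check. `derive_dims` is a hypothetical helper; the default expressions are copied from the function above.

```python
def derive_dims(raw: dict) -> dict:
    """Apply the same fallbacks as parse_model_config for three fields."""
    hidden_size = raw["hidden_size"]
    num_attention_heads = raw["num_attention_heads"]
    return {
        # No GQA information: assume one KV head per attention head.
        "num_key_value_heads": raw.get("num_key_value_heads", num_attention_heads),
        # Default MLP width is four times the hidden size.
        "intermediate_size": raw.get("intermediate_size", hidden_size * 4),
        # Default head size splits the hidden size evenly across heads.
        "head_dim": raw.get("head_dim", hidden_size // num_attention_heads),
    }
```

For `hidden_size=2048` and `num_attention_heads=32` this gives `head_dim = 2048 // 32 = 64` and `intermediate_size = 8192`, matching the values asserted in `test_parse_model_config`.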
diff --git a/src/exo/shared/tests/test_architecture.py b/src/exo/shared/tests/test_architecture.py
new file mode 100644
index 0000000000..c977f49f23
--- /dev/null
+++ b/src/exo/shared/tests/test_architecture.py
@@ -0,0 +1,158 @@
+import json
+from pathlib import Path
+
+import pytest
+from pydantic import ValidationError
+
+
+def test_registered_architecture():
+ """
+ Checks that the following architectures are present in the
+ architecture registry:
+ - Llama
+ - Qwen
+ - Mistral
+ """
+
+ from exo.shared.architecture import ARCHITECTURE_REGISTRY
+
+ architectures = [
+ "LlamaForCausalLM",
+ "Qwen2ForCausalLM",
+ "MistralForCausalLM",
+ ]
+
+ assert all(arch in ARCHITECTURE_REGISTRY for arch in architectures)
+
+def test_detect_llama():
+ """
+ detect_architecture() should resolve Llama from a raw config
+ dict.
+ """
+
+ from exo.shared.architecture import detect_architecture
+
+ raw = {
+ "architectures": ["LlamaForCausalLM"],
+ "model_type": "llama",
+ }
+
+ spec = detect_architecture(raw)
+
+ assert spec.name == "llama"
+ assert spec.attention_type == "grouped_query"
+ assert spec.mlp_type == "swiglu"
+ assert spec.norm_type == "rms_norm"
+
+def test_detect_unknown_raises():
+ """
+ Unsupported architectures should raise ValueError.
+ """
+ from exo.shared.architecture import detect_architecture
+
+ with pytest.raises(ValueError, match="Unsupported Architecture"):
+ detect_architecture({"architectures": ["FakeModelForCausalLM"]})
+
+def test_architecture_spec_is_frozen():
+ """
+ Architectures must be immutable
+ """
+ from exo.shared.architecture import ARCHITECTURE_REGISTRY
+
+ spec = ARCHITECTURE_REGISTRY["LlamaForCausalLM"]
+ with pytest.raises(ValidationError):
+ spec.name = "modified" # type: ignore[misc]
+
+
+def test_parse_model_config(tmp_path: Path):
+ """
+ parse_model_config() must produce a correct ModelConfig
+ from JSON.
+ """
+
+ from exo.shared.model_config import parse_model_config
+
+ config = {
+ "architectures" : ["LlamaForCausalLM"],
+ "hidden_size": 2048,
+ "intermediate_size": 8192,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "num_hidden_layers": 16,
+ "vocab_size": 128256,
+ "rope_theta": 500000.0,
+ "max_position_embeddings": 4096,
+ "rms_norm_eps": 1e-5,
+ }
+
+ config_path = tmp_path / "config.json"
+ config_path.write_text(json.dumps(config))
+ parsed = parse_model_config(config_path)
+ assert parsed.num_attention_heads == 32
+ assert parsed.num_key_value_heads == 8
+ assert parsed.head_dim == 64 # 2048 // 32
+ assert parsed.hidden_size == 2048
+ assert parsed.quantization_config is None
+
+
+def test_parse_model_config_with_quantization(tmp_path: Path):
+ """
+ Quantized models should populate quantization_config.
+ """
+ from exo.shared.model_config import parse_model_config
+
+ config = {
+ "architectures": ["LlamaForCausalLM"],
+ "hidden_size": 2048,
+ "intermediate_size": 8192,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "num_hidden_layers": 16,
+ "vocab_size": 128256,
+ "quantization_config": {"bits": 4, "group_size": 64},
+ }
+ config_path = tmp_path / "config.json"
+ config_path.write_text(json.dumps(config))
+ parsed = parse_model_config(config_path)
+ assert parsed.quantization_config is not None
+ assert parsed.quantization_config.bits == 4
+ assert parsed.quantization_config.group_size == 64
+
+
+def test_parse_model_config_defaults(tmp_path: Path):
+ """
+ Missing optional fields should use sensible defaults.
+ """
+ from exo.shared.model_config import parse_model_config
+
+ config = {
+ "architectures": ["LlamaForCausalLM"],
+ "hidden_size": 2048,
+ "num_attention_heads": 32,
+ "vocab_size": 128256,
+ }
+ config_path = tmp_path / "config.json"
+ config_path.write_text(json.dumps(config))
+ parsed = parse_model_config(config_path)
+ assert parsed.num_key_value_heads == 32 # defaults to num_attention_heads
+ assert parsed.tie_word_embeddings is False
+ assert parsed.rope_scaling is None
+
+
+def test_model_config_is_frozen(tmp_path: Path):
+ """
+ ModelConfig must be immutable (frozen=True).
+ """
+ from exo.shared.model_config import parse_model_config
+
+ config = {
+ "architectures": ["LlamaForCausalLM"],
+ "hidden_size": 2048,
+ "num_attention_heads": 32,
+ "vocab_size": 128256,
+ }
+ config_path = tmp_path / "config.json"
+ config_path.write_text(json.dumps(config))
+ parsed = parse_model_config(config_path)
+ with pytest.raises(ValidationError):
+ parsed.hidden_size = 4096 # type: ignore[misc]
diff --git a/src/exo/shared/tests/test_tokenizer_shared.py b/src/exo/shared/tests/test_tokenizer_shared.py
new file mode 100644
index 0000000000..29798441ad
--- /dev/null
+++ b/src/exo/shared/tests/test_tokenizer_shared.py
@@ -0,0 +1,88 @@
+from typing import Any
+from unittest.mock import MagicMock
+
+from exo.shared.types.common import ModelId
+
+
+def test_kimi_eos() -> None:
+ """Kimi K2 has a known EOS token ID."""
+ from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model
+
+ assert get_eos_token_ids_for_model(ModelId("moonshotai/Kimi-K2-Instruct")) == [163586]
+
+
+def test_glm_flash_eos() -> None:
+ """GLM-4.7-flash has specific EOS token IDs."""
+ from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model
+
+ result = get_eos_token_ids_for_model(ModelId("THUDM/glm-4.7-flash"))
+ assert result == [154820, 154827, 154829]
+
+
+def test_glm_generic_eos() -> None:
+ """Generic GLM models have different EOS tokens than flash."""
+ from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model
+
+ result = get_eos_token_ids_for_model(ModelId("THUDM/glm-4-9b"))
+ assert result == [151336, 151329, 151338]
+
+
+def test_unknown_model_returns_none() -> None:
+ """Unknown models should return None (use tokenizer default)."""
+ from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model
+
+ assert get_eos_token_ids_for_model(ModelId("unknown-org/UnknownModel-7B")) is None
+
+
+def test_chat_template_chatml_fallback() -> None:
+ """Without apply_chat_template, fall back to ChatML format."""
+ from exo.shared.tokenizer.chat_template import apply_chat_template
+ from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
+
+ tokenizer = MagicMock(spec=[])
+
+ task = TextGenerationTaskParams(
+ model=ModelId("test-model"),
+ input=[InputMessage(role="user", content="Hello")],
+ )
+
+ result: Any = apply_chat_template(tokenizer, task)
+ assert "<|im_start|>user" in result
+ assert "Hello" in result
+ assert "<|im_end|>" in result
+ assert result.endswith("<|im_start|>assistant\n") # pyright: ignore[reportAny]
+
+
+def test_chat_template_delegates_to_tokenizer() -> None:
+ """When tokenizer has apply_chat_template, use it."""
+ from exo.shared.tokenizer.chat_template import apply_chat_template
+ from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
+
+ tokenizer: Any = MagicMock()
+ tokenizer.apply_chat_template.return_value = "" # pyright: ignore[reportAny]
+
+ task = TextGenerationTaskParams(
+ model=ModelId("test-model"),
+ input=[InputMessage(role="user", content="Hello")],
+ )
+
+ result: Any = apply_chat_template(tokenizer, task)
+ assert result == ""
+ tokenizer.apply_chat_template.assert_called_once() # pyright: ignore[reportAny]
+
+
+def test_normalize_tool_calls_parses_json_string() -> None:
+ """Tool call arguments stored as strings should be parsed to dicts."""
+ from exo.shared.tokenizer.chat_template import (
+ _normalize_tool_calls, # pyright: ignore[reportPrivateUsage]
+ )
+
+ msg: dict[str, Any] = {
+ "role": "assistant",
+ "tool_calls": [
+ {"function": {"name": "get_weather", "arguments": '{"city": "London"}'}}
+ ],
+ }
+ _normalize_tool_calls(msg)
+ tool_calls: Any = msg["tool_calls"] # pyright: ignore[reportAny]
+ assert tool_calls[0]["function"]["arguments"] == {"city": "London"}
diff --git a/src/exo/shared/tokenizer/chat_template.py b/src/exo/shared/tokenizer/chat_template.py
new file mode 100644
index 0000000000..bee80ffe4f
--- /dev/null
+++ b/src/exo/shared/tokenizer/chat_template.py
@@ -0,0 +1,84 @@
+import contextlib
+import json
+from typing import Any, cast
+
+from loguru import logger
+
+from exo.shared.types.text_generation import TextGenerationTaskParams
+
+
+def apply_chat_template(
+ tokenizer: Any, # pyright: ignore[reportAny]
+ task_params: TextGenerationTaskParams,
+) -> str:
+ """Convert TextGenerationTaskParams to a chat template prompt.
+
+ Converts the internal format (input + instructions) to a messages list
+ that can be processed by the tokenizer's chat template.
+
+ When chat_template_messages is available (from Chat Completions API),
+ uses those directly to preserve tool_calls, thinking, and other fields.
+ Otherwise builds messages from the task params input/instructions.
+ """
+ formatted_messages: list[dict[str, Any]] = []
+ if task_params.chat_template_messages is not None:
+ formatted_messages = list(task_params.chat_template_messages)
+ for msg in formatted_messages:
+ _normalize_tool_calls(msg)
+ else:
+ if task_params.instructions:
+ formatted_messages.append(
+ {"role": "system", "content": task_params.instructions}
+ )
+
+ for msg in task_params.input:
+ if not msg.content:
+ logger.warning("Received message with empty content, skipping")
+ continue
+ formatted_messages.append({"role": msg.role, "content": msg.content})
+
+ # For assistant prefilling, append content after templating to avoid a closing turn token.
+ partial_assistant_content: str | None = None
+ if formatted_messages and formatted_messages[-1].get("role") == "assistant":
+ partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
+ formatted_messages = formatted_messages[:-1]
+
+ if hasattr(tokenizer, "apply_chat_template"): # pyright: ignore[reportAny]
+ prompt: str = tokenizer.apply_chat_template( # pyright: ignore[reportAny]
+ formatted_messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ tools=task_params.tools,
+ )
+ if partial_assistant_content:
+ prompt += partial_assistant_content
+ return prompt
+
+ # Fallback: ChatML format
+ parts: list[str] = []
+ for msg in formatted_messages:
+ parts.append(f"<|im_start|>{msg['role']}\n{msg.get('content', '')}<|im_end|>\n")
+ parts.append("<|im_start|>assistant\n")
+ result = "".join(parts)
+ if partial_assistant_content:
+ result += partial_assistant_content
+ return result
+
+
+def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
+ tool_calls: list[Any] | None = msg_dict.get("tool_calls")
+ if not tool_calls:
+ return
+
+ for tool_call in tool_calls: # pyright: ignore[reportAny]
+ if not isinstance(tool_call, dict):
+ continue
+
+ func: dict[str, Any] | None = tool_call.get("function") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+ if not isinstance(func, dict):
+ continue
+
+ args: Any = func.get("arguments") # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
+ if isinstance(args, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ func["arguments"] = json.loads(args)
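Reviewer note: the ChatML fallback and the assistant-prefill handling in `apply_chat_template` can be checked without a tokenizer. The sketch below is a minimal standalone approximation, not the PR's code: `render_chatml` is a hypothetical helper name, and it takes plain dicts rather than `TextGenerationTaskParams`.

```python
from typing import Any, Optional


def render_chatml(messages: list[dict[str, Any]]) -> str:
    """Hypothetical standalone version of the ChatML fallback path."""
    # A trailing assistant message is treated as a prefill: it is removed
    # from the templated turns and appended after the generation prompt,
    # so no closing <|im_end|> token follows it.
    prefill: Optional[str] = None
    if messages and messages[-1].get("role") == "assistant":
        prefill = str(messages[-1].get("content", ""))
        messages = messages[:-1]

    parts = [
        f"<|im_start|>{m['role']}\n{m.get('content', '')}<|im_end|>\n"
        for m in messages
    ]
    parts.append("<|im_start|>assistant\n")
    return "".join(parts) + (prefill or "")


prompt = render_chatml(
    [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
)
```

Here the final assistant turn stays open, so the model continues from `Hi` instead of starting a fresh reply. A test for this prefill case might be worth adding next to `test_chat_template_chatml_fallback`.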
diff --git a/src/exo/shared/tokenizer/eos_tokens.py b/src/exo/shared/tokenizer/eos_tokens.py
new file mode 100644
index 0000000000..4f05ccbb3e
--- /dev/null
+++ b/src/exo/shared/tokenizer/eos_tokens.py
@@ -0,0 +1,16 @@
+from exo.shared.models.model_cards import ModelId
+
+
+def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
+ model_id_lower = model_id.lower()
+
+ if "kimi-k2" in model_id_lower:
+ return [163586]
+ elif "glm-4.7-flash" in model_id_lower:
+ return [154820, 154827, 154829]
+ elif "glm" in model_id_lower:
+ return [151336, 151329, 151338]
+ elif "llama" in model_id_lower:
+ return [128001, 128008, 128009]
+
+ return None
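Reviewer note: the `elif` chain above relies on branch order, because `"glm"` is a substring of `"glm-4.7-flash"`. One way to make that ordering explicit is a table-driven lookup. This is only a sketch of an alternative: `_EOS_RULES` and `lookup_eos` are hypothetical names, and a plain `str` stands in for `ModelId`.

```python
from typing import Optional

# Ordered (substring, EOS ids) pairs copied from the branches above.
# More specific substrings must come before more general ones.
_EOS_RULES: list[tuple[str, list[int]]] = [
    ("kimi-k2", [163586]),
    ("glm-4.7-flash", [154820, 154827, 154829]),
    ("glm", [151336, 151329, 151338]),
    ("llama", [128001, 128008, 128009]),
]


def lookup_eos(model_id: str) -> Optional[list[int]]:
    """Hypothetical table-driven equivalent of get_eos_token_ids_for_model()."""
    lowered = model_id.lower()
    for needle, ids in _EOS_RULES:
        if needle in lowered:
            return list(ids)  # copy so callers cannot mutate the table
    return None  # unknown model: caller falls back to the tokenizer default
```

Swapping the two GLM rows would make every flash model resolve to the generic GLM IDs, which is the regression that `test_glm_flash_eos` guards against.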
diff --git a/src/exo/shared/tokenizer/loader.py b/src/exo/shared/tokenizer/loader.py
new file mode 100644
index 0000000000..fb56f54950
--- /dev/null
+++ b/src/exo/shared/tokenizer/loader.py
@@ -0,0 +1,40 @@
+import sys
+from pathlib import Path
+from typing import Any
+
+from exo.shared.models.model_cards import ModelId
+
+
+def _apply_transformers_5x_compat() -> None:
+ """Monkey patch for transformers 5.x compatibility (Kimi tokenizer)."""
+ try:
+ import transformers.models.gpt2.tokenization_gpt2 as gpt2_tok
+ from transformers.convert_slow_tokenizer import bytes_to_unicode
+
+ if not hasattr(gpt2_tok, "bytes_to_unicode"):
+ gpt2_tok.bytes_to_unicode = bytes_to_unicode # pyright: ignore[reportAttributeAccessIssue]
+ except ImportError:
+ pass
+
+
+def load_tokenizer_for_model(model_id: ModelId, model_path: Path) -> Any: # pyright: ignore[reportAny]
+ model_id_lower = model_id.lower()
+
+ if "kimi-k2" in model_id_lower:
+ _apply_transformers_5x_compat()
+ sys.path.insert(0, str(model_path))
+
+ from tokenization_kimi import ( # pyright: ignore[reportMissingImports]
+ TikTokenTokenizer, # pyright: ignore[reportUnknownVariableType]
+ )
+ hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+
+ def _patched_encode(text: str, **_kwargs: object) -> list[int]:
+ return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
+
+ hf_tokenizer.encode = _patched_encode
+ return hf_tokenizer # pyright: ignore[reportUnknownVariableType]
+
+ from transformers import AutoTokenizer
+
+ return AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
diff --git a/src/exo/shared/tokenizer/wrapper.py b/src/exo/shared/tokenizer/wrapper.py
new file mode 100644
index 0000000000..8e296b0000
--- /dev/null
+++ b/src/exo/shared/tokenizer/wrapper.py
@@ -0,0 +1,102 @@
+from collections.abc import Callable
+from typing import Any, final
+
+_THINK_TOKENS: list[tuple[str, str]] = [
+ ("", ""),
+ ("", ""),
+]
+
+
+@final
+class TokenizerWrapper:
+ """Engine-agnostic tokenizer wrapper that satisfies the Tokenizer protocol.
+
+ Wraps a raw HuggingFace tokenizer and adds tool-calling and thinking
+ properties expected by the runner. All other attribute access is proxied
+ to the underlying tokenizer.
+ """
+
+ _tokenizer: Any
+ _tool_call_start: str | None
+ _tool_call_end: str | None
+ _tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None
+ _think_start: str | None
+ _think_start_id: int | None
+ _think_end_id: int | None
+
+ def __init__(
+ self,
+ tokenizer: Any, # pyright: ignore[reportAny]
+ *,
+ tool_call_start: str | None = None,
+ tool_call_end: str | None = None,
+ tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None = None,
+ ) -> None:
+ object.__setattr__(self, "_tokenizer", tokenizer)
+ object.__setattr__(self, "_tool_call_start", tool_call_start)
+ object.__setattr__(self, "_tool_call_end", tool_call_end)
+ object.__setattr__(self, "_tool_parser", tool_parser)
+
+ # Detect thinking tokens from vocabulary
+ think_start: str | None = None
+ think_start_id: int | None = None
+ think_end_id: int | None = None
+ vocab: dict[str, int] = tokenizer.get_vocab() # pyright: ignore[reportAny]
+ for start_tok, end_tok in _THINK_TOKENS:
+ if start_tok in vocab and end_tok in vocab:
+ think_start = start_tok
+ think_start_id = vocab[start_tok]
+ think_end_id = vocab[end_tok]
+ break
+ object.__setattr__(self, "_think_start", think_start)
+ object.__setattr__(self, "_think_start_id", think_start_id)
+ object.__setattr__(self, "_think_end_id", think_end_id)
+
+ # Disable tool calling if tokens aren't in the vocabulary
+ if (tool_call_start and tool_call_start not in vocab) or (
+ tool_call_end and tool_call_end not in vocab
+ ):
+ object.__setattr__(self, "_tool_call_start", None)
+ object.__setattr__(self, "_tool_call_end", None)
+ object.__setattr__(self, "_tool_parser", None)
+
+ @property
+ def has_tool_calling(self) -> bool:
+ return self._tool_call_start is not None
+
+ @property
+ def has_call_start(self) -> bool:
+ return self._tool_call_start is not None
+
+ @property
+ def tool_call_start(self) -> str | None:
+ return self._tool_call_start
+
+ @property
+ def tool_call_end(self) -> str | None:
+ return self._tool_call_end
+
+ @property
+ def tool_parser(self) -> Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None:
+ return self._tool_parser
+
+ @property
+ def think_start(self) -> str | None:
+ return self._think_start
+
+ @property
+ def think_start_id(self) -> int | None:
+ return self._think_start_id
+
+ @property
+ def think_end(self) -> int | None:
+ return self._think_end_id
+
+ def __getattr__(self, name: str) -> Any: # pyright: ignore[reportAny]
+ return getattr(self._tokenizer, name) # pyright: ignore[reportAny]
+
+ def __setattr__(self, name: str, value: Any) -> None: # pyright: ignore[reportAny]
+ if name.startswith("_"):
+ object.__setattr__(self, name, value)
+ else:
+ setattr(self._tokenizer, name, value) # pyright: ignore[reportAny]
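Reviewer note: the attribute forwarding in `TokenizerWrapper` splits writes by name: underscore-prefixed attributes stay on the wrapper, everything else goes to the wrapped tokenizer. The sketch below isolates that pattern with hypothetical `Proxy` and `FakeTokenizer` classes; it is an illustration, not the PR's class.

```python
from typing import Any


class Proxy:
    """Hypothetical stand-in for TokenizerWrapper's attribute forwarding."""

    def __init__(self, inner: Any) -> None:
        # Bypass our own __setattr__ so "_inner" lands on the proxy itself.
        object.__setattr__(self, "_inner", inner)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, so "_inner" and anything
        # defined on the proxy class are never forwarded.
        return getattr(self._inner, name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name.startswith("_"):
            object.__setattr__(self, name, value)  # private: keep on the proxy
        else:
            setattr(self._inner, name, value)  # public: forward to the wrapped object


class FakeTokenizer:
    pad_token = "<pad>"


inner = FakeTokenizer()
proxy = Proxy(inner)
proxy.pad_token = "<eos>"       # forwarded to the wrapped object
proxy._tool_call_start = "<t>"  # stored on the proxy only
```

One caveat that may be worth checking: if the private reference is ever missing (for example on an instance created without `__init__`, as `copy` or unpickling can do), `__getattr__` looks up `self._inner`, which calls `__getattr__` again and recurses.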
diff --git a/src/exo/shared/types/api.py b/src/exo/shared/types/api.py
index 23ca9b7b6a..842a5e90d4 100644
--- a/src/exo/shared/types/api.py
+++ b/src/exo/shared/types/api.py
@@ -8,7 +8,12 @@
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
-from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
+from exo.shared.types.worker.instances import (
+ Instance,
+ InstanceId,
+ InstanceMeta,
+ default_instance_meta,
+)
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
@@ -224,7 +229,7 @@ class HuggingFaceSearchResult(BaseModel):
class PlaceInstanceParams(BaseModel):
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
- instance_meta: InstanceMeta = InstanceMeta.MlxRing
+ instance_meta: InstanceMeta = Field(default_factory=default_instance_meta)
min_nodes: int = 1
diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py
index 76bd6fd4e6..f504ec12c8 100644
--- a/src/exo/shared/types/worker/instances.py
+++ b/src/exo/shared/types/worker/instances.py
@@ -1,3 +1,4 @@
+import sys
from enum import Enum
from pydantic import model_validator
@@ -15,6 +16,13 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
+ Tinygrad = "Tinygrad"
+
+
+def default_instance_meta() -> InstanceMeta:
+ if sys.platform == "darwin":
+ return InstanceMeta.MlxRing
+ return InstanceMeta.Tinygrad
class BaseInstance(TaggedModel):
@@ -34,9 +42,12 @@ class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
+class TinygradInstance(BaseInstance):
+ hosts_by_node: dict[NodeId, list[Host]] | None = None
+ ephemeral_port: int | None = None
# TODO: Single node instance
-Instance = MlxRingInstance | MlxJacclInstance
+Instance = MlxRingInstance | MlxJacclInstance | TinygradInstance
class BoundInstance(CamelCaseModel):
diff --git a/src/exo/shared/types/worker/tokenizer.py b/src/exo/shared/types/worker/tokenizer.py
new file mode 100644
index 0000000000..28b71d093a
--- /dev/null
+++ b/src/exo/shared/types/worker/tokenizer.py
@@ -0,0 +1,36 @@
+from collections.abc import Callable
+from typing import Any, Protocol
+
+
+class Tokenizer(Protocol):
+ @property
+ def has_tool_calling(self) -> bool: ...
+
+ @property
+ def has_call_start(self) -> bool: ...
+
+ @property
+ def tool_call_start(self) -> str | None: ...
+
+ @property
+ def tool_call_end(self) -> str | None: ...
+
+ @property
+ def tool_parser(self) -> Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None: ...
+
+ @property
+ def think_start(self) -> str | None: ...
+
+ @property
+ def think_start_id(self) -> int | None: ...
+
+ @property
+ def think_end(self) -> int | None: ...
+
+
+class MutableTokenizer(Tokenizer, Protocol):
+ """Extended protocol for tokenizers whose tool-calling fields can be reassigned."""
+
+ _tool_call_start: str | None
+ _tool_call_end: str | None
+ _tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None
diff --git a/src/exo/worker/engines/engine_factory.py b/src/exo/worker/engines/engine_factory.py
new file mode 100644
index 0000000000..2500997c4e
--- /dev/null
+++ b/src/exo/worker/engines/engine_factory.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from collections.abc import Callable, Generator
+from typing import Any
+
+from pydantic import BaseModel
+
+from exo.shared.types.worker.instances import BoundInstance
+from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
+
+"""
+Setting frozen=True makes the Engine model immutable, which prevents
+accidental mutation (side effects) at runtime. This is critical since
+we are working with heterogeneous consumer-grade devices.
+"""
+
+class Engine(BaseModel, frozen=True):
+ # (BoundInstance) -> context
+ initialize: Callable[[BoundInstance], Any]
+
+ # (BoundInstance, context) -> model, tokenizer
+ load: Callable[..., tuple[Any, Any]]
+
+ # (model, tokenizer, task, prompt) -> Generator[GenerationResponse]
+ generate: Callable[..., Generator[GenerationResponse | ToolCallResponse]]
+
+ apply_chat_template: Callable[..., str]
+ detect_thinking_prompt_suffix: Callable[..., bool]
+
+ # (model, tokenizer) -> int, e.g. the number of tokens generated during warmup
+ warmup: Callable[..., int]
+ cleanup: Callable[[], None]
+
+def create_engine(bound_instance: BoundInstance) -> Engine:
+ from exo.shared.types.worker.instances import (
+ MlxJacclInstance,
+ MlxRingInstance,
+ TinygradInstance,
+ )
+
+ match bound_instance.instance:
+ case MlxRingInstance() | MlxJacclInstance():
+ # Lazy import: MLX must only be loaded on macOS.
+ from exo.worker.engines.mlx.generator.generate import (
+ mlx_generate_with_postprocessing,
+ warmup_inference,
+ )
+ from exo.worker.engines.mlx.utils_mlx import (
+ apply_chat_template,
+ detect_thinking_prompt_suffix,
+ initialize_mlx,
+ load_mlx_items,
+ )
+
+ return Engine(
+ initialize=initialize_mlx,
+ load=load_mlx_items,
+ generate=mlx_generate_with_postprocessing,
+ apply_chat_template=apply_chat_template,
+ detect_thinking_prompt_suffix=detect_thinking_prompt_suffix,
+ warmup=warmup_inference,
+ cleanup=_mlx_cleanup,
+ )
+
+ case TinygradInstance():
+ from exo.shared.tokenizer.chat_template import (
+ apply_chat_template,
+ )
+ from exo.worker.engines.tinygrad.generator.generate import (
+ tinygrad_generate,
+ warmup_inference,
+ )
+ from exo.worker.engines.tinygrad.utils_tinygrad import (
+ initialize_tinygrad,
+ load_tinygrad_items,
+ )
+
+ return Engine(
+ initialize=initialize_tinygrad,
+ load=load_tinygrad_items,
+ generate=tinygrad_generate,
+ apply_chat_template=apply_chat_template,
+ detect_thinking_prompt_suffix=_tinygrad_detect_thinking,
+ warmup=warmup_inference,
+ cleanup=_tinygrad_cleanup,
+ )
+
+ case _: # pyright: ignore[reportUnnecessaryComparison]
+ raise ValueError(f"Unsupported Instance: {type(bound_instance.instance)}")
+
+
+def _mlx_cleanup() -> None:
+ from mlx.core import clear_cache
+ clear_cache()
+
+
+def _tinygrad_cleanup() -> None:
+ from exo.worker.engines.tinygrad.generator.generate import cleanup_jit_state
+ cleanup_jit_state()
+
+
+def _tinygrad_detect_thinking(prompt: str, tokenizer: Any) -> bool: # pyright: ignore[reportAny]
+ return False
diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py
index 6e92620513..d6e83fe1f1 100644
--- a/src/exo/worker/engines/mlx/generator/generate.py
+++ b/src/exo/worker/engines/mlx/generator/generate.py
@@ -1,18 +1,27 @@
import time
from copy import deepcopy
-from typing import Callable, Generator, cast, get_args
+from functools import cache
+from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
+from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
+from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
+ HarmonyEncodingName,
+ Role,
+ StreamableParser,
+ load_harmony_encoding,
+)
from exo.shared.types.api import (
CompletionTokensDetails,
FinishReason,
GenerationStats,
PromptTokensDetails,
+ ToolCallItem,
TopLogprobItem,
Usage,
)
@@ -22,6 +31,7 @@
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
+ ToolCallResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import set_pipeline_prefill
@@ -494,3 +504,247 @@ def mlx_generate(
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
+
+
+def mlx_generate_with_postprocessing(
+ model: Model,
+ tokenizer: TokenizerWrapper,
+ task: TextGenerationTaskParams,
+ prompt: str,
+ model_id: str,
+ kv_prefix_cache: KVPrefixCache | None = None,
+ group: mx.distributed.Group | None = None,
+) -> Generator[GenerationResponse | ToolCallResponse]:
+ """
+ Wrapper around mlx_generate() that applies the model-specific
+ post-processing required by GPT-OSS, Kimi and GLM.
+
+ This keeps engine-dependent post-processing contained within the
+ engine package, so only minimal changes are needed in the
+ `runner.py` file.
+ """
+
+ gen = mlx_generate(model=model, tokenizer=tokenizer, task=task,
+ prompt=prompt, kv_prefix_cache=kv_prefix_cache, group=group)
+
+ # Kimi-K2 has tool call sections - we don't care about them
+ if "kimi" in model_id.lower():
+ gen = filter_kimi_tokens(gen)
+ patch_kimi_tokenizer(tokenizer)
+
+ # GLM models need patched parser (upstream has bug with None regex match)
+ elif "glm" in model_id.lower():
+ patch_glm_tokenizer(tokenizer)
+
+ # GPT-OSS specific parsing to match other model formats.
+ elif isinstance(model, GptOssModel):
+ gen = parse_gpt_oss(gen)
+
+ return gen
+
+
+def parse_gpt_oss(
+ responses: Generator[GenerationResponse | ToolCallResponse],
+) -> Generator[GenerationResponse | ToolCallResponse]:
+ encoding = get_gpt_oss_encoding() # pyright: ignore[reportAny]
+ stream: Any = StreamableParser(encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
+ thinking = False
+ current_tool_name: str | None = None
+ tool_arg_parts: list[str] = []
+
+ for response in responses:
+ assert isinstance(response, GenerationResponse)
+ stream.process(response.token) # pyright: ignore[reportAny]
+
+ delta: str | None = stream.last_content_delta # pyright: ignore[reportAny]
+ ch: str | None = stream.current_channel # pyright: ignore[reportAny]
+ recipient: str | None = stream.current_recipient # pyright: ignore[reportAny]
+
+ if recipient != current_tool_name:
+ if current_tool_name is not None:
+ prefix = "functions."
+ if current_tool_name.startswith(prefix):
+ current_tool_name = current_tool_name[len(prefix) :]
+ yield ToolCallResponse(
+ tool_calls=[
+ ToolCallItem(
+ name=current_tool_name,
+ arguments="".join(tool_arg_parts).strip(),
+ )
+ ],
+ usage=response.usage,
+ )
+ tool_arg_parts = []
+ current_tool_name = recipient
+
+ # If inside a tool call, accumulate arguments
+ if current_tool_name is not None:
+ if delta:
+ tool_arg_parts.append(delta)
+ continue
+
+ if ch == "analysis" and not thinking:
+ thinking = True
+ yield response.model_copy(update={"text": "<think>"})
+
+ if ch != "analysis" and thinking:
+ thinking = False
+ yield response.model_copy(update={"text": "</think>"})
+
+ if delta:
+ yield response.model_copy(update={"text": delta})
+
+ if response.finish_reason is not None:
+ if thinking:
+ yield response.model_copy(update={"text": "</think>"})
+ yield response
+
+
+@cache
+def get_gpt_oss_encoding() -> Any: # pyright: ignore[reportAny]
+ return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+
+def filter_kimi_tokens(
+ responses: Generator[GenerationResponse | ToolCallResponse],
+) -> Generator[GenerationResponse]:
+ for resp in responses:
+ assert isinstance(resp, GenerationResponse)
+ if (
+ resp.text == "<|tool_calls_section_begin|>"
+ or resp.text == "<|tool_calls_section_end|>"
+ ):
+ continue
+ yield resp
+
+
+def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
+ """
+ Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
+ """
+ import ast
+ import json
+ from typing import Any
+
+ import regex as re
+
+ _func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
+ _func_arg_regex = re.compile(
+ r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
+ re.DOTALL,
+ )
+
+ tool_call_start = "<tool_call>"
+ tool_call_end = "</tool_call>"
+
+ def _is_string_type(
+ tool_name: str,
+ arg_name: str,
+ tools: list[Any] | None,
+ ) -> bool:
+ if tools is None:
+ return False
+ for tool in tools: # pyright: ignore[reportAny]
+ func = tool["function"] # pyright: ignore[reportAny]
+ if func["name"] == tool_name:
+ params = func["parameters"] # pyright: ignore[reportAny]
+ if params is None:
+ return False
+ props = params.get("properties", {}) # pyright: ignore[reportAny]
+ arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
+ arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
+ return arg_type == "string" # pyright: ignore[reportAny]
+ return False
+
+ def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
+ try:
+ return json.loads(value) # pyright: ignore[reportAny]
+ except Exception:
+ pass
+ try:
+ return ast.literal_eval(value) # pyright: ignore[reportAny]
+ except Exception:
+ pass
+ return value
+
+
+ def parse_tool_call(text: str, tools: list[Any] | None = None):
+ func_name_match = _func_name_regex.search(text)
+ if func_name_match is None:
+ raise ValueError(f"Could not parse function name from tool call: {text!r}")
+ func_name = func_name_match.group(1)
+
+ pairs = _func_arg_regex.findall(text)
+ arg_dct: dict[str, Any] = {}
+ for key, value in pairs: # pyright: ignore[reportAny]
+ arg_key = key.strip() # pyright: ignore[reportAny]
+ arg_val = value.strip() # pyright: ignore[reportAny]
+ if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
+ arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
+ arg_dct[arg_key] = arg_val
+ return dict(name=func_name, arguments=arg_dct)
+
+ tokenizer._tool_call_start = tool_call_start
+ tokenizer._tool_call_end = tool_call_end
+ tokenizer._tool_parser = parse_tool_call
+
+
+def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
+ """
+ Version of to-be-upstreamed kimi-k2 tool parser
+ """
+ import ast
+ import json
+ from typing import Any
+
+ import regex as re
+
+ # kimi has a fixed function naming scheme, with a json formatted arg
+ # functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
+ # Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
+ _func_name_regex = re.compile(
+ r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
+ )
+ _func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
+
+ # kimi has a tool_calls_section - we're leaving this up to the caller to handle
+ tool_call_start = "<|tool_call_begin|>"
+ tool_call_end = "<|tool_call_end|>"
+
+ def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
+ try:
+ return json.loads(value) # pyright: ignore[reportAny]
+ except Exception:
+ pass
+
+ try:
+ return ast.literal_eval(value) # pyright: ignore[reportAny]
+ except Exception:
+ pass
+ return value
+
+ def parse_tool_call(text: str, tools: Any | None = None):
+ func_name_match = _func_name_regex.search(text)
+ if func_name_match is None:
+ raise ValueError(f"Could not parse function name from tool call: {text!r}")
+ original_func_name = func_name_match.group(1)
+ tool_id = func_name_match.group(2)
+ # strip off the `functions.` prefix, if it exists.
+ func_name = original_func_name[original_func_name.find(".") + 1 :]
+
+ func_args_match = _func_arg_regex.search(text)
+ if func_args_match is None:
+ raise ValueError(f"Could not parse function args from tool call: {text!r}")
+ func_args = func_args_match.group(1)
+ # the args should be valid json - no need to check against our tools to deserialize
+ arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
+
+ return dict(
+ id=f"{original_func_name}:{tool_id}",
+ name=func_name,
+ arguments=arg_dct, # pyright: ignore[reportAny]
+ )
+
+ tokenizer._tool_call_start = tool_call_start
+ tokenizer._tool_call_end = tool_call_end
+ tokenizer._tool_parser = parse_tool_call
diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py
index 48b902ff8c..bedece01d4 100644
--- a/src/exo/worker/engines/mlx/utils_mlx.py
+++ b/src/exo/worker/engines/mlx/utils_mlx.py
@@ -149,6 +149,9 @@ def mlx_distributed_init(
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
+ case _:
+ raise TypeError(f"Unsupported instance type for MLX: {type(bound_instance.instance)}")
+
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
diff --git a/src/exo/worker/engines/tinygrad/__init__.py b/src/exo/worker/engines/tinygrad/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/exo/worker/engines/tinygrad/cache.py b/src/exo/worker/engines/tinygrad/cache.py
new file mode 100644
index 0000000000..b837a03c3f
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/cache.py
@@ -0,0 +1,73 @@
+from tinygrad.dtype import dtypes
+from tinygrad.tensor import Tensor
+
+
+class KVCache:
+ def __init__(self,
+ num_layers: int,
+ num_kv_heads: int,
+ head_dim: int,
+ max_seq_len: int,
+ ) -> None:
+ self.keys: list[Tensor] = [Tensor.zeros(1, num_kv_heads, max_seq_len, head_dim, dtype=dtypes.float16) for _ in range(num_layers)] # pyright: ignore[reportUnknownMemberType]
+ self.values: list[Tensor] = [Tensor.zeros(1, num_kv_heads, max_seq_len, head_dim, dtype=dtypes.float16) for _ in range(num_layers)] # pyright: ignore[reportUnknownMemberType]
+
+ self.max_seq_len = max_seq_len
+ self._positions: Tensor = Tensor.arange(max_seq_len).reshape(1, 1, max_seq_len, 1) # pyright: ignore[reportUnknownMemberType]
+ self.col_indices: Tensor = Tensor.arange(max_seq_len).reshape(1, 1, 1, max_seq_len) # pyright: ignore[reportUnknownMemberType]
+
+ def update(
+ self,
+ layers_idx: int,
+ key: Tensor,
+ value: Tensor,
+ position: int | Tensor = 0,
+ ) -> tuple[Tensor, Tensor]:
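+        seq_len = key.shape[2]
+
+        # Write key/value into the cache at positions
+        #     [position, position + seq_len)
+        #
+        # tinygrad tensors cannot handle slice assignment, so the
+        # Tensor path selects the target position with a mask and the
+        # int path pads the new block to max_seq_len and merges it
+        # with Tensor.where.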
+
+ positions = self._positions
+
+ if isinstance(position, Tensor):
+ mask = positions == position
+ self.keys[layers_idx] = Tensor.where(
+ mask, key.half(), self.keys[layers_idx]
+ )
+ self.values[layers_idx] = Tensor.where(
+ mask, value.half(), self.values[layers_idx]
+ )
+ else:
+ mask = (positions >= position) & (positions < position + seq_len)
+ pad_prev = position
+ pad_next = self.max_seq_len - position - seq_len
+ new_k = key.pad(
+ ((0, 0), (0, 0), (pad_prev, pad_next), (0, 0))
+ ).half()
+ new_v = value.pad(
+ ((0, 0), (0, 0), (pad_prev, pad_next), (0, 0))
+ ).half()
+
+            if position > 0:
+                # Merge into the existing cache so earlier positions survive.
+                new_k = Tensor.where(
+                    mask, new_k, self.keys[layers_idx]
+                )
+                new_v = Tensor.where(
+                    mask, new_v, self.values[layers_idx]
+                )
+
+            self.keys[layers_idx] = new_k
+            self.values[layers_idx] = new_v
+
+ return self.keys[layers_idx], self.values[layers_idx]
+
+ @property
+ def seq_len(self) -> int:
+ return int(self.keys[0].shape[2])
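
The mask-and-pad update in `KVCache.update` above can be illustrated without tinygrad. The following is a minimal sketch on a single 1-D row of cache slots, assuming plain Python lists stand in for tensors; `masked_update` is a hypothetical helper, not part of this PR.

```python
def masked_update(cache: list[float], new: list[float], position: int) -> list[float]:
    """Write `new` into `cache` starting at `position` without slice assignment."""
    max_len = len(cache)
    seq_len = len(new)
    if position < 0 or position + seq_len > max_len:
        raise ValueError("update does not fit in the cache")
    # Pad the new values out to the full cache length (zeros outside the window).
    padded = [0.0] * position + new + [0.0] * (max_len - position - seq_len)
    # Mask that is True only inside [position, position + seq_len).
    mask = [position <= i < position + seq_len for i in range(max_len)]
    # Element-wise select, the list analogue of Tensor.where(mask, padded, cache).
    return [p if m else c for m, p, c in zip(mask, padded, cache)]


cache = [0.0] * 8
cache = masked_update(cache, [1.0, 2.0, 3.0], position=0)  # prefill
cache = masked_update(cache, [4.0], position=3)            # one decode step
print(cache)
```

The select step is what keeps earlier positions intact when a later write lands at `position > 0`.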
diff --git a/src/exo/worker/engines/tinygrad/constants.py b/src/exo/worker/engines/tinygrad/constants.py
new file mode 100644
index 0000000000..d1276e486a
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/constants.py
@@ -0,0 +1,4 @@
+DEFAULT_MAX_TOKENS: int = 32168
+DEFAULT_TEMPERATURE: float = 0.7
+DEFAULT_TOP_P: float = 1.0
+DEFAULT_TOP_LOGPROBS: int = 5
diff --git a/src/exo/worker/engines/tinygrad/forward.py b/src/exo/worker/engines/tinygrad/forward.py
new file mode 100644
index 0000000000..92934e1861
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/forward.py
@@ -0,0 +1,117 @@
+from typing import NamedTuple
+
+from tinygrad.tensor import Tensor
+
+from exo.shared.model_config import ModelConfig
+from exo.worker.engines.tinygrad.cache import KVCache
+from exo.worker.engines.tinygrad.layers.attention import grouped_query_attention
+from exo.worker.engines.tinygrad.layers.embedding import apply_embedding, apply_lm_head
+from exo.worker.engines.tinygrad.layers.mlp import swiglu_mlp
+from exo.worker.engines.tinygrad.layers.normalization import rms_norm
+from exo.worker.engines.tinygrad.weight_loader import LayerWeights, TransformerWeights
+
+
+class TransformerBlockBuilder(NamedTuple):
+ cos_freqs: Tensor
+ sin_freqs: Tensor
+ layer: LayerWeights
+ idx: int
+ offset: int | Tensor
+
+def forward_pass(
+ weights: TransformerWeights,
+ input_ids: Tensor,
+ cache: KVCache | None,
+ position_offset: int | Tensor = 0,
+ rope_cos: Tensor | None = None,
+ rope_sin: Tensor | None = None,
+) -> tuple[Tensor, KVCache]:
+ config = weights.config
+ _batch, _seq_len = input_ids.shape
+
+ x = apply_embedding(weights.embed_tokens, input_ids)
+
+    if cache is None:
+        # Cap max_seq_len at 4096 so the KV cache fits on
+        # consumer-grade GPUs.
+        #
+        # Unlike Apple systems, most computers with integrated
+        # graphics partition their memory statically, and with
+        # discrete GPUs the VRAM limit is explicit.
+        #
+        # The cap was chosen to avoid out-of-memory errors
+        # while testing on an AMD RX 6600M; tinygrad_generate
+        # applies the same cap when it builds its cache.
+ cache = KVCache(
+ num_layers = len(weights.layers),
+ num_kv_heads = config.num_key_value_heads,
+ head_dim = config.head_dim,
+ max_seq_len = min(config.max_position_embeddings, 4096),
+ )
+
+ cos = rope_cos if rope_cos is not None else weights.rope_cos
+ sin = rope_sin if rope_sin is not None else weights.rope_sin
+
+ for layer_idx, layer in enumerate(weights.layers):
+ builder = TransformerBlockBuilder(
+ cos, sin,
+ layer, layer_idx, position_offset,
+ )
+ x = _transformer_block(x, config, cache, builder)
+
+ if isinstance(position_offset, int):
+ x = x.realize(cache.keys[layer_idx], cache.values[layer_idx])
+
+ x = rms_norm(x, weights.final_norm, config.rms_norm_eps)
+ logits = apply_lm_head(x, weights.lm_head)
+
+ return logits, cache
+
+def _transformer_block(
+ x: Tensor,
+ config: ModelConfig,
+ cache: KVCache,
+ builder: TransformerBlockBuilder,
+) -> Tensor:
+ residual = x
+ layer = builder.layer
+ x = rms_norm(x, layer.input_norm, config.rms_norm_eps)
+
+ match config.architecture_spec.attention_type:
+ case "grouped_query" | "multi_head":
+ x = grouped_query_attention(
+ x, qkv_proj = layer.qkv_proj,
+ o_proj = layer.o_proj,
+ cos_freqs = builder.cos_freqs,
+ sin_freqs = builder.sin_freqs,
+ cache = cache, layer_idx = builder.idx,
+ position_offset = builder.offset,
+ cache_position = builder.offset,
+ num_heads = config.num_attention_heads,
+ num_kv_heads = config.num_key_value_heads,
+ head_dim = config.head_dim,
+ q_norm = layer.q_norm,
+ k_norm = layer.k_norm,
+ rms_norm_eps = config.rms_norm_eps,
+ )
+ case "multi_latent":
+            raise NotImplementedError(
+                "MLA attention is not yet implemented"
+            )
+
+ x = x + residual
+ residual = x
+
+ x = rms_norm(x, layer.post_attn_norm, config.rms_norm_eps)
+ match config.architecture_spec.mlp_type:
+ case "swiglu":
+ x = swiglu_mlp(
+ x, layer.gate_up_proj, layer.down_proj
+ )
+ case "moe_top_k":
+ raise NotImplementedError(
+ "MoE MLP: not yet implemented"
+ )
+
+ x = x + residual
+ return x
diff --git a/src/exo/worker/engines/tinygrad/generator/generate.py b/src/exo/worker/engines/tinygrad/generator/generate.py
new file mode 100644
index 0000000000..7ba551af1e
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/generator/generate.py
@@ -0,0 +1,340 @@
+import struct
+import time
+from collections.abc import Callable, Generator
+from dataclasses import dataclass
+from typing import Any
+
+from tinygrad.dtype import dtypes
+from tinygrad.engine.jit import TinyJit
+from tinygrad.helpers import Context
+from tinygrad.tensor import Tensor
+
+from exo.shared.model_config import ModelConfig
+from exo.shared.models.model_cards import ModelId
+from exo.shared.tokenizer.eos_tokens import get_eos_token_ids_for_model
+from exo.shared.types.api import (
+ CompletionTokensDetails,
+ GenerationStats,
+ PromptTokensDetails,
+ TopLogprobItem,
+ Usage,
+)
+from exo.shared.types.memory import Memory
+from exo.shared.types.text_generation import TextGenerationTaskParams
+from exo.shared.types.worker.runner_response import GenerationResponse
+from exo.worker.engines.tinygrad.constants import (
+ DEFAULT_MAX_TOKENS,
+ DEFAULT_TEMPERATURE,
+ DEFAULT_TOP_P,
+)
+
+from ..cache import KVCache
+from ..forward import forward_pass
+from ..sampling import sample_token
+from ..weight_loader import TransformerWeights
+
+_PREFILL_BUCKETS: list[int] = [32, 64, 128, 256, 512]
+
+def _pad_to_bucket(input_ids: list[int], pad_id: int = 0) -> list[int]:
+    """
+    Pad input_ids up to the nearest bucket size so that prefill tensor
+    shapes repeat across requests. Repeated shapes are more likely to hit
+    the compiled-kernel cache, which shortens time to first token.
+    """
+
+ for bucket in _PREFILL_BUCKETS:
+ if len(input_ids) <= bucket:
+ return input_ids + [pad_id] * (bucket - len(input_ids))
+ return input_ids
+
+
+@dataclass
+class _JitState:
+ """Persistent decode state reused across requests."""
+ jit_decode: Callable[..., tuple[Tensor, ...]]
+ cache: KVCache
+ input_buffer: Tensor
+ position_buffer: Tensor
+
+_jit_registry: dict[int, _JitState] = {}
+
+def cleanup_jit_state() -> None:
+ """Called by engine cleanup to free all JIT state."""
+ _jit_registry.clear()
+
+
+def _build_jit_decode(
+ weights: TransformerWeights,
+ cache: KVCache,
+) -> Callable[..., tuple[Tensor, ...]]:
+ num_layers = len(weights.layers)
+
+ @TinyJit
+ def decode(
+ input_ids: Tensor, position: Tensor,
+ rope_cos_table: Tensor, rope_sin_table: Tensor,
+ *cache_kv: Tensor,
+ ) -> tuple[Tensor, ...]:
+ for i in range(num_layers):
+ cache.keys[i] = cache_kv[i]
+ cache.values[i] = cache_kv[num_layers + i]
+
+ logits, _ = forward_pass(
+ weights, input_ids, cache,
+ position_offset=position,
+ rope_cos=rope_cos_table, rope_sin=rope_sin_table,
+ )
+
+ # Realize everything at once — JIT captures one fused kernel schedule
+ # instead of 32 fragmented ones.
+ logits = logits.realize(*cache.keys, *cache.values)
+
+ return (logits, *cache.keys, *cache.values)
+
+ return decode
+
+def tinygrad_generate(
+ model: TransformerWeights,
+ tokenizer: Any, # pyright: ignore[reportAny]
+ task: TextGenerationTaskParams,
+ prompt: str,
+ kv_prefix_cache: Any = None, # pyright: ignore[reportAny]
+ on_prefill_progress: Callable[[int, int], None] | None = None,
+ group: None = None,
+) -> Generator[GenerationResponse]:
+ input_ids = _encode_prompt(tokenizer, prompt)
+
+ max_tokens = task.max_output_tokens or DEFAULT_MAX_TOKENS
+ temperature = task.temperature or DEFAULT_TEMPERATURE
+ top_p = task.top_p or DEFAULT_TOP_P
+
+ request_logprobs = task.logprobs
+ top_logprobs_count = task.top_logprobs or 0
+
+ eos_ids = _get_eos_ids(tokenizer, model.config)
+ print(f"[DEBUG] eos_ids={eos_ids}")
+ prompt_tokens = len(input_ids)
+ input_ids = _pad_to_bucket(input_ids)
+
+ model_key = id(model)
+ num_layers = len(model.layers)
+ state = _jit_registry.get(model_key)
+
+ if state is None:
+ # First request: create cache, JIT, and pre-allocate buffers
+ config = model.config
+ cache = KVCache(
+ num_layers=len(model.layers),
+ num_kv_heads=config.num_key_value_heads,
+ head_dim=config.head_dim,
+ max_seq_len=min(config.max_position_embeddings, 4096),
+ )
+ # Realize cache tensors — Tensor.zeros() produces lazy/const tensors
+ # that TinyJit rejects as inputs. Force device buffer allocation.
+ for i in range(len(model.layers)):
+ cache.keys[i] = cache.keys[i].contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ cache.values[i] = cache.values[i].contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ jit_decode = _build_jit_decode(model, cache)
+ input_buffer = Tensor.empty(1, 1, dtype=dtypes.int32).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ position_buffer = Tensor.empty(1, dtype=dtypes.int32).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ state = _JitState(
+ jit_decode=jit_decode,
+ cache=cache,
+ input_buffer=input_buffer,
+ position_buffer=position_buffer,
+ )
+ _jit_registry[model_key] = state
+
+ cache = state.cache
+
+ # Batched prefill: process all prompt tokens in a single forward pass.
+ # Uses position_offset=0 (int) which triggers local attention (seq_len × seq_len)
+ # instead of full-cache attention, producing ~324 kernel dispatches total
+ # instead of 324 × N token-by-token dispatches.
+ # BEAM is disabled because prefill shapes vary per prompt length (not cacheable)
+ # and BEAM may select WMMA kernels incompatible with RDNA 2 (gfx1032).
+    if prompt_tokens == 0:
+ raise ValueError("Prompt must contain at least one token")
+
+ prefill_start = time.time()
+ prompt_tensor = Tensor(input_ids, dtype=dtypes.int32).reshape(1, -1).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ with Context(BEAM=0):
+ logits, _ = forward_pass(
+ model, prompt_tensor, cache,
+ position_offset=0,
+ rope_cos=model.rope_cos, rope_sin=model.rope_sin,
+ )
+ # Take the real last token's logits (not the padded last position).
+ logits = logits[:, prompt_tokens - 1:prompt_tokens, :].contiguous() # pyright: ignore[reportUnknownMemberType]
+ # Make cache tensors contiguous for JIT compatibility.
+ for i in range(num_layers):
+ cache.keys[i] = cache.keys[i].contiguous() # pyright: ignore[reportUnknownMemberType]
+ cache.values[i] = cache.values[i].contiguous() # pyright: ignore[reportUnknownMemberType]
+ # Realize everything at once — same pattern as _build_jit_decode.
+ logits = logits.realize(*cache.keys, *cache.values)
+
+ # Rebuild the JIT after prefill. The batched prefill creates entirely new
+ # cache tensor objects (via Tensor.where + contiguous + realize) that differ
+ # from the JIT's captured output buffers. Rebuilding ensures the JIT
+ # re-captures with the correct buffer objects. The cost is 2 slow decode
+ # steps per request (cnt=0 jit-ignore, cnt=1 jit-capture), after which
+ # all subsequent tokens use fast JIT replay.
+ jit_decode = _build_jit_decode(model, cache)
+ state.jit_decode = jit_decode
+
+ prefill_time = time.time() - prefill_start
+ prompt_tps = prompt_tokens / max(prefill_time, 1e-9)
+
+ position = prompt_tokens
+
+ # Decode
+ generation_start = time.time()
+ for token_idx in range(max_tokens):
+ result = sample_token(
+ logits, temperature=temperature, top_p=top_p,
+ top_logprobs_count=top_logprobs_count,
+ request_logprobs=request_logprobs,
+ )
+
+ token_text: str = tokenizer.decode([result.token_id]) # pyright: ignore[reportAny]
+
+ is_eos = result.token_id in eos_ids
+ if is_eos:
+ print(f"[DEBUG] EOS detected: token_id={result.token_id}, text={token_text!r}")
+ if "<|eot_id|>" in token_text or "<|end" in token_text:
+ print(f"[DEBUG] Special token in text but is_eos={is_eos}, token_id={result.token_id}")
+ tokens_generated = token_idx + 1
+ elapsed = time.time() - generation_start
+ generation_tps = tokens_generated / max(elapsed, 1e-9)
+
+ finish_reason = None
+ stats = None
+ usage = None
+
+ if is_eos:
+ finish_reason = "stop"
+ elif token_idx == max_tokens - 1:
+ finish_reason = "length"
+
+ if finish_reason is not None:
+ stats = GenerationStats(
+ prompt_tps=prompt_tps, generation_tps=generation_tps,
+ prompt_tokens=prompt_tokens,
+ generation_tokens=tokens_generated,
+ peak_memory_usage=Memory.from_bytes(0),
+ )
+
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=tokens_generated,
+ total_tokens=prompt_tokens + tokens_generated,
+ prompt_tokens_details=PromptTokensDetails(),
+ completion_tokens_details=CompletionTokensDetails(),
+ )
+
+ logprob = result.logprob if request_logprobs else None
+ top_lps = None
+ if request_logprobs and task.top_logprobs:
+ top_lps = [
+ TopLogprobItem(
+ token=str(tokenizer.decode([tok_id])), # pyright: ignore[reportAny]
+ logprob=lp,
+ bytes=list(str(tokenizer.decode([tok_id])).encode("utf-8")), # pyright: ignore[reportAny]
+ )
+ for tok_id, lp in result.top_logprobs
+ ]
+
+ if is_eos:
+ token_text = ""
+
+ yield GenerationResponse(
+ text=token_text, token=result.token_id,
+ logprob=logprob, top_logprobs=top_lps,
+ finish_reason=finish_reason, stats=stats, usage=usage,
+ )
+
+ if finish_reason is not None:
+ break
+
+ state.input_buffer._buffer().copyin(memoryview(bytearray(struct.pack('=i', result.token_id)))) # pyright: ignore[reportPrivateUsage]
+ state.position_buffer._buffer().copyin(memoryview(bytearray(struct.pack('=i', position)))) # pyright: ignore[reportPrivateUsage]
+ results = jit_decode(
+ state.input_buffer, state.position_buffer,
+ model.rope_cos, model.rope_sin,
+ *cache.keys, *cache.values,
+ )
+
+ logits = results[0]
+ for i in range(num_layers):
+ cache.keys[i] = results[1 + i]
+ cache.values[i] = results[1 + num_layers + i]
+ position += 1
+
+def warmup_inference(model: TransformerWeights, tokenizer: Any, group: None = None) -> int: # pyright: ignore[reportAny]
+ """Run a full generation loop to warm up forward pass, KV cache, and sampling."""
+ from exo.shared.tokenizer.chat_template import apply_chat_template
+ from exo.shared.types.common import ModelId as CommonModelId
+ from exo.shared.types.text_generation import InputMessage
+
+ warmup_task = TextGenerationTaskParams(
+ model=CommonModelId("warmup"),
+ input=[InputMessage(role="user", content="Time to warm up!")],
+ )
+
+ prompt: str = apply_chat_template(tokenizer, warmup_task)
+ tokens_generated = 0
+
+ for _ in tinygrad_generate(model, tokenizer, warmup_task, prompt):
+ tokens_generated += 1
+ if tokens_generated >= 5:
+ break
+
+ _warmup_prefill_buckets(model)
+
+ return tokens_generated
+
+
+def _warmup_prefill_buckets(model: TransformerWeights) -> None:
+ """Pre-compile prefill kernels at each bucket size to avoid first-request compilation."""
+ model_key = id(model)
+ state = _jit_registry.get(model_key)
+ if state is None:
+ return
+
+ cache = state.cache
+ num_layers = len(model.layers)
+
+ with Context(BEAM=0):
+ for bucket_size in _PREFILL_BUCKETS:
+ dummy = Tensor.zeros(1, bucket_size, dtype=dtypes.int32).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ logits, _ = forward_pass(
+ model, dummy, cache,
+ position_offset=0,
+ rope_cos=model.rope_cos, rope_sin=model.rope_sin,
+ )
+ logits = logits[:, -1:, :].contiguous() # pyright: ignore[reportUnknownMemberType]
+ for i in range(num_layers):
+ cache.keys[i] = cache.keys[i].contiguous() # pyright: ignore[reportUnknownMemberType]
+ cache.values[i] = cache.values[i].contiguous() # pyright: ignore[reportUnknownMemberType]
+ logits.realize(*cache.keys, *cache.values)
+
+def _encode_prompt(tokenizer: Any, prompt: str) -> list[int]: # pyright: ignore[reportAny]
+ result: Any = tokenizer.encode(prompt) # pyright: ignore[reportAny]
+
+ return result.ids if hasattr(result, "ids") else result # pyright: ignore[reportAny]
+
+def _get_eos_ids(tokenizer: Any, config: ModelConfig) -> set[int]: # pyright: ignore[reportAny]
+ eos_ids: set[int] = set()
+
+ model_eos = get_eos_token_ids_for_model(ModelId(config.architecture_spec.name))
+
+ if model_eos:
+ eos_ids.update(model_eos)
+
+ if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: # pyright: ignore[reportAny]
+ eos_ids.add(int(tokenizer.eos_token_id)) # pyright: ignore[reportAny]
+
+ if not eos_ids:
+        eos_ids.add(2)  # Last-resort fallback; assumes id 2 is EOS, as in Llama-style SentencePiece vocabularies.
+
+ return eos_ids
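
The bucket padding used by `tinygrad_generate` above interacts with where the first logits are read: after padding, the last position in the tensor is a pad token, so the logits must come from index `prompt_tokens - 1`. Below is a minimal sketch of that bookkeeping; `pad_to_bucket` is a hypothetical variant of `_pad_to_bucket` that also returns the index of the last real token, and the empty-prompt check is an assumption about where validation belongs.

```python
PREFILL_BUCKETS = [32, 64, 128, 256, 512]  # same sizes as _PREFILL_BUCKETS


def pad_to_bucket(ids: list[int], pad_id: int = 0) -> tuple[list[int], int]:
    """Return (padded ids, index of the last real token)."""
    if not ids:
        # Validate before padding: an empty prompt would otherwise become
        # a full bucket of pad tokens and pass a later emptiness check.
        raise ValueError("Prompt must contain at least one token")
    last_real = len(ids) - 1
    for bucket in PREFILL_BUCKETS:
        if len(ids) <= bucket:
            return ids + [pad_id] * (bucket - len(ids)), last_real
    # Longer than the largest bucket: leave the prompt unpadded.
    return ids, last_real


padded, last = pad_to_bucket(list(range(1, 41)))  # 40 real tokens
print(len(padded), last)
```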
diff --git a/src/exo/worker/engines/tinygrad/layers/__init__.py b/src/exo/worker/engines/tinygrad/layers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/exo/worker/engines/tinygrad/layers/attention.py b/src/exo/worker/engines/tinygrad/layers/attention.py
new file mode 100644
index 0000000000..04cf6266cb
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/layers/attention.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import math
+
+from tinygrad.tensor import Tensor
+
+from exo.worker.engines.tinygrad.cache import KVCache
+from exo.worker.engines.tinygrad.layers.normalization import rms_norm
+from exo.worker.engines.tinygrad.layers.rotary import apply_rope
+from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+
+LinearWeight = Tensor | QuantizedLinear
+
+def linear_forward(x: Tensor, weight: LinearWeight) -> Tensor:
+ if isinstance(weight, QuantizedLinear):
+ return weight(x)
+ return x @ weight.T
+
+def grouped_query_attention(
+ x: Tensor,
+ qkv_proj: LinearWeight,
+ o_proj: LinearWeight,
+ cos_freqs: Tensor,
+ sin_freqs: Tensor,
+ cache: KVCache,
+ layer_idx: int,
+ position_offset: int | Tensor,
+ cache_position: int | Tensor,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ q_norm: Tensor | None = None,
+ k_norm: Tensor | None = None,
+ rms_norm_eps: float = 1e-6,
+) -> Tensor:
+ _batch, seq_len, _ = x.shape
+
+ q_dim = num_heads * head_dim
+ kv_dim = num_kv_heads * head_dim
+ qkv = linear_forward(x, qkv_proj)
+ q = qkv[..., :q_dim].reshape(int(_batch), seq_len, num_heads, head_dim).permute(0, 2, 1, 3) # pyright: ignore[reportUnknownMemberType]
+ k = qkv[..., q_dim:q_dim + kv_dim].reshape(int(_batch), seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3) # pyright: ignore[reportUnknownMemberType]
+ v = qkv[..., q_dim + kv_dim:].reshape(int(_batch), seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3) # pyright: ignore[reportUnknownMemberType]
+
+ if q_norm is not None:
+ q = rms_norm(q, q_norm, rms_norm_eps)
+ if k_norm is not None:
+ k = rms_norm(k, k_norm, rms_norm_eps)
+
+ q = apply_rope(q, cos_freqs, sin_freqs, position_offset)
+ k = apply_rope(k, cos_freqs, sin_freqs, position_offset)
+
+ # Store K,V in cache for future decode steps
+ cache.update(layer_idx, k, v, position = cache_position)
+
+ if isinstance(position_offset, int):
+ # Prefill: compute attention against local K,V (seq_len × seq_len).
+ # This avoids the wasteful (seq_len × max_seq_len) matmul that the
+ # full-cache path would produce — up to 80× less work for short prompts.
+ k_attn, v_attn = k, v
+ else:
+ # Decode (JIT): use full pre-allocated cache K,V.
+ # Shapes are fixed (max_seq_len) which is required for TinyJit replay.
+ k_attn = cache.keys[layer_idx]
+ v_attn = cache.values[layer_idx]
+
+ if num_kv_heads < num_heads:
+ repeat_factor = num_heads // num_kv_heads
+ k_attn = k_attn.unsqueeze(2).expand( # pyright: ignore[reportUnknownMemberType]
+ int(_batch), num_kv_heads, repeat_factor, -1, head_dim,
+ ).reshape(int(_batch), num_heads, -1, head_dim)
+ v_attn = v_attn.unsqueeze(2).expand( # pyright: ignore[reportUnknownMemberType]
+ int(_batch), num_kv_heads, repeat_factor, -1, head_dim,
+ ).reshape(int(_batch), num_heads, -1, head_dim)
+
+ scale = 1.0 / math.sqrt(head_dim)
+ scores: Tensor = (q @ k_attn.transpose(-2, -1)) * scale
+
+ if isinstance(position_offset, int):
+ # Prefill: standard causal mask over (seq_len × seq_len).
+ # All positions in the local K,V are valid, so no unfilled-position mask needed.
+ if seq_len > 1:
+ causal_mask = Tensor.ones(seq_len, seq_len).triu(1).reshape(1, 1, seq_len, seq_len) # pyright: ignore[reportUnknownMemberType]
+ scores = scores + causal_mask * float("-1e9")
+ else:
+ # Decode: mask unfilled positions in the full cache.
+ valid_len = cache_position + seq_len # pyright: ignore[reportOperatorIssue, reportUnknownVariableType]
+        col_indices: Tensor = cache.col_indices
+        unfilled_mask: Tensor = col_indices >= valid_len  # pyright: ignore[reportOperatorIssue, reportUnknownVariableType]
+ scores = scores + unfilled_mask * float("-1e9") # pyright: ignore[reportUnknownVariableType]
+
+ attn_weights: Tensor = scores.softmax(axis=-1) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+ out: Tensor = attn_weights @ v_attn # pyright: ignore[reportUnknownVariableType]
+ out = out.permute(0, 2, 1, 3).reshape(int(_batch), seq_len, num_heads * head_dim) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+
+ return linear_forward(out, o_proj) # pyright: ignore[reportUnknownArgumentType]
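
The decode branch of `grouped_query_attention` above attends over the full fixed-size cache and relies on an additive `-1e9` mask to cancel unfilled slots. A minimal pure-Python sketch of that masking for one query row follows; `masked_softmax` is a hypothetical helper and assumes `valid_len >= 1`.

```python
import math


def masked_softmax(scores: list[float], valid_len: int) -> list[float]:
    """Softmax over a fixed-length row, ignoring columns >= valid_len."""
    # Additive mask: unfilled cache columns get a large negative score.
    masked = [s + (-1e9 if i >= valid_len else 0.0) for i, s in enumerate(scores)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(m - peak) for m in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Cache of length 4 with only the first 2 positions filled.
weights = masked_softmax([0.5, 1.5, 0.0, 0.0], valid_len=2)
print(weights)
```

Because the masked columns underflow to zero weight, the row length can stay at `max_seq_len` on every step, which is what keeps tensor shapes fixed for JIT replay.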
diff --git a/src/exo/worker/engines/tinygrad/layers/embedding.py b/src/exo/worker/engines/tinygrad/layers/embedding.py
new file mode 100644
index 0000000000..808ac465c6
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/layers/embedding.py
@@ -0,0 +1,22 @@
+from tinygrad.tensor import Tensor
+
+from exo.worker.engines.tinygrad.quantization.layers import (
+ QuantizedEmbedding,
+ QuantizedLinear,
+)
+
+EmbedWeight = Tensor | QuantizedEmbedding
+LinearWeight = Tensor | QuantizedLinear
+
+def apply_embedding(
+ embed: EmbedWeight,
+ input_ids: Tensor,
+) -> Tensor:
+ if isinstance(embed, QuantizedEmbedding):
+ return embed(input_ids)
+ return embed[input_ids]
+
+def apply_lm_head(x: Tensor, lm_head_weight: LinearWeight) -> Tensor:
+ if isinstance(lm_head_weight, QuantizedLinear):
+ return lm_head_weight(x)
+ return x @ lm_head_weight.T
diff --git a/src/exo/worker/engines/tinygrad/layers/mlp.py b/src/exo/worker/engines/tinygrad/layers/mlp.py
new file mode 100644
index 0000000000..06b964c1fb
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/layers/mlp.py
@@ -0,0 +1,16 @@
+from tinygrad.tensor import Tensor
+
+from exo.worker.engines.tinygrad.layers.attention import LinearWeight, linear_forward
+
+
+def swiglu_mlp(
+ x: Tensor,
+ gate_up_proj: LinearWeight,
+ down_proj: LinearWeight,
+) -> Tensor:
+ gate_up = linear_forward(x, gate_up_proj)
+ half = gate_up.shape[-1] // 2
+ gate = gate_up[..., :half]
+ up = gate_up[..., half:]
+ activated = gate.silu() * up
+ return linear_forward(activated, down_proj)
diff --git a/src/exo/worker/engines/tinygrad/layers/normalization.py b/src/exo/worker/engines/tinygrad/layers/normalization.py
new file mode 100644
index 0000000000..e73ba92f3e
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/layers/normalization.py
@@ -0,0 +1,8 @@
+from tinygrad.tensor import Tensor
+
+
+def rms_norm(x: Tensor, weight: Tensor, eps: float) -> Tensor:
+ variance = (x * x).mean(axis=-1, keepdim=True)
+ x_normed = x * (variance + eps).rsqrt()
+
+ return x_normed * weight
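
For reference, `rms_norm` above computes `x / sqrt(mean(x^2) + eps) * weight` over the last axis. A minimal pure-Python sketch on a single vector, with plain lists assumed in place of tensors (`rms_norm_vec` is a hypothetical name):

```python
import math


def rms_norm_vec(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """RMS-normalize one vector, then apply the per-channel weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


out = rms_norm_vec([3.0, 4.0], [1.0, 1.0], eps=0.0)
# With unit weights and eps=0, the output has a mean square of 1.
print(sum(v * v for v in out) / len(out))
```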
diff --git a/src/exo/worker/engines/tinygrad/layers/rotary.py b/src/exo/worker/engines/tinygrad/layers/rotary.py
new file mode 100644
index 0000000000..7e533ce14b
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/layers/rotary.py
@@ -0,0 +1,52 @@
+from typing import Tuple
+
+from tinygrad.dtype import dtypes
+from tinygrad.tensor import Tensor
+
+
+def compute_rope_frequencies(
+ head_dim: int,
+ max_seq_len: int,
+ rope_theta: float = 10000.0,
+) -> Tuple[Tensor, Tensor]:
+ dim_pairs = head_dim // 2
+ freq_exponents = Tensor.arange(0, dim_pairs, dtype=dtypes.float32) * 2.0 / head_dim # pyright: ignore[reportUnknownMemberType]
+ inv_freq = 1.0 / (rope_theta ** freq_exponents)
+
+ positions = Tensor.arange(0, max_seq_len, dtype=dtypes.float32) # pyright: ignore[reportUnknownMemberType]
+ angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
+
+ cos_freqs = angles.cos()
+ sin_freqs = angles.sin()
+
+ return cos_freqs, sin_freqs # Each: (max_seq_len, head_dim // 2)
+
+def apply_rope(
+ x: Tensor,
+ cos_freqs: Tensor,
+ sin_freqs: Tensor,
+ position_offset: int | Tensor = 0,
+) -> Tensor:
+ seq_len = x.shape[2]
+
+ cos = cos_freqs
+ sin = sin_freqs
+
+ if isinstance(position_offset, Tensor):
+ # position_offset is shape (1,) — index into full tables.
+ # cos_freqs is (max_seq_len, dim/2), result: (1, dim/2) → (1, 1, 1, dim/2)
+ cos = cos_freqs[position_offset].reshape(1, 1, seq_len, -1) # pyright: ignore[reportUnknownMemberType]
+ sin = sin_freqs[position_offset].reshape(1, 1, seq_len, -1) # pyright: ignore[reportUnknownMemberType]
+ else:
+ cos = cos_freqs[position_offset:position_offset + seq_len].reshape(1, 1, seq_len, -1) # pyright: ignore[reportUnknownMemberType]
+ sin = sin_freqs[position_offset:position_offset + seq_len].reshape(1, 1, seq_len, -1) # pyright: ignore[reportUnknownMemberType]
+
+ half = x.shape[-1] // 2
+
+ x1 = x[..., :half]
+ x2 = x[..., half:]
+
+ out1 = x1 * cos - x2 * sin
+ out2 = x2 * cos + x1 * sin
+
+ return out1.cat(out2, dim=-1)
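
`apply_rope` above uses the split-half convention: channel `i` is paired with channel `i + head_dim // 2`, and each pair is rotated by `position * inv_freq[i]`. A minimal sketch for one head vector at one position, assuming plain lists in place of tensors (`rope_rotate` is a hypothetical name):

```python
import math


def rope_rotate(x: list[float], position: int, theta: float = 10000.0) -> list[float]:
    """Rotate one head vector using the split-half RoPE layout."""
    head_dim = len(x)
    half = head_dim // 2
    out1: list[float] = []
    out2: list[float] = []
    for i in range(half):
        # Same frequency schedule as compute_rope_frequencies.
        inv_freq = 1.0 / (theta ** (2.0 * i / head_dim))
        angle = position * inv_freq
        c, s = math.cos(angle), math.sin(angle)
        x1, x2 = x[i], x[i + half]
        out1.append(x1 * c - x2 * s)
        out2.append(x2 * c + x1 * s)
    return out1 + out2


vec = [1.0, 2.0, 3.0, 4.0]
rotated = rope_rotate(vec, position=7)
print(rotated)
```

Each pair undergoes a pure rotation, so the vector norm is unchanged and position 0 is the identity.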
diff --git a/src/exo/worker/engines/tinygrad/quantization/__init__.py b/src/exo/worker/engines/tinygrad/quantization/__init__.py
new file mode 100644
index 0000000000..9ad8af517e
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/quantization/__init__.py
@@ -0,0 +1,13 @@
+from .dequantization import affine_dequantize
+from .layers import QuantizedEmbedding, QuantizedLinear
+from .packing import PackedTensor, calculate_pack_factor, pack_bits, unpack_bits
+
+__all__ = [
+ "PackedTensor",
+ "QuantizedEmbedding",
+ "QuantizedLinear",
+ "affine_dequantize",
+ "calculate_pack_factor",
+ "pack_bits",
+ "unpack_bits",
+]
diff --git a/src/exo/worker/engines/tinygrad/quantization/dequantization.py b/src/exo/worker/engines/tinygrad/quantization/dequantization.py
new file mode 100644
index 0000000000..2a99d57790
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/quantization/dequantization.py
@@ -0,0 +1,19 @@
+from tinygrad.tensor import Tensor
+
+
+def affine_dequantize(
+ quantized: Tensor,
+ scales: Tensor,
+ biases: Tensor,
+ group_size: int,
+) -> Tensor:
+ """
+ Apply group-wise affine dequantization:
+ out = quantized * scale + bias
+ This is the same algorithm used by MLX quantization
+ """
+
+ extended_scales = scales.repeat_interleave(group_size, dim=-1)
+ extended_biases = biases.repeat_interleave(group_size, dim=-1)
+
+ return quantized * extended_scales + extended_biases
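
A small worked example of the group-wise affine formula in `affine_dequantize` above, assuming one row of four quantized values split into two groups of two. `dequantize_row` is a hypothetical list-based stand-in in which the group index `i // group_size` plays the role of `repeat_interleave`.

```python
def dequantize_row(
    q: list[int], scales: list[float], biases: list[float], group_size: int
) -> list[float]:
    """out[i] = q[i] * scale[group(i)] + bias[group(i)]."""
    return [
        q[i] * scales[i // group_size] + biases[i // group_size]
        for i in range(len(q))
    ]


row = dequantize_row([0, 15, 3, 8], scales=[0.5, 2.0], biases=[-1.0, 1.0], group_size=2)
print(row)
```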
diff --git a/src/exo/worker/engines/tinygrad/quantization/layers.py b/src/exo/worker/engines/tinygrad/quantization/layers.py
new file mode 100644
index 0000000000..8711fc5c47
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/quantization/layers.py
@@ -0,0 +1,84 @@
+from typing import final
+
+from tinygrad.tensor import Tensor
+
+from .dequantization import affine_dequantize
+from .packing import PackedTensor, unpack_bits
+
+
+@final
+class QuantizedLinear:
+    """
+    This is a drop-in replacement for tinygrad's Linear class that
+    supports quantization.
+
+    Weights are stored bit-packed; __init__ only keeps references to
+    them. Dequantization happens in __call__ (via dequantize) on each
+    forward pass.
+    """
+
+ def __init__(
+ self,
+ weight_q: PackedTensor,
+ scales: Tensor,
+ biases: Tensor,
+ group_size: int = 64,
+ bias: Tensor | None = None,
+ ) -> None:
+ self.weight_q = weight_q
+ self.scales = scales
+ self.biases = biases
+ self.group_size = group_size
+ self.bias = bias
+
+ def __call__(self, x: Tensor) -> Tensor:
+ result = x @ self.dequantize().T
+ if self.bias is not None:
+ result = result + self.bias
+ return result
+
+ def dequantize(self) -> Tensor:
+ unpacked = unpack_bits(self.weight_q)
+ return affine_dequantize(
+ unpacked, self.scales,
+ self.biases, self.group_size)
+
+ @property
+ def in_features(self) -> int:
+ return self.weight_q.original_shape[1]
+
+ @property
+ def out_features(self) -> int:
+ return self.weight_q.original_shape[0]
+
+@final
+class QuantizedEmbedding:
+    """
+    This is a drop-in replacement for tinygrad's Embedding class that
+    supports quantized weights.
+    """
+
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ weight_q: PackedTensor,
+ scales: Tensor,
+ biases: Tensor,
+ group_size: int = 64,
+ ) -> None:
+ self.num_embeddings = num_embeddings
+ self.weight_q = weight_q
+ self.embedding_dim = embedding_dim
+ self.scales = scales
+ self.biases = biases
+ self.group_size = group_size
+
+ def __call__(self, indices: Tensor) -> Tensor:
+ return self.dequantize()[indices]
+
+ def dequantize(self) -> Tensor:
+ unpacked = unpack_bits(self.weight_q)
+ return affine_dequantize(
+ unpacked, self.scales, self.biases, self.group_size
+ )
diff --git a/src/exo/worker/engines/tinygrad/quantization/packing.py b/src/exo/worker/engines/tinygrad/quantization/packing.py
new file mode 100644
index 0000000000..fdbb2eca4d
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/quantization/packing.py
@@ -0,0 +1,85 @@
+from math import ceil
+from typing import final
+
+from pydantic import BaseModel, ConfigDict
+from tinygrad.dtype import dtypes
+from tinygrad.tensor import Tensor
+
+
+@final
+class PackedTensor(BaseModel):
+ # Immutable container for bit-packed weight data with unpacking metadata
+
+ model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)
+
+ tensor: Tensor
+ original_shape: tuple[int, ...]
+ pack_factor: int
+ bits: int
+
+ def __setattr__(self, name: str, value: object) -> None:
+ raise AttributeError(f"'{type(self).__name__}' is frozen")
+
+def calculate_pack_factor(bits: int) -> int:
+    """
+    The pack factor is the number of quantized values stored in one
+    packed word: pack_factor = packed_type_bits // bits.
+    Here the packed type is uint32, so packed_type_bits is 32.
+    """
+
+ return 32 // bits
+
+def pack_bits(unpacked: Tensor, bits: int) -> PackedTensor:
+    """
+    Converts a float16 tensor of quantized values into a packed uint32
+    tensor, the layout the MLX community uses for its quantized models.
+    """
+
+ pack_factor = calculate_pack_factor(bits)
+ original_shape = tuple(int(d) for d in unpacked.shape)
+ last_dim = original_shape[-1]
+ padded_last_dim = ceil(last_dim / pack_factor) * pack_factor
+ packed_last_dim = padded_last_dim // pack_factor
+ packed_shape = (*original_shape[:-1], packed_last_dim)
+
+ if last_dim != padded_last_dim:
+ pad_shape = (*original_shape[:-1], padded_last_dim - last_dim)
+ padding = Tensor.zeros(*pad_shape, dtype = unpacked.dtype) # pyright: ignore[reportUnknownMemberType]
+ unpacked = unpacked.cat(padding, dim = -1)
+
+ packed = Tensor.zeros(*packed_shape, dtype=dtypes.uint32) # pyright: ignore[reportUnknownMemberType]
+ for slot in range(pack_factor):
+ values = unpacked[..., slot::pack_factor].cast(dtypes.uint32)
+ packed = packed | (values << (slot * bits))
+
+ return PackedTensor(
+ tensor = packed,
+ original_shape = original_shape,
+ pack_factor = pack_factor,
+ bits = bits,
+ )
+
+def unpack_bits(packed_tensor: PackedTensor) -> Tensor:
+ """
+ Unpack a bit-packed uint32 tensor back into individual float16 values.
+ """
+ packed = packed_tensor.tensor
+ original_shape = packed_tensor.original_shape
+ pack_factor = packed_tensor.pack_factor
+ bits = packed_tensor.bits
+ mask = (1 << bits) - 1
+
+ slices: list[Tensor] = []
+ for slot in range(pack_factor):
+ slices.append((packed >> (slot * bits)) & mask)
+
+ stacked = Tensor.stack(*slices, dim=0)
+ padded_last_dim = packed.shape[-1] * pack_factor
+ padded_shape = (*original_shape[:-1], padded_last_dim)
+ result = stacked.permute(*range(1, len(stacked.shape)), 0).reshape(padded_shape) # pyright: ignore[reportUnknownMemberType]
+
+ if padded_last_dim != original_shape[-1]:
+ # Trim padding for padded, packed tensors
+ result = result[..., :original_shape[-1]]
+
+ return result.cast(dtypes.float16)
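Reviewer note: the slot layout used by `pack_bits` and `unpack_bits` can be checked on plain integers. The sketch below mirrors the same idea (value `j * pack_factor + slot` lands in word `j` at bit offset `slot * bits`) without tinygrad; `pack_row` and `unpack_row` are hypothetical names used only for illustration, and unlike the PR code this sketch masks each value to `bits` bits before shifting.

```python
def pack_row(values: list[int], bits: int) -> list[int]:
    """Pack small unsigned integers into 32-bit words, lowest slot first."""
    pack_factor = 32 // bits
    mask = (1 << bits) - 1
    # Zero-pad so the length is a multiple of pack_factor.
    padded = values + [0] * (-len(values) % pack_factor)
    words: list[int] = []
    for start in range(0, len(padded), pack_factor):
        word = 0
        for slot, value in enumerate(padded[start:start + pack_factor]):
            word |= (value & mask) << (slot * bits)
        words.append(word)
    return words


def unpack_row(words: list[int], bits: int, length: int) -> list[int]:
    """Invert pack_row and trim the zero padding back to the original length."""
    pack_factor = 32 // bits
    mask = (1 << bits) - 1
    values = [
        (word >> (slot * bits)) & mask
        for word in words
        for slot in range(pack_factor)
    ]
    return values[:length]
```

For 4-bit values the pack factor is 8, so ten values occupy two words and the last six slots of the second word are padding.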
diff --git a/src/exo/worker/engines/tinygrad/quantization/shapes.py b/src/exo/worker/engines/tinygrad/quantization/shapes.py
new file mode 100644
index 0000000000..138bb8ea15
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/quantization/shapes.py
@@ -0,0 +1,71 @@
+from typing import Literal
+
+from exo.shared.model_config import ModelConfig
+
+LayerType = Literal[
+ "q_proj", "k_proj", "v_proj", "o_proj", "qkv_proj",
+ "gate_proj", "up_proj", "down_proj",
+ "embed_tokens", "lm_head",
+ "unknown",
+]
+
+_LAYER_PATTERNS: dict[LayerType, list[str]] = {
+ "q_proj": [".q_proj.", ".q.", ".attn.q"],
+ "k_proj": [".k_proj.", ".k.", ".attn.k"],
+ "v_proj": [".v_proj.", ".v.", ".attn.v"],
+ "o_proj": [".o_proj.", ".o.", ".attn.o"],
+ "qkv_proj": [".c_attn.", ".qkv.", ".attn.c_attn"],
+ "gate_proj": [".gate_proj.", ".gate.", ".mlp.gate."],
+ "up_proj": [".up_proj.", ".up.", ".mlp.up."],
+ "down_proj": [".down_proj.", ".down.", ".mlp.down."],
+ "embed_tokens": ["embed_tokens", "wte"],
+ "lm_head": ["lm_head", "output_layer"],
+}
+
+def detect_layer_type(weight_key: str) -> LayerType:
+ """
+ Detect canonical layer type from a safetensors key.
+ """
+
+ key = weight_key.lower()
+
+ for layer_type, patterns in _LAYER_PATTERNS.items():
+ if any(pattern in key for pattern in patterns):
+ return layer_type
+
+ return "unknown"
+
+def infer_weight_shape(weight_key: str, config: ModelConfig) -> tuple[int, ...]:
+ """
+ Infer original weight shape from layer name and ModelConfig.
+
+ Handles GQA correctly: k_proj and v_proj use kv_dim instead of
+ hidden_size.
+ """
+
+ hidden_size = config.hidden_size
+ intermediate_size = config.intermediate_size
+ vocab_size = config.vocab_size
+
+ kv_dim = config.num_key_value_heads * config.head_dim
+ q_dim = config.num_attention_heads * config.head_dim
+
+ shape_map: dict[LayerType, tuple[int, ...]] = {
+ "q_proj": (q_dim, hidden_size),
+ "k_proj": (kv_dim, hidden_size),
+ "v_proj": (kv_dim, hidden_size),
+ "o_proj": (hidden_size, q_dim),
+ "qkv_proj": (q_dim + 2 * kv_dim, hidden_size),
+ "gate_proj": (intermediate_size, hidden_size),
+ "up_proj": (intermediate_size, hidden_size),
+ "down_proj": (hidden_size, intermediate_size),
+ "embed_tokens": (vocab_size, hidden_size),
+ "lm_head": (vocab_size, hidden_size),
+ }
+
+ layer_type = detect_layer_type(weight_key)
+ if layer_type in shape_map:
+ return shape_map[layer_type]
+
+ # Generic fallback
+ return (q_dim, hidden_size)
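Reviewer note: the GQA handling in `infer_weight_shape` is easiest to see with concrete numbers. The sketch below recomputes the attention shapes for a made-up configuration; `ToyConfig` and `attention_shapes` are hypothetical stand-ins for `ModelConfig` and the `shape_map` above, and the numbers are illustrative rather than taken from any particular model.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical subset of the fields that infer_weight_shape reads.
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int


def attention_shapes(cfg: ToyConfig) -> dict[str, tuple[int, int]]:
    """Return (out_features, in_features) for each attention projection."""
    q_dim = cfg.num_attention_heads * cfg.head_dim
    kv_dim = cfg.num_key_value_heads * cfg.head_dim
    return {
        "q_proj": (q_dim, cfg.hidden_size),
        "k_proj": (kv_dim, cfg.hidden_size),
        "v_proj": (kv_dim, cfg.hidden_size),
        "o_proj": (cfg.hidden_size, q_dim),
        "qkv_proj": (q_dim + 2 * kv_dim, cfg.hidden_size),
    }
```

With 32 query heads, 8 key/value heads, and `head_dim=128`, the K and V projections have 1024 output rows while Q has 4096, which is why using `hidden_size` for every projection would give the wrong shape under GQA.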
diff --git a/src/exo/worker/engines/tinygrad/sampling.py b/src/exo/worker/engines/tinygrad/sampling.py
new file mode 100644
index 0000000000..6fbcb6f15d
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/sampling.py
@@ -0,0 +1,44 @@
+from typing import NamedTuple
+
+from tinygrad.tensor import Tensor
+
+
+class SampleResult(NamedTuple):
+ token_id: int
+ logprob: float
+ top_logprobs: list[tuple[int, float]] # (token_id, logprob)
+
+def sample_token(
+ logits: Tensor,
+ temperature: float = 0.7,
+ top_p: float = 0.9, # NOTE: accepted but not applied anywhere in this function
+ top_logprobs_count: int = 0,
+ request_logprobs: bool = False,
+) -> SampleResult:
+ last_logits = logits[0, -1, :]
+
+ if temperature == 0:
+ token_id = int(last_logits.argmax().item()) # pyright: ignore[reportUnknownMemberType]
+ else:
+ # Gumbel-max trick: argmax(logits / T + Gumbel noise) draws a sample from
+ # softmax(logits / T) and needs only one GPU -> CPU sync.
+
+ scaled = last_logits / temperature
+ gumbel_noise = -(-Tensor.rand(scaled.shape).log()).log() # pyright: ignore[reportUnknownMemberType]
+ token_id = int((scaled + gumbel_noise).argmax().item()) # pyright: ignore[reportUnknownMemberType]
+
+ selected_logprob: float = 0.0
+ top_logprobs: list[tuple[int, float]] = []
+
+ if request_logprobs:
+ log_probs = last_logits.log_softmax(axis = -1)
+ selected_logprob = float(log_probs[token_id].item())
+
+ if top_logprobs_count > 0:
+ values, indices = log_probs.topk(top_logprobs_count)
+ top_logprobs = [(int(idx), float(val)) for val, idx in zip(values.tolist(), indices.tolist(), strict=True)] # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType, reportArgumentType]
+
+ return SampleResult(
+ token_id=token_id,
+ logprob=selected_logprob,
+ top_logprobs=top_logprobs,
+ )
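Reviewer note: the sampling branch in `sample_token` relies on the Gumbel-max trick, where adding independent Gumbel(0, 1) noise to the temperature-scaled logits and taking the argmax yields a draw from `softmax(logits / temperature)`. A minimal pure-Python sketch of that equivalence, using only the standard library (`gumbel_max_sample` and `softmax` are hypothetical helpers, not part of this PR):

```python
import math
import random


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax, used here only as the reference distribution."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_max_sample(logits: list[float], temperature: float, rng: random.Random) -> int:
    """Return an index distributed as softmax(logits / temperature)."""
    best_index, best_score = 0, -math.inf
    for index, logit in enumerate(logits):
        # Clamp away from 0 so that log(u) is defined; u is always below 1.
        u = max(rng.random(), 1e-12)
        gumbel = -math.log(-math.log(u))
        score = logit / temperature + gumbel
        if score > best_score:
            best_index, best_score = index, score
    return best_index
```

Drawing many samples and comparing the empirical frequencies with `softmax(logits)` is a quick sanity check. The tensor version in the diff computes the same expression elementwise and pulls only the final argmax back to the host.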
diff --git a/src/exo/worker/engines/tinygrad/utils_tinygrad.py b/src/exo/worker/engines/tinygrad/utils_tinygrad.py
new file mode 100644
index 0000000000..b85ac6efec
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/utils_tinygrad.py
@@ -0,0 +1,36 @@
+from collections.abc import Callable
+from typing import Any
+
+from exo.download.download_utils import build_model_path
+from exo.shared.model_config import parse_model_config
+from exo.shared.tokenizer.loader import load_tokenizer_for_model
+from exo.shared.tokenizer.wrapper import TokenizerWrapper
+from exo.shared.types.worker.instances import BoundInstance, TinygradInstance
+
+from .weight_loader import TransformerWeights, load_transformer_weights
+
+
+def initialize_tinygrad(bound_instance: BoundInstance) -> None:
+ instance = bound_instance.instance
+ assert isinstance(instance, TinygradInstance)
+ # For a single instance, we let tinygrad choose the device.
+ # When sharding models, further explicit initialization routines may be
+ # needed here to ensure effective sharding.
+
+
+def load_tinygrad_items(
+ bound_instance: BoundInstance, group: None,
+ on_timeout: Callable[[], None] | None = None,
+) -> tuple[TransformerWeights, Any]:
+ shard = bound_instance.bound_shard
+ model_id = shard.model_card.model_id
+ model_path = build_model_path(model_id)
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(
+ model_path=model_path, config=config,
+ start_layer=shard.start_layer, end_layer=shard.end_layer,
+ )
+
+ tokenizer = TokenizerWrapper(load_tokenizer_for_model(model_id, model_path))
+
+ return weights, tokenizer
diff --git a/src/exo/worker/engines/tinygrad/weight_loader.py b/src/exo/worker/engines/tinygrad/weight_loader.py
new file mode 100644
index 0000000000..7d23fb1131
--- /dev/null
+++ b/src/exo/worker/engines/tinygrad/weight_loader.py
@@ -0,0 +1,269 @@
+from pathlib import Path
+from typing import Literal, NamedTuple, overload
+
+from tinygrad.device import Device
+from tinygrad.nn.state import safe_load
+from tinygrad.tensor import Tensor
+
+from exo.shared.architecture import ArchitectureSpec
+from exo.shared.model_config import ModelConfig
+from exo.worker.engines.tinygrad.layers.rotary import compute_rope_frequencies
+from exo.worker.engines.tinygrad.quantization.layers import (
+ QuantizedEmbedding,
+ QuantizedLinear,
+)
+from exo.worker.engines.tinygrad.quantization.packing import PackedTensor
+from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+LinearWeight = Tensor | QuantizedLinear
+EmbedWeight = Tensor | QuantizedEmbedding
+
+class LayerWeights(NamedTuple):
+ qkv_proj: LinearWeight # Merged Q+K+V
+ o_proj: LinearWeight
+
+ gate_up_proj: LinearWeight # Merged gate+up
+ down_proj: LinearWeight
+
+ input_norm: Tensor
+ post_attn_norm: Tensor
+
+ # Optional layers
+ q_norm: Tensor | None = None
+ k_norm: Tensor | None = None
+
+ # MoE (None for dense models)
+ router_weight: Tensor | None = None
+ expert_gate_projs: list[LinearWeight] | None = None
+ expert_up_projs: list[LinearWeight] | None = None
+ expert_down_projs: list[LinearWeight] | None = None
+
+class TransformerWeights(NamedTuple):
+ embed_tokens: EmbedWeight
+ lm_head: LinearWeight
+ final_norm: Tensor
+ layers: list[LayerWeights]
+ config: ModelConfig
+ rope_sin: Tensor
+ rope_cos: Tensor
+
+def load_transformer_weights(
+ model_path: Path,
+ config: ModelConfig,
+ start_layer: int = 0,
+ end_layer: int | None = None,
+) -> TransformerWeights:
+
+ if end_layer is None:
+ end_layer = config.num_hidden_layers
+
+ spec = config.architecture_spec
+ raw_weights = _load_all_safetensors(model_path)
+
+ embed_tokens = _build_weight(raw_weights,
+ f"{spec.embed_key}.weight",
+ config, is_embedding = True)
+
+ lm_head: LinearWeight
+ if config.tie_word_embeddings:
+ if isinstance(embed_tokens, QuantizedEmbedding):
+ lm_head = QuantizedLinear(
+ weight_q = embed_tokens.weight_q,
+ scales = embed_tokens.scales,
+ biases = embed_tokens.biases,
+ group_size = embed_tokens.group_size,
+ )
+ else:
+ lm_head = embed_tokens
+ else:
+ lm_head = _build_weight(
+ raw_weights, f"{spec.lm_head_key}.weight", config,
+ )
+
+ final_norm = raw_weights[f"{spec.final_norm_key}.weight"]
+
+ layers: list[LayerWeights] = []
+
+ for layer_idx in range(start_layer, end_layer):
+ prefix = spec.layer_prefix.format(layer_idx=layer_idx)
+ layers.append(_build_layer_weights(raw_weights, prefix, spec, config))
+
+ rope_cos, rope_sin = compute_rope_frequencies(
+ head_dim = config.head_dim,
+ max_seq_len = config.max_position_embeddings,
+ rope_theta = config.rope_theta,
+ )
+
+ return TransformerWeights(
+ embed_tokens=embed_tokens, lm_head=lm_head,
+ final_norm=final_norm, layers=layers, config=config,
+ rope_cos = rope_cos.realize(),
+ rope_sin = rope_sin.realize(),
+ )
+
+def _merge_linear_weights(*weights: LinearWeight) -> LinearWeight:
+ """Merge multiple LinearWeight objects by concatenating along the output dimension (dim 0).
+
+ For quantized weights, concatenates packed uint32, scales, and biases directly, with
+ no extra memory from dequantization. For non-quantized or mixed weights, dequantizes
+ first then concatenates plain Tensors.
+ """
+ if all(isinstance(w, QuantizedLinear) for w in weights):
+ qls = [w for w in weights if isinstance(w, QuantizedLinear)]
+ merged_tensor = qls[0].weight_q.tensor.cat(
+ *[w.weight_q.tensor for w in qls[1:]], dim=0
+ ).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ merged_scales = qls[0].scales.cat(
+ *[w.scales for w in qls[1:]], dim=0
+ ).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ merged_biases = qls[0].biases.cat(
+ *[w.biases for w in qls[1:]], dim=0
+ ).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+ return QuantizedLinear(
+ weight_q=PackedTensor(
+ tensor=merged_tensor,
+ original_shape=(
+ sum(w.weight_q.original_shape[0] for w in qls),
+ qls[0].weight_q.original_shape[1],
+ ),
+ pack_factor=qls[0].weight_q.pack_factor,
+ bits=qls[0].weight_q.bits,
+ ),
+ scales=merged_scales,
+ biases=merged_biases,
+ group_size=qls[0].group_size,
+ )
+ # Non-quantized or mixed: dequantize if needed, then cat
+ tensors: list[Tensor] = []
+ for w in weights:
+ if isinstance(w, QuantizedLinear):
+ tensors.append(w.dequantize())
+ else:
+ tensors.append(w)
+ return tensors[0].cat(*tensors[1:], dim=0).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+
+def _build_layer_weights(
+ raw: dict[str, Tensor],
+ prefix: str,
+ spec: ArchitectureSpec,
+ config: ModelConfig,
+) -> LayerWeights:
+ def key(suffix: str) -> str:
+ return f"{prefix}.{suffix}.weight"
+
+ q_norm = raw.get(f"{prefix}.{spec.q_norm_key}.weight") if spec.q_norm_key else None
+ k_norm = raw.get(f"{prefix}.{spec.k_norm_key}.weight") if spec.k_norm_key else None
+
+ q_proj = _build_weight(raw, key(spec.q_proj_key), config)
+ k_proj = _build_weight(raw, key(spec.k_proj_key), config)
+ v_proj = _build_weight(raw, key(spec.v_proj_key), config)
+ qkv_proj = _merge_linear_weights(q_proj, k_proj, v_proj)
+
+ gate_proj = _build_weight(raw, key(spec.gate_proj_key), config)
+ up_proj = _build_weight(raw, key(spec.up_proj_key), config)
+ gate_up_proj = _merge_linear_weights(gate_proj, up_proj)
+
+ return LayerWeights(
+ qkv_proj=qkv_proj,
+ o_proj=_build_weight(raw, key(spec.o_proj_key), config),
+ gate_up_proj=gate_up_proj,
+ down_proj=_build_weight(raw, key(spec.down_proj_key), config),
+ input_norm=raw[f"{prefix}.{spec.input_norm_key}.weight"],
+ post_attn_norm=raw[f"{prefix}.{spec.post_attn_norm_key}.weight"],
+ q_norm=q_norm,
+ k_norm=k_norm,
+ )
+
+@overload
+def _build_weight(raw: dict[str, Tensor], key: str, config: ModelConfig, is_embedding: Literal[True]) -> EmbedWeight: ...
+@overload
+def _build_weight(raw: dict[str, Tensor], key: str, config: ModelConfig, is_embedding: Literal[False] = ...) -> LinearWeight: ...
+
+def _build_weight(
+ raw: dict[str, Tensor],
+ key: str,
+ config: ModelConfig,
+ is_embedding: bool = False,
+) -> LinearWeight | EmbedWeight:
+ scales_key = key.replace(".weight", ".scales")
+ biases_key = key.replace(".weight", ".biases")
+
+ # MLX quantized: .weight (packed uint32) + .scales + .biases + quantization_config
+ if key in raw and config.quantization_config is not None and scales_key in raw and biases_key in raw:
+ qcfg = config.quantization_config
+ packed = PackedTensor(
+ tensor = raw[key],
+ original_shape = infer_weight_shape(key, config),
+ pack_factor = 32 // qcfg.bits,
+ bits = qcfg.bits,
+ )
+
+ if is_embedding:
+ return QuantizedEmbedding(
+ num_embeddings = config.vocab_size,
+ embedding_dim = config.hidden_size,
+ weight_q = packed,
+ scales = raw[scales_key],
+ biases = raw[biases_key],
+ group_size = qcfg.group_size,
+ )
+
+ return QuantizedLinear(
+ weight_q = packed,
+ scales = raw[scales_key],
+ biases = raw[biases_key],
+ group_size = qcfg.group_size,
+ )
+
+ # Plain unquantized: .weight only
+ if key in raw:
+ return raw[key]
+
+ # Legacy .qweight format
+ qweight_key = key.replace(".weight", ".qweight")
+
+ if qweight_key in raw and config.quantization_config is not None:
+ qcfg = config.quantization_config
+ packed = PackedTensor(
+ tensor = raw[qweight_key],
+ original_shape = infer_weight_shape(key, config),
+ pack_factor = 32 // qcfg.bits,
+ bits = qcfg.bits,
+ )
+
+ if is_embedding:
+ return QuantizedEmbedding(
+ num_embeddings = config.vocab_size,
+ embedding_dim = config.hidden_size,
+ weight_q = packed,
+ scales = raw[scales_key],
+ biases = raw[biases_key],
+ group_size = qcfg.group_size,
+ )
+
+ return QuantizedLinear(
+ weight_q = packed,
+ scales = raw[scales_key],
+ biases = raw[biases_key],
+ group_size = qcfg.group_size,
+ )
+
+ raise KeyError(f"Weight key '{key}' not found (also tried {qweight_key})")
+
+def _load_all_safetensors(path: Path) -> dict[str, Tensor]:
+ from tinygrad.helpers import Context
+
+ merged: dict[str, Tensor] = {}
+
+ # Disable BEAM during weight loading. Copy kernels (DISK -> GPU) have unique
+ # shapes per tensor and don't benefit from beam search optimisation.
+ with Context(BEAM=0):
+ for safetensor_file in sorted(path.glob("*.safetensors")):
+ shard = safe_load(str(safetensor_file))
+ for key, tensor in shard.items():
+ merged[key] = tensor.to(Device.DEFAULT).contiguous().realize() # pyright: ignore[reportUnknownMemberType]
+
+ if not merged:
+ raise FileNotFoundError(f"No .safetensors file found in {path}")
+
+ return merged
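Reviewer note: `_build_weight` tries three on-disk layouts in a fixed order. The sketch below isolates just that decision so the precedence is easy to test. `resolve_weight_format` is a hypothetical helper that works on a set of key names and a boolean, not on real tensors or a `ModelConfig`.

```python
def resolve_weight_format(keys: set[str], key: str, has_quant_config: bool) -> str:
    """Mirror the lookup order of _build_weight using key names only."""
    scales_key = key.replace(".weight", ".scales")
    biases_key = key.replace(".weight", ".biases")
    qweight_key = key.replace(".weight", ".qweight")

    # 1. MLX quantized: packed .weight plus .scales and .biases, with a quantization config.
    if key in keys and has_quant_config and scales_key in keys and biases_key in keys:
        return "mlx_quantized"
    # 2. Plain unquantized tensor stored under .weight.
    if key in keys:
        return "plain"
    # 3. Legacy layout that stores the packed data under .qweight.
    if qweight_key in keys and has_quant_config:
        return "legacy_qweight"
    raise KeyError(f"Weight key '{key}' not found (also tried {qweight_key})")
```

One consequence of this ordering is that a `.weight` tensor with scales and biases present but no quantization config is treated as plain.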
diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py
index 9949cb7e65..746c993fcf 100644
--- a/src/exo/worker/runner/bootstrap.py
+++ b/src/exo/worker/runner/bootstrap.py
@@ -1,15 +1,41 @@
import os
+import signal
+import threading
+from multiprocessing.connection import Connection
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
-from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
+from exo.shared.types.worker.instances import (
+ BoundInstance,
+ MlxJacclInstance,
+ TinygradInstance,
+)
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
logger: "loguru.Logger" = loguru.logger
+def _start_parent_death_watchdog(conn: Connection) -> None:
+ """Watch for parent death via a multiprocessing.Connection.
+
+ The parent holds the other end of the pipe. When the parent dies,
+ the OS closes its end and our recv() raises EOFError.
+ This behaviour holds on both POSIX and Windows, so the watchdog is cross-platform.
+ """
+
+ def _watchdog() -> None:
+ try:
+ conn.recv()
+ except (EOFError, OSError):
+ pass
+ finally:
+ conn.close()
+ os.kill(os.getpid(), signal.SIGTERM)
+
+ t = threading.Thread(target=_watchdog, daemon=True)
+ t.start()
def entrypoint(
bound_instance: BoundInstance,
@@ -17,27 +43,45 @@ def entrypoint(
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
+ parent_death_conn: Connection,
) -> None:
- fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
- if fast_synch_override == "on" or (
- fast_synch_override != "off"
- and (
- isinstance(bound_instance.instance, MlxJacclInstance)
- and len(bound_instance.instance.jaccl_devices) >= 2
- )
- ):
- os.environ["MLX_METAL_FAST_SYNCH"] = "1"
- else:
- os.environ["MLX_METAL_FAST_SYNCH"] = "0"
+ _start_parent_death_watchdog(parent_death_conn)
+
+ is_tinygrad = isinstance(bound_instance.instance, TinygradInstance)
+
+ if not is_tinygrad:
+ fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
+ if fast_synch_override == "on" or (
+ fast_synch_override != "off"
+ and (
+ isinstance(bound_instance.instance, MlxJacclInstance)
+ and len(bound_instance.instance.jaccl_devices) >= 2
+ )
+ ):
+ os.environ["MLX_METAL_FAST_SYNCH"] = "1"
+ else:
+ os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
- logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
+ if not is_tinygrad:
+ logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
+
+ # Import main after setting global logger - this lets us just import logger from this module.
+ # Guard by instance type: TinygradInstance must never import MLX modules (macOS-only).
+ if is_tinygrad:
+ # Default JIT=1 and BEAM=2 for decode inference. Prefill/weight-loading
+ # paths override BEAM with Context(BEAM=0). User can still override via env.
+ os.environ.setdefault("JIT", "1")
+ os.environ.setdefault("BEAM", "2")
+
+ os.environ.setdefault("TC", "1")
- # Import main after setting global logger - this lets us just import logger from this module
try:
- if bound_instance.is_image_model:
+ if is_tinygrad:
+ from exo.worker.runner.llm_inference.tinygrad_runner import main
+ elif bound_instance.is_image_model:
from exo.worker.runner.image_models.runner import main
else:
from exo.worker.runner.llm_inference.runner import main
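Reviewer note: the pipe-based watchdog can be exercised inside a single process, since closing one end of a `multiprocessing.Pipe` makes a blocked `recv()` on the other end raise `EOFError`. The sketch below follows the same structure as `_start_parent_death_watchdog` but calls a callback instead of sending `SIGTERM`; `watch_for_peer_close` is a hypothetical name used only for illustration.

```python
import threading
from collections.abc import Callable
from multiprocessing.connection import Connection


def watch_for_peer_close(conn: Connection, on_close: Callable[[], None]) -> threading.Thread:
    """Run on_close once the other end of the pipe goes away (or sends anything)."""

    def _watch() -> None:
        try:
            conn.recv()  # Blocks until data arrives or the peer end is closed.
        except (EOFError, OSError):
            pass
        finally:
            conn.close()
            on_close()

    thread = threading.Thread(target=_watch, daemon=True)
    thread.start()
    return thread
```

Creating a pair with `multiprocessing.Pipe()`, starting the watcher on one end, and closing the other end should trigger the callback. One caveat that may be worth checking for this PR: under the `fork` start method the child can inherit its own copy of the parent's end of the pipe, in which case EOF is not delivered until that copy is closed as well.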
diff --git a/src/exo/worker/runner/llm_inference/tinygrad_runner.py b/src/exo/worker/runner/llm_inference/tinygrad_runner.py
new file mode 100644
index 0000000000..68db7569c1
--- /dev/null
+++ b/src/exo/worker/runner/llm_inference/tinygrad_runner.py
@@ -0,0 +1,403 @@
+import resource
+import time
+from collections.abc import Generator
+from typing import TYPE_CHECKING
+
+from exo.shared.models.model_cards import ModelTask
+from exo.shared.tokenizer.chat_template import apply_chat_template
+from exo.shared.types.chunks import (
+ ErrorChunk,
+ TokenChunk,
+ ToolCallChunk,
+)
+from exo.shared.types.events import (
+ ChunkGenerated,
+ Event,
+ RunnerStatusUpdated,
+ TaskAcknowledged,
+ TaskStatusUpdated,
+)
+from exo.shared.types.tasks import (
+ LoadModel,
+ Shutdown,
+ StartWarmup,
+ Task,
+ TaskId,
+ TaskStatus,
+ TextGeneration,
+)
+from exo.shared.types.text_generation import TextGenerationTaskParams
+from exo.shared.types.worker.instances import BoundInstance
+from exo.shared.types.worker.runner_response import (
+ GenerationResponse,
+ ToolCallResponse,
+)
+from exo.shared.types.worker.runners import (
+ RunnerFailed,
+ RunnerIdle,
+ RunnerLoaded,
+ RunnerLoading,
+ RunnerReady,
+ RunnerRunning,
+ RunnerShutdown,
+ RunnerShuttingDown,
+ RunnerStatus,
+ RunnerWarmingUp,
+)
+from exo.utils.channels import MpReceiver, MpSender
+from exo.worker.engines.tinygrad.generator.generate import (
+ cleanup_jit_state,
+ tinygrad_generate,
+ warmup_inference,
+)
+from exo.worker.engines.tinygrad.utils_tinygrad import (
+ initialize_tinygrad,
+ load_tinygrad_items,
+)
+from exo.worker.engines.tinygrad.weight_loader import TransformerWeights
+from exo.worker.runner.bootstrap import logger
+
+from .tool_parsers import ToolParser, make_mlx_parser
+
+
+def main(
+ bound_instance: BoundInstance,
+ event_sender: MpSender[Event],
+ task_receiver: MpReceiver[Task],
+ cancel_receiver: MpReceiver[TaskId],
+) -> None:
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
+
+ runner_id = bound_instance.bound_runner_id
+ shard_metadata = bound_instance.bound_shard
+
+ logger.info("hello from the tinygrad runner")
+
+ setup_start_time = time.time()
+ cancelled_tasks = set[TaskId]()
+
+ inference_model: TransformerWeights | None = None
+ tokenizer = None
+ tool_parser: ToolParser | None = None
+ check_for_cancel_every: int | None = None
+
+ current_status: RunnerStatus = RunnerIdle()
+ logger.info("tinygrad runner created")
+ event_sender.send(
+ RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
+ )
+ seen = set[TaskId]()
+ with task_receiver as tasks:
+ for task in tasks:
+ if task.task_id in seen:
+ logger.warning("repeat task - potential error")
+ seen.add(task.task_id)
+ cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
+ event_sender.send(
+ TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
+ )
+ match task:
+ case LoadModel() if isinstance(current_status, (RunnerIdle, RunnerFailed)):
+ total_layers = shard_metadata.end_layer - shard_metadata.start_layer
+ current_status = RunnerLoading(
+ layers_loaded=0, total_layers=total_layers
+ )
+ logger.info("tinygrad runner loading")
+ event_sender.send(
+ RunnerStatusUpdated(
+ runner_id=runner_id, runner_status=current_status
+ )
+ )
+ event_sender.send(TaskAcknowledged(task_id=task.task_id))
+
+ def on_model_load_timeout() -> None:
+ event_sender.send(
+ RunnerStatusUpdated(
+ runner_id=runner_id,
+ runner_status=RunnerFailed(
+ error_message="Model loading timed out"
+ ),
+ )
+ )
+ time.sleep(0.5)
+
+ assert (
+ ModelTask.TextGeneration in shard_metadata.model_card.tasks
+ ), f"Incorrect model task(s): {shard_metadata.model_card.tasks}"
+
+ initialize_tinygrad(bound_instance)
+
+ inference_model, tokenizer = load_tinygrad_items( # pyright: ignore[reportAny]
+ bound_instance,
+ None,
+ on_timeout=on_model_load_timeout,
+ )
+ logger.info(
+ f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}" # pyright: ignore[reportAny]
+ )
+ if tokenizer.has_tool_calling: # pyright: ignore[reportAny]
+ assert tokenizer.tool_call_start # pyright: ignore[reportAny]
+ assert tokenizer.tool_call_end # pyright: ignore[reportAny]
+ assert tokenizer.tool_parser # pyright: ignore[reportAny]
+ tool_parser = make_mlx_parser(
+ tokenizer.tool_call_start, # pyright: ignore[reportAny]
+ tokenizer.tool_call_end, # pyright: ignore[reportAny]
+ tokenizer.tool_parser, # pyright: ignore[reportAny]
+ )
+ current_status = RunnerLoaded()
+ logger.info("tinygrad runner loaded")
+
+ case StartWarmup() if isinstance(current_status, RunnerLoaded):
+ current_status = RunnerWarmingUp()
+ logger.info("tinygrad runner warming up")
+ event_sender.send(
+ RunnerStatusUpdated(
+ runner_id=runner_id, runner_status=current_status
+ )
+ )
+ event_sender.send(TaskAcknowledged(task_id=task.task_id))
+
+ assert inference_model
+ assert tokenizer
+
+ t = time.monotonic()
+ toks = warmup_inference(
+ model=inference_model,
+ tokenizer=tokenizer,
+ )
+ logger.info(f"warmed up by generating {toks} tokens")
+ check_for_cancel_every = min(
+ max(1, round(toks / max(time.monotonic() - t, 0.001))), 100
+ )
+ logger.info(
+ f"tinygrad runner checking for cancellation every {check_for_cancel_every} tokens"
+ )
+ logger.info(
+ f"tinygrad runner initialized in {time.time() - setup_start_time} seconds"
+ )
+ current_status = RunnerReady()
+ logger.info("tinygrad runner ready")
+
+ case TextGeneration(task_params=task_params, command_id=command_id) if (
+ isinstance(current_status, RunnerReady)
+ ):
+ logger.info(f"received chat request: {task}")
+ current_status = RunnerRunning()
+ logger.info("tinygrad runner running")
+ event_sender.send(
+ RunnerStatusUpdated(
+ runner_id=runner_id, runner_status=current_status
+ )
+ )
+ event_sender.send(TaskAcknowledged(task_id=task.task_id))
+ assert inference_model
+ assert tokenizer
+ assert check_for_cancel_every
+
+ try:
+ _check_for_debug_prompts(task_params)
+
+ prompt = apply_chat_template(tokenizer, task_params)
+
+ gen: Generator[GenerationResponse | ToolCallResponse] = tinygrad_generate(
+ model=inference_model,
+ tokenizer=tokenizer,
+ task=task_params,
+ prompt=prompt,
+ )
+
+ if tool_parser:
+ gen = _parse_tool_calls(gen, tool_parser)
+
+ completion_tokens = 0
+ tokens_since_last_cancel_check = check_for_cancel_every
+ for response in gen:
+ tokens_since_last_cancel_check += 1
+ if tokens_since_last_cancel_check >= check_for_cancel_every:
+ tokens_since_last_cancel_check = 0
+ cancelled_tasks.update(cancel_receiver.collect())
+ want_to_cancel = (task.task_id in cancelled_tasks) or (
+ TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
+ )
+ if want_to_cancel:
+ break
+
+ match response:
+ case GenerationResponse():
+ completion_tokens += 1
+ if response.finish_reason == "error":
+ event_sender.send(
+ ChunkGenerated(
+ command_id=command_id,
+ chunk=ErrorChunk(
+ error_message=response.text,
+ model=shard_metadata.model_card.model_id,
+ ),
+ )
+ )
+ else:
+ assert response.finish_reason not in (
+ "error",
+ "tool_calls",
+ "function_call",
+ )
+ event_sender.send(
+ ChunkGenerated(
+ command_id=command_id,
+ chunk=TokenChunk(
+ model=shard_metadata.model_card.model_id,
+ text=response.text,
+ token_id=response.token,
+ usage=response.usage,
+ finish_reason=response.finish_reason,
+ stats=response.stats,
+ logprob=response.logprob,
+ top_logprobs=response.top_logprobs,
+ is_thinking=response.is_thinking,
+ ),
+ )
+ )
+ case ToolCallResponse():
+ event_sender.send(
+ ChunkGenerated(
+ command_id=command_id,
+ chunk=ToolCallChunk(
+ tool_calls=response.tool_calls,
+ model=shard_metadata.model_card.model_id,
+ usage=response.usage,
+ stats=response.stats,
+ ),
+ )
+ )
+
+ except Exception as e:
+ event_sender.send(
+ ChunkGenerated(
+ command_id=command_id,
+ chunk=ErrorChunk(
+ model=shard_metadata.model_card.model_id,
+ finish_reason="error",
+ error_message=str(e),
+ ),
+ )
+ )
+ raise
+
+ current_status = RunnerReady()
+ logger.info("tinygrad runner ready")
+
+ case Shutdown():
+ current_status = RunnerShuttingDown()
+ logger.info("tinygrad runner shutting down")
+ if not TYPE_CHECKING:
+ del inference_model, tokenizer
+ cleanup_jit_state()
+ import gc
+
+ gc.collect()
+
+ event_sender.send(
+ RunnerStatusUpdated(
+ runner_id=runner_id, runner_status=current_status
+ )
+ )
+ event_sender.send(TaskAcknowledged(task_id=task.task_id))
+
+ current_status = RunnerShutdown()
+
+ case _:
+ raise ValueError(
+ f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
+ )
+
+ was_cancelled = (task.task_id in cancelled_tasks) or (
+ TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
+ )
+ if not was_cancelled:
+ event_sender.send(
+ TaskStatusUpdated(
+ task_id=task.task_id, task_status=TaskStatus.Complete
+ )
+ )
+ event_sender.send(
+ RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
+ )
+
+ if isinstance(current_status, RunnerShutdown):
+ break
+
+
+def _parse_tool_calls(
+ responses: Generator[GenerationResponse | ToolCallResponse],
+ tool_parser: ToolParser,
+) -> Generator[GenerationResponse | ToolCallResponse]:
+ """Wrap a generation stream to detect and parse tool calls."""
+ in_tool_call = False
+ tool_call_text_parts: list[str] = []
+ for response in responses:
+ if not isinstance(response, GenerationResponse):
+ yield response
+ continue
+
+ if response.text.startswith(tool_parser.start_parsing):
+ in_tool_call = True
+
+ if in_tool_call:
+ tool_call_text_parts.append(response.text)
+ if response.text.endswith(tool_parser.end_parsing):
+ parsed = tool_parser.parse_tool_calls(
+ "".join(tool_call_text_parts).strip()
+ )
+ logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
+ if parsed is not None:
+ yield ToolCallResponse(
+ tool_calls=parsed, usage=response.usage, stats=response.stats
+ )
+ else:
+ logger.warning(
+ f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
+ )
+ response = response.model_copy(
+ update={"text": "".join(tool_call_text_parts)}
+ )
+ yield response
+
+ in_tool_call = False
+ tool_call_text_parts = []
+ continue
+
+ if response.finish_reason is not None:
+ logger.info(
+ "tool call parsing interrupted, yielding partial tool call as text"
+ )
+ response = response.model_copy(
+ update={
+ "text": "".join(tool_call_text_parts),
+ "token": 0,
+ }
+ )
+ yield response
+
+ continue
+
+ yield response
+
+
+EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
+EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
+
+
+def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
+ if len(task_params.input) == 0:
+ return
+ prompt = task_params.input[0].content
+
+ if not prompt:
+ return
+
+ if EXO_RUNNER_MUST_FAIL in prompt:
+ logger.info("raising exception")
+ raise Exception("Artificial runner exception - for testing purposes only.")
+ if EXO_RUNNER_MUST_TIMEOUT in prompt:
+ time.sleep(100)
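Reviewer note: the buffering logic in `_parse_tool_calls` can be illustrated on plain strings. The sketch below is a simplified stand-in, assuming hypothetical `<tool_call>` and `</tool_call>` markers and skipping the actual parser: it passes ordinary pieces through, buffers everything between the markers, and falls back to emitting the buffered text if the stream ends mid-call.

```python
from collections.abc import Iterable, Iterator

# Hypothetical markers; the real ones come from the tokenizer.
START = "<tool_call>"
END = "</tool_call>"


def group_tool_calls(pieces: Iterable[str]) -> Iterator[tuple[str, str]]:
    """Yield ("text", piece) for normal output and ("tool_call", body) for buffered calls."""
    buffer: list[str] = []
    in_call = False
    for piece in pieces:
        if piece.startswith(START):
            in_call = True
        if not in_call:
            yield ("text", piece)
            continue
        buffer.append(piece)
        if piece.endswith(END):
            yield ("tool_call", "".join(buffer))
            buffer, in_call = [], False
    if buffer:
        # The stream ended inside a call: surface the partial text instead of dropping it.
        yield ("text", "".join(buffer))
```

Like the code in the diff, this only detects markers that begin or end a streamed piece, so it depends on the tokenizer emitting each marker as its own token.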
diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py
index e8a06a77c3..28aee80f05 100644
--- a/src/exo/worker/runner/runner_supervisor.py
+++ b/src/exo/worker/runner/runner_supervisor.py
@@ -1,7 +1,8 @@
import contextlib
import signal
from dataclasses import dataclass, field
-from multiprocessing import Process
+from multiprocessing import Pipe, Process
+from multiprocessing.connection import Connection
from typing import Self
import anyio
@@ -52,6 +53,7 @@ class RunnerSupervisor:
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
+ _death_conn: Connection
@classmethod
def create(
@@ -65,6 +67,8 @@ def create(
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
+ parent_conn, child_conn = Pipe()
+
runner_process = Process(
target=entrypoint,
args=(
@@ -73,8 +77,9 @@ def create(
task_recv,
cancel_recv,
logger,
+ child_conn,
),
- daemon=True,
+ daemon=False,
)
shard_metadata = bound_instance.bound_shard
@@ -88,6 +93,7 @@ def create(
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
+ _death_conn=parent_conn,
)
return self
@@ -98,6 +104,7 @@ async def run(self):
def shutdown(self):
logger.info("Runner supervisor shutting down")
+ self._death_conn.close()
self._ev_recv.close()
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
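
Note on the `_death_conn` change above: the supervisor keeps `parent_conn`, hands `child_conn` to the runner entrypoint, and closes its own end in `shutdown()`. The runner side is not shown in this hunk. One plausible use, sketched here purely as an assumption, is for the runner to treat end-of-file on its end of the pipe as "the supervisor is gone" and exit, which matters now that the process is no longer a daemon. `peer_is_alive` is a hypothetical helper, not a function from exo.

```python
from multiprocessing import Pipe
from multiprocessing.connection import Connection


def peer_is_alive(conn: Connection) -> bool:
    """Return False once the other end of the pipe has been closed.

    Hypothetical helper for illustration only; not part of exo.
    """
    try:
        # poll() also reports "ready" when the peer has closed its end;
        # in that case recv() raises EOFError instead of returning data.
        while conn.poll(0):
            conn.recv()  # discard any payload; only liveness matters here
        return True
    except (EOFError, OSError):
        return False


parent_conn, child_conn = Pipe()
alive_before = peer_is_alive(child_conn)  # supervisor end is still open
parent_conn.close()                       # what shutdown() does to _death_conn
alive_after = peer_is_alive(child_conn)   # EOF is now visible on the child end
print(alive_before, alive_after)
```

This runs in a single process to keep the sketch small. With a real child process, EOF only appears once every copy of the parent end is closed, so under the `fork` start method the child would likely need to close any inherited copy of `parent_conn` first.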
diff --git a/src/exo/worker/tests/unittests/test_runner/test_runner_import.py b/src/exo/worker/tests/unittests/test_runner/test_runner_import.py
new file mode 100644
index 0000000000..48eb8bcf4a
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_runner/test_runner_import.py
@@ -0,0 +1,68 @@
+import inspect
+
+import pytest
+
+
+def _mlx_backend_available() -> bool:
+ """Return True only if mlx.core can be fully loaded (native libs present)."""
+ try:
+ import mlx.core # noqa: F401 # pyright: ignore[reportUnusedImport]
+ return True
+ except (ImportError, OSError):
+ return False
+
+requires_mlx = pytest.mark.skipif(
+ not _mlx_backend_available(),
+ reason="MLX native backend not available (missing CUDA/Metal libraries)",
+)
+
+def test_tinygrad_runner_imports_without_mlx() -> None:
+ """tinygrad_runner.py must be importable on Linux where MLX is absent."""
+ from exo.worker.runner.llm_inference.tinygrad_runner import (
+ main, # noqa: F401 # pyright: ignore[reportUnusedImport]
+ )
+
+def test_engine_factory_importable() -> None:
+ """engine_factory.py must be importable on any platform."""
+ from exo.worker.engines.engine_factory import (
+ Engine, # noqa: F401 # pyright: ignore[reportUnusedImport]
+ create_engine, # noqa: F401 # pyright: ignore[reportUnusedImport]
+ )
+
+def test_engine_is_immutable() -> None:
+ """Engine must be a Pydantic model with the expected fields (immutability itself is not asserted here)."""
+ from pydantic import BaseModel
+
+ from exo.worker.engines.engine_factory import Engine
+ assert issubclass(Engine, BaseModel)
+ field_names = set(Engine.model_fields.keys())
+
+ required_fields = [
+ "initialize", "load", "generate", "warmup", "cleanup",
+ "apply_chat_template", "detect_thinking_prompt_suffix",
+ ]
+
+ assert all(field in field_names for field in required_fields)
+
+def test_tokenizer_protocol_importable() -> None:
+ from exo.shared.types.worker.tokenizer import (
+ Tokenizer, # noqa: F401 # pyright: ignore[reportUnusedImport]
+ )
+
+@requires_mlx
+def test_mlx_engine_has_postprocessing_importable() -> None:
+ from exo.worker.engines.mlx.generator.generate import (
+ mlx_generate_with_postprocessing, # noqa: F401 # pyright: ignore[reportUnusedImport]
+ )
+
+@requires_mlx
+def test_mlx_engine_has_postprocessing_signature() -> None:
+ from exo.worker.engines.mlx.generator.generate import (
+ mlx_generate_with_postprocessing,
+ )
+ sig = inspect.signature(mlx_generate_with_postprocessing)
+ params = list(sig.parameters.keys())
+
+ expected_params = ["model", "tokenizer", "model_id"]
+
+ assert all(param in params for param in expected_params)
diff --git a/src/exo/worker/tests/unittests/test_tinygrad/__init__.py b/src/exo/worker/tests/unittests/test_tinygrad/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_cache.py b/src/exo/worker/tests/unittests/test_tinygrad/test_cache.py
new file mode 100644
index 0000000000..a5e5833ebb
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_tinygrad/test_cache.py
@@ -0,0 +1,58 @@
+# pyright: reportUnknownMemberType=false
+import math
+
+from tinygrad.tensor import Tensor
+
+
+def test_cache_initial_seq_len():
+ """Fresh cache should have seq_len equal to max_seq_len (pre-allocated)."""
+ from exo.worker.engines.tinygrad.cache import KVCache
+
+ cache = KVCache(num_layers=4, num_kv_heads=8, head_dim=64, max_seq_len=128)
+ assert cache.seq_len == 128
+
+
+def test_cache_update_writes_to_position():
+ """Update should write keys/values at the specified position."""
+ from exo.worker.engines.tinygrad.cache import KVCache
+
+ cache = KVCache(num_layers=4, num_kv_heads=8, head_dim=64, max_seq_len=128)
+ k = Tensor.ones(1, 8, 3, 64)
+ v = Tensor.ones(1, 8, 3, 64)
+ k_out, _v_out = cache.update(0, k, v, position=0)
+
+ # Output shape is always the full pre-allocated buffer
+ assert k_out.shape == (1, 8, 128, 64)
+ # The written positions should be non-zero
+ assert float(k_out[0, 0, 0, 0].item()) != 0.0
+ assert float(k_out[0, 0, 2, 0].item()) != 0.0
+ # Positions beyond the write should remain zero
+ assert float(k_out[0, 0, 3, 0].item()) == 0.0
+
+
+def test_cache_update_tensor_position():
+ """Update with Tensor position should write a single token."""
+ from exo.worker.engines.tinygrad.cache import KVCache
+
+ cache = KVCache(num_layers=4, num_kv_heads=8, head_dim=64, max_seq_len=128)
+ k = Tensor.ones(1, 8, 1, 64) * 5.0
+ v = Tensor.ones(1, 8, 1, 64) * 5.0
+ pos = Tensor([10]).reshape(1, 1, 1, 1)
+ k_out, _v_out = cache.update(0, k, v, position=pos)
+
+ # Position 10 should have the value we wrote
+ assert math.isclose(float(k_out[0, 0, 10, 0].item()), 5.0, rel_tol=1e-2)
+ # Other positions should remain zero
+ assert float(k_out[0, 0, 0, 0].item()) == 0.0
+
+
+def test_cache_layers_are_independent():
+ """Updates to layer 0 should not affect layer 1."""
+ from exo.worker.engines.tinygrad.cache import KVCache
+
+ cache = KVCache(num_layers=4, num_kv_heads=8, head_dim=64, max_seq_len=128)
+ k = Tensor.ones(1, 8, 3, 64)
+ v = Tensor.ones(1, 8, 3, 64)
+ cache.update(0, k, v, position=0)
+ # Layer 1 should still be all zeros (pre-allocated but untouched)
+ assert float(cache.keys[1].sum().item()) == 0.0
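
The cache tests above describe a pre-allocated buffer: `update` writes at a given position, always returns the full-length buffer, and leaves other layers untouched. The real `KVCache` is not part of this section, so the following is only a list-based sketch of that behavior under simplifying assumptions (one head, no batch dimension, integer positions only). `TinyKVCache` is a hypothetical name.

```python
class TinyKVCache:
    """Hypothetical stand-in for a pre-allocated KV cache (one head, no batch)."""

    def __init__(self, num_layers: int, max_seq_len: int, head_dim: int) -> None:
        self.max_seq_len = max_seq_len
        # Every layer owns its own zero-filled (max_seq_len, head_dim) buffers.
        self.keys = [
            [[0.0] * head_dim for _ in range(max_seq_len)] for _ in range(num_layers)
        ]
        self.values = [
            [[0.0] * head_dim for _ in range(max_seq_len)] for _ in range(num_layers)
        ]

    def update(self, layer, k_rows, v_rows, position):
        """Write rows starting at `position` and return the full buffers."""
        if position + len(k_rows) > self.max_seq_len:
            raise ValueError("write would run past the pre-allocated buffer")
        for i, (k, v) in enumerate(zip(k_rows, v_rows)):
            self.keys[layer][position + i] = list(k)
            self.values[layer][position + i] = list(v)
        return self.keys[layer], self.values[layer]


cache = TinyKVCache(num_layers=2, max_seq_len=8, head_dim=4)
rows = [[1.0] * 4 for _ in range(3)]
k_out, v_out = cache.update(0, rows, rows, position=0)
layer1_sum = sum(sum(row) for row in cache.keys[1])
print(len(k_out), k_out[2][0], k_out[3][0], layer1_sum)
```

The returned buffer keeps its full length regardless of how many rows were written, which mirrors the `(1, 8, 128, 64)` shape assertion in `test_cache_update_writes_to_position`.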
diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_forward.py b/src/exo/worker/tests/unittests/test_tinygrad/test_forward.py
new file mode 100644
index 0000000000..2be8f52274
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_tinygrad/test_forward.py
@@ -0,0 +1,48 @@
+# pyright: reportUnknownMemberType=false
+import pytest
+from tinygrad.dtype import dtypes
+from tinygrad.tensor import Tensor
+
+
+@pytest.mark.slow
+def test_forward_pass_shape():
+ """Forward pass should produce logits of shape (batch, seq, vocab_size)."""
+ from pathlib import Path
+
+ from exo.shared.model_config import parse_model_config
+ from exo.worker.engines.tinygrad.forward import forward_pass
+ from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+ model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+ if not model_path.exists():
+ pytest.skip("Model not downloaded")
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2)
+
+ input_ids = Tensor([[1, 2, 3, 4]], dtype=dtypes.int32)
+ logits, cache = forward_pass(weights, input_ids, cache=None, position_offset=0)
+ assert logits.shape == (1, 4, config.vocab_size)
+ assert cache.seq_len == 4
+
+@pytest.mark.slow
+def test_decode_step_after_prefill():
+ """A single-token decode step should extend the cache by 1."""
+ from pathlib import Path
+
+ from exo.shared.model_config import parse_model_config
+ from exo.worker.engines.tinygrad.forward import forward_pass
+ from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+ model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+ if not model_path.exists():
+ pytest.skip("Model not downloaded")
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2)
+
+ input_ids = Tensor([[1, 2, 3, 4]], dtype=dtypes.int32)
+ _logits, cache = forward_pass(weights, input_ids, cache=None, position_offset=0)
+
+ next_input = Tensor([[5]], dtype=dtypes.int32)
+ logits2, cache2 = forward_pass(weights, next_input, cache, position_offset=4)
+ assert logits2.shape == (1, 1, config.vocab_size)
+ assert cache2.seq_len == 5
diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_generate.py b/src/exo/worker/tests/unittests/test_tinygrad/test_generate.py
new file mode 100644
index 0000000000..be64c47b0d
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_tinygrad/test_generate.py
@@ -0,0 +1,184 @@
+# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
+import pytest
+
+
+def test_constants_exist():
+ """Default constants should be defined."""
+ from exo.worker.engines.tinygrad.constants import (
+ DEFAULT_MAX_TOKENS,
+ DEFAULT_TEMPERATURE,
+ DEFAULT_TOP_P,
+ )
+
+ assert isinstance(DEFAULT_MAX_TOKENS, int)
+ assert DEFAULT_MAX_TOKENS > 0
+ assert isinstance(DEFAULT_TEMPERATURE, float)
+ assert isinstance(DEFAULT_TOP_P, float)
+
+def test_initialize_tinygrad_callable():
+ """initialize_tinygrad must be importable and callable."""
+ from exo.worker.engines.tinygrad.utils_tinygrad import initialize_tinygrad
+
+ assert callable(initialize_tinygrad)
+
+def test_load_tinygrad_items_callable():
+ """load_tinygrad_items must be importable and callable."""
+ from exo.worker.engines.tinygrad.utils_tinygrad import load_tinygrad_items
+
+ assert callable(load_tinygrad_items)
+
+def test_tinygrad_generate_accepts_prompt_parameter():
+ """tinygrad_generate must accept prompt as a parameter (matching MLX pattern)."""
+ import inspect
+
+ from exo.worker.engines.tinygrad.generator.generate import tinygrad_generate
+
+ sig = inspect.signature(tinygrad_generate)
+ assert "prompt" in sig.parameters, "tinygrad_generate must accept a 'prompt' parameter"
+ assert "model" in sig.parameters
+ assert "tokenizer" in sig.parameters
+ assert "task" in sig.parameters
+
+@pytest.mark.slow
+def test_generate_yields_generation_responses():
+ """tinygrad_generate should yield GenerationResponse objects."""
+ from pathlib import Path
+ from unittest.mock import MagicMock
+
+ from exo.shared.model_config import parse_model_config
+ from exo.shared.types.worker.runner_response import GenerationResponse
+ from exo.worker.engines.tinygrad.generator.generate import tinygrad_generate
+ from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+ model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+ if not model_path.exists():
+ pytest.skip("Model not downloaded")
+
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2)
+
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
+
+ task = MagicMock()
+ task.max_tokens = 3
+ task.temperature = 0.0
+ task.top_p = 0.9
+ task.logprobs = False
+ task.top_logprobs = 0
+
+ # The runner normally builds the prompt via apply_chat_template; a hand-written ChatML-style prompt stands in for it here
+ prompt = "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
+
+ responses = list(tinygrad_generate(weights, tokenizer, task, prompt=prompt))
+ assert len(responses) > 0
+ assert all(isinstance(r, GenerationResponse) for r in responses)
+
+ # Last response should have finish_reason, stats, and usage
+ assert responses[-1].finish_reason is not None
+ assert responses[-1].stats is not None
+ assert responses[-1].usage is not None
+ assert responses[-1].usage.prompt_tokens > 0
+ assert responses[-1].usage.completion_tokens > 0
+
+ # Intermediate responses should have no stats/usage
+ if len(responses) > 1:
+ assert responses[0].stats is None
+ assert responses[0].usage is None
+
+@pytest.mark.slow
+def test_generate_populates_logprobs_when_requested():
+ """When task.logprobs=True, responses should include logprob and top_logprobs."""
+ from pathlib import Path
+ from unittest.mock import MagicMock
+
+ from exo.shared.model_config import parse_model_config
+ from exo.worker.engines.tinygrad.generator.generate import tinygrad_generate
+ from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+ model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+ if not model_path.exists():
+ pytest.skip("Model not downloaded")
+
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2)
+
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
+
+ task = MagicMock()
+ task.max_tokens = 2
+ task.temperature = 0.0
+ task.top_p = 0.9
+ task.logprobs = True
+ task.top_logprobs = 3
+
+ prompt = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"
+ responses = list(tinygrad_generate(weights, tokenizer, task, prompt=prompt))
+
+ for r in responses:
+ assert r.logprob is not None, "logprob must be set when task.logprobs=True"
+ assert r.logprob <= 0.0, "logprob must be <= 0"
+ assert r.top_logprobs is not None
+ assert len(r.top_logprobs) == 3
+
+@pytest.mark.slow
+def test_generate_omits_logprobs_when_not_requested():
+ """When task.logprobs=False, responses should have logprob=None."""
+ from pathlib import Path
+ from unittest.mock import MagicMock
+
+ from exo.shared.model_config import parse_model_config
+ from exo.worker.engines.tinygrad.generator.generate import tinygrad_generate
+ from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+ model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+ if not model_path.exists():
+ pytest.skip("Model not downloaded")
+
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2)
+
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
+
+ task = MagicMock()
+ task.max_tokens = 2
+ task.temperature = 0.0
+ task.top_p = 0.9
+ task.logprobs = False
+ task.top_logprobs = 0
+
+ prompt = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"
+ responses = list(tinygrad_generate(weights, tokenizer, task, prompt=prompt))
+
+ for r in responses:
+ assert r.logprob is None
+ assert r.top_logprobs is None
+
+@pytest.mark.slow
+def test_warmup_runs_full_generation():
+ """warmup_inference should run a short, real generation loop (1 to 10 tokens), not just a forward pass."""
+ from pathlib import Path
+
+ from exo.shared.model_config import parse_model_config
+ from exo.worker.engines.tinygrad.generator.generate import warmup_inference
+ from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+ model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+ if not model_path.exists():
+ pytest.skip("Model not downloaded")
+
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2)
+
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
+
+ tokens_generated = warmup_inference(model=weights, tokenizer=tokenizer)
+ assert tokens_generated >= 1, "warmup must generate at least 1 token"
+ assert tokens_generated <= 10, "warmup should be short (not full generation)"
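
The logprob tests above assert `r.logprob <= 0.0` and `len(r.top_logprobs) == 3`. How `tinygrad_generate` computes these values is not shown in this section. The sketch below uses the standard log-softmax definition in plain Python to show why both properties hold; the function names are hypothetical and this is not the exo implementation.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax: subtract the max before exponentiating."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def top_logprobs(logits: list[float], k: int) -> list[tuple[int, float]]:
    """Return the k (token_id, logprob) pairs with the highest log-probability."""
    lps = log_softmax(logits)
    ranked = sorted(enumerate(lps), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


logits = [2.0, 0.5, -1.0, 0.0]
lps = log_softmax(logits)
top3 = top_logprobs(logits, 3)
total_prob = sum(math.exp(lp) for lp in lps)
print(top3)
```

Each log-probability is the log of a value in (0, 1], so it can never be positive, and the exponentiated values sum to 1.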
diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_layers.py b/src/exo/worker/tests/unittests/test_tinygrad/test_layers.py
new file mode 100644
index 0000000000..5c0933cbf9
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_tinygrad/test_layers.py
@@ -0,0 +1,120 @@
+# pyright: reportUnknownMemberType=false
+from tinygrad.tensor import Tensor
+
+# --- RMS Norm ---
+
+def test_rms_norm_shape():
+ """Output shape must match input shape."""
+ from exo.worker.engines.tinygrad.layers.normalization import rms_norm
+
+ x = Tensor.randn(1, 4, 2048)
+ weight = Tensor.ones(2048)
+ out = rms_norm(x, weight, eps=1e-6)
+ assert out.shape == (1, 4, 2048)
+
+def test_rms_norm_unit_weight():
+ """With unit weight, the output keeps the input shape and its per-position RMS has shape (1, 4)."""
+ from exo.worker.engines.tinygrad.layers.normalization import rms_norm
+
+ x = Tensor.randn(1, 4, 2048)
+ weight = Tensor.ones(2048)
+ out = rms_norm(x, weight, eps=1e-6)
+ assert out.shape == x.shape
+ # Verify the computation graph produces the right intermediate shapes
+ rms = ((out * out).mean(axis=-1)).sqrt()
+ assert rms.shape == (1, 4)
+
+def test_rms_norm_scales_with_weight():
+ """Doubling the weight should produce an output with the same shape."""
+ from exo.worker.engines.tinygrad.layers.normalization import rms_norm
+
+ x = Tensor.randn(1, 4, 2048)
+ weight_1x = Tensor.ones(2048)
+ weight_2x = Tensor.ones(2048) * 2.0
+ out_1x = rms_norm(x, weight_1x, eps=1e-6)
+ out_2x = rms_norm(x, weight_2x, eps=1e-6)
+ assert out_1x.shape == out_2x.shape == (1, 4, 2048)
+ # Verify the difference tensor has the right shape
+ diff = (out_2x - out_1x * 2.0)
+ assert diff.shape == (1, 4, 2048)
+
+# --- RoPE ---
+
+def test_rope_frequency_shapes():
+ """cos/sin frequency tables should have shape (max_seq_len, head_dim//2)."""
+ from exo.worker.engines.tinygrad.layers.rotary import compute_rope_frequencies
+
+ cos_f, sin_f = compute_rope_frequencies(head_dim=64, max_seq_len=128)
+ assert cos_f.shape == (128, 32)
+ assert sin_f.shape == (128, 32)
+
+def test_rope_changes_with_position():
+ """apply_rope at different offsets should produce tensors of the same shape."""
+ from exo.worker.engines.tinygrad.layers.rotary import (
+ apply_rope,
+ compute_rope_frequencies,
+ )
+
+ cos_f, sin_f = compute_rope_frequencies(head_dim=64, max_seq_len=128)
+ x = Tensor.randn(1, 8, 4, 64) # (batch, heads, seq=4, head_dim)
+ out_0 = apply_rope(x, cos_f, sin_f, position_offset=0)
+ out_10 = apply_rope(x, cos_f, sin_f, position_offset=10)
+ assert out_0.shape == out_10.shape == (1, 8, 4, 64)
+ # Verify the difference tensor can be constructed (computation graph is valid)
+ diff = (out_0 - out_10).abs()
+ assert diff.shape == (1, 8, 4, 64)
+
+def test_rope_preserves_shape():
+ """apply_rope must not change the tensor shape."""
+ from exo.worker.engines.tinygrad.layers.rotary import (
+ apply_rope,
+ compute_rope_frequencies,
+ )
+
+ cos_f, sin_f = compute_rope_frequencies(head_dim=64, max_seq_len=128)
+ x = Tensor.randn(1, 8, 4, 64)
+ out = apply_rope(x, cos_f, sin_f, position_offset=0)
+ assert out.shape == x.shape
+
+# --- Embedding ---
+
+def test_embedding_lookup():
+ """Embedding lookup should index into the weight table."""
+ from exo.worker.engines.tinygrad.layers.embedding import apply_embedding
+
+ embed = Tensor.randn(100, 32) # vocab=100, dim=32
+ ids = Tensor([[0, 1, 2]])
+ out = apply_embedding(embed, ids)
+ assert out.shape == (1, 3, 32)
+
+def test_lm_head_shape():
+ """LM head should project hidden_size to vocab_size."""
+ from exo.worker.engines.tinygrad.layers.embedding import apply_lm_head
+
+ x = Tensor.randn(1, 4, 2048)
+ lm_head = Tensor.randn(128256, 2048) # (vocab, hidden)
+ out = apply_lm_head(x, lm_head)
+ assert out.shape == (1, 4, 128256)
+
+# --- MLP ---
+
+def test_swiglu_mlp_shape():
+ """SwiGLU MLP: hidden_size in, hidden_size out (merged gate+up proj)."""
+ from exo.worker.engines.tinygrad.layers.mlp import swiglu_mlp
+
+ x = Tensor.randn(1, 4, 2048)
+ gate_up = Tensor.randn(16384, 2048) # gate(8192) + up(8192) merged
+ down = Tensor.randn(2048, 8192)
+ out = swiglu_mlp(x, gate_up, down)
+ assert out.shape == (1, 4, 2048)
+
+# --- Attention ---
+
+def test_linear_forward_shape():
+ """linear_forward(x, weight) should compute x @ weight.T."""
+ from exo.worker.engines.tinygrad.layers.attention import linear_forward
+
+ x = Tensor.randn(1, 4, 2048)
+ weight = Tensor.randn(512, 2048) # (out, in)
+ out = linear_forward(x, weight)
+ assert out.shape == (1, 4, 512)
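
The RMS norm tests above only check shapes, so they do not verify that the output actually scales with the weight. As a numeric illustration, here is the standard RMSNorm formula in plain Python. It assumes exo's `rms_norm` follows the usual definition `x / sqrt(mean(x^2) + eps) * weight`, which this section does not confirm.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Standard RMSNorm over a single vector (illustrative, not exo's code)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


x = [1.0, -2.0, 3.0, -4.0]
out_1x = rms_norm(x, [1.0] * 4)
out_2x = rms_norm(x, [2.0] * 4)

# With unit weight the normalized vector has an RMS of (almost exactly) 1.
rms_out = math.sqrt(sum(v * v for v in out_1x) / len(out_1x))
print(out_1x, out_2x, rms_out)
```

A value-level test along these lines would give `test_rms_norm_scales_with_weight` the check its name suggests, at the cost of realizing the tensors.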
diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_quantization.py b/src/exo/worker/tests/unittests/test_tinygrad/test_quantization.py
new file mode 100644
index 0000000000..89809f38c4
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_tinygrad/test_quantization.py
@@ -0,0 +1,401 @@
+from __future__ import annotations
+
+# pyright: reportUnknownMemberType=false
+import pytest
+from tinygrad.device import Device
+from tinygrad.dtype import dtypes
+from tinygrad.tensor import Tensor
+
+from exo.shared.model_config import ModelConfig
+
+Device.DEFAULT = "CPU"
+
+# ── Helpers ──
+
+def _make_model_config(
+ hidden_size: int = 2048,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int | None = None,
+ intermediate_size: int = 8192,
+ vocab_size: int = 128256,
+) -> ModelConfig:
+ """Build a ModelConfig with sensible defaults for quantization tests."""
+ from exo.shared.architecture.llama import LLAMA_SPEC
+
+ n_kv = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+ return ModelConfig(
+ architecture_spec=LLAMA_SPEC,
+ num_hidden_layers=32,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=n_kv,
+ vocab_size=vocab_size,
+ head_dim=hidden_size // num_attention_heads,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ max_position_embeddings=4096,
+ rms_norm_eps=1e-6,
+ tie_word_embeddings=False,
+ quantization_config=None,
+ )
+
+# ── Packing / Unpacking ──
+
+def test_calculate_pack_factor():
+ """pack_factor = 32 // bits."""
+ from exo.worker.engines.tinygrad.quantization.packing import calculate_pack_factor
+
+ assert calculate_pack_factor(4) == 8
+ assert calculate_pack_factor(8) == 4
+ assert calculate_pack_factor(2) == 16
+
+def test_pack_bits_output_shape():
+ """Packed last dim should be original_last_dim / pack_factor."""
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ original = Tensor.zeros(64, 4096, dtype=dtypes.float16)
+ packed = pack_bits(original, bits=4)
+ assert packed.tensor.shape == (64, 512)
+
+def test_packed_tensor_metadata():
+ """PackedTensor must carry original_shape, pack_factor, bits."""
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ original = Tensor.zeros(64, 4096, dtype=dtypes.float16)
+ packed = pack_bits(original, bits=4)
+ assert packed.original_shape == (64, 4096)
+ assert packed.pack_factor == 8
+ assert packed.bits == 4
+
+def test_packed_tensor_is_frozen():
+ """PackedTensor must be immutable (frozen dataclass)."""
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ packed = pack_bits(Tensor.zeros(4, 8, dtype=dtypes.float16), bits=4)
+ with pytest.raises(AttributeError):
+ packed.bits = 8
+
+def test_pack_unpack_roundtrip_4bit():
+ """pack then unpack should recover original values (4-bit)."""
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits, unpack_bits
+
+ original = Tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=dtypes.float16)
+ packed = pack_bits(original, bits=4)
+ unpacked = unpack_bits(packed)
+ assert unpacked.shape == original.shape
+ for i in range(8):
+ assert abs(unpacked[0, i].item() - original[0, i].item()) < 1e-3
+
+def test_pack_unpack_roundtrip_8bit():
+ """pack then unpack should recover original values (8-bit)."""
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits, unpack_bits
+
+ original = Tensor([[10, 20, 30, 40]], dtype=dtypes.float16)
+ packed = pack_bits(original, bits=8)
+ unpacked = unpack_bits(packed)
+ assert unpacked.shape == original.shape
+ for i in range(4):
+ assert abs(unpacked[0, i].item() - original[0, i].item()) < 1e-3
+
+def test_pack_unpack_non_divisible_dim():
+ """Unpacking handles last dims not divisible by pack_factor."""
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits, unpack_bits
+
+ original = Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=dtypes.float16)
+ packed = pack_bits(original, bits=4)
+ unpacked = unpack_bits(packed)
+ assert unpacked.shape == (1, 10)
+ for i in range(10):
+ assert abs(unpacked[0, i].item() - original[0, i].item()) < 1e-3
+
+def test_unpack_dtype_is_float16():
+ """Unpacked tensor must be float16."""
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits, unpack_bits
+
+ original = Tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=dtypes.float16)
+ unpacked = unpack_bits(pack_bits(original, bits=4))
+ assert unpacked.dtype == dtypes.float16
+
+# ── Dequantization ──
+
+def test_affine_dequantize_identity():
+ """scale=1, bias=0 → identity."""
+ from exo.worker.engines.tinygrad.quantization.dequantization import (
+ affine_dequantize,
+ )
+
+ quantized = Tensor([[1.0, 2.0, 3.0, 4.0]])
+ scales = Tensor([[1.0]])
+ biases = Tensor([[0.0]])
+ result = affine_dequantize(quantized, scales, biases, group_size=4)
+ assert result.shape == (1, 4)
+ for i in range(4):
+ assert abs(result[0, i].item() - quantized[0, i].item()) < 1e-5
+
+def test_affine_dequantize_scaling():
+ """scale=2 should double values."""
+ from exo.worker.engines.tinygrad.quantization.dequantization import (
+ affine_dequantize,
+ )
+
+ quantized = Tensor([[1.0, 2.0, 3.0, 4.0]])
+ scales = Tensor([[2.0]])
+ biases = Tensor([[0.0]])
+ result = affine_dequantize(quantized, scales, biases, group_size=4)
+ for i in range(4):
+ assert abs(result[0, i].item() - quantized[0, i].item() * 2.0) < 1e-5
+
+def test_affine_dequantize_bias_offset():
+ """bias=10 should offset by 10."""
+ from exo.worker.engines.tinygrad.quantization.dequantization import (
+ affine_dequantize,
+ )
+
+ quantized = Tensor([[0.0, 0.0, 0.0, 0.0]])
+ scales = Tensor([[1.0]])
+ biases = Tensor([[10.0]])
+ result = affine_dequantize(quantized, scales, biases, group_size=4)
+ for i in range(4):
+ assert abs(result[0, i].item() - 10.0) < 1e-5
+
+def test_affine_dequantize_multiple_groups():
+ """Each group uses its own scale and bias."""
+ from exo.worker.engines.tinygrad.quantization.dequantization import (
+ affine_dequantize,
+ )
+
+ quantized = Tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]])
+ scales = Tensor([[2.0, 3.0]]) # 2 groups of 4
+ biases = Tensor([[0.0, 10.0]])
+ result = affine_dequantize(quantized, scales, biases, group_size=4)
+ assert abs(result[0, 0].item() - 2.0) < 1e-5 # group 0: 1*2+0
+ assert abs(result[0, 4].item() - 13.0) < 1e-5 # group 1: 1*3+10
+
+def test_affine_dequantize_2d_weight_matrix():
+ """Dequantize a realistic 2D weight matrix [out_features, in_features]."""
+ from exo.worker.engines.tinygrad.quantization.dequantization import (
+ affine_dequantize,
+ )
+
+ out_features, in_features, group_size = 32, 64, 32
+ quantized = Tensor.ones(out_features, in_features)
+ num_groups = in_features // group_size
+ scales = Tensor.ones(out_features, num_groups) * 0.5
+ biases = Tensor.zeros(out_features, num_groups)
+ result = affine_dequantize(quantized, scales, biases, group_size=group_size)
+ assert result.shape == (out_features, in_features)
+ assert abs(result[0, 0].item() - 0.5) < 1e-5
+
+# ── QuantizedLinear ──
+
+def test_quantized_linear_is_callable():
+ """QuantizedLinear instances must be callable via __call__."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ weight = Tensor.ones(16, 32, dtype=dtypes.float16)
+ packed = pack_bits(weight, bits=4)
+ layer = QuantizedLinear(
+ weight_q=packed,
+ scales=Tensor.ones(16, 1),
+ biases=Tensor.zeros(16, 1),
+ group_size=32,
+ )
+ x = Tensor.randn(1, 4, 32)
+ out = layer(x)
+ assert out.shape == (1, 4, 16)
+
+def test_quantized_linear_output_shape():
+ """QuantizedLinear(in=64, out=32) → (batch, seq, 32)."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ weight = Tensor.ones(32, 64, dtype=dtypes.float16)
+ packed = pack_bits(weight, bits=4)
+ layer = QuantizedLinear(
+ weight_q=packed,
+ scales=Tensor.ones(32, 1),
+ biases=Tensor.zeros(32, 1),
+ group_size=64,
+ )
+ x = Tensor.randn(1, 8, 64)
+ out = layer(x)
+ assert out.shape == (1, 8, 32)
+
+def test_quantized_linear_dequantize_returns_correct_shape():
+ """dequantize() returns a tensor with the original weight shape."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ weight = Tensor.ones(16, 32, dtype=dtypes.float16)
+ packed = pack_bits(weight, bits=4)
+ layer = QuantizedLinear(
+ weight_q=packed,
+ scales=Tensor.ones(16, 1),
+ biases=Tensor.zeros(16, 1),
+ group_size=32,
+ )
+ dequantized = layer.dequantize()
+ assert dequantized.shape == (16, 32)
+
+def test_quantized_linear_in_out_features():
+ """in_features and out_features reflect the original weight shape."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ weight = Tensor.ones(128, 256, dtype=dtypes.float16)
+ packed = pack_bits(weight, bits=4)
+ layer = QuantizedLinear(
+ weight_q=packed,
+ scales=Tensor.ones(128, 4),
+ biases=Tensor.zeros(128, 4),
+ group_size=64,
+ )
+ assert layer.out_features == 128
+ assert layer.in_features == 256
+
+def test_quantized_linear_no_bias_by_default():
+ """QuantizedLinear should not add a bias term by default."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ weight = Tensor.ones(16, 32, dtype=dtypes.float16)
+ packed = pack_bits(weight, bits=4)
+ layer = QuantizedLinear(
+ weight_q=packed,
+ scales=Tensor.ones(16, 1),
+ biases=Tensor.zeros(16, 1),
+ group_size=32,
+ )
+ assert layer.bias is None
+
+def test_quantized_linear_is_final():
+ """QuantizedLinear must be decorated with @final."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+
+ assert getattr(QuantizedLinear, "__final__", False)
+
+# ── QuantizedEmbedding ──
+
+def test_quantized_embedding_lookup():
+ """QuantizedEmbedding must return correct shape for index lookup."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedEmbedding
+ from exo.worker.engines.tinygrad.quantization.packing import pack_bits
+
+ weight = Tensor.ones(100, 32, dtype=dtypes.float16)
+ packed = pack_bits(weight, bits=4)
+ layer = QuantizedEmbedding(
+ num_embeddings=100,
+ embedding_dim=32,
+ weight_q=packed,
+ scales=Tensor.ones(100, 1),
+ biases=Tensor.zeros(100, 1),
+ group_size=32,
+ )
+ ids = Tensor([[0, 1, 2]])
+ out = layer(ids)
+ assert out.shape == (1, 3, 32)
+
+def test_quantized_embedding_is_final():
+ """QuantizedEmbedding must be decorated with @final."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedEmbedding
+
+ assert getattr(QuantizedEmbedding, "__final__", False)
+
+# ── Shape Inference (requires Phase 1 ModelConfig) ──
+
+def test_infer_shape_q_proj():
+ """q_proj shape = (hidden_size, hidden_size)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+ config = _make_model_config(hidden_size=2048, num_attention_heads=32)
+ shape = infer_weight_shape("model.layers.0.self_attn.q_proj.weight", config)
+ assert shape == (2048, 2048)
+
+def test_infer_shape_k_proj_gqa():
+ """k_proj with GQA = (num_kv_heads * head_dim, hidden_size)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+ config = _make_model_config(
+ hidden_size=2048, num_attention_heads=32, num_key_value_heads=8,
+ )
+ # head_dim = 2048/32 = 64, kv_dim = 8*64 = 512
+ shape = infer_weight_shape("model.layers.0.self_attn.k_proj.weight", config)
+ assert shape == (512, 2048)
+
+def test_infer_shape_v_proj_gqa():
+ """v_proj with GQA = (num_kv_heads * head_dim, hidden_size)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+ config = _make_model_config(
+ hidden_size=2048, num_attention_heads=32, num_key_value_heads=8,
+ )
+ shape = infer_weight_shape("model.layers.0.self_attn.v_proj.weight", config)
+ assert shape == (512, 2048)
+
+def test_infer_shape_k_proj_mha():
+ """When num_kv_heads == num_heads (MHA), k_proj = (hidden, hidden)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+ config = _make_model_config(
+ hidden_size=2048, num_attention_heads=32, num_key_value_heads=32,
+ )
+ shape = infer_weight_shape("model.layers.0.self_attn.k_proj.weight", config)
+ assert shape == (2048, 2048)
+
+def test_detect_layer_type_v_proj():
+ """Must detect v_proj correctly (not w_proj)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import detect_layer_type
+
+ assert detect_layer_type("model.layers.0.self_attn.v_proj.weight") == "v_proj"
+
+def test_infer_shape_gate_proj():
+ """gate_proj = (intermediate_size, hidden_size)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+ config = _make_model_config(hidden_size=2048, intermediate_size=8192)
+ shape = infer_weight_shape("model.layers.0.mlp.gate_proj.weight", config)
+ assert shape == (8192, 2048)
+
+def test_infer_shape_down_proj():
+ """down_proj = (hidden_size, intermediate_size)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+ config = _make_model_config(hidden_size=2048, intermediate_size=8192)
+ shape = infer_weight_shape("model.layers.0.mlp.down_proj.weight", config)
+ assert shape == (2048, 8192)
+
+def test_infer_shape_embed_tokens():
+ """embed_tokens = (vocab_size, hidden_size)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+ config = _make_model_config(hidden_size=2048, vocab_size=128256)
+ shape = infer_weight_shape("model.embed_tokens.weight", config)
+ assert shape == (128256, 2048)
+
+
+def test_infer_shape_lm_head():
+ """lm_head = (vocab_size, hidden_size)."""
+ from exo.worker.engines.tinygrad.quantization.shapes import infer_weight_shape
+
+ config = _make_model_config(hidden_size=2048, vocab_size=128256)
+ shape = infer_weight_shape("lm_head.weight", config)
+ assert shape == (128256, 2048)
+
+# ── Package Exports ──
+
+def test_package_exports_quantized_linear():
+ from exo.worker.engines.tinygrad.quantization import (
+ QuantizedLinear, # noqa: F401 # pyright: ignore[reportUnusedImport]
+ )
+
+def test_package_exports_quantized_embedding():
+ from exo.worker.engines.tinygrad.quantization import (
+ QuantizedEmbedding, # noqa: F401 # pyright: ignore[reportUnusedImport]
+ )
+
+def test_package_exports_packed_tensor():
+ from exo.worker.engines.tinygrad.quantization import (
+ PackedTensor, # noqa: F401 # pyright: ignore[reportUnusedImport]
+ )
diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_sampling.py b/src/exo/worker/tests/unittests/test_tinygrad/test_sampling.py
new file mode 100644
index 0000000000..322681231d
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_tinygrad/test_sampling.py
@@ -0,0 +1,94 @@
+# pyright: reportUnknownMemberType=false
+from tinygrad.tensor import Tensor
+
+
+def test_sample_result_structure():
+ """sample_token should return a SampleResult with token_id, logprob, top_logprobs."""
+ from exo.worker.engines.tinygrad.sampling import SampleResult, sample_token
+
+ logits = Tensor([[[0.1, 0.2, 0.9, 0.3]]])
+ result = sample_token(logits, temperature=0.0)
+ assert isinstance(result, SampleResult)
+ assert isinstance(result.token_id, int)
+ assert isinstance(result.logprob, float)
+ assert isinstance(result.top_logprobs, list)
+
+def test_greedy_sampling():
+ """Temperature=0 should return the argmax token."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ logits = Tensor([[[0.1, 0.2, 0.9, 0.3]]]) # token 2 has highest logit
+ result = sample_token(logits, temperature=0.0)
+ assert result.token_id == 2
+
+def test_temperature_zero_is_deterministic():
+ """Greedy sampling must always produce the same result."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ logits = Tensor.randn(1, 1, 1000)
+ r1 = sample_token(logits, temperature=0.0)
+ r2 = sample_token(logits, temperature=0.0)
+ assert r1.token_id == r2.token_id
+
+def test_sampling_returns_valid_token():
+ """Sampled token must be in [0, vocab_size)."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ vocab_size = 100
+ logits = Tensor.randn(1, 1, vocab_size)
+ result = sample_token(logits, temperature=0.7)
+ assert 0 <= result.token_id < vocab_size
+
+def test_logprob_is_non_positive():
+ """Log-probabilities must be <= 0 (probabilities are in [0, 1])."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ logits = Tensor.randn(1, 1, 100)
+ result = sample_token(logits, temperature=0.7, request_logprobs=True)
+ assert result.logprob <= 0.0
+
+def test_greedy_logprob_is_highest():
+ """Greedy token should have the highest logprob."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ logits = Tensor([[[0.1, 0.2, 0.9, 0.3]]])
+ result = sample_token(logits, temperature=0.0, top_logprobs_count=4, request_logprobs=True)
+ # The selected token's logprob should match the top entry
+ assert result.token_id == result.top_logprobs[0][0]
+ assert abs(result.logprob - result.top_logprobs[0][1]) < 1e-5
+
+def test_top_logprobs_count():
+ """top_logprobs should return exactly the requested count."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ logits = Tensor.randn(1, 1, 100)
+ result = sample_token(logits, temperature=0.0, top_logprobs_count=5, request_logprobs=True)
+ assert len(result.top_logprobs) == 5
+
+def test_top_logprobs_empty_when_not_requested():
+ """top_logprobs should be an empty list when top_logprobs_count is 0."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ logits = Tensor.randn(1, 1, 100)
+ result = sample_token(logits, temperature=0.0, top_logprobs_count=0)
+ assert result.top_logprobs == []
+
+def test_top_logprobs_sorted_descending():
+ """top_logprobs entries should be sorted by logprob (highest first)."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ logits = Tensor.randn(1, 1, 100)
+ result = sample_token(logits, temperature=0.0, top_logprobs_count=5, request_logprobs=True)
+ logprobs = [lp for _, lp in result.top_logprobs]
+ assert logprobs == sorted(logprobs, reverse=True)
+
+def test_high_temperature_is_more_random():
+ """Higher temperature should produce more token diversity across runs."""
+ from exo.worker.engines.tinygrad.sampling import sample_token
+
+ logits = Tensor.randn(1, 1, 10)
+ low_temp_tokens = {sample_token(logits, temperature=0.01).token_id for _ in range(20)}
+ high_temp_tokens = {sample_token(logits, temperature=2.0).token_id for _ in range(20)}
+ # High temp should generally produce more unique tokens
+ # (probabilistic — use a lenient check)
+ assert len(high_temp_tokens) >= len(low_temp_tokens)
diff --git a/src/exo/worker/tests/unittests/test_tinygrad/test_weight_loader.py b/src/exo/worker/tests/unittests/test_tinygrad/test_weight_loader.py
new file mode 100644
index 0000000000..1773ffdc44
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_tinygrad/test_weight_loader.py
@@ -0,0 +1,214 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+# pyright: reportUnknownMemberType=false
+import pytest
+from tinygrad.device import Device
+from tinygrad.dtype import dtypes
+from tinygrad.tensor import Tensor
+
+from exo.shared.model_config import ModelConfig, QuantizationConfig
+
+Device.DEFAULT = "CPU"
+
+# ── Helpers ──
+
+def _make_model_config(
+ hidden_size: int = 64,
+ num_attention_heads: int = 4,
+ num_key_value_heads: int | None = None,
+ intermediate_size: int = 128,
+ vocab_size: int = 256,
+ quantized: bool = False,
+) -> ModelConfig:
+ """Build a ModelConfig with small defaults for weight-loader tests."""
+ from exo.shared.architecture.llama import LLAMA_SPEC
+
+ n_kv = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+ return ModelConfig(
+ architecture_spec=LLAMA_SPEC,
+ num_hidden_layers=2,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=n_kv,
+ vocab_size=vocab_size,
+ head_dim=hidden_size // num_attention_heads,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ max_position_embeddings=4096,
+ rms_norm_eps=1e-6,
+ tie_word_embeddings=False,
+ quantization_config=QuantizationConfig(bits=4, group_size=32) if quantized else None,
+ )
+
+
+def _mlx_quantized_raw(
+ key_prefix: str,
+ out_features: int,
+ in_features: int,
+ bits: int = 4,
+ group_size: int = 32,
+) -> dict[str, Tensor]:
+ """Create a synthetic MLX-format quantized weight dict (.weight + .scales + .biases)."""
+ pack_factor = 32 // bits
+ packed_dim = in_features // pack_factor
+ num_groups = in_features // group_size
+
+ return {
+ f"{key_prefix}.weight": Tensor.zeros(out_features, packed_dim, dtype=dtypes.uint32),
+ f"{key_prefix}.scales": Tensor.ones(out_features, num_groups, dtype=dtypes.float16),
+ f"{key_prefix}.biases": Tensor.zeros(out_features, num_groups, dtype=dtypes.float16),
+ }
+
+
+# ── NamedTuple structure tests (pre-existing) ──
+
+def test_layer_weights_is_named_tuple():
+ """LayerWeights should be a NamedTuple with expected fields."""
+ from exo.worker.engines.tinygrad.weight_loader import LayerWeights
+
+ assert hasattr(LayerWeights, "_fields")
+ fields = LayerWeights._fields
+ assert "qkv_proj" in fields
+ assert "o_proj" in fields
+ assert "gate_up_proj" in fields
+ assert "down_proj" in fields
+ assert "input_norm" in fields
+ assert "post_attn_norm" in fields
+
+def test_transformer_weights_is_named_tuple():
+ """TransformerWeights should contain embed_tokens, lm_head, final_norm, layers, config."""
+ from exo.worker.engines.tinygrad.weight_loader import TransformerWeights
+
+ fields = TransformerWeights._fields
+ assert "embed_tokens" in fields
+ assert "lm_head" in fields
+ assert "final_norm" in fields
+ assert "layers" in fields
+ assert "config" in fields
+
+@pytest.mark.slow
+def test_load_llama_weights():
+ """Load 2 layers from a real Llama model and verify structure."""
+ from exo.shared.model_config import parse_model_config
+ from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+ model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+ if not model_path.exists():
+ pytest.skip("Model not downloaded")
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(model_path, config, start_layer=0, end_layer=2)
+ assert len(weights.layers) == 2
+ assert weights.embed_tokens is not None
+ assert weights.final_norm is not None
+ assert weights.lm_head is not None
+
+@pytest.mark.slow
+def test_load_respects_layer_range():
+ """Loading start_layer=2, end_layer=4 (end exclusive) should produce exactly 2 LayerWeights."""
+ from exo.shared.model_config import parse_model_config
+ from exo.worker.engines.tinygrad.weight_loader import load_transformer_weights
+
+ model_path = Path.home() / ".cache/exo/downloads/mlx-community/Llama-3.2-1B-Instruct-4bit"
+ if not model_path.exists():
+ pytest.skip("Model not downloaded")
+ config = parse_model_config(model_path / "config.json")
+ weights = load_transformer_weights(model_path, config, start_layer=2, end_layer=4)
+ assert len(weights.layers) == 2
+
+def test_load_missing_safetensors_raises(tmp_path: Path):
+ """Loading from an empty directory should raise FileNotFoundError."""
+ from exo.worker.engines.tinygrad.weight_loader import (
+ _load_all_safetensors, # pyright: ignore[reportPrivateUsage]
+ )
+
+ with pytest.raises(FileNotFoundError, match="No .safetensors"):
+ _load_all_safetensors(tmp_path)
+
+# ── _build_weight: MLX quantized format ──
+
+def test_build_weight_mlx_quantized_linear():
+ """MLX quantized linear (.weight + .scales + .biases with quantization_config) → QuantizedLinear."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.weight_loader import (
+ _build_weight, # pyright: ignore[reportPrivateUsage]
+ )
+
+ config = _make_model_config(quantized=True)
+ # o_proj shape: (hidden_size, hidden_size) = (64, 64)
+ raw = _mlx_quantized_raw(
+ "model.layers.0.self_attn.o_proj", out_features=64, in_features=64,
+ )
+ result = _build_weight(raw, "model.layers.0.self_attn.o_proj.weight", config)
+ assert isinstance(result, QuantizedLinear)
+
+def test_build_weight_mlx_quantized_embedding():
+ """MLX quantized embedding (.weight + .scales + .biases with is_embedding=True) → QuantizedEmbedding."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedEmbedding
+ from exo.worker.engines.tinygrad.weight_loader import (
+ _build_weight, # pyright: ignore[reportPrivateUsage]
+ )
+
+ config = _make_model_config(quantized=True)
+ # embed_tokens shape: (vocab_size, hidden_size) = (256, 64)
+ raw = _mlx_quantized_raw(
+ "model.embed_tokens", out_features=256, in_features=64,
+ )
+ result = _build_weight(raw, "model.embed_tokens.weight", config, is_embedding=True)
+ assert isinstance(result, QuantizedEmbedding)
+
+def test_build_weight_unquantized_returns_tensor():
+ """Unquantized model (.weight only, no quantization_config) → plain Tensor."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.weight_loader import (
+ _build_weight, # pyright: ignore[reportPrivateUsage]
+ )
+
+ config = _make_model_config(quantized=False)
+ raw = {"model.layers.0.self_attn.o_proj.weight": Tensor.zeros(64, 64)}
+ result = _build_weight(raw, "model.layers.0.self_attn.o_proj.weight", config)
+ assert isinstance(result, Tensor)
+ assert not isinstance(result, QuantizedLinear)
+
+def test_build_weight_quantized_config_but_no_companions():
+ """Quantized config present but no .scales/.biases companions → plain Tensor (not quantized)."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.weight_loader import (
+ _build_weight, # pyright: ignore[reportPrivateUsage]
+ )
+
+ config = _make_model_config(quantized=True)
+ # Only .weight, no .scales or .biases
+ raw = {"model.layers.0.self_attn.o_proj.weight": Tensor.zeros(64, 64)}
+ result = _build_weight(raw, "model.layers.0.self_attn.o_proj.weight", config)
+ assert isinstance(result, Tensor)
+ assert not isinstance(result, QuantizedLinear)
+
+def test_build_weight_missing_key_raises():
+ """Empty dict → KeyError."""
+ from exo.worker.engines.tinygrad.weight_loader import (
+ _build_weight, # pyright: ignore[reportPrivateUsage]
+ )
+
+ config = _make_model_config(quantized=False)
+ with pytest.raises(KeyError):
+ _build_weight({}, "model.layers.0.self_attn.o_proj.weight", config)
+
+def test_build_weight_mlx_quantized_shape_correctness():
+ """MLX quantized gate_proj should have correct in_features and out_features."""
+ from exo.worker.engines.tinygrad.quantization.layers import QuantizedLinear
+ from exo.worker.engines.tinygrad.weight_loader import (
+ _build_weight, # pyright: ignore[reportPrivateUsage]
+ )
+
+ config = _make_model_config(quantized=True)
+ # gate_proj shape: (intermediate_size, hidden_size) = (128, 64)
+ raw = _mlx_quantized_raw(
+ "model.layers.0.mlp.gate_proj", out_features=128, in_features=64,
+ )
+ result = _build_weight(raw, "model.layers.0.mlp.gate_proj.weight", config)
+ assert isinstance(result, QuantizedLinear)
+ assert result.out_features == 128
+ assert result.in_features == 64