Skip to content

Commit b5b6496

Browse files
committed
feat: add Ollama local generation backend
- OllamaProvider for Ollama's OpenAI-compatible API - Supports model selection, temperature, max_tokens - No external deps (uses urllib) - 13 new tests (173 total) Closes #2
1 parent fe704cf commit b5b6496

3 files changed

Lines changed: 237 additions & 0 deletions

File tree

src/castwright/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AnthropicProvider,
2424
LLMProvider,
2525
MockProvider,
26+
OllamaProvider,
2627
OpenAIProvider,
2728
)
2829

@@ -47,6 +48,7 @@
4748
"LLMProvider",
4849
"OpenAIProvider",
4950
"AnthropicProvider",
51+
"OllamaProvider",
5052
"MockProvider",
5153
# Filters
5254
"filter_examples",

src/castwright/providers.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import re
1111
from abc import ABC, abstractmethod
1212
from typing import Any, Dict, List, Optional, Tuple
13+
from urllib.request import Request, urlopen
1314

1415
from castwright._types import ProviderError
1516

@@ -179,6 +180,66 @@ def generate(
179180
return text, input_tokens, output_tokens
180181

181182

183+
class OllamaProvider(LLMProvider):
    """Ollama local LLM provider (uses its OpenAI-compatible API).

    Connects to Ollama at ``http://localhost:11434`` by default.
    No extra dependencies required — uses only ``urllib``.
    """

    def __init__(
        self,
        model: str = "llama3",
        host: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: float = 120.0,
    ) -> None:
        """Configure the provider.

        Args:
            model: Ollama model name (e.g. ``"llama3"``, ``"mistral"``).
            host: Ollama host root (e.g. ``"http://myhost:11434"``); the
                OpenAI-compatible ``/v1`` suffix is appended automatically.
            base_url: Full API base URL; takes precedence over ``host``.
            timeout: Socket timeout in seconds for each request. Local models
                can be slow to load into memory, so the default is generous.
        """
        if base_url:
            self._base_url = base_url.rstrip("/")
        elif host:
            self._base_url = host.rstrip("/") + "/v1"
        else:
            self._base_url = "http://localhost:11434/v1"
        self._model = model
        self._timeout = timeout

    def generate(
        self,
        prompt: str,
        system: str = "",
        temperature: float = 0.9,
        max_tokens: int = 4096,
    ) -> Tuple[str, int, int]:
        """Generate a completion via Ollama's ``/chat/completions`` endpoint.

        Args:
            prompt: User message content.
            system: Optional system prompt; omitted from the request if empty.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Maximum completion length forwarded to the API.

        Returns:
            Tuple of (completion text, input tokens, output tokens). Token
            counts fall back to 0 when the server omits the ``usage`` field.

        Raises:
            ProviderError: On any network/HTTP failure or malformed response.
        """
        # Use typing.Dict/List for consistency with this module's imports.
        messages: List[Dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        payload = json.dumps({
            "model": self._model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }).encode()

        req = Request(
            f"{self._base_url}/chat/completions",
            data=payload,
            headers={"Content-Type": "application/json"},
        )

        try:
            # Without a timeout, a wedged local daemon would block forever.
            with urlopen(req, timeout=self._timeout) as resp:
                data = json.loads(resp.read())
        except Exception as e:
            # Boundary wrap: surface every transport/parse failure uniformly.
            raise ProviderError(f"Ollama API error: {e}") from e

        try:
            choices = data.get("choices") or []
            text = choices[0]["message"]["content"] if choices else ""
        except (KeyError, IndexError, TypeError) as e:
            # A present-but-malformed "choices" entry must not leak a raw
            # KeyError to callers; keep the ProviderError contract.
            raise ProviderError(f"Ollama API error: malformed response: {e}") from e

        usage = data.get("usage", {})
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)

        return text, input_tokens, output_tokens
241+
242+
182243
class MockProvider(LLMProvider):
183244
"""Mock provider for testing without API calls.
184245

tests/test_ollama.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""Tests for castwright OllamaProvider."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
from unittest.mock import MagicMock, patch
7+
8+
import pytest
9+
10+
from castwright._types import ProviderError
11+
from castwright.providers import OllamaProvider
12+
13+
14+
class TestOllamaProviderInit:
    """Constructor resolution of model name and base URL."""

    def test_default_init(self):
        # No arguments: stock model, local Ollama endpoint.
        provider = OllamaProvider()
        assert provider._base_url == "http://localhost:11434/v1"
        assert provider._model == "llama3"

    def test_custom_model(self):
        assert OllamaProvider(model="mistral")._model == "mistral"

    def test_custom_base_url(self):
        # base_url is taken verbatim (modulo trailing-slash stripping).
        url = "http://myhost:8000/v1"
        assert OllamaProvider(base_url=url)._base_url == url

    def test_custom_host(self):
        # host gets the OpenAI-compatible "/v1" suffix appended.
        provider = OllamaProvider(host="http://myhost:11434")
        assert provider._base_url == "http://myhost:11434/v1"
31+
32+
33+
class TestOllamaProviderGenerate:
    """Behavior of OllamaProvider.generate against a mocked ``urlopen``."""

    def _mock_response(self, text: str, prompt_tokens: int = 10, completion_tokens: int = 20):
        """Build a minimal OpenAI-compatible chat-completion response dict."""
        return {
            "choices": [{"message": {"content": text}}],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
            },
        }

    @staticmethod
    def _install_response(mock_urlopen, resp_data):
        """Wire *mock_urlopen* to act as a context manager whose read()
        returns *resp_data* serialized as JSON bytes.

        Extracted helper: this 5-line setup was previously duplicated in
        every test method.
        """
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps(resp_data).encode()
        mock_resp.__enter__ = lambda s: s
        mock_resp.__exit__ = MagicMock(return_value=False)
        mock_urlopen.return_value = mock_resp

    @staticmethod
    def _sent_body(mock_urlopen):
        """Decode the JSON payload of the Request passed to urlopen."""
        req = mock_urlopen.call_args[0][0]
        return json.loads(req.data)

    @patch("castwright.providers.urlopen")
    def test_basic_generation(self, mock_urlopen):
        self._install_response(mock_urlopen, self._mock_response("Hello world"))

        p = OllamaProvider(model="llama3")
        text, in_tok, out_tok = p.generate("Say hello")

        assert text == "Hello world"
        assert in_tok == 10
        assert out_tok == 20

    @patch("castwright.providers.urlopen")
    def test_with_system_prompt(self, mock_urlopen):
        self._install_response(mock_urlopen, self._mock_response("response"))

        p = OllamaProvider()
        p.generate("prompt", system="You are helpful")

        # Verify the request was made with the system message first.
        body = self._sent_body(mock_urlopen)
        assert body["messages"][0]["role"] == "system"
        assert body["messages"][0]["content"] == "You are helpful"
        assert body["messages"][1]["role"] == "user"

    @patch("castwright.providers.urlopen")
    def test_temperature_and_max_tokens(self, mock_urlopen):
        self._install_response(mock_urlopen, self._mock_response("ok"))

        p = OllamaProvider()
        p.generate("hi", temperature=0.1, max_tokens=512)

        body = self._sent_body(mock_urlopen)
        assert body["temperature"] == 0.1
        assert body["max_tokens"] == 512

    @patch("castwright.providers.urlopen")
    def test_no_usage_field(self, mock_urlopen):
        # Missing "usage" must degrade to zero token counts, not crash.
        self._install_response(mock_urlopen, {"choices": [{"message": {"content": "hi"}}]})

        p = OllamaProvider()
        text, in_tok, out_tok = p.generate("prompt")
        assert text == "hi"
        assert in_tok == 0
        assert out_tok == 0

    @patch("castwright.providers.urlopen")
    def test_empty_choices(self, mock_urlopen):
        resp_data = {
            "choices": [{"message": {"content": ""}}],
            "usage": {"prompt_tokens": 5, "completion_tokens": 0},
        }
        self._install_response(mock_urlopen, resp_data)

        p = OllamaProvider()
        text, _, _ = p.generate("prompt")
        assert text == ""

    @patch("castwright.providers.urlopen")
    def test_network_error_raises_provider_error(self, mock_urlopen):
        from urllib.error import URLError
        mock_urlopen.side_effect = URLError("Connection refused")

        p = OllamaProvider()
        with pytest.raises(ProviderError, match="Ollama API error"):
            p.generate("prompt")

    def test_parse_json_array_inherited(self):
        """OllamaProvider should inherit parse_json_array from LLMProvider."""
        # No urlopen patch needed: parse_json_array performs no network I/O
        # (the previous @patch decorator and mock parameter were unused).
        p = OllamaProvider()
        result = p.parse_json_array('[{"a": 1}]')
        assert result == [{"a": 1}]

    @patch("castwright.providers.urlopen")
    def test_without_system(self, mock_urlopen):
        self._install_response(mock_urlopen, self._mock_response("response"))

        p = OllamaProvider()
        p.generate("prompt")

        body = self._sent_body(mock_urlopen)
        assert len(body["messages"]) == 1
        assert body["messages"][0]["role"] == "user"

    @patch("castwright.providers.urlopen")
    def test_model_in_request(self, mock_urlopen):
        self._install_response(mock_urlopen, self._mock_response("ok"))

        p = OllamaProvider(model="phi3")
        p.generate("hi")

        body = self._sent_body(mock_urlopen)
        assert body["model"] == "phi3"

0 commit comments

Comments
 (0)