Skip to content

Commit d0e13ce

Browse files
committed
fix(py): race condition when starting dev server and implement reflection api parity and action cancellation
- Fixed a race condition that caused the Dev UI to show "runtime not detected" errors. - Refactored signal handling. - Implemented `GET /api/values` to support retrieving `defaultModel`. - Added `GET /api/envs`, which returns the default dev environment. - Added `POST` support for `/api/__quitquitquit` for JS parity. - Implemented `POST /api/cancelAction` for cancelling running actions. - Updated `create_reflection_asgi_app` to track active actions and handle cancellation requests.
1 parent 1548a2f commit d0e13ce

File tree

6 files changed

+253
-17
lines changed

6 files changed

+253
-17
lines changed

py/packages/genkit/src/genkit/ai/_base_async.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,33 @@ async def run_user_coro_wrapper():
139139

140140
reflection_server = _make_reflection_server(self.registry, spec)
141141

142+
# Setup signal handlers for graceful shutdown (parity with JS)
143+
import signal
144+
145+
def handle_signal(signum, frame):
    """Log that a shutdown signal arrived; takes no further action.

    The handler deliberately does not try to cancel the task group: doing so
    from a synchronous signal handler is awkward, so shutdown is left to the
    default signal behaviour (KeyboardInterrupt/SystemExit propagation, as the
    original author notes) and to the separate ``handle_sigterm`` task.

    Args:
        signum: Number of the signal that was received.
        frame: Stack frame that was active when the signal arrived (unused).
    """
    # NOTE(review): no signal.signal(...) call installing this handler is
    # visible in this hunk -- confirm it is registered somewhere, or drop it.
    logger.info(f'Received signal {signum}, initiating shutdown...')
153+
154+
# Actually, anyio.run handles Ctrl+C (SIGINT) by raising KeyboardInterrupt/CancelledError
155+
# For SIGTERM, we might need to be explicit if we run in a container/process manager.
156+
# JS uses: process.on('SIGTERM', shutdown); process.on('SIGINT', shutdown);
157+
158+
# Since anyio/asyncio handles SIGINT well, let's add a task to catch SIGTERM
159+
async def handle_sigterm(tg_to_cancel):
    """Cancel the given task group when the process receives SIGTERM.

    Runs as a background task: it blocks on a signal receiver and, on the
    first SIGTERM, cancels ``tg_to_cancel.cancel_scope`` and returns.

    On platforms whose event loop cannot install signal handlers (for example
    Windows), opening the receiver raises ``NotImplementedError``. That error
    is caught here so it does not propagate into the task group and take the
    reflection server down with it; SIGTERM handling is simply skipped.

    Args:
        tg_to_cancel: Task group whose ``cancel_scope`` is cancelled on
            SIGTERM.
    """
    try:
        with anyio.open_signal_receiver(signal.SIGTERM) as signals:
            # Only the first SIGTERM matters: cancel once and stop listening.
            async for _ in signals:
                logger.info('Received SIGTERM, cancelling tasks...')
                tg_to_cancel.cancel_scope.cancel()
                return
    except NotImplementedError:
        # Best effort: without loop-level signal support we fall back to the
        # platform's default SIGTERM behaviour instead of crashing startup.
        logger.warning('SIGTERM handling is not supported on this platform; skipping.')
165+
142166
try:
143-
async with RuntimeManager(spec):
167+
# Use lazy_write=True to prevent race condition where file exists before server is up
168+
async with RuntimeManager(spec, lazy_write=True) as runtime_manager:
144169
# We use anyio.TaskGroup because it is compatible with
145170
# asyncio's event loop and works with Python 3.10
146171
# (asyncio.TaskGroup was added in 3.11, and we can switch to
@@ -150,6 +175,48 @@ async def run_user_coro_wrapper():
150175
tg.start_soon(reflection_server.serve, name='genkit-reflection-server')
151176
await logger.ainfo(f'Started Genkit reflection server at {spec.url}')
152177

178+
# Start SIGTERM handler
179+
tg.start_soon(handle_sigterm, tg, name='genkit-sigterm-handler')
180+
181+
# Wait for server to be responsive
182+
# We need to loop and poll the health endpoint or wait for uvicorn to be ready
183+
# Since uvicorn run is blocking (but we are in a task), we can't easily hook into its startup
184+
# unless we use uvicorn's server object directly which we do.
185+
# reflection_server.started is set when uvicorn starts.
186+
187+
# Simple polling loop
188+
import urllib.error
189+
import urllib.request
190+
191+
max_retries = 20 # 2 seconds total roughly
192+
for i in range(max_retries):
193+
try:
194+
# TODO: Use async http client if available to avoid blocking loop?
195+
# But we are in dev mode, so maybe okay.
196+
# Actually we should use anyio.to_thread to avoid blocking event loop
197+
# or assume standard lib urllib is fast enough for localhost.
198+
199+
# Using sync urllib in async loop blocks the loop!
200+
# We must use anyio.to_thread or a non-blocking check.
201+
# But let's check if reflection_server object has a 'started' flag we can trust.
202+
# uvicorn.Server has 'started' attribute but it might be internal state.
203+
204+
# Let's stick to simple polling with to_thread for safety
205+
def check_health():
    """Return True when the reflection server's health endpoint answers 200.

    Performs a blocking GET against ``{spec.url}/api/__health`` with a 0.5s
    timeout; any connection or HTTP error raised by ``urlopen`` propagates to
    the caller.
    """
    health_url = f'{spec.url}/api/__health'
    with urllib.request.urlopen(health_url, timeout=0.5) as resp:
        return resp.status == 200
208+
209+
is_healthy = await anyio.to_thread.run_sync(check_health)
210+
if is_healthy:
211+
break
212+
except Exception:
213+
await anyio.sleep(0.1)
214+
else:
215+
logger.warning(f'Reflection server at {spec.url} did not become healthy in time.')
216+
217+
# Now write the file (or verify it persisted)
218+
runtime_manager.write_runtime_file()
219+
153220
# Start the (potentially short-lived) user coroutine wrapper
154221
tg.start_soon(run_user_coro_wrapper, name='genkit-user-coroutine')
155222
await logger.ainfo('Started Genkit user coroutine')

py/packages/genkit/src/genkit/ai/_runtime.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,29 +146,38 @@ class RuntimeManager:
146146
that the context manager exits cleanly and allows exceptions to propagate.
147147
"""
148148

149-
def __init__(self, spec: ServerSpec, runtime_dir: str | Path | None = None):
149+
def __init__(
    self,
    spec: ServerSpec,
    runtime_dir: str | Path | None = None,
    lazy_write: bool = False,
):
    """Store the server spec and resolve where the runtime file will live.

    Args:
        spec: The server specification for the reflection server.
        runtime_dir: Directory that will hold the runtime file. When omitted,
            DEFAULT_RUNTIME_DIR_NAME under the current working directory is
            used.
        lazy_write: When True, entering the context does not write the
            runtime file; the caller must invoke write_runtime_file()
            explicitly.
    """
    self.spec = spec
    self.lazy_write = lazy_write
    # No runtime file exists until write_runtime_file() creates one.
    self._runtime_file_path: Path | None = None
    self._runtime_dir = (
        Path(os.getcwd()) / DEFAULT_RUNTIME_DIR_NAME if runtime_dir is None else Path(runtime_dir)
    )
164173

165174
async def __aenter__(self) -> RuntimeManager:
166175
"""Create the runtime directory and file."""
167176
try:
168177
await logger.adebug(f'Ensuring runtime directory exists: {self._runtime_dir}')
169178
self._runtime_dir.mkdir(parents=True, exist_ok=True)
170-
self._runtime_file_path = _create_and_write_runtime_file(self._runtime_dir, self.spec)
171-
_register_atexit_cleanup_handler(self._runtime_file_path)
179+
if not self.lazy_write:
180+
self.write_runtime_file()
172181

173182
except Exception as e:
174183
logger.error(f'Failed to initialize runtime file: {e}', exc_info=True)
@@ -189,18 +198,19 @@ async def __aexit__(
189198
exc_tb: The traceback of the exception that occurred.
190199
191200
Returns:
192-
True if cleanup was successful, False if cleanup failed.
201+
False to indicate exceptions should propagate.
193202
"""
203+
self.cleanup()
194204
await logger.adebug('RuntimeManager async context exited.')
195-
return True
205+
return False
196206

197207
def __enter__(self) -> RuntimeManager:
198208
"""Synchronous entry point: Create the runtime directory and file."""
199209
try:
200210
logger.debug(f'[sync] Ensuring runtime directory exists: {self._runtime_dir}')
201211
self._runtime_dir.mkdir(parents=True, exist_ok=True)
202-
self._runtime_file_path = _create_and_write_runtime_file(self._runtime_dir, self.spec)
203-
_register_atexit_cleanup_handler(self._runtime_file_path)
212+
if not self.lazy_write:
213+
self.write_runtime_file()
204214

205215
except Exception as e:
206216
logger.error(f'[sync] Failed to initialize runtime file: {e}', exc_info=True)
@@ -219,5 +229,25 @@ def __exit__(self, exc_type: Exception | None, exc_val: Exception | None, exc_tb
219229
Returns:
220230
False to indicate exceptions should propagate.
221231
"""
232+
self.cleanup()
222233
logger.debug('RuntimeManager sync context exited.')
223234
return False
235+
236+
def write_runtime_file(self) -> Path:
    """Write the runtime file on first use and return its path.

    The first call creates the file in the runtime directory and registers an
    atexit cleanup handler for it. Later calls return the recorded path
    without writing again.

    Returns:
        The path of the runtime file.
    """
    if not self._runtime_file_path:
        created = _create_and_write_runtime_file(self._runtime_dir, self.spec)
        self._runtime_file_path = created
        _register_atexit_cleanup_handler(created)
    return self._runtime_file_path
248+
249+
def cleanup(self) -> None:
    """Remove the runtime file, if one was written, and forget its path.

    Does nothing when no runtime file has been recorded, so repeated calls
    are harmless. The recorded path is cleared after removal so that a later
    write_runtime_file() call creates a fresh file instead of returning the
    path of a file that no longer exists.
    """
    runtime_file = self._runtime_file_path
    if not runtime_file:
        return

    logger.debug(f'Cleaning up runtime file: {runtime_file}')
    _remove_file(runtime_file)
    # Reset only after removal ran, so a failed removal can be retried.
    self._runtime_file_path = None

py/packages/genkit/src/genkit/core/action/_action.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
self,
125125
on_chunk: StreamingCallback | None = None,
126126
context: dict[str, Any] | None = None,
127+
on_trace_start: Callable[[str], None] | None = None,
127128
):
128129
"""Initializes an ActionRunContext instance.
129130
@@ -136,9 +137,12 @@ def __init__(
136137
context: An optional dictionary containing context data to be made
137138
available within the action execution. Defaults to an empty
138139
dictionary.
140+
on_trace_start: A callable to be invoked with the trace ID when
141+
the trace is started.
139142
"""
140143
self._on_chunk = on_chunk if on_chunk is not None else noop_streaming_callback
141144
self._context = context if context is not None else {}
145+
self._on_trace_start = on_trace_start if on_trace_start else lambda _: None
142146

143147
@property
144148
def context(self) -> dict[str, Any]:
@@ -302,6 +306,7 @@ async def arun(
302306
input: Any = None,
303307
on_chunk: StreamingCallback | None = None,
304308
context: dict[str, Any] | None = None,
309+
on_trace_start: Callable[[str], None] | None = None,
305310
telemetry_labels: dict[str, Any] | None = None,
306311
) -> ActionResponse:
307312
"""Executes the action asynchronously with the given input.
@@ -331,14 +336,15 @@ async def arun(
331336

332337
return await self._afn(
333338
input,
334-
ActionRunContext(on_chunk=on_chunk, context=_action_context.get(None)),
339+
ActionRunContext(on_chunk=on_chunk, context=_action_context.get(None), on_trace_start=on_trace_start),
335340
)
336341

337342
async def arun_raw(
338343
self,
339344
raw_input: Any,
340345
on_chunk: StreamingCallback | None = None,
341346
context: dict[str, Any] | None = None,
347+
on_trace_start: Callable[[str], None] | None = None,
342348
telemetry_labels: dict[str, Any] | None = None,
343349
):
344350
"""Executes the action asynchronously with raw, unvalidated input.
@@ -367,6 +373,7 @@ async def arun_raw(
367373
input=input_action,
368374
on_chunk=on_chunk,
369375
context=context,
376+
on_trace_start=on_trace_start,
370377
telemetry_labels=telemetry_labels,
371378
)
372379

@@ -501,6 +508,7 @@ async def async_tracing_wrapper(input: Any | None, ctx: ActionRunContext) -> Act
501508
afn = ensure_async(fn)
502509
with tracer.start_as_current_span(name) as span:
503510
trace_id = str(span.get_span_context().trace_id)
511+
ctx._on_trace_start(trace_id)
504512
record_input_metadata(
505513
span=span,
506514
kind=kind,
@@ -541,6 +549,7 @@ def sync_tracing_wrapper(input: Any | None, ctx: ActionRunContext) -> ActionResp
541549
"""
542550
with tracer.start_as_current_span(name) as span:
543551
trace_id = str(span.get_span_context().trace_id)
552+
ctx._on_trace_start(trace_id)
544553
record_input_metadata(
545554
span=span,
546555
kind=kind,

py/packages/genkit/src/genkit/core/error.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,17 @@ def __init__(
7979
detail: Optional detail information.
8080
source: Optional source of the error.
8181
"""
82-
source_prefix = f'{source}: ' if source else ''
83-
super().__init__(f'{source_prefix}{status}: {message}')
84-
self.original_message = message
85-
8682
self.status = status
8783
if not self.status and isinstance(cause, GenkitError):
8884
self.status = cause.status
8985

9086
if not self.status:
9187
self.status = 'INTERNAL'
9288

89+
source_prefix = f'{source}: ' if source else ''
90+
super().__init__(f'{source_prefix}{self.status}: {message}')
91+
self.original_message = message
92+
9393
self.http_code = http_status_code(self.status)
9494

9595
if not details:

0 commit comments

Comments
 (0)