Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions sdks/python/src/opik/runner/activate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import os
import signal
import threading

from rich.console import Console
Expand All @@ -19,6 +20,18 @@
_lock = threading.Lock()


def install_signal_handlers(shutdown_event: threading.Event) -> None:
    """Register SIGTERM/SIGINT handlers that set *shutdown_event*.

    If called from a non-main thread (where CPython refuses to install
    signal handlers), a warning is logged and nothing is installed.
    """

    def _on_signal(signum: int, frame: object) -> None:
        LOGGER.info("Received signal %s, shutting down", signum)
        shutdown_event.set()

    try:
        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, _on_signal)
    except ValueError:
        # signal.signal raises ValueError outside the main thread.
        LOGGER.warning("Cannot install signal handlers outside main thread")


def activate_runner() -> None:
"""Start the runner loop in a background thread (non-blocking)."""
if os.environ.get("OPIK_RUNNER_MODE") != "true":
Expand All @@ -30,11 +43,14 @@ def activate_runner() -> None:
return
_started = True

t = threading.Thread(target=_run, daemon=True)
shutdown_event = threading.Event()
install_signal_handlers(shutdown_event)

t = threading.Thread(target=_run, args=(shutdown_event,), daemon=True)
t.start()


def _run() -> None:
def _run(shutdown_event: threading.Event) -> None:
runner_id = os.environ.get("OPIK_RUNNER_ID", "")
project_name = os.environ.get("OPIK_PROJECT_NAME", "")

Expand Down Expand Up @@ -77,7 +93,6 @@ def _sync_agent(name: str) -> None:

LOGGER.info("Runner activated")

shutdown_event = threading.Event()
loop = InProcessRunnerLoop(api, runner_id, shutdown_event)

try:
Expand Down
20 changes: 20 additions & 0 deletions sdks/python/src/opik/runner/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Per-job context tracking using contextvars (works in both asyncio and threads)."""

import contextvars
from typing import Optional

_job_id_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"runner_job_id", default=None
)


def get_current_job_id() -> Optional[str]:
return _job_id_var.get()


def set_job_id(job_id: str) -> contextvars.Token:
return _job_id_var.set(job_id)


def reset_job_id(token: contextvars.Token) -> None:
_job_id_var.reset(token)
33 changes: 15 additions & 18 deletions sdks/python/src/opik/runner/in_process_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import asyncio
import collections
import contextvars
import inspect
import logging
import random
import signal
import threading
import time
from typing import Callable, Optional
Expand All @@ -16,6 +16,8 @@
from ..rest_api.core.api_error import ApiError
from ..rest_api.types.local_runner_job import LocalRunnerJob
from . import registry
from .context import reset_job_id, set_job_id
from .log_streamer import LogStreamer

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,10 +62,9 @@ def __init__(
self._lock = threading.Lock()
self._job_queue: asyncio.Queue[LocalRunnerJob] = asyncio.Queue()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._log_streamer: Optional[LogStreamer] = None

def run(self) -> None:
self._install_signal_handlers()

heartbeat_thread = threading.Thread(
target=self._heartbeat_loop,
daemon=True,
Expand Down Expand Up @@ -132,12 +133,17 @@ def _heartbeat_loop(self) -> None:

def _run_job_loop(self) -> None:
    """Own the runner's asyncio loop: create it, wire up per-job log
    streaming, run the job consumer, then tear everything down."""
    loop = asyncio.new_event_loop()
    self._loop = loop
    streamer = LogStreamer(self._api, loop)
    self._log_streamer = streamer
    streamer.install()
    try:
        loop.run_until_complete(self._job_consumer())
    finally:
        # Flush any buffered log entries before closing the loop.
        loop.run_until_complete(streamer.stop())
        loop.close()

async def _job_consumer(self) -> None:
assert self._log_streamer is not None
self._log_streamer.start()
tasks: set[asyncio.Task] = set()
while not self._shutdown_event.is_set():
try:
Expand Down Expand Up @@ -188,6 +194,8 @@ async def _execute_job(self, job: LocalRunnerJob) -> None:
trace_id=trace_id,
)

token = set_job_id(job_id)
ctx = contextvars.copy_context()
try:
timeout = job.timeout
if inspect.iscoroutinefunction(func):
Expand All @@ -206,13 +214,13 @@ def _run_with_mask() -> object:
if timeout:
result = await asyncio.wait_for(
asyncio.get_running_loop().run_in_executor(
None, _run_with_mask
None, ctx.run, _run_with_mask
),
timeout=timeout,
)
else:
result = await asyncio.get_running_loop().run_in_executor(
None, _run_with_mask
None, ctx.run, _run_with_mask
)

if not isinstance(result, (dict, str, int, float, bool, list, type(None))):
Expand Down Expand Up @@ -240,6 +248,8 @@ def _run_with_mask() -> object:
error=f"{type(e).__name__}: {e}",
trace_id=trace_id,
)
finally:
reset_job_id(token)

def _safe_report_job_result(self, job_id: str, **kwargs: object) -> None:
"""Report a job result, logging and swallowing any exception."""
Expand All @@ -262,16 +272,3 @@ def _prune_cancelled_jobs(self, now: float) -> None:
def _backoff_wait(self, backoff: float) -> None:
    """Sleep for the capped backoff, jittered to 50-100% of its value.

    Waits on the shutdown event so a shutdown interrupts the sleep early.
    """
    capped = min(backoff, self._backoff_cap_seconds)
    jitter = 0.5 + random.random() * 0.5
    self._shutdown_event.wait(capped * jitter)

def _install_signal_handlers(self) -> None:
shutdown = self._shutdown_event

def handler(signum: int, frame: object) -> None:
LOGGER.info("Received signal %s, shutting down", signum)
shutdown.set()

try:
signal.signal(signal.SIGTERM, handler)
signal.signal(signal.SIGINT, handler)
except ValueError:
LOGGER.warning("Cannot install signal handlers outside main thread")
120 changes: 120 additions & 0 deletions sdks/python/src/opik/runner/log_streamer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Captures stdout/stderr per-job and streams to the backend asynchronously."""

import asyncio
import io
import logging
import sys
import typing

from .context import get_current_job_id
from ..rest_api.client import OpikApi
from ..rest_api.types.local_runner_log_entry import LocalRunnerLogEntry

LOGGER = logging.getLogger(__name__)

_FLUSH_INTERVAL_SECONDS = 0.5


class _CaptureStream(io.TextIOBase):
    """Intercepts writes, captures per-job log entries, forwards to the inner stream."""

    # Class-level default; overridden per instance from the wrapped stream.
    encoding: str = "utf-8"

    def __init__(
        self,
        stream: typing.TextIO,
        stream_name: str,
        loop: asyncio.AbstractEventLoop,
        queue: asyncio.Queue,
    ) -> None:
        self._stream = stream
        self._stream_name = stream_name
        self._loop = loop
        self._queue = queue
        # Mirror the wrapped stream's encoding when it exposes one.
        self.encoding = getattr(stream, "encoding", "utf-8")

    def write(self, s: str) -> int:
        """Forward *s* to the wrapped stream; also queue it as a log entry
        when a job is bound to the current context and *s* is non-blank."""
        if s.strip():
            job_id = get_current_job_id()
            if job_id is not None:
                self._capture(job_id, s)
        return self._stream.write(s)

    def _capture(self, job_id: str, text: str) -> None:
        # Hand off to the streamer's loop thread-safely — writes may
        # arrive from worker threads running job code.
        entry = LocalRunnerLogEntry(stream=self._stream_name, text=text)
        try:
            self._loop.call_soon_threadsafe(
                self._queue.put_nowait, (job_id, entry)
            )
        except RuntimeError:
            # Event loop already closed — drop the capture but still
            # let the write through to the real stream.
            pass

    def flush(self) -> None:
        self._stream.flush()

    def isatty(self) -> bool:
        return self._stream.isatty()

    def fileno(self) -> int:
        return self._stream.fileno()


class LogStreamer:
    """Collects captured stdout/stderr per job and ships batches to the backend."""

    def __init__(self, api: OpikApi, loop: asyncio.AbstractEventLoop) -> None:
        self._api = api
        self._loop = loop
        self._queue: asyncio.Queue[typing.Tuple[str, LocalRunnerLogEntry]] = (
            asyncio.Queue()
        )
        self._task: typing.Optional[asyncio.Task] = None

    def install(self) -> None:
        """Replace sys.stdout/sys.stderr with capturing wrappers."""
        sys.stdout = _CaptureStream(sys.stdout, "stdout", self._loop, self._queue)  # type: ignore[assignment]
        sys.stderr = _CaptureStream(sys.stderr, "stderr", self._loop, self._queue)  # type: ignore[assignment]

    def start(self) -> None:
        """Begin consuming captured entries on the streamer's loop."""
        self._task = self._loop.create_task(self._run())

    async def stop(self) -> None:
        """Cancel the consumer task; its cancellation handler flushes pending logs."""
        task = self._task
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _run(self) -> None:
        """Consume (job_id, entry) pairs, batching per job between flushes."""
        buffered: typing.Dict[str, typing.List[LocalRunnerLogEntry]] = {}

        while True:
            try:
                job_id, entry = await asyncio.wait_for(
                    self._queue.get(), timeout=_FLUSH_INTERVAL_SECONDS
                )
                buffered.setdefault(job_id, []).append(entry)

                # Flush once the queue is momentarily empty so bursts of
                # writes are coalesced into one batch per job.
                if self._queue.empty():
                    await self._drain_all(buffered)
            except asyncio.TimeoutError:
                # Idle period: push out whatever has accumulated.
                await self._drain_all(buffered)
            except asyncio.CancelledError:
                # Final flush on shutdown, then exit cleanly.
                await self._drain_all(buffered)
                return

    async def _drain_all(
        self, buffered: typing.Dict[str, typing.List[LocalRunnerLogEntry]]
    ) -> None:
        """Send and clear every buffered batch, one request per job."""
        for job_id in list(buffered):
            batch = buffered.pop(job_id)
            if batch:
                await self._send_batch(job_id, batch)

    async def _send_batch(
        self, job_id: str, entries: typing.List[LocalRunnerLogEntry]
    ) -> None:
        """Post one batch via the blocking API client off the event loop."""

        def _post() -> None:
            self._api.runners.append_job_logs(job_id=job_id, request=entries)

        try:
            await self._loop.run_in_executor(None, _post)
        except Exception:
            # Log delivery is best-effort; never let it break the runner.
            LOGGER.debug("Failed to send logs for job %s", job_id, exc_info=True)
7 changes: 6 additions & 1 deletion sdks/python/src/opik/runner/prefixed_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,13 @@ def fileno(self) -> int:
return self._stream.fileno()


# Tracks whether the prefixed streams have already been installed, so
# repeated calls never double-wrap sys.stdout/sys.stderr.
_installed = False


def install() -> None:
    """Wrap sys.stdout and sys.stderr in PrefixedStream, exactly once.

    Idempotent: subsequent calls are no-ops. (The pasted diff left the
    removed isatty guard line interleaved with the new idempotency guard;
    this is the coherent post-diff version — 6 additions, 1 deletion.)
    """
    global _installed
    if _installed:
        return
    sys.stdout = PrefixedStream(sys.stdout)  # type: ignore[assignment]
    sys.stderr = PrefixedStream(sys.stderr)  # type: ignore[assignment]
    _installed = True
1 change: 1 addition & 0 deletions sdks/python/tests/e2e/runner/echo_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

@opik.track(entrypoint=True)
def echo(message: str) -> str:
    """Return the echoed message; prints to stdout so log capture is exercised."""
    reply = f"echo: {message}"
    print(f"echo stdout: {message}")
    return reply


Expand Down
21 changes: 20 additions & 1 deletion sdks/python/tests/e2e/runner/test_runner_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _is_agent_registered():


def test_runner_happy_path(api_client, runner_process: RunnerInfo, project_id):
"""Basic: register echo agent, run job, verify job result and trace output."""
"""Basic: register echo agent, run job, verify job result, trace output, and job logs."""
message = f"hello-e2e-{int(time.time())}"

wait_for_agent_registration(api_client, "echo", project_id)
Expand All @@ -141,6 +141,25 @@ def test_runner_happy_path(api_client, runner_process: RunnerInfo, project_id):
trace = find_trace_by_input(api_client, OPIK_E2E_TESTS_PROJECT_NAME, message)
assert f"echo: {message}" in str(trace.output)

logs_result = []

def _find_logs():
logs = api_client.runners.get_job_logs(job.id)
if logs:
logs_result.clear()
logs_result.extend(logs)
return True
return False

assert opik.synchronization.until(
_find_logs,
max_try_seconds=5,
allow_errors=True,
), f"Expected job logs for job {job.id}, got none"

log_text = " ".join(entry.text for entry in logs_result)
assert message in log_text, f"Expected '{message}' in job logs, got: {log_text}"


def test_runner_with_mask(
opik_client, api_client, runner_process: RunnerInfo, project_id
Expand Down
Loading
Loading