diff --git a/docs/custom-pipeline.md b/docs/custom-pipeline.md index cc311df19..35dbc4961 100644 --- a/docs/custom-pipeline.md +++ b/docs/custom-pipeline.md @@ -1,4 +1,4 @@ -# Creating a Custom Live Pipeline +# Creating a Custom Pipeline This guide explains how to create a custom pipeline for the AI Runner from a **separate repository**. We'll use [scope-runner](https://github.com/livepeer/scope-runner) as a reference implementation. @@ -9,12 +9,12 @@ This guide explains how to create a custom pipeline for the AI Runner from a **s A custom pipeline is a Python package that: 1. Extends the `ai-runner[realtime]` library as a dependency (or `ai-runner[batch]` for a batch pipeline) -2. Implements the [`Pipeline`](../runner/src/runner/live/pipelines/interface.py#L46) interface for frame processing +2. Uses the [`@pipeline`](../runner/src/runner/live/pipelines/create.py) decorator to define frame processing logic 3. Optionally defines custom parameters extending [`BaseParams`](../runner/src/runner/live/pipelines/interface.py#L10) 4. Provides a `prepare_models()` classmethod for model download/compilation 5. Ships as a Docker image, ideally extending `livepeer/ai-runner:live-base` -## Prerequisites +## Requirements - Python 3.10+ (stricter dependency will likely come from your pipeline code) - [uv](https://docs.astral.sh/uv/) package manager @@ -74,63 +74,98 @@ touch src/my_pipeline/pipeline/params.py --- -## Step 2: Implement the Pipeline Interface +## Step 2: Implement the Pipeline -### 2.1 Define Parameters (Optional) +Use the `@pipeline` decorator to define your pipeline. The decorator handles frame queues, lifecycle management, parameter validation, and threading automatically. -Implement `src/my_pipeline/pipeline/params.py`: +**Function form** — simplest possible pipeline: ```python -from runner.live.pipelines import BaseParams - -class MyPipelineParams(BaseParams): - # Define your custom fields here +# src/my_pipeline/pipeline/pipeline.py +import torch +from runner.live.pipelines import pipeline, BaseParams +from runner.live.trickle import VideoFrame + +@pipeline(name="green-shift") +async def green_shift(frame: VideoFrame, params: BaseParams) -> torch.Tensor: + # Process frame tensor and return modified tensor + tensor = frame.tensor.clone() + tensor[:, :, :, 1] = torch.clamp(tensor[:, :, :, 1] + 0.3, -1.0, 1.0) + return tensor ``` -### 2.2 Implement the Pipeline - -Implement `src/my_pipeline/pipeline/pipeline.py`: +**Class form** — for state, device setup, and mid-stream parameter updates: ```python -import asyncio +# src/my_pipeline/pipeline/pipeline.py import logging +import torch +import torch.nn.functional as F +from pydantic import Field +from runner.live.pipelines import pipeline, BaseParams +from runner.live.trickle import VideoFrame + +class EdgeParams(BaseParams): + threshold: float = Field(default=0.1, ge=0.0, le=1.0, description="Edge threshold.") + colorize: bool = Field(default=False, description="Colorize edges by direction.") + +@pipeline(name="edge-detect", params=EdgeParams) +class EdgeDetect: + + def on_ready(self, **params): + self._threshold = params.get("threshold", 0.1) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.sobel_x = torch.tensor( + [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=self.device + ).view(1, 1, 3, 3) + self.sobel_y = torch.tensor( + [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=self.device + ).view(1, 1, 3, 3) + + def transform(self, frame: VideoFrame, params: EdgeParams) -> torch.Tensor: + tensor = frame.tensor.to(self.device) + gray = tensor.mean(dim=-1, keepdim=True).permute(0, 3, 1, 2) + edges_x = F.conv2d(gray, self.sobel_x, padding=1) + edges_y = F.conv2d(gray, self.sobel_y, padding=1) + magnitude = torch.sqrt(edges_x ** 2 + edges_y ** 2) + magnitude = magnitude / (magnitude.max() + 1e-8) + edges = (magnitude > self._threshold).float() + out = edges.expand(-1, 3, -1, -1).permute(0, 2, 3, 1) + return (out * 2.0 - 1.0) + + def on_update(self, **params): + self._threshold = params.get("threshold", 0.1) + logging.info(f"Edge threshold updated: {self._threshold}") + + def on_stop(self): + logging.info("EdgeDetect stopped") +``` -from runner.live.pipelines import Pipeline -from runner.live.trickle import VideoFrame, VideoOutput - -class MyPipeline(Pipeline): - def __init__(self): - self.frame_queue: asyncio.Queue[VideoOutput] = asyncio.Queue() +**Lifecycle methods:** - async def initialize(self, **params): - logging.info(f"Initializing with params: {params}") - # Load your model here (use asyncio.to_thread for blocking operations) - # self.model = await asyncio.to_thread(load_model, params) +| Method | Required | When it runs | What to do here | +| --- | --- | --- | --- | +| `prepare_models(cls)` | No | **Build time** | Download weights, compile TensorRT engines | +| `on_ready(self, **params)` | No | **Process startup** | Load model from disk to GPU | +| `transform(self, frame, params)` | Yes | **Every frame** | Run inference, return tensor | +| `on_update(self, **params)` | No | **Mid-stream** | Handle param changes | +| `on_stop(self)` | No | **Shutdown** | Release resources | - async def put_video_frame(self, frame: VideoFrame, request_id: str): - # Process frame here (use asyncio.to_thread for blocking inference) - # result = await asyncio.to_thread(self.model.predict, frame.tensor) - await self.frame_queue.put(VideoOutput(frame, request_id)) +Both `async def` and `def` work for all methods. Sync functions automatically run in a thread pool. - async def get_processed_video_frame(self) -> VideoOutput: - return await self.frame_queue.get() +See [`examples/live-video-to-video/`](../examples/live-video-to-video/) for complete working examples. - async def update_params(self, **params): - logging.info(f"Updating params: {params}") - # Return asyncio.create_task(...) if reload needed (shows loading overlay) +### Define Parameters (Optional) - async def stop(self): - logging.info("Stopping pipeline") +```python +# src/my_pipeline/pipeline/params.py +from runner.live.pipelines import BaseParams - @classmethod - def prepare_models(cls): - logging.info("Preparing models") - # Download models, compile TensorRT engines, etc. +class MyPipelineParams(BaseParams): + # Define your custom fields here ``` -For a real-world implementation, see [scope-runner's pipeline](https://github.com/daydreamlive/scope-runner/blob/dec9ecf7e306892df9cfae21759c23fdf15b0510/src/scope_runner/pipeline/pipeline.py#L22). - -### 2.3 Keep Module Exports Minimal +### Keep Module Exports Minimal > **⚠️ Important**: Do **not** export `Pipeline` or `Params` classes from `__init__.py`. The loader imports these by their full path (`module.path:ClassName`), and re-exporting from `__init__.py` would trigger expensive imports (torch, etc.) when only loading the params class. @@ -207,9 +242,16 @@ CMD ["uv", "run", "--frozen", "my-pipeline"] ## Step 5: Implement Model Preparation -The `prepare_models()` classmethod is called when running with the `PREPARE_MODELS=1` environment variable (or `--prepare-models` flag). It is set automatically by `dl_checkpoints.sh` during operator setup. +The `prepare_models()` classmethod runs at **build time** when an operator sets up their node, not when a stream or request arrives. It is triggered by the `PREPARE_MODELS=1` environment variable (or `--prepare-models` flag), and is called automatically by `dl_checkpoints.sh` during operator setup. -Example implementation (in your `pipeline.py`): +This is the right place for any expensive one-time work: + +- **Downloading model weights** from HuggingFace, Google Drive, etc. +- **Compiling TensorRT engines** for optimized GPU inference +- **Converting model formats** (e.g., ONNX export, quantization) +- **Warming up caches** or generating lookup tables + +Unlike runtime (where `HF_HUB_OFFLINE=1` prevents accidental downloads), `prepare_models` runs with full network access so you can fetch weights from HuggingFace, Google Drive, or other sources. ```python @classmethod @@ -230,12 +272,26 @@ def prepare_models(cls): local_dir_use_symlinks=False, ) - # Compile TensorRT engines if needed - # This is where you'd run expensive one-time operations + # Optional: compile TensorRT engine for faster inference + # import torch_tensorrt + # model = torch.load(models_dir / "my-model" / "model.pt") + # trt_model = torch_tensorrt.compile(model, inputs=[...]) + # torch.save(trt_model, models_dir / "my-model" / "model_trt.pt") logging.info("Model preparation complete") ``` +Then in `on_ready`, just load the pre-downloaded (and optionally pre-compiled) model from disk: + +```python +def on_ready(self, **params): + """Load model from disk to GPU. Should be fast (seconds, not minutes).""" + models_dir = Path(os.environ.get("MODEL_DIR", "/models")) / "MyPipeline--models" + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = torch.load(models_dir / "my-model" / "model.pt", map_location=self.device) + self.model.eval() +``` + --- ## Step 6: Integration with Livepeer Infrastructure @@ -322,49 +378,78 @@ When running the orchestrator, configure it to advertise your pipeline capabilit ### Integration Testing with go-livepeer Box -For full end-to-end testing with the Livepeer stack (gateway, orchestrator, Trickle streams), use the [go-livepeer box](https://github.com/livepeer/go-livepeer/blob/master/box/box.md) with your local runner. +For full end-to-end testing with the Livepeer stack (gateway, orchestrator, MediaMTX, Trickle streams), use the [go-livepeer box](https://github.com/livepeer/go-livepeer/blob/master/box/box.md) with your local runner. You will need [go-livepeer](https://github.com/livepeer/go-livepeer) cloned and [MediaMTX](https://github.com/bluenviron/mediamtx) — either installed locally or via Docker (`DOCKER=true`). -1. **Start your local pipeline**: +#### 1. Start your local pipeline - ```bash - uv run my-pipeline - # Pipeline starts on http://localhost:8000 - ``` +```bash +# From your pipeline project (or the ai-runner examples) +cd ai-runner/runner +uv sync --extra realtime-dev +PYTHONPATH=src:../examples/live-video-to-video uv run python \ + ../examples/live-video-to-video/test_examples.py green-shift -2. **Create an `aiModels.json` file** pointing to your local runner: +# Pipeline starts on http://localhost:8000 +``` - ```json - [ - { - "pipeline": "live-video-to-video", - "model_id": "my-pipeline", - "url": "http://localhost:8000" - } - ] - ``` +#### 2. Create an `aiModels.json` file - The `url` field tells the orchestrator to use your local runner instead of starting a Docker container. The `model_id` must match your pipeline's `name` in the `PipelineSpec`. +Create `go-livepeer/box/aiModels.json` pointing to your local runner: -3. **Start the go-livepeer box** with your config: +```json +[ + { + "pipeline": "live-video-to-video", + "model_id": "my-pipeline", + "url": "http://localhost:8000" + } +] +``` - ```bash - cd /path/to/go-livepeer/box +The `url` field tells the orchestrator to use your local runner as an "external container" instead of starting a Docker container. The `model_id` must match your pipeline's `name` in the `PipelineSpec`. - # Point to your aiModels.json file - export AI_MODELS_JSON=/path/to/aiModels.json +#### 3. Start the go-livepeer box - # Start the orchestrator and gateway - make box - ``` +```bash +cd go-livepeer -4. **Stream and playback**: +make box REBUILD=false DOCKER=true \ + AI_MODELS_JSON="$(pwd)/box/aiModels.json" +``` - ```bash - make box-stream # Start streaming - make box-playback # View the output - ``` +Key flags: -The orchestrator will route requests to your local runner at `http://localhost:8000` instead of spinning up a Docker container. +| Flag | Purpose | +| ------ | --------- | +| `REBUILD=false` | Skips rebuilding go-livepeer and runner (must have been built at least once) | +| `DOCKER=true` | Runs gateway, orchestrator, and MediaMTX in Docker containers | +| `AI_MODELS_JSON` | **Must be an absolute path** — the file is mounted into the orchestrator container | + +Verify the orchestrator started correctly by looking for these log lines: + +```bash +Starting external container name=live-video-to-video_my-pipeline_http://localhost:8000 modelID=my-pipeline +Capability live-video-to-video (ID: Live video to video) advertised with model constraint my-pipeline +``` + +#### 4. Stream and playback + +In separate terminals: + +```bash +# Send a test stream (uses ffmpeg test pattern) +PIPELINE=my-pipeline make box-stream + +# Or stream from your webcam +PIPELINE=my-pipeline INPUT_WEBCAM=/dev/video0 make box-stream + +# View the processed output +make box-playback +``` + +`PIPELINE` is needed for `box-stream` so the RTMP URL includes `pipeline=my-pipeline` in the query string. It is not needed for `box-playback` — the output stream is always `my-stream-out`. + +To list available webcam devices, run `ls /dev/video*`. --- @@ -378,8 +463,7 @@ The orchestrator will route requests to your local runner at `http://localhost:8 ### Async Operations -- Use `asyncio.to_thread()` for blocking/CPU-bound operations -- Never block the event loop in `put_video_frame` or `get_processed_video_frame` +- Both `async def` and `def` work — the `@pipeline` decorator automatically runs sync methods in a thread pool so they won't block the event loop ### Error Handling @@ -388,8 +472,8 @@ The orchestrator will route requests to your local runner at `http://localhost:8 ### Parameter Updates -- Return nothing from `update_params()` for instant updates -- Return an `asyncio.Task` for updates that will take a long time, normally a "pipeline reload". The runtime shows loading overlay while the reload is running. +- Return nothing from `on_update()` for instant updates +- For slow reloads, the runtime shows a loading overlay while the update is running --- @@ -403,6 +487,8 @@ The orchestrator will route requests to your local runner at `http://localhost:8 3. **CUDA out of memory**: The pipeline runs in an isolated subprocess - OOM errors will trigger a restart. +4. **"Connection refused" on RTMP port 1935**: This usually means MediaMTX isn't running. Check the `make box` output for MediaMTX errors — if you see `[RTMP] listener is closing` shortly after startup, a port conflict likely caused it to shut down (see issue 4 above). + --- ## Reference Implementation @@ -449,7 +535,7 @@ class MyBatchPipeline(Pipeline): **Key differences from live pipelines:** | Aspect | Live Pipeline | Batch Pipeline | -|--------|---------------|----------------| +| --- | --- | --- | | Base class | `runner.live.pipelines.Pipeline` | `runner.pipelines.base.Pipeline` | | Processing | Continuous frame stream | Single request/response | | Entry point | `start_app(pipeline_spec)` | `start_app(pipeline_instance)` | diff --git a/examples/live-video-to-video/__init__.py b/examples/live-video-to-video/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/live-video-to-video/edge_detect.py b/examples/live-video-to-video/edge_detect.py new file mode 100644 index 000000000..930e517c5 --- /dev/null +++ b/examples/live-video-to-video/edge_detect.py @@ -0,0 +1,100 @@ +"""Edge Detection (Sobel) -- @pipeline example (class form) with lifecycle hooks. + +Pure-torch Sobel edge detection. No external model needed. +Demonstrates on_ready, transform, on_update, and on_stop lifecycle hooks. + +Usage: + python examples/live-video-to-video/edge_detect.py +""" + +import logging + +import torch +import torch.nn.functional as F +from pydantic import Field + +from runner.app import start_app +from runner.live.pipelines import pipeline, BaseParams +from runner.live.trickle import VideoFrame + + +class EdgeParams(BaseParams): + """Parameters for edge detection, adjustable mid-stream.""" + + threshold: float = Field( + default=0.1, ge=0.0, le=1.0, + description="Edge detection threshold. Higher = fewer edges.", + ) + colorize: bool = Field( + default=False, + description="Colorize edges based on gradient direction.", + ) + + +@pipeline(name="edge-detect", params=EdgeParams) +class EdgeDetect: + """Real-time Sobel edge detection. + + Demonstrates: + - on_ready: initializes Sobel kernels on the correct device + - transform: runs edge detection per frame + - on_update: adjusts threshold mid-stream + - on_stop: cleanup on shutdown + """ + + def on_ready(self, **params): + self._threshold = params.get("threshold", 0.1) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Sobel kernels for horizontal and vertical gradients + self.sobel_x = torch.tensor( + [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=self.device + ).view(1, 1, 3, 3) + self.sobel_y = torch.tensor( + [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=self.device + ).view(1, 1, 3, 3) + + logging.info(f"EdgeDetect ready on {self.device}") + + def transform(self, frame: VideoFrame, params: EdgeParams) -> torch.Tensor: + """Run Sobel edge detection on each frame.""" + # frame.tensor is (B, H, W, C) in [-1.0, 1.0] + tensor = frame.tensor.to(self.device) + + # Convert to grayscale: (B, H, W, C) -> (B, 1, H, W) + gray = tensor.mean(dim=-1, keepdim=True).permute(0, 3, 1, 2) + + # Apply Sobel filters + edges_x = F.conv2d(gray, self.sobel_x, padding=1) + edges_y = F.conv2d(gray, self.sobel_y, padding=1) + + # Edge magnitude + magnitude = torch.sqrt(edges_x ** 2 + edges_y ** 2) + magnitude = magnitude / (magnitude.max() + 1e-8) # normalize to [0, 1] + + # Apply threshold + edges = (magnitude > self._threshold).float() + + if params.colorize: + # Color edges by gradient direction + angle = torch.atan2(edges_y, edges_x + 1e-8) # [-pi, pi] + angle = (angle + torch.pi) / (2 * torch.pi) # [0, 1] + r = edges * angle + g = edges * (1.0 - torch.abs(angle - 0.5) * 2.0) + b = edges * (1.0 - angle) + out = torch.cat([r, g, b], dim=1) # (B, 3, H, W) + else: + out = edges.expand(-1, 3, -1, -1) # (B, 3, H, W) + + # Convert back to (B, H, W, C) in [-1, 1] + out = out.permute(0, 2, 3, 1) + out = out * 2.0 - 1.0 + return out + + def on_update(self, **params): + """Update threshold when params change mid-stream.""" + self._threshold = params.get("threshold", 0.1) + logging.info(f"Edge threshold updated: {self._threshold}") + + def on_stop(self): + logging.info("EdgeDetect stopped") diff --git a/examples/live-video-to-video/green_shift.py b/examples/live-video-to-video/green_shift.py new file mode 100644 index 000000000..268886566 --- /dev/null +++ b/examples/live-video-to-video/green_shift.py @@ -0,0 +1,26 @@ +"""Green Shift -- minimal @pipeline example (function form). + +Boosts the green channel of every video frame. No model download, no GPU +inference — just a tensor op. Ideal for verifying the pipeline infrastructure. + +Usage: + python examples/live-video-to-video/green_shift.py +""" + +import torch + +from runner.app import start_app +from runner.live.pipelines import pipeline, BaseParams +from runner.live.trickle import VideoFrame + + +@pipeline(name="green-shift") +async def green_shift(frame: VideoFrame, params: BaseParams) -> torch.Tensor: + """Boost the green channel of every frame. + + Frame tensor layout: (B, H, W, C), values in [-1.0, 1.0]. + Channel order: R=0, G=1, B=2. + """ + tensor = frame.tensor.clone() + tensor[:, :, :, 1] = torch.clamp(tensor[:, :, :, 1] + 0.3, -1.0, 1.0) + return tensor diff --git a/examples/live-video-to-video/test_examples.py b/examples/live-video-to-video/test_examples.py new file mode 100644 index 000000000..f28f87351 --- /dev/null +++ b/examples/live-video-to-video/test_examples.py @@ -0,0 +1,35 @@ +"""Run an example pipeline as a live server you can interact with. + +Usage: + cd runner + PYTHONPATH=src:../examples/live-video-to-video uv run python \ + ../examples/live-video-to-video/test_examples.py [green-shift|edge-detect] +""" + +import sys + +from runner.app import start_app +from runner.live.pipelines import PipelineSpec + +EXAMPLES = { + "green-shift": PipelineSpec( + name="green-shift", + pipeline_cls="green_shift:green_shift", + ), + "edge-detect": PipelineSpec( + name="edge-detect", + pipeline_cls="edge_detect:EdgeDetect", + params_cls="edge_detect:EdgeParams", + ), +} + +if __name__ == "__main__": + name = sys.argv[1] if len(sys.argv) > 1 else "green-shift" + spec = EXAMPLES.get(name) + if spec is None: + print(f"Unknown example: {name}") + print(f"Available: {', '.join(EXAMPLES)}") + sys.exit(1) + + print(f"Starting {name} pipeline...") + start_app(pipeline=spec) diff --git a/runner/src/runner/app.py b/runner/src/runner/app.py index cd543039d..ebe09c3c0 100644 --- a/runner/src/runner/app.py +++ b/runner/src/runner/app.py @@ -95,13 +95,14 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None: route.operation_id = route.name -def prepare_models(pipeline_spec: PipelineSpec) -> None: +async def _prepare_models_async(pipeline_spec: PipelineSpec) -> None: """Prepare models for a live pipeline (download, compile TensorRT engines, etc.).""" from .live.pipelines.loader import load_pipeline_class + from .live.pipelines.create import _invoke logger.info(f"Preparing models for pipeline: {pipeline_spec.name}") pipeline_class = load_pipeline_class(pipeline_spec.pipeline_cls) - pipeline_class.prepare_models() + await _invoke(pipeline_class.prepare_models) logger.info("Model preparation complete") @@ -181,7 +182,8 @@ def start_app( if isinstance(pipeline, PipelineSpec): # Check for model preparation mode if os.getenv("PREPARE_MODELS") == "1" or "--prepare-models" in sys.argv: - prepare_models(pipeline) + import asyncio + asyncio.run(_prepare_models_async(pipeline)) return # Wrap in LiveVideoToVideoPipeline for normal operation diff --git a/runner/src/runner/live/pipelines/__init__.py b/runner/src/runner/live/pipelines/__init__.py index 4b96065c1..07627aa61 100644 --- a/runner/src/runner/live/pipelines/__init__.py +++ b/runner/src/runner/live/pipelines/__init__.py @@ -1,4 +1,5 @@ from .interface import Pipeline, BaseParams from .spec import PipelineSpec, builtin_pipeline_spec +from .create import pipeline -__all__ = ["Pipeline", "BaseParams", "PipelineSpec", "builtin_pipeline_spec"] +__all__ = ["Pipeline", "BaseParams", "PipelineSpec", "builtin_pipeline_spec", "pipeline"] diff --git a/runner/src/runner/live/pipelines/create.py b/runner/src/runner/live/pipelines/create.py new file mode 100644 index 000000000..179546ad3 --- /dev/null +++ b/runner/src/runner/live/pipelines/create.py @@ -0,0 +1,166 @@ +"""@pipeline decorator for creating live pipelines. + +Lifecycle hooks: +* ``prepare_models``: called at build time (model download, TensorRT compile) +* ``on_ready``: called once at startup +* ``transform``: called per frame +* ``on_update``: called when params change mid-stream +* ``on_stop``: called on shutdown +""" + +import asyncio +import inspect +import logging +from typing import Optional, Type + +from .interface import Pipeline, BaseParams +from .spec import PipelineSpec +from ..trickle import VideoFrame, VideoOutput + + +async def _invoke(func, *args, **kwargs): + """Call a function, handling both async and sync. Sync runs in a thread pool.""" + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + return await asyncio.to_thread(func, *args, **kwargs) + + +def pipeline( + name: str, + params: Optional[Type[BaseParams]] = None, + initial_params: Optional[dict] = None, +): + """Decorator to define a pipeline. Can decorate a function or class. + + Args: + name: Pipeline identifier. Must match the MODEL_ID on the orchestrator. + params: Pydantic model for pipeline parameters. Defaults to BaseParams. + initial_params: Default parameter values passed on init. + """ + params_cls = params or BaseParams + + def decorator(func_or_class): + if callable(func_or_class) and not isinstance(func_or_class, type): + # Function form: wrap into a minimal class with a transform method. + func = func_or_class + + class _Wrapper: + async def transform(self, frame, p): + return await _invoke(func, frame, p) + + _Wrapper.__name__ = func.__name__ + _Wrapper.__qualname__ = func.__qualname__ + _Wrapper.__module__ = func.__module__ + user_cls = _Wrapper + elif isinstance(func_or_class, type): + user_cls = func_or_class + else: + raise TypeError( + "@pipeline can only decorate a function or class, got " + f"{type(func_or_class)}" + ) + + if "." in user_cls.__qualname__: + logging.warning( + f"@pipeline decorating nested '{user_cls.__qualname__}' — " + f"this may cause import issues if the enclosing scope is not accessible" + ) + + pipeline_cls = _build_pipeline(user_cls, params_cls) + + # Auto-generate PipelineSpec. + params_import = ( + f"{params_cls.__module__}:{params_cls.__qualname__}" + if params_cls is not BaseParams + else None + ) + pipeline_cls._spec = PipelineSpec( + name=name, + pipeline_cls=f"{pipeline_cls.__module__}:{pipeline_cls.__qualname__}", + params_cls=params_import, + initial_params=initial_params or {}, + ) + + return pipeline_cls + + return decorator + + +def _build_pipeline(user_cls, params_cls: Type[BaseParams]) -> Type[Pipeline]: + """Build a Pipeline subclass from a user class. + + The user class must define a ``transform`` method. Lifecycle hooks are + detected by name: ``on_ready``, ``on_update``, ``on_stop``, + ``prepare_models``. + """ + if not hasattr(user_cls, "transform"): + raise TypeError( + f"@pipeline class {user_cls.__name__} must define a 'transform' method" + ) + + has_on_ready = hasattr(user_cls, "on_ready") + has_on_update = hasattr(user_cls, "on_update") + has_on_stop = hasattr(user_cls, "on_stop") + has_prepare = hasattr(user_cls, "prepare_models") + label = user_cls.__name__ + + class GeneratedPipeline(Pipeline): + def __init__(self): + super().__init__() + self._lock = asyncio.Lock() + self.frame_queue: asyncio.Queue[VideoOutput] = asyncio.Queue() + self.params_instance: Optional[BaseParams] = None + self._inner = user_cls() + + async def initialize(self, **kw_params): + logging.info(f"Initializing {label} pipeline with params: {kw_params}") + async with self._lock: + self.params_instance = params_cls(**kw_params) + if has_on_ready: + await _invoke(self._inner.on_ready, **kw_params) + logging.info("Pipeline initialization complete") + + async def put_video_frame(self, frame: VideoFrame, request_id: str): + async with self._lock: + result = await _invoke( + self._inner.transform, frame, self.params_instance + ) + if isinstance(result, VideoOutput): + # Normalize request_id to match the current stream. + if result.request_id != request_id: + result = VideoOutput(result.frame, request_id) + await self.frame_queue.put(result) + else: + await self.frame_queue.put( + VideoOutput(frame, request_id).replace_tensor(result) + ) + + async def get_processed_video_frame(self) -> VideoOutput: + return await self.frame_queue.get() + + async def update_params(self, **kw_params): + logging.info(f"Updating {label} params: {kw_params}") + async with self._lock: + self.params_instance = params_cls(**kw_params) + if has_on_update: + await _invoke(self._inner.on_update, **kw_params) + + async def stop(self): + logging.info(f"Stopping {label} pipeline") + if has_on_stop: + await _invoke(self._inner.on_stop) + + @classmethod + async def prepare_models(cls): + if has_prepare: + await _invoke(user_cls.prepare_models) + else: + logging.info(f"{label} pipeline does not require model preparation") + + # Keep the original decorated name so importlib can find it via + # getattr(module, name) after the decorator replaces the symbol. + GeneratedPipeline.__name__ = user_cls.__name__ + GeneratedPipeline.__qualname__ = user_cls.__qualname__ + GeneratedPipeline.__module__ = user_cls.__module__ + + return GeneratedPipeline diff --git a/runner/src/runner/live/pipelines/interface.py b/runner/src/runner/live/pipelines/interface.py index 1cdc245a2..e695dae8b 100644 --- a/runner/src/runner/live/pipelines/interface.py +++ b/runner/src/runner/live/pipelines/interface.py @@ -44,9 +44,16 @@ def get_output_resolution(self) -> tuple[int, int]: return (self.width, self.height) class Pipeline(ABC): - """Abstract base class for image processing pipelines. + """Abstract base class for frame processing pipelines. - Processes frames sequentially and supports dynamic parameter updates. + .. deprecated:: + For new pipelines, use the ``@pipeline`` decorator instead of + subclassing this ABC directly. The decorator handles frame queues, + lifecycle management, and parameter validation automatically. + See ``docs/custom-pipeline.md`` for usage. + + This ABC is retained for internal use and backward compatibility with + existing pipeline implementations (e.g., ComfyUI, StreamDiffusion). Notes: - Error handling is done by the caller, so the implementation can let @@ -109,7 +116,6 @@ async def stop(self): pass @classmethod - @abstractmethod def prepare_models(cls): """Download and/or compile any assets required for this pipeline.""" pass