Skip to content

Commit 84170b3

Browse files
committed
Handle partial image streaming
1 parent f95fd21 commit 84170b3

8 files changed

Lines changed: 398 additions & 93 deletions

File tree

src/exo/master/api.py

Lines changed: 125 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import json
23
import time
34
from collections.abc import AsyncGenerator
45
from typing import Literal, cast
@@ -194,7 +195,9 @@ def _setup_routes(self) -> None:
194195
self.app.post("/v1/chat/completions", response_model=None)(
195196
self.chat_completions
196197
)
197-
self.app.post("/v1/images/generations")(self.image_generations)
198+
self.app.post("/v1/images/generations", response_model=None)(
199+
self.image_generations
200+
)
198201
self.app.post("/v1/images/edits")(self.image_edits)
199202
self.app.get("/state")(lambda: self.state)
200203
self.app.get("/events")(lambda: self._event_log)
@@ -551,8 +554,12 @@ async def chat_completions(
551554

552555
async def image_generations(
553556
self, payload: ImageGenerationTaskParams
554-
) -> ImageGenerationResponse:
555-
"""Handle image generation requests."""
557+
) -> ImageGenerationResponse | StreamingResponse:
558+
"""Handle image generation requests.
559+
560+
When stream=True and partial_images > 0, returns a StreamingResponse
561+
with SSE-formatted events for partial and final images.
562+
"""
556563
model_meta = await resolve_model_meta(payload.model)
557564
payload.model = model_meta.model_id
558565

@@ -570,22 +577,128 @@ async def image_generations(
570577
)
571578
await self._send(command)
572579

573-
# Collect all image chunks (non-streaming)
574-
num_images = payload.n or 1
580+
# Check if streaming is requested
581+
if payload.stream and payload.partial_images and payload.partial_images > 0:
582+
return StreamingResponse(
583+
self._generate_image_stream(
584+
command_id=command.command_id,
585+
num_images=payload.n or 1,
586+
response_format=payload.response_format or "b64_json",
587+
),
588+
media_type="text/event-stream",
589+
)
590+
591+
# Non-streaming: collect all image chunks
592+
return await self._collect_image_generation(
593+
command_id=command.command_id,
594+
num_images=payload.n or 1,
595+
response_format=payload.response_format or "b64_json",
596+
)
597+
598+
async def _generate_image_stream(
    self,
    command_id: CommandId,
    num_images: int,
    response_format: str,
) -> AsyncGenerator[str, None]:
    """Yield SSE-formatted events for partial and final images.

    Registers a receive channel for ``command_id``, reassembles the incoming
    ``ImageChunk`` pieces into complete base64 payloads and emits one
    ``data: <json>\\n\\n`` event per completed image. Partial images produce
    ``"type": "partial"`` events; final images produce ``"type": "final"``
    events. After ``num_images`` final images a ``data: [DONE]`` event is
    emitted and the stream ends.

    Args:
        command_id: Id of the already-sent image generation command.
        num_images: Number of final images to wait for before finishing.
        response_format: When equal to ``"b64_json"`` the image data is
            included in the event; otherwise ``b64_json`` is ``None``.

    Yields:
        SSE ``data:`` lines, each terminated by a blank line.
    """
    # Reassembly buffers. The key includes partial_index so that chunks of
    # two different partials of the same image can never be merged into one
    # buffer, even if they arrive interleaved.
    pending: dict[tuple[int, bool, int | None], dict[int, str]] = {}
    # Per-key (total_chunks, total_partials), captured from the first chunk.
    first_seen: dict[tuple[int, bool, int | None], tuple[int, int | None]] = {}
    images_complete = 0

    try:
        # NOTE(review): the queue is registered only once the response body
        # starts being iterated; presumably no chunk can arrive before that
        # -- confirm against the caller, which sends the command earlier.
        self._image_generation_queues[command_id], recv = channel[ImageChunk]()

        with recv as chunks:
            async for chunk in chunks:
                key = (chunk.image_index, chunk.is_partial, chunk.partial_index)

                if key not in pending:
                    pending[key] = {}
                    first_seen[key] = (chunk.total_chunks, chunk.total_partials)

                parts = pending[key]
                parts[chunk.chunk_index] = chunk.data

                expected_chunks, total_partials = first_seen[key]
                if len(parts) != expected_chunks:
                    continue

                # All pieces received: stitch them together in index order
                # and release the buffers before emitting the event.
                full_data = "".join(parts[i] for i in range(len(parts)))
                del pending[key]
                del first_seen[key]

                image_payload = {
                    "b64_json": full_data if response_format == "b64_json" else None,
                }

                if chunk.is_partial:
                    # image_index is included so clients can attribute a
                    # preview to the right image when num_images > 1.
                    event_data = {
                        "type": "partial",
                        "image_index": chunk.image_index,
                        "partial_index": chunk.partial_index,
                        "total_partials": total_partials,
                        "data": image_payload,
                    }
                    yield f"data: {json.dumps(event_data)}\n\n"
                    continue

                event_data = {
                    "type": "final",
                    "image_index": chunk.image_index,
                    "data": image_payload,
                }
                yield f"data: {json.dumps(event_data)}\n\n"
                images_complete += 1

                if images_complete >= num_images:
                    yield "data: [DONE]\n\n"
                    break
    finally:
        # Drop the queue first so it is removed even if sending fails.
        self._image_generation_queues.pop(command_id, None)
        # If the client disconnected, this task is being cancelled and an
        # unshielded await would be interrupted immediately, so TaskFinished
        # would never be sent. Shield the cleanup from cancellation.
        with anyio.CancelScope(shield=True):
            await self._send(TaskFinished(finished_command_id=command_id))
678+
679+
async def _collect_image_generation(
680+
self,
681+
command_id: CommandId,
682+
num_images: int,
683+
response_format: str,
684+
) -> ImageGenerationResponse:
685+
"""Collect all image chunks (non-streaming) and return a single response."""
576686
# Track chunks per image: {image_index: {chunk_index: data}}
687+
# Only track non-partial (final) images
577688
image_chunks: dict[int, dict[int, str]] = {}
578689
image_total_chunks: dict[int, int] = {}
579690
images_complete = 0
580691

581692
try:
582-
self._image_generation_queues[command.command_id], recv = channel[
583-
ImageChunk
584-
]()
693+
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
585694

586695
while images_complete < num_images:
587696
with recv as chunks:
588697
async for chunk in chunks:
698+
# Skip partial images in non-streaming mode
699+
if chunk.is_partial:
700+
continue
701+
589702
if chunk.image_index not in image_chunks:
590703
image_chunks[chunk.image_index] = {}
591704
image_total_chunks[chunk.image_index] = chunk.total_chunks
@@ -609,26 +722,18 @@ async def image_generations(
609722
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
610723
images.append(
611724
ImageData(
612-
b64_json=full_data
613-
if payload.response_format == "b64_json"
614-
else None,
725+
b64_json=full_data if response_format == "b64_json" else None,
615726
url=None, # URL format not implemented yet
616727
)
617728
)
618729

619730
return ImageGenerationResponse(data=images)
620731
except anyio.get_cancelled_exc_class():
621-
# TODO(ciaran): TaskCancelled
622-
"""
623-
self.command_sender.send_nowait(
624-
ForwarderCommand(origin=self.node_id, command=command)
625-
)
626-
"""
627732
raise
628733
finally:
629-
# Send TaskFinished command
630-
await self._send(TaskFinished(finished_command_id=command.command_id))
631-
del self._image_generation_queues[command.command_id]
734+
await self._send(TaskFinished(finished_command_id=command_id))
735+
if command_id in self._image_generation_queues:
736+
del self._image_generation_queues[command_id]
632737

633738
async def image_edits(
634739
self,

src/exo/shared/types/chunks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ class ImageChunk(BaseChunk):
3030
chunk_index: int
3131
total_chunks: int
3232
image_index: int
33+
is_partial: bool = False
34+
partial_index: int | None = None
35+
total_partials: int | None = None
3336

3437
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
3538
for name, value in super().__repr_args__():

src/exo/shared/types/worker/runner_response.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,19 @@ def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
3232
yield name, value
3333

3434

35+
class PartialImageResponse(BaseRunnerResponse):
    # Runner response carrying one intermediate (partial) image.
    # Encoded image bytes; summarised rather than dumped in repr output.
    image_data: bytes
    # Encoding of image_data.
    format: Literal["png", "jpeg", "webp"] = "png"
    # Index of this partial and the number of partials expected in total.
    partial_index: int
    total_partials: int

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        # Same pairs as the parent yields, except that unnamed entries are
        # dropped and image_data is replaced by a short byte-count summary.
        for field_name, field_value in super().__repr_args__():
            if field_name is None:
                continue
            if field_name == "image_data":
                field_value = f"<{len(self.image_data)} bytes>"
            yield field_name, field_value
47+
48+
3549
class FinishedResponse(BaseRunnerResponse):
3650
pass

src/exo/worker/engines/image/base.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from collections.abc import Generator
12
from pathlib import Path
2-
from typing import Literal, Optional, Protocol, runtime_checkable
3+
from typing import Literal, Protocol, runtime_checkable
34

45
from PIL import Image
56

@@ -21,11 +22,18 @@ def generate(
2122
seed: int,
2223
image_path: Path | None = None,
2324
image_strength: float | None = None,
24-
) -> Optional[Image.Image]:
25+
partial_images: int = 0,
26+
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
2527
"""Generate an image from a text prompt, or edit an existing image.
2628
27-
For distributed inference, only the first stage (rank 0) returns the image.
28-
Other stages return None after participating in the pipeline.
29+
For distributed inference, only the last stage returns images.
30+
Other stages yield nothing after participating in the pipeline.
31+
32+
When partial_images > 0, yields intermediate images during diffusion
33+
as tuples of (image, partial_index, total_partials), then yields
34+
the final image.
35+
36+
When partial_images = 0 (default), only yields the final image.
2937
3038
Args:
3139
prompt: Text description of the image to generate
@@ -35,8 +43,10 @@ def generate(
3543
seed: Random seed for reproducibility
3644
image_path: Optional path to input image for img2img
3745
image_strength: Optional strength for img2img (0.0-1.0, higher = more change)
46+
partial_images: Number of intermediate images to yield (0 for none)
3847
39-
Returns:
40-
Generated PIL Image (rank 0) or None (other ranks)
48+
Yields:
49+
Intermediate images as (Image, partial_index, total_partials) tuples
50+
Final PIL Image (last stage) or nothing (other stages)
4151
"""
4252
...

src/exo/worker/engines/image/distributed_model.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Generator
12
from pathlib import Path
23
from typing import TYPE_CHECKING, Any, Literal, Optional
34

@@ -186,7 +187,8 @@ def generate(
186187
seed: int = 2,
187188
image_path: Path | None = None,
188189
image_strength: float | None = None,
189-
) -> Optional[Image.Image]:
190+
partial_images: int = 0,
191+
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
190192
# Determine number of inference steps based on quality
191193
steps = self._config.get_steps_for_quality(quality)
192194

@@ -197,20 +199,22 @@ def generate(
197199
image_path=image_path,
198200
image_strength=image_strength,
199201
)
200-
image = self._generate_image(settings=config, prompt=prompt, seed=seed)
201-
logger.info("generated image")
202202

203-
# Only final rank returns the actual image
204-
if self.is_last_stage:
205-
return image.image
206-
207-
def _generate_image(self, settings: Config, prompt: str, seed: int) -> Any:
208-
"""Generate image by delegating to the runner."""
209-
return self._runner.generate_image(
210-
settings=settings,
203+
# Generate images via the runner
204+
for result in self._runner.generate_image(
205+
settings=config,
211206
prompt=prompt,
212207
seed=seed,
213-
)
208+
partial_images=partial_images,
209+
):
210+
if isinstance(result, tuple):
211+
# Partial image: (GeneratedImage, partial_index, total_partials)
212+
generated_image, partial_idx, total_partials = result
213+
yield (generated_image.image, partial_idx, total_partials)
214+
else:
215+
# Final image: GeneratedImage
216+
logger.info("generated image")
217+
yield result.image
214218

215219

216220
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:

0 commit comments

Comments
 (0)