11import base64
2+ import json
23import time
34from collections .abc import AsyncGenerator
45from typing import Literal , cast
@@ -194,7 +195,9 @@ def _setup_routes(self) -> None:
194195 self .app .post ("/v1/chat/completions" , response_model = None )(
195196 self .chat_completions
196197 )
197- self .app .post ("/v1/images/generations" )(self .image_generations )
198+ self .app .post ("/v1/images/generations" , response_model = None )(
199+ self .image_generations
200+ )
198201 self .app .post ("/v1/images/edits" )(self .image_edits )
199202 self .app .get ("/state" )(lambda : self .state )
200203 self .app .get ("/events" )(lambda : self ._event_log )
@@ -551,8 +554,12 @@ async def chat_completions(
551554
552555 async def image_generations (
553556 self , payload : ImageGenerationTaskParams
554- ) -> ImageGenerationResponse :
555- """Handle image generation requests."""
557+ ) -> ImageGenerationResponse | StreamingResponse :
558+ """Handle image generation requests.
559+
560+ When stream=True and partial_images > 0, returns a StreamingResponse
561+ with SSE-formatted events for partial and final images.
562+ """
556563 model_meta = await resolve_model_meta (payload .model )
557564 payload .model = model_meta .model_id
558565
@@ -570,22 +577,128 @@ async def image_generations(
570577 )
571578 await self ._send (command )
572579
573- # Collect all image chunks (non-streaming)
574- num_images = payload .n or 1
580+ # Check if streaming is requested
581+ if payload .stream and payload .partial_images and payload .partial_images > 0 :
582+ return StreamingResponse (
583+ self ._generate_image_stream (
584+ command_id = command .command_id ,
585+ num_images = payload .n or 1 ,
586+ response_format = payload .response_format or "b64_json" ,
587+ ),
588+ media_type = "text/event-stream" ,
589+ )
590+
591+ # Non-streaming: collect all image chunks
592+ return await self ._collect_image_generation (
593+ command_id = command .command_id ,
594+ num_images = payload .n or 1 ,
595+ response_format = payload .response_format or "b64_json" ,
596+ )
597+
598+ async def _generate_image_stream (
599+ self ,
600+ command_id : CommandId ,
601+ num_images : int ,
602+ response_format : str ,
603+ ) -> AsyncGenerator [str , None ]:
604+ """Generate SSE stream of partial and final images."""
605+ # Track chunks: {(image_index, is_partial): {chunk_index: data}}
606+ image_chunks : dict [tuple [int , bool ], dict [int , str ]] = {}
607+ image_total_chunks : dict [tuple [int , bool ], int ] = {}
608+ image_metadata : dict [tuple [int , bool ], tuple [int | None , int | None ]] = {}
609+ images_complete = 0
610+
611+ try :
612+ self ._image_generation_queues [command_id ], recv = channel [ImageChunk ]()
613+
614+ with recv as chunks :
615+ async for chunk in chunks :
616+ key = (chunk .image_index , chunk .is_partial )
617+
618+ if key not in image_chunks :
619+ image_chunks [key ] = {}
620+ image_total_chunks [key ] = chunk .total_chunks
621+ image_metadata [key ] = (
622+ chunk .partial_index ,
623+ chunk .total_partials ,
624+ )
625+
626+ image_chunks [key ][chunk .chunk_index ] = chunk .data
627+
628+ # Check if this image is complete
629+ if len (image_chunks [key ]) == image_total_chunks [key ]:
630+ full_data = "" .join (
631+ image_chunks [key ][i ] for i in range (len (image_chunks [key ]))
632+ )
633+
634+ partial_idx , total_partials = image_metadata [key ]
635+
636+ if chunk .is_partial :
637+ # Yield partial image event
638+ event_data = {
639+ "type" : "partial" ,
640+ "partial_index" : partial_idx ,
641+ "total_partials" : total_partials ,
642+ "data" : {
643+ "b64_json" : full_data
644+ if response_format == "b64_json"
645+ else None ,
646+ },
647+ }
648+ yield f"data: { json .dumps (event_data )} \n \n "
649+ else :
650+ # Final image
651+ event_data = {
652+ "type" : "final" ,
653+ "image_index" : chunk .image_index ,
654+ "data" : {
655+ "b64_json" : full_data
656+ if response_format == "b64_json"
657+ else None ,
658+ },
659+ }
660+ yield f"data: { json .dumps (event_data )} \n \n "
661+ images_complete += 1
575662
663+ if images_complete >= num_images :
664+ yield "data: [DONE]\n \n "
665+ break
666+
667+ # Clean up completed image chunks
668+ del image_chunks [key ]
669+ del image_total_chunks [key ]
670+ del image_metadata [key ]
671+
672+ except anyio .get_cancelled_exc_class ():
673+ raise
674+ finally :
675+ await self ._send (TaskFinished (finished_command_id = command_id ))
676+ if command_id in self ._image_generation_queues :
677+ del self ._image_generation_queues [command_id ]
678+
679+ async def _collect_image_generation (
680+ self ,
681+ command_id : CommandId ,
682+ num_images : int ,
683+ response_format : str ,
684+ ) -> ImageGenerationResponse :
685+ """Collect all image chunks (non-streaming) and return a single response."""
576686 # Track chunks per image: {image_index: {chunk_index: data}}
687+ # Only track non-partial (final) images
577688 image_chunks : dict [int , dict [int , str ]] = {}
578689 image_total_chunks : dict [int , int ] = {}
579690 images_complete = 0
580691
581692 try :
582- self ._image_generation_queues [command .command_id ], recv = channel [
583- ImageChunk
584- ]()
693+ self ._image_generation_queues [command_id ], recv = channel [ImageChunk ]()
585694
586695 while images_complete < num_images :
587696 with recv as chunks :
588697 async for chunk in chunks :
698+ # Skip partial images in non-streaming mode
699+ if chunk .is_partial :
700+ continue
701+
589702 if chunk .image_index not in image_chunks :
590703 image_chunks [chunk .image_index ] = {}
591704 image_total_chunks [chunk .image_index ] = chunk .total_chunks
@@ -609,26 +722,18 @@ async def image_generations(
609722 full_data = "" .join (chunks_dict [i ] for i in range (len (chunks_dict )))
610723 images .append (
611724 ImageData (
612- b64_json = full_data
613- if payload .response_format == "b64_json"
614- else None ,
725+ b64_json = full_data if response_format == "b64_json" else None ,
615726 url = None , # URL format not implemented yet
616727 )
617728 )
618729
619730 return ImageGenerationResponse (data = images )
620731 except anyio .get_cancelled_exc_class ():
621- # TODO(ciaran): TaskCancelled
622- """
623- self.command_sender.send_nowait(
624- ForwarderCommand(origin=self.node_id, command=command)
625- )
626- """
627732 raise
628733 finally :
629- # Send TaskFinished command
630- await self ._send ( TaskFinished ( finished_command_id = command . command_id ))
631- del self ._image_generation_queues [command . command_id ]
734+ await self . _send ( TaskFinished ( finished_command_id = command_id ))
735+ if command_id in self ._image_generation_queues :
736+ del self ._image_generation_queues [command_id ]
632737
633738 async def image_edits (
634739 self ,
0 commit comments