[WIP][RFC] Support Qwen-Image Flow-GRPO Training based on vLLM-Omni #4639

@chenyingshu

Description

Feature Request

This RFC proposes to support Qwen-Image Flow-GRPO Training based on vLLM-Omni.

Motivation

The goal is to enhance verl’s scalability so that it can support online DPO-like training for state-of-the-art diffusion image and video generation models, including Qwen-Image, Z-Image, Wan2.2, and others. We choose Flow-GRPO as the representative algorithm in this domain, while additional algorithms such as DiffusionNFT and DanceGRPO can be seamlessly integrated following this update. As an initial step, Qwen-Image has been selected as the first supported model for multimodal generation tasks.

At present, verl does not support diffusion-based generation models. Enabling this functionality requires two major extensions: first, a rollout engine capable of handling image and video generation tasks, built on components such as vLLM-Omni; and second, a training engine for diffusion models, which will rely on diffusers with an FSDP backend. Integrating diffusers and vLLM-Omni is therefore a necessary part of this change.

However, because of the strong coupling between verl and LLM training, directly modifying verl to support diffusion models could risk altering its original behavior with LLMs. To address this challenge, we suggest separating the pipelines of LLMs and diffusion models, adding diffusion model support through inheritance of the existing classes without breaking the original code.

In the following section, we will first briefly present the overall picture of FlowGRPO, and then demonstrate the necessary code changes for enabling Qwen-Image FlowGRPO training.

Overall Structure


Figure 1. Overview of integrating the FlowGRPO algorithm into verl. The left panel shows the entry point and algorithm implementation with a standalone RayFlowGRPOTrainer. The right panel shows the corresponding workers (class names in bold) that need to be implemented. Other miscellaneous changes, such as configs, dataloaders and the logger, are not shown here.

Algorithm Implementation

Here is a brief explanation of the FlowGRPO algorithm functions shown in the left panel:

i. generate_sequence: For diffusion models, sequence generation produces a sequence of image/video latent samples over the denoising process, rather than tokens as in LLMs. The output includes the final generated image/video, prompt embeddings, timesteps, and other information needed from the denoising stage.

  • Inputs: Prompts
  • Outputs: Generated Images/Videos; Prompt embeddings; Timesteps; Latent samples, log probabilities, and latent sample means during the sampling stage

ii. compute_rm_score: The reward model scores the generated images/videos according to the user’s task, e.g., OCR accuracy, GenEval score, or CLIP score. For simplicity and generality, we may rely on API calls to a vLLM server.

  • Inputs: Generated Images/Videos
  • Outputs: Scores of each image/video

iii. compute_old_log_prob: As with LLMs, the training and inference engines can disagree slightly, so we also need support for recomputing the old log probabilities on the training side.

  • Inputs: Latent samples, timesteps, and prompt embeddings from the inference engine
  • Outputs: Updated old log probabilities

iv. compute_ref_old_prob: We need to compute the latent sample means from the reference model, which are used for the KL-divergence term.

  • Inputs: Latent samples, timesteps, and prompt embeddings from the inference engine
  • Outputs: Reference latent sample means

v. compute_advantage: The advantage calculation is essentially the same as in GRPO: rewards are normalized within each prompt group (illustrated in the sketch after this list).

  • Inputs: Scores of each image/video, UIDs to specify the group the sample belongs to
  • Outputs: Advantage score of each sample

vi. update_actor: The overall workflow of updating the actor is very similar to PPO/GRPO, with a slightly modified loss function (see the sketch after this list).

  • Inputs: Latent samples; prompt embeddings; timesteps; updated old log probabilities; advantages; reference latent sample means
  • Outputs: Loss value
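
To make items v and vi more concrete, below is a minimal, illustrative sketch of the group-normalized advantage and a Flow-GRPO-style actor loss. The tensor layouts, the clipping range, and the closed-form KL between Gaussians sharing the sampling std are assumptions for illustration, not the final verl implementation.

import torch


def compute_grpo_advantage(scores, uids, eps=1e-6):
    """Group-normalize rewards; samples sharing a uid come from the same prompt."""
    advantages = torch.zeros_like(scores)
    for uid in set(uids):
        mask = torch.tensor([u == uid for u in uids], device=scores.device)
        group = scores[mask]
        advantages[mask] = (group - group.mean()) / (group.std() + eps)
    return advantages


def flow_grpo_actor_loss(
    log_probs,         # [bsz, T] new per-step log-probs from the training engine
    old_log_probs,     # [bsz, T] old log-probs recomputed by compute_old_log_prob
    advantages,        # [bsz]    group-normalized advantages
    latent_means,      # [bsz, T, ...] latent sample means predicted by the actor
    ref_latent_means,  # [bsz, T, ...] latent sample means from the reference model
    std,               # sampling std (scalar or broadcastable), assumed shared
    clip_eps=1e-4,     # clipping range is illustrative only
    kl_coef=0.01,
):
    """PPO-style clipped objective over denoising steps plus a KL penalty."""
    ratio = torch.exp(log_probs - old_log_probs)
    adv = advantages.unsqueeze(-1)                      # broadcast over steps
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    policy_loss = -torch.minimum(unclipped, clipped).mean()

    # KL between two Gaussians with a shared std reduces to a scaled
    # mean-squared difference of their predicted means (one common choice).
    kl = ((latent_means - ref_latent_means) ** 2).flatten(2).mean(-1) / (2.0 * std ** 2)
    return policy_loss + kl_coef * kl.mean()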

New Components

The right panel shows the classes we need to add to support the FlowGRPO algorithm. Other diffusion-based RL algorithms can also reuse these components to speed up development.

Rollout

We add DiffusionAgentLoopWorker, AsyncDiffusionServerManager, and DiffusionSingleTurnAgentLoop for the diffusion-based agent loop and async rollout.
DiffusionAgentLoopWorker runs DiffusionSingleTurnAgentLoop, which calls the server manager to generate sequences.
AsyncDiffusionServerManager issues the API calls that produce the responses.

Image generation goes through the vLLM-Omni API, which relies on the following PRs:
vllm-project/vllm-omni#355
vllm-project/vllm-omni#376
vllm-project/vllm-omni#371

1a) DiffusionAgentLoopWorker

The agent loop worker for asynchronous rollout; it exposes generation through generate_sequences. A rough sketch of a possible implementation follows the class stub below.

class DiffusionAgentLoopWorkerBase:
    """Agent loop worker takes a batch of messages and run each message in an agent loop."""
    def __init__(self, config, server_handles):
        """Initialize agent loop manager.
        Args:
            config (DictConfig): YAML config.
            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
            reward_router_address (str): reward router address.
        """
        ...

    async def generate_sequences(self, batch):
        """Generate responses from agent loop.
        Args:
            batch (DataProto): Input batch.
        Returns:
            DataProto: Output batch.
            - prompts: [bsz, prompt_length], prompts from dataset.
            - responses: [bsz, channel, height, width],  output images
              from diffusion generation from tool_calls.
            - prompt_embeddings: [bsz, ], prompt embeddings
            - timesteps: [bsz, ], timesteps
            - latent_samples: [bsz, ], latents per step
            - log_probs: [bsz, ], log probabilities
            - latent_sample_means: [bsz, ], latent means
        """
        ...

@ray.remote
class DiffusionAgentLoopWorker(DiffusionAgentLoopWorkerBase):
    """Agent loop worker takes a batch of prompts and run each prompt in an agent loop."""
    ...
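
As a rough illustration, generate_sequences could fan out one agent loop per prompt and gather the per-sample outputs back into a batch. The raw_prompt key and the _postprocess helper below are hypothetical placeholders, not existing APIs.

import asyncio

async def generate_sequences(self, batch):
    """A possible body for DiffusionAgentLoopWorkerBase.generate_sequences (sketch)."""
    prompts = batch.non_tensor_batch["raw_prompt"]  # assumed key name
    loops = [
        DiffusionSingleTurnAgentLoop(self.config, self.server_manager)
        for _ in prompts
    ]
    # Run all prompts concurrently; each loop returns the image plus the
    # per-step latents, log-probs, latent means and timesteps for one sample.
    outputs = await asyncio.gather(
        *(loop.run(prompt) for loop, prompt in zip(loops, prompts))
    )
    return self._postprocess(outputs)  # hypothetical helper: pack results into a DataProto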

1b) AsyncDiffusionServerManager

The diffusion server manager that calls the API for generation.

class AsyncDiffusionServerManager(AsyncLLMServerManager):
    """
    A class to manage multiple OpenAI compatible Diffusion servers, e.g., vLLM-Omni. This class provides
    - Load balance: least requests load balancing
    - Sticky session: send multi-turn chat completions to same server for automatic prefix caching
    """
    def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000):
        """Initialize the AsyncLLMServerManager.

        Args:
            config (DictConfig): YAML config.
            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible diffusion server actor handles.
            max_cache_size (int, optional): max cache size for request_id to server mapping. Defaults to 10000.
        """
        ...

    async def generate(self, request_id, prompt):
        """Generate image/video responses from prompt.

        Args:
            request_id (str): request id for sticky session.
            prompt (str): prompt.

        Returns:
            ImageOutput: image output with extra info.

        """
        ...

1c) DiffusionSingleTurnAgentLoop

The agent loop that supports single-turn response generation from the server. A possible run() body is sketched after the class stub below.

class DiffusionSingleTurnAgentLoop(AgentLoopBase):
    """Agent loop that only do single turn generation."""

    def __init__(self, trainer_config, server_manager):
        """Initialize agent loop, each sample will have its own loop instance.
        Args:
            trainer_config (DictConfigWrap): trainer config.
            server_manager (AsyncDiffusionServerManager): OpenAI compatible diffusion server manager.
        """
        ...

    async def run(self, prompts, **kwargs):
        """
        Run agent loop to interact with vLLM-Omni server and environment.
        """
        ...
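
One possible body for run, assuming the server manager returns the ImageOutput described in 1b); the _to_agent_loop_output helper is a hypothetical placeholder for packaging the result.

import time
from uuid import uuid4

async def run(self, prompts, **kwargs):
    """A possible body for DiffusionSingleTurnAgentLoop.run (sketch)."""
    request_id = uuid4().hex  # sticky-session id for the server manager
    start = time.time()
    # Delegate denoising to the vLLM-Omni server through the manager; the
    # returned ImageOutput is expected to carry the generated image together
    # with per-step latents, log-probs, latent means and timesteps.
    output = await self.server_manager.generate(request_id=request_id, prompt=prompts)
    metrics = {"generate_sequences": time.time() - start}
    return self._to_agent_loop_output(output, metrics)  # hypothetical packaging helper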

1d) vLLMOmniReplica

The rollout class that launches API servers (i.e., vLLMOmniHttpServer) for async rollout calls.

class vLLMOmniReplica(vLLMReplica):
    ...

1e) vLLMOmniHttpServer

The vLLM HTTP server on a single node, with support for calling the vLLM-Omni server.

class vLLMOmniHttpServerBase(vLLMHttpServerBase):
    async def launch_server(self, master_address: str = None, master_port: int = None):
        ...

    async def wake_up(self):
        ...

    async def generate(self, prompts):
        """Generate images/video from prompts """
        ...

@ray.remote(num_cpus=1)
class vLLMOmniHttpServer(vLLMOmniHttpServerBase):
    ...

Reward

2a) DiffusionRewardLoopManager

Manages reward loop workers and generates rewards asynchronously by calling the workers.

class DiffusionRewardLoopManager(RewardLoopManager):
    def compute_rm_score(self, data: DataProto) -> DataProto:
        ...

2b) DiffusionRewardLoopWorker

A loop worker that computes rewards under different scoring logics.

@ray.remote
class DiffusionRewardLoopWorker(RewardLoopWorker):
    def __init__(self, config: DictConfig, reward_router_address: str = None):
        ...

    async def compute_score_batch(self, data: DataProto) -> list[dict]:
        ...

    async def compute_score(self, data: DataProto) -> dict:
        ...

2c) DiffusionRewardManager

Manages diffusion-related reward computation. A possible run_single() body is sketched after the class stub below.

class DiffusionRewardManager(RewardManagerBase):
    def __init__(self, config, compute_score=None, reward_router_address=None):
        ...

    async def run_single(self, data: DataProto) -> dict:
        ...
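
For illustration, run_single might score one sample by calling a user-provided compute_score function (or a reward-model server behind reward_router_address). The field names below are assumptions about the batch layout, not existing keys.

async def run_single(self, data):
    """A possible body for DiffusionRewardManager.run_single (sketch)."""
    image = data.batch["responses"][0]               # assumed layout: [channel, height, width]
    prompt = data.non_tensor_batch["raw_prompt"][0]  # assumed key name
    # compute_score may be a local metric (OCR match, GenEval, CLIP score) or an
    # async call to a vLLM reward-model server behind reward_router_address.
    score = await self.compute_score(prompt=prompt, image=image)
    return {"reward": float(score)}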

Actor

3) DiffusersFSDPEngine

The default base training engine for diffusers models, supporting diffusion pipeline instantiation and forward/backward steps. A sketch of the per-step log-probability computation needed inside forward_step follows the class stub below.

class DiffusersFSDPEngine(FSDPEngine):
    """
    Concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP).
    Supports model sharding, activation/optimizer offloading, LoRA.
    """
    def __init__(
        self,
        model_config: HFModelConfig,
        engine_config: FSDPEngineConfig,
        optimizer_config: FSDPOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        """
        Initialize the DiffusersFSDPEngine.
        Set up distributed device meshes, LoRA, and offload policies based on config.
        Args:
            config: Configuration object with FSDP and model settings.
        """
        ...

    def initialize(self):
        """
        Build the model, optimizer, and learning rate scheduler under FSDP.

        Applies device, dtype, and precision configurations, including mixed precision.
        Sets up checkpoint manager and FLOPs counter.
        """
        ...

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):
        ...
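
Inside forward_step, the main diffusion-specific piece is recomputing the per-step log-probability: Flow-GRPO replaces the deterministic sampler with a stochastic one, so each denoising transition is modeled as an isotropic Gaussian around the predicted latent mean. A minimal sketch, with shapes and the noise schedule treated as assumptions:

import torch

def denoising_step_log_prob(latent_next, mean, std):
    """Log-probability of one denoising transition x_{t-1} ~ N(mean, std^2 I).

    forward_step can recompute `mean` from (latent_samples, timesteps,
    prompt_embeddings) with the diffusers transformer, then evaluate this
    density to obtain the new per-step log-probs used in the policy ratio.
    """
    dist = torch.distributions.Normal(mean, std)
    # Sum per-pixel log densities over channel/spatial dims -> one value per sample.
    return dist.log_prob(latent_next).flatten(start_dim=1).sum(dim=-1)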

Future Plan: Rollout with Async Reward Computing

Asynchronous reward computation during rollout is a commonly used trick to speed up RL training, and it is worth integrating.

Feature: Asynchronous reward computation during Rollout stage.

Left: synchronous reward computation. Right: asynchronous reward computation during rollout.

Detail: We will deliver an example script demonstrating asynchronous reward computation during rollout. The core changes will reside in the Rollout class.
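
A minimal sketch of the asynchronous pattern, assuming hypothetical agent_loop_worker and reward_manager handles: reward scoring for each finished sample starts immediately instead of waiting for the whole rollout batch.

import asyncio

async def rollout_with_async_rewards(prompts, agent_loop_worker, reward_manager):
    """Overlap reward computation with rollout (illustrative names only)."""
    reward_tasks = []

    async def rollout_one(prompt):
        sample = await agent_loop_worker.generate_sequences(prompt)
        # Start scoring this sample right away; do not block further rollouts.
        reward_tasks.append(asyncio.create_task(reward_manager.run_single(sample)))
        return sample

    samples = await asyncio.gather(*(rollout_one(p) for p in prompts))
    rewards = await asyncio.gather(*reward_tasks)
    return samples, rewards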

Development Plan

See the latest progress in PR #5297.

  • Rollout
  • Trainer (left panel)
  • Actor
  • Reward
  • FlowGRPO Algorithm Support
  • Dataloader and Logger
