Feature Request
This RFC proposes to support Qwen-Image Flow-GRPO Training based on vLLM-Omni.
Motivation
The goal is to enhance verl’s scalability so that it can support online DPO-like training for state-of-the-art diffusion image and video generation models, including Qwen-Image, Z-Image, Wan2.2, and others. We choose Flow-GRPO as the representative algorithm in this domain, while additional algorithms such as DiffusionNFT and DanceGRPO can be seamlessly integrated following this update. As an initial step, Qwen-Image has been selected as the first supported model for multimodal generation tasks.
At present, verl does not support diffusion-based generation models. To enable this functionality, two major extensions are required: first, the addition of a rollout engine capable of handling image and video generation tasks, incorporating components such as vLLM-Omni; and second, the addition of a training engine for diffusion model training, which will rely on diffusers with an FSDP backend. Consequently, integrating diffusers and vLLM-Omni becomes a necessary change after this modification.
However, because of the strong coupling between verl and LLM training, directly modifying verl to support diffusion models could risk altering its original behavior with LLMs. To address this challenge, we suggest separating the pipelines of LLMs and diffusion models, adding diffusion model support through inheritance of the existing classes without breaking the original code.
In the following sections, we first briefly present the overall picture of Flow-GRPO and then describe the code changes needed to enable Qwen-Image Flow-GRPO training.
Overall Structure
Figure 1. Overview of integrating the Flow-GRPO algorithm into verl. The left panel shows the entry point and algorithm implementation with a standalone RayFlowGRPOTrainer. The right panel shows the corresponding workers (class names in bold) that need to be implemented. Other miscellaneous changes, such as configs, dataloaders and the logger, are not shown here.
Algorithm Implementation
Here is a brief explanation of the Flow-GRPO algorithm functions shown in the left panel:
i. generate_sequence: For diffusion models, sequence generation produces a sequence of image/video latent samples during the diffusion process, rather than tokens as in LLMs. The rollout output includes the final generated image/video, prompt embeddings, timesteps, and other necessary information from the denoising stage.
- Inputs: Prompts
- Outputs: Generated Images/Videos; Prompt embeddings; Timesteps; Latent samples, log probabilities, and latent sample means during the sampling stage
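To make these rollout outputs concrete, below is a minimal, illustrative sketch (not the actual vLLM-Omni implementation) of how a single SDE denoising step can yield the latent sample, its mean, and the per-step log probability listed above; the function name, drift/noise formulas, and arguments are assumptions for illustration only.

import torch

def sde_denoise_step(x_t, velocity, dt, noise_scale):
    """One stochastic denoising step of a flow model (illustrative).

    Replacing the deterministic ODE update with an SDE makes the next latent a
    Gaussian sample, so a per-step log probability can be recorded for RL.
    """
    mean = x_t + velocity * dt                  # latent sample mean (drift term)
    std = noise_scale * (dt ** 0.5)             # injected noise scale
    dist = torch.distributions.Normal(mean, std)
    x_next = dist.sample()                      # latent sample for the next step
    # Log probability of the sampled latent, summed over all non-batch dims.
    log_prob = dist.log_prob(x_next).sum(dim=tuple(range(1, x_next.dim())))
    return x_next, mean, log_prob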
ii. compute_rm_score: The reward model calculates the score of the generated images/videos based on the user’s task, such as OCR, GenEval score, CLIP score, etc. For simplicity and generality, we may use API calls to a vLLM server.
- Inputs: Generated Images/Videos
- Outputs: Scores of each image/video
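As a rough illustration of the API-based reward path, the snippet below queries an OpenAI-compatible vLLM server hosting a vision-language model with the generated image and parses a scalar score from the reply; the server address, model name, and prompt format are placeholders rather than the final reward implementation.

import base64
from openai import AsyncOpenAI

# Placeholder address of an OpenAI-compatible vLLM server serving a VLM reward model.
client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

async def ocr_reward(image_bytes: bytes, target_text: str) -> float:
    """Ask the served VLM whether the target text is rendered in the generated image."""
    image_b64 = base64.b64encode(image_bytes).decode()
    resp = await client.chat.completions.create(
        model="Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder reward model
        messages=[{
            "role": "user",
            "content": [
                {"type": "text",
                 "text": f"Does the image contain the text '{target_text}'? Answer with 1 or 0."},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
            ],
        }],
    )
    try:
        return float(resp.choices[0].message.content.strip())
    except ValueError:
        return 0.0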
iii. compute_old_log_prob: As with LLMs, where the training and inference engines can behave slightly differently, we need to recompute the old log probabilities on the training side.
- Inputs: Latent samples, timesteps, and prompt embeddings from the inference engine
- Outputs: Updated old log probabilities
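A minimal sketch of this recomputation, assuming the training engine exposes a forward pass (predict_mean below, a hypothetical helper) that returns the per-step Gaussian mean for the stored latents:

import torch

@torch.no_grad()
def compute_old_log_prob(latents, next_latents, timesteps, prompt_embeds, predict_mean, std):
    """Recompute per-step log probabilities with the training-side model (illustrative).

    latents/next_latents/timesteps are the per-step tensors saved by the rollout;
    predict_mean(x_t, t, prompt_embeds) is the training engine's denoising forward pass.
    """
    log_probs = []
    for x_t, x_next, t in zip(latents, next_latents, timesteps):
        mean = predict_mean(x_t, t, prompt_embeds)        # training-side latent mean
        dist = torch.distributions.Normal(mean, std)
        log_probs.append(dist.log_prob(x_next).sum(dim=tuple(range(1, x_next.dim()))))
    return torch.stack(log_probs, dim=1)                  # [bsz, num_steps]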
iv. compute_ref_old_prob: We calculate the latent sample means from the reference model, which are used for KL divergence computation.
- Inputs: Latent samples, timesteps, and prompt embeddings from the inference engine
- Outputs: Reference latent sample means
v. compute_advantage: The advantage calculation is basically the same as in the GRPO algorithm.
- Inputs: Scores of each image/video, UIDs to specify the group the sample belongs to
- Outputs: Advantage score of each sample
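Since the advantage follows GRPO, a minimal sketch of the group normalization is given below; grouping samples by their UID (prompt group) is assumed.

import torch

def compute_grpo_advantage(scores: torch.Tensor, uids: list[str], eps: float = 1e-6) -> torch.Tensor:
    """Normalize each reward by the mean/std of its prompt group (same UID)."""
    advantages = torch.zeros_like(scores)
    for uid in set(uids):
        idx = torch.tensor([i for i, u in enumerate(uids) if u == uid])
        group = scores[idx]
        # Groups are expected to contain multiple samples generated from the same prompt.
        advantages[idx] = (group - group.mean()) / (group.std() + eps)
    return advantages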
vi. update_actor: The overall workflow of updating the actor is very similar to PPO/GRPO, with a slightly modified loss function compared to GRPO.
- Inputs: Latent samples, prompt embeddings, timesteps, updated old log probabilities, advantages, and reference latent sample means
- Outputs: Loss value
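To make the objective concrete, here is a minimal sketch of a clipped, PPO-style surrogate with a Gaussian KL penalty toward the reference latent means; the clip range, KL coefficient, and shared standard deviation are illustrative values, not the final hyperparameters.

import torch

def flow_grpo_loss(log_prob, old_log_prob, advantages, means, ref_means, std,
                   clip_range: float = 1e-4, kl_coef: float = 0.01):
    """Clipped surrogate loss over denoising steps plus a KL penalty (illustrative).

    log_prob / old_log_prob: [bsz, num_steps] per-step log probabilities.
    advantages:              [bsz] per-sample advantages, broadcast over steps.
    means / ref_means:       per-step latent means from the policy and reference model.
    """
    ratio = torch.exp(log_prob - old_log_prob)
    adv = advantages.unsqueeze(-1)                         # broadcast over denoising steps
    unclipped = -adv * ratio
    clipped = -adv * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    policy_loss = torch.maximum(unclipped, clipped).mean()
    # Closed-form KL between two Gaussians that share the same std.
    kl = ((ref_means - means) ** 2).mean() / (2 * std ** 2)
    return policy_loss + kl_coef * kl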
New Components
The right panel shows the classes we need to add to support the Flow-GRPO algorithm. Other diffusion-based RL algorithms can also build on these classes to speed up development.
Rollout
We add DiffusionAgentLoopWorker, AsyncDiffusionServerManager, and DiffusionSingleTurnAgentLoop for the diffusion-based agent loop and async rollout.
DiffusionAgentLoopWorker runs DiffusionSingleTurnAgentLoop, which calls the server manager to generate sequences.
AsyncDiffusionServerManager calls the server API to generate responses.
We use the vLLM-Omni API for image generation, which relies on the following PRs:
vllm-project/vllm-omni#355
vllm-project/vllm-omni#376
vllm-project/vllm-omni#371
1a) DiffusionAgentLoopWorker
The agent loop worker for asynchronous rollout; it supports generation via generate_sequences.
class DiffusionAgentLoopWorkerBase:
    """Agent loop worker that takes a batch of messages and runs each message in an agent loop."""

    def __init__(self, config, server_handles):
        """Initialize the agent loop worker.

        Args:
            config (DictConfig): YAML config.
            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible diffusion server actor handles.
        """
        ...

    async def generate_sequences(self, batch):
        """Generate responses from the agent loop.

        Args:
            batch (DataProto): Input batch.

        Returns:
            DataProto: Output batch.
                - prompts: [bsz, prompt_length], prompts from the dataset.
                - responses: [bsz, channel, height, width], output images from diffusion generation.
                - prompt_embeddings: [bsz, ], prompt embeddings
                - timesteps: [bsz, ], timesteps
                - latent_samples: [bsz, ], latents per step
                - log_probs: [bsz, ], log probabilities
                - latent_sample_means: [bsz, ], latent means
        """
        ...


@ray.remote
class DiffusionAgentLoopWorker(DiffusionAgentLoopWorkerBase):
    """Agent loop worker that takes a batch of prompts and runs each prompt in an agent loop."""
    ...

1b) AsyncDiffusionServerManager
The diffusion server manager that calls the API for generation.
class AsyncDiffusionServerManager(AsyncLLMServerManager):
    """
    A class to manage multiple OpenAI compatible diffusion servers, e.g., vLLM-Omni. This class provides:
    - Load balancing: least-requests load balancing
    - Sticky sessions: send multi-turn chat completions to the same server for automatic prefix caching
    """

    def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000):
        """Initialize the AsyncDiffusionServerManager.

        Args:
            config (DictConfig): YAML config.
            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible diffusion server actor handles.
            max_cache_size (int, optional): max cache size for the request_id to server mapping. Defaults to 10000.
        """
        ...

    async def generate(self, request_id, prompt):
        """Generate image/video responses from a prompt.

        Args:
            request_id (str): request id for sticky session.
            prompt (str): prompt.

        Returns:
            ImageOutput: image output with extra info.
        """
        ...

1c) DiffusionSingleTurnAgentLoop
The agent loop supports single-turn response generation from the server.
class DiffusionSingleTurnAgentLoop(AgentLoopBase):
    """Agent loop that only does single-turn generation."""

    def __init__(self, trainer_config, server_manager):
        """Initialize the agent loop; each sample will have its own loop instance.

        Args:
            trainer_config (DictConfigWrap): trainer config.
            server_manager (AsyncDiffusionServerManager): OpenAI compatible diffusion server manager.
        """
        ...

    async def run(self, prompts, **kwargs):
        """Run the agent loop to interact with the vLLM-Omni server and environment."""
        ...

1d) vLLMOmniReplica
The rollout class that launches API servers (i.e., vLLMOmniHttpServer) for async rollout calls.
class vLLMOmniReplica(vLLMReplica):
    ...

1e) vLLMOmniHttpServer
The vLLM HTTP server running on a single node; it supports vLLM-Omni server calls.
class vLLMOmniHttpServerBase(vLLMHttpServerBase):
    async def launch_server(self, master_address: str = None, master_port: int = None):
        ...

    async def wake_up(self):
        ...

    async def generate(self, prompts):
        """Generate images/videos from prompts."""
        ...


@ray.remote(num_cpus=1)
class vLLMOmniHttpServer(vLLMOmniHttpServerBase):
    ...

Reward
2a) DiffusionRewardLoopManager
Manages reward loop workers and generates async rewards by calling the workers.
class DiffusionRewardLoopManager(RewardLoopManager):
    def compute_rm_score(self, data: DataProto) -> DataProto:
        ...

2b) DiffusionRewardLoopWorker
A loop worker that computes rewards with different scoring logic.
@ray.remote
class DiffusionRewardLoopWorker(RewardLoopWorker):
    def __init__(self, config: DictConfig, reward_router_address: str = None):
        ...

    async def compute_score_batch(self, data: DataProto) -> list[dict]:
        ...

    async def compute_score(self, data: DataProto) -> dict:
        ...

2c) DiffusionRewardManager
Manages diffusion-related reward computation.
class DiffusionRewardManager(RewardManagerBase):
    def __init__(self, config, compute_score=None, reward_router_address=None):
        ...

    async def run_single(self, data: DataProto) -> dict:
        ...

Actor
3) DiffusersFSDPEngine
The default base training engine for diffusers models, supporting diffusion pipeline instantiation as well as forward and backward steps.
class DiffusersFSDPEngine(FSDPEngine):
    """
    Concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP).
    Supports model sharding, activation/optimizer offloading, and LoRA.
    """

    def __init__(
        self,
        model_config: HFModelConfig,
        engine_config: FSDPEngineConfig,
        optimizer_config: FSDPOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        """
        Initialize the DiffusersFSDPEngine.
        Set up distributed device meshes, LoRA, and offload policies based on the configs.

        Args:
            model_config: HF model configuration.
            engine_config: FSDP engine configuration.
            optimizer_config: FSDP optimizer configuration.
            checkpoint_config: checkpoint configuration.
        """
        ...

    def initialize(self):
        """
        Build the model, optimizer, and learning rate scheduler under FSDP.
        Applies device, dtype, and precision configurations, including mixed precision.
        Sets up the checkpoint manager and FLOPs counter.
        """
        ...

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):
        ...

Future Plan: Rollout with Async Reward Computing
Asynchronous reward computation during rollout is a commonly used trick to speed up RL training, so it is worth integrating.
Feature: Asynchronous reward computation during the rollout stage.
Detail: We will deliver an example script demonstrating asynchronous reward computation during rollout. The core changes will reside in the Rollout class.
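One possible shape for this, sketched with asyncio (the function names are placeholders, not the final Rollout API): reward requests are fired as soon as each sample finishes generating, and the trainer awaits all of them before advantage computation.

import asyncio

async def rollout_with_async_rewards(prompts, generate_one, score_one):
    """Overlap reward computation with generation (illustrative placeholders).

    generate_one(prompt) -> image and score_one(image) -> float are assumed to be
    provided by the rollout worker and the reward worker, respectively.
    """
    images, reward_tasks = [], []
    for prompt in prompts:
        image = await generate_one(prompt)
        images.append(image)
        # Kick off reward scoring immediately instead of waiting for the whole batch.
        reward_tasks.append(asyncio.create_task(score_one(image)))
    rewards = await asyncio.gather(*reward_tasks)
    return images, list(rewards)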
Development Plan
See latest progress in PR #5297.
Rollout
- DiffusionAgentLoopWorker @zhtmike @knlnguyen1802
- DiffusionSingleTurnAgentLoop @zhtmike @knlnguyen1802
- vLLM-Omni: support LoRA weight update in vLLM-Omni @AndyZhou952
- vLLM-Omni: APIs of pipeline and worker customization @knlnguyen1802
- vLLM-Omni: Async Server Engine support @knlnguyen1802
- vLLMOmniServerAdapter @zhtmike @knlnguyen1802
- vLLMOmniColocateWorkerExtension (customized vLLM-Omni's worker) @zhtmike @knlnguyen1802
- vLLMOmniHttpServer & vLLMOmniReplica @zhtmike @knlnguyen1802
- QwenImagePipelineWithLogProb @zhtmike @knlnguyen1802
Trainer (left panel)
- Modifications of Engine Workers @zhtmike @chenyingshu
- RayFlowGRPOTrainer @zhtmike
- Embedding padding conversions @zhtmike @chenyingshu
Actor
Reward
- DiffusionRewardLoopManager @chenyingshu
- DiffusionRewardLoopWorker @chenyingshu
- DiffusionRewardManager @chenyingshu
- Reward support: (vllm) Qwen2.5-VL OCR @chenyingshu
- (TBD) Rollout with Async Reward Computing @chenyingshu
Flow-GRPO Algorithm Support
- Advantage and Loss Computation @chenyingshu
Dataloader and Logger
- Dataset & Dataloader for T2I Generation RL @chenyingshu
- Support wandb logger, etc. @chenyingshu
