Commit dfaf146

mickqi and dougyster authored and committed
[diffusion] refactor: refactor sampling params (#13706)
1 parent 7bc738f · commit dfaf146

File tree: 7 files changed, +135 -149 lines changed

python/sglang/multimodal_gen/configs/sample/base.py

Lines changed: 100 additions & 7 deletions
@@ -5,12 +5,12 @@
 import dataclasses
 import hashlib
 import json
+import math
 import os.path
 import re
 import time
 import unicodedata
 import uuid
-from copy import deepcopy
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import Any
@@ -137,7 +137,7 @@ class SamplingParams:
     return_trajectory_latents: bool = False  # returns all latents for each timestep
     return_trajectory_decoded: bool = False  # returns decoded latents for each timestep

-    def set_output_file_ext(self):
+    def _set_output_file_ext(self):
         # add extension if needed
         if not any(
             self.output_file_name.endswith(ext)
@@ -147,7 +147,7 @@ def set_output_file_ext(self):
                 f"{self.output_file_name}.{self.data_type.get_default_extension()}"
             )

-    def set_output_file_name(self):
+    def _set_output_file_name(self):
         # settle output_file_name
         if (
             self.output_file_name is None
@@ -178,7 +178,7 @@ def set_output_file_name(self):
         self.output_file_name = _sanitize_filename(self.output_file_name)

         # Ensure a proper extension is present
-        self.set_output_file_ext()
+        self._set_output_file_ext()

     def __post_init__(self) -> None:
         assert self.num_frames >= 1
@@ -195,6 +195,93 @@ def check_sampling_param(self):
         if self.prompt_path and not self.prompt_path.endswith(".txt"):
             raise ValueError("prompt_path must be a txt file")

+    def adjust(
+        self,
+        server_args: ServerArgs,
+    ):
+        """
+        final adjustment, called after merged with user params
+        """
+        pipeline_config = server_args.pipeline_config
+        if not isinstance(self.prompt, str):
+            raise TypeError(f"`prompt` must be a string, but got {type(self.prompt)}")
+
+        # Process negative prompt
+        if self.negative_prompt is not None and not self.negative_prompt.isspace():
+            # avoid stripping default negative prompt: ' ' for qwen-image
+            self.negative_prompt = self.negative_prompt.strip()
+
+        # Validate dimensions
+        if self.num_frames <= 0:
+            raise ValueError(
+                f"height, width, and num_frames must be positive integers, got "
+                f"height={self.height}, width={self.width}, "
+                f"num_frames={self.num_frames}"
+            )
+
+        if pipeline_config.task_type.is_image_gen():
+            # settle num_frames
+            logger.debug(f"Setting num_frames to 1 because this is a image-gen model")
+            self.num_frames = 1
+            self.data_type = DataType.IMAGE
+        else:
+            # Adjust number of frames based on number of GPUs for video task
+            use_temporal_scaling_frames = (
+                pipeline_config.vae_config.use_temporal_scaling_frames
+            )
+            num_frames = self.num_frames
+            num_gpus = server_args.num_gpus
+            temporal_scale_factor = (
+                pipeline_config.vae_config.arch_config.temporal_compression_ratio
+            )
+
+            if use_temporal_scaling_frames:
+                orig_latent_num_frames = (num_frames - 1) // temporal_scale_factor + 1
+            else:  # stepvideo only
+                orig_latent_num_frames = self.num_frames // 17 * 3
+
+            if orig_latent_num_frames % server_args.num_gpus != 0:
+                # Adjust latent frames to be divisible by number of GPUs
+                if self.num_frames_round_down:
+                    # Ensure we have at least 1 batch per GPU
+                    new_latent_num_frames = (
+                        max(1, (orig_latent_num_frames // num_gpus)) * num_gpus
+                    )
+                else:
+                    new_latent_num_frames = (
+                        math.ceil(orig_latent_num_frames / num_gpus) * num_gpus
+                    )
+
+                if use_temporal_scaling_frames:
+                    # Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor
+                    new_num_frames = (
+                        new_latent_num_frames - 1
+                    ) * temporal_scale_factor + 1
+                else:  # stepvideo only
+                    # Find the least common multiple of 3 and num_gpus
+                    divisor = math.lcm(3, num_gpus)
+                    # Round up to the nearest multiple of this LCM
+                    new_latent_num_frames = (
+                        (new_latent_num_frames + divisor - 1) // divisor
+                    ) * divisor
+                    # Convert back to actual frames using the StepVideo formula
+                    new_num_frames = new_latent_num_frames // 3 * 17
+
+                logger.info(
+                    "Adjusting number of frames from %s to %s based on number of GPUs (%s)",
+                    self.num_frames,
+                    new_num_frames,
+                    server_args.num_gpus,
+                )
+                self.num_frames = new_num_frames
+
+        self.num_frames = server_args.pipeline_config.adjust_num_frames(
+            self.num_frames
+        )
+
+        self._set_output_file_name()
+        self.log(server_args=server_args)
+
     def update(self, source_dict: dict[str, Any]) -> None:
         for key, value in source_dict.items():
             if hasattr(self, key):
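A small standalone sketch of the temporal-scaling branch above, using illustrative numbers that do not come from this commit (an 81-frame request, a temporal compression ratio of 4, and 4 GPUs); it only replays the arithmetic `adjust()` applies when the latent frame count is not divisible by the GPU count:

```python
import math

# Illustrative inputs; in adjust() these come from SamplingParams, pipeline_config
# (vae_config.arch_config.temporal_compression_ratio) and server_args.num_gpus.
num_frames = 81
temporal_scale_factor = 4
num_gpus = 4
num_frames_round_down = False

orig_latent_num_frames = (num_frames - 1) // temporal_scale_factor + 1  # 21
if orig_latent_num_frames % num_gpus != 0:
    if num_frames_round_down:
        # keep at least one latent frame per GPU, rounding down
        new_latent_num_frames = max(1, orig_latent_num_frames // num_gpus) * num_gpus  # 20
    else:
        new_latent_num_frames = math.ceil(orig_latent_num_frames / num_gpus) * num_gpus  # 24
    # convert back so that (num_frames - 1) is again a multiple of temporal_scale_factor
    new_num_frames = (new_latent_num_frames - 1) * temporal_scale_factor + 1  # 93
    print(orig_latent_num_frames, new_latent_num_frames, new_num_frames)  # 21 24 93
```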
@@ -220,9 +307,15 @@ def from_pretrained(cls, model_path: str, **kwargs) -> "SamplingParams":
         sampling_params = cls(**kwargs)
         return sampling_params

-    def from_user_sampling_params(self, user_params):
-        sampling_params = deepcopy(self)
-        sampling_params._merge_with_user_params(user_params)
+    @staticmethod
+    def from_user_sampling_params_args(model_path: str, server_args, *args, **kwargs):
+        sampling_params = SamplingParams.from_pretrained(model_path)
+
+        user_sampling_params = SamplingParams(*args, **kwargs)
+        sampling_params._merge_with_user_params(user_sampling_params)
+
+        sampling_params.adjust(server_args)
+
         return sampling_params

     @staticmethod
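With the new static constructor, callers build fully-adjusted params in one call instead of chaining `from_pretrained()`, `from_user_sampling_params()`, `set_output_file_name()` and `log()`. A hedged usage sketch; the import path is inferred from the file layout, and `server_args` is assumed to be whatever the entrypoints obtain from `get_global_server_args()`:

```python
from sglang.multimodal_gen.configs.sample.base import SamplingParams  # path inferred from file tree

def build_params(server_args, request_id: str, prompt: str) -> SamplingParams:
    # Loads pretrained defaults, merges the user kwargs, then runs adjust(),
    # which also settles the output file name and logs the final params.
    return SamplingParams.from_user_sampling_params_args(
        model_path=server_args.model_path,
        server_args=server_args,
        request_id=request_id,
        prompt=prompt,
        save_output=True,
        output_file_name=request_id,
    )
```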

python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py

Lines changed: 3 additions & 13 deletions
@@ -264,15 +264,15 @@ def generate(
             else DataType.VIDEO
         )
         pretrained_sampling_params.data_type = data_type
-        pretrained_sampling_params.set_output_file_name()
+        pretrained_sampling_params._set_output_file_name()
+        pretrained_sampling_params.adjust(self.server_args)

         requests: list[Req] = []
         for output_idx, p in enumerate(prompts):
             current_sampling_params = deepcopy(pretrained_sampling_params)
             current_sampling_params.prompt = p
             requests.append(
                 prepare_request(
-                    p,
                     server_args=self.server_args,
                     sampling_params=current_sampling_params,
                 )
@@ -310,21 +310,11 @@ def generate(
                 continue
             for output_idx, sample in enumerate(output_batch.output):
                 num_outputs = len(output_batch.output)
-                output_file_name = req.output_file_name
-                if num_outputs > 1 and output_file_name:
-                    base, ext = os.path.splitext(output_file_name)
-                    output_file_name = f"{base}_{output_idx}{ext}"
-
-                save_path = (
-                    os.path.join(req.output_path, output_file_name)
-                    if output_file_name
-                    else None
-                )
                 frames = self.post_process_sample(
                     sample,
                     fps=req.fps,
                     save_output=req.save_output,
-                    save_file_path=save_path,
+                    save_file_path=req.output_file_path(num_outputs, output_idx),
                     data_type=req.data_type,
                 )
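The per-output save-path logic removed above now sits behind `req.output_file_path(num_outputs, output_idx)`, which this commit does not show. A sketch of what that helper presumably encapsulates, reconstructed from the deleted lines; the standalone function name and signature here are illustrative only:

```python
import os
from typing import Optional

def output_file_path(output_path: str, output_file_name: Optional[str],
                     num_outputs: int, output_idx: int) -> Optional[str]:
    # Presumed behavior, mirroring the removed inline code: suffix the output index
    # when a request yields multiple samples, then join with the output directory.
    if not output_file_name:
        return None
    if num_outputs > 1:
        base, ext = os.path.splitext(output_file_name)
        output_file_name = f"{base}_{output_idx}{ext}"
    return os.path.join(output_path, output_file_name)

# e.g. output_file_path("/tmp/out", "clip.mp4", 3, 1) -> "/tmp/out/clip_1.mp4"
```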

python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py

Lines changed: 4 additions & 18 deletions
@@ -56,12 +56,10 @@ def _build_sampling_params_from_request(
 ) -> SamplingParams:
     width, height = _parse_size(size)
     ext = _choose_ext(output_format, background)
-
     server_args = get_global_server_args()
-    sampling_params = SamplingParams.from_pretrained(server_args.model_path)
-
     # Build user params
-    user_params = SamplingParams(
+    sampling_params = SamplingParams.from_user_sampling_params_args(
+        model_path=server_args.model_path,
         request_id=request_id,
         prompt=prompt,
         image_path=image_path,
@@ -70,18 +68,9 @@ def _build_sampling_params_from_request(
         height=height,
         num_outputs_per_prompt=max(1, min(int(n or 1), 10)),
         save_output=True,
+        server_args=server_args,
+        output_file_name=f"{request_id}.{ext}",
     )
-
-    # Let SamplingParams auto-generate a file name, then force desired extension
-    sampling_params = sampling_params.from_user_sampling_params(user_params)
-    if not sampling_params.output_file_name:
-        sampling_params.output_file_name = request_id
-    if not sampling_params.output_file_name.endswith(f".{ext}"):
-        # strip any existing extension and apply desired one
-        base = sampling_params.output_file_name.rsplit(".", 1)[0]
-        sampling_params.output_file_name = f"{base}.{ext}"
-
-    sampling_params.log(server_args)
     return sampling_params

@@ -107,7 +96,6 @@ def _build_req_from_sampling(s: SamplingParams) -> Req:
 async def generations(
     request: ImageGenerationsRequest,
 ):
-
     request_id = generate_request_id()
     sampling = _build_sampling_params_from_request(
         request_id=request_id,
@@ -118,7 +106,6 @@ async def generations(
         background=request.background,
     )
     batch = prepare_request(
-        prompt=request.prompt,
         server_args=get_global_server_args(),
         sampling_params=sampling,
     )
@@ -175,7 +162,6 @@ async def edits(
     background: Optional[str] = Form("auto"),
     user: Optional[str] = Form(None),
 ):
-
     request_id = generate_request_id()
     # Resolve images from either `image` or `image[]` (OpenAI SDK sends `image[]` when list is provided)
     images = image or image_array

python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py

Lines changed: 7 additions & 7 deletions
@@ -42,6 +42,8 @@
 router = APIRouter(prefix="/v1/videos", tags=["videos"])


+# NOTE(mick): the sampling params needs to be further adjusted
+# FIXME: duplicated with the one in `image_api.py`
 def _build_sampling_params_from_request(
     request_id: str, request: VideoGenerationsRequest
 ) -> SamplingParams:
@@ -56,9 +58,8 @@ def _build_sampling_params_from_request(
         request.num_frames if request.num_frames is not None else derived_num_frames
     )
     server_args = get_global_server_args()
-    # TODO: should we cache this sampling_params?
-    sampling_params = SamplingParams.from_pretrained(server_args.model_path)
-    user_params = SamplingParams(
+    sampling_params = SamplingParams.from_user_sampling_params_args(
+        model_path=server_args.model_path,
         request_id=request_id,
         prompt=request.prompt,
         num_frames=num_frames,
@@ -67,10 +68,10 @@ def _build_sampling_params_from_request(
         height=height,
         image_path=request.input_reference,
         save_output=True,
+        server_args=server_args,
+        output_file_name=request_id,
     )
-    sampling_params = sampling_params.from_user_sampling_params(user_params)
-    sampling_params.set_output_file_name()
-    sampling_params.log(server_args)
+
     return sampling_params
@@ -195,7 +196,6 @@ async def create_video(
195196

196197
# Build Req for scheduler
197198
batch = prepare_request(
198-
prompt=req.prompt,
199199
server_args=get_global_server_args(),
200200
sampling_params=sampling_params,
201201
)
