55import dataclasses
66import hashlib
77import json
8+ import math
89import os .path
910import re
1011import time
1112import unicodedata
1213import uuid
13- from copy import deepcopy
1414from dataclasses import dataclass
1515from enum import Enum , auto
1616from typing import Any
@@ -137,7 +137,7 @@ class SamplingParams:
137137 return_trajectory_latents : bool = False # returns all latents for each timestep
138138 return_trajectory_decoded : bool = False # returns decoded latents for each timestep
139139
140- def set_output_file_ext (self ):
140+ def _set_output_file_ext (self ):
141141 # add extension if needed
142142 if not any (
143143 self .output_file_name .endswith (ext )
@@ -147,7 +147,7 @@ def set_output_file_ext(self):
147147 f"{ self .output_file_name } .{ self .data_type .get_default_extension ()} "
148148 )
149149
150- def set_output_file_name (self ):
150+ def _set_output_file_name (self ):
151151 # settle output_file_name
152152 if (
153153 self .output_file_name is None
@@ -178,7 +178,7 @@ def set_output_file_name(self):
178178 self .output_file_name = _sanitize_filename (self .output_file_name )
179179
180180 # Ensure a proper extension is present
181- self .set_output_file_ext ()
181+ self ._set_output_file_ext ()
182182
183183 def __post_init__ (self ) -> None :
184184 assert self .num_frames >= 1
@@ -195,6 +195,93 @@ def check_sampling_param(self):
195195 if self .prompt_path and not self .prompt_path .endswith (".txt" ):
196196 raise ValueError ("prompt_path must be a txt file" )
197197
def adjust(
    self,
    server_args: ServerArgs,
) -> None:
    """
    Final adjustment of the sampling params, called after merging with user params.

    Performed in order:
      1. Check that ``prompt`` is a string.
      2. Strip ``negative_prompt`` unless it is None or whitespace-only.
      3. Settle ``num_frames``:
         - image-gen task: force ``num_frames`` to 1 and ``data_type`` to
           ``DataType.IMAGE``;
         - otherwise: if the latent frame count is not divisible by
           ``server_args.num_gpus``, round it (down or up, depending on
           ``num_frames_round_down``) to a multiple of the GPU count and
           convert back to a frame count, then pass the result through
           ``pipeline_config.adjust_num_frames``.
      4. Settle the output file name and log the final params.

    Args:
        server_args: provides ``pipeline_config`` and ``num_gpus``.

    Raises:
        TypeError: if ``prompt`` is not a string.
        ValueError: if ``num_frames`` is not positive.
    """
    pipeline_config = server_args.pipeline_config
    if not isinstance(self.prompt, str):
        raise TypeError(f"`prompt` must be a string, but got {type(self.prompt)}")

    # Process negative prompt
    if self.negative_prompt is not None and not self.negative_prompt.isspace():
        # avoid stripping default negative prompt: ' ' for qwen-image
        self.negative_prompt = self.negative_prompt.strip()

    # Validate dimensions.
    # NOTE: only num_frames is checked here; height and width are included in
    # the message for context but are not validated by this method.
    if self.num_frames <= 0:
        raise ValueError(
            f"height, width, and num_frames must be positive integers, got "
            f"height={self.height}, width={self.width}, "
            f"num_frames={self.num_frames}"
        )

    if pipeline_config.task_type.is_image_gen():
        # settle num_frames
        logger.debug("Setting num_frames to 1 because this is a image-gen model")
        self.num_frames = 1
        self.data_type = DataType.IMAGE
    else:
        # Adjust number of frames based on number of GPUs for video task
        use_temporal_scaling_frames = (
            pipeline_config.vae_config.use_temporal_scaling_frames
        )
        num_frames = self.num_frames
        num_gpus = server_args.num_gpus
        temporal_scale_factor = (
            pipeline_config.vae_config.arch_config.temporal_compression_ratio
        )

        if use_temporal_scaling_frames:
            orig_latent_num_frames = (num_frames - 1) // temporal_scale_factor + 1
        else:  # stepvideo only
            orig_latent_num_frames = num_frames // 17 * 3

        if orig_latent_num_frames % num_gpus != 0:
            # Adjust latent frames to be divisible by number of GPUs
            if self.num_frames_round_down:
                # Ensure we have at least 1 batch per GPU
                new_latent_num_frames = (
                    max(1, orig_latent_num_frames // num_gpus) * num_gpus
                )
            else:
                new_latent_num_frames = (
                    math.ceil(orig_latent_num_frames / num_gpus) * num_gpus
                )

            if use_temporal_scaling_frames:
                # Convert back to number of frames, ensuring num_frames-1 is a
                # multiple of temporal_scale_factor
                new_num_frames = (
                    new_latent_num_frames - 1
                ) * temporal_scale_factor + 1
            else:  # stepvideo only
                # Find the least common multiple of 3 and num_gpus
                divisor = math.lcm(3, num_gpus)
                # Round up to the nearest multiple of this LCM
                new_latent_num_frames = (
                    (new_latent_num_frames + divisor - 1) // divisor
                ) * divisor
                # Convert back to actual frames using the StepVideo formula
                new_num_frames = new_latent_num_frames // 3 * 17

            logger.info(
                "Adjusting number of frames from %s to %s based on number of GPUs (%s)",
                num_frames,
                new_num_frames,
                num_gpus,
            )
            self.num_frames = new_num_frames

        # Let the pipeline config apply its own frame-count adjustment last.
        self.num_frames = pipeline_config.adjust_num_frames(self.num_frames)

    self._set_output_file_name()
    self.log(server_args=server_args)
284+
198285 def update (self , source_dict : dict [str , Any ]) -> None :
199286 for key , value in source_dict .items ():
200287 if hasattr (self , key ):
@@ -220,9 +307,15 @@ def from_pretrained(cls, model_path: str, **kwargs) -> "SamplingParams":
220307 sampling_params = cls (** kwargs )
221308 return sampling_params
222309
223- def from_user_sampling_params (self , user_params ):
224- sampling_params = deepcopy (self )
225- sampling_params ._merge_with_user_params (user_params )
@staticmethod
def from_user_sampling_params_args(model_path: str, server_args, *args, **kwargs):
    """
    Build the sampling params for a request from user-supplied arguments.

    Starts from ``SamplingParams.from_pretrained(model_path)``, merges in a
    ``SamplingParams`` constructed from ``*args`` / ``**kwargs`` via
    ``_merge_with_user_params``, then runs ``adjust(server_args)`` on the
    merged result and returns it.
    """
    # Baseline params obtained for the given model path
    merged_params = SamplingParams.from_pretrained(model_path)

    # Whatever the caller passed is wrapped in its own SamplingParams and
    # merged on top of the baseline
    user_overrides = SamplingParams(*args, **kwargs)
    merged_params._merge_with_user_params(user_overrides)

    # Final adjustment happens only after the merge
    merged_params.adjust(server_args)

    return merged_params
227320
228321 @staticmethod
0 commit comments