 from pipeline_txt2img_xl import Txt2ImgXLPipeline
 
 
-def run_demo():
-    """Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""
-
-    args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")
-
-    prompt, negative_prompt = repeat_prompt(args)
-
-    # Recommend image size as one of those used in training (see Appendix I in https://arxiv.org/pdf/2307.01952.pdf).
-    image_height = args.height
-    image_width = args.width
-
+def load_pipelines(args, batch_size):
     # Register TensorRT plugins
     engine_type = get_engine_type(args.engine)
     if engine_type == EngineType.TRT:
@@ -49,37 +39,83 @@ def run_demo():
 
     max_batch_size = 16
     if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and (
-        args.build_dynamic_shape or image_height > 512 or image_width > 512
+        args.build_dynamic_shape or args.height > 512 or args.width > 512
     ):
         max_batch_size = 4
 
-    batch_size = len(prompt)
     if batch_size > max_batch_size:
         raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")
 
+    # For TensorRT, the performance of an engine built with dynamic shape is very sensitive to the range of image sizes.
+    # Here, we reduce the image size range for TensorRT to trade flexibility for performance.
+    # This range covers the most frequent shapes: landscape (832x1216), portrait (1216x832) and square (1024x1024).
+    min_image_size = 832 if args.engine != "ORT_CUDA" else 512
+    max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048
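+    # (ORT_CUDA does not compile a shape-specialized engine, so it can allow the wider 512-2048 range.)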
+
     # No VAE decoder in base when it outputs latent instead of image.
-    base_info = PipelineInfo(args.version, use_vae=False)
-    base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size)
+    base_info = PipelineInfo(
+        args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size
+    )
 
-    refiner_info = PipelineInfo(args.version, is_refiner=True)
-    refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)
+    # Ideally, the optimized batch size and image size for the TRT engine should align with the user's preference,
+    # that is, optimize for the shape used most frequently. We can let the user configure it once we develop a UI plugin.
+    # In this demo, we optimize for batch size 1 and image size 1024x1024 for the SD XL dynamic engine.
+    # This is mainly for benchmarking, to simulate the case where we have no knowledge of the user's preference.
+    opt_batch_size = 1 if args.build_dynamic_batch else batch_size
+    opt_image_height = base_info.default_image_size() if args.build_dynamic_shape else args.height
+    opt_image_width = base_info.default_image_size() if args.build_dynamic_shape else args.width
+
+    base = init_pipeline(
+        Txt2ImgXLPipeline,
+        base_info,
+        engine_type,
+        args,
+        max_batch_size,
+        opt_batch_size,
+        opt_image_height,
+        opt_image_width,
+    )
+
+    refiner = None
+    if not args.disable_refiner:
+        refiner_info = PipelineInfo(
+            args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size
+        )
+        refiner = init_pipeline(
+            Img2ImgXLPipeline,
+            refiner_info,
+            engine_type,
+            args,
+            max_batch_size,
+            opt_batch_size,
+            opt_image_height,
+            opt_image_width,
+        )
 
     if engine_type == EngineType.TRT:
-        max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory())
+        max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
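+        # Base and refiner share one device-memory allocation, sized for the larger of the two requirements.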
         _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
         base.backend.activate_engines(shared_device_memory)
-        refiner.backend.activate_engines(shared_device_memory)
+        if refiner:
+            refiner.backend.activate_engines(shared_device_memory)
 
     if engine_type == EngineType.ORT_CUDA:
        enable_vae_slicing = args.enable_vae_slicing
        if batch_size > 4 and not enable_vae_slicing:
            print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.")
            enable_vae_slicing = True
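+        # VAE slicing decodes the batch one image at a time, trading a little speed for lower peak memory.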
        if enable_vae_slicing:
-            refiner.backend.enable_vae_slicing()
+            (refiner or base).backend.enable_vae_slicing()
+    return base, refiner
+
 
+def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False):
+    image_height = args.height
+    image_width = args.width
+    batch_size = len(prompt)
     base.load_resources(image_height, image_width, batch_size)
-    refiner.load_resources(image_height, image_width, batch_size)
+    if refiner:
+        refiner.load_resources(image_height, image_width, batch_size)
 
     def run_base_and_refiner(warmup=False):
         images, time_base = base.run(
@@ -91,8 +127,13 @@ def run_base_and_refiner(warmup=False):
             denoising_steps=args.denoising_steps,
             guidance=args.guidance,
             seed=args.seed,
-            return_type="latent",
+            return_type="latent" if refiner else "image",
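+            # When a refiner follows, keep the base output in latent space; otherwise decode to an image.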
         )
+        if refiner is None:
+            return images, time_base
+
+        # Use same seed in base and refiner.
+        seed = base.get_current_seed()
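+        # get_current_seed should return the seed the base actually used, so the refiner still matches when args.seed is None.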
 
         images, time_refiner = refiner.run(
             prompt,
@@ -103,7 +144,7 @@ def run_base_and_refiner(warmup=False):
             warmup=warmup,
             denoising_steps=args.denoising_steps,
             guidance=args.guidance,
-            seed=args.seed,
+            seed=seed,
         )
 
         return images, time_base + time_refiner
@@ -112,25 +153,104 @@ def run_base_and_refiner(warmup=False):
         # inference once to get cuda graph
         _, _ = run_base_and_refiner(warmup=True)
 
-    print("[I] Warming up ..")
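+    # Print the banner only when warm-up is requested; the dynamic-shape demo sets num_warmup_runs to 0 for timed runs.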
+    if args.num_warmup_runs > 0:
+        print("[I] Warming up ..")
     for _ in range(args.num_warmup_runs):
         _, _ = run_base_and_refiner(warmup=True)
 
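+    # In warm-up mode, stop here: resources are loaded and engines exercised, but the timed run and report are skipped.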
+    if is_warm_up:
+        return
+
     print("[I] Running StableDiffusion XL pipeline")
     if args.nvtx_profile:
         cudart.cudaProfilerStart()
     _, latency = run_base_and_refiner(warmup=False)
     if args.nvtx_profile:
         cudart.cudaProfilerStop()
 
-    base.teardown()
-
     print("|------------|--------------|")
     print("| {:^10} | {:>9.2f} ms |".format("e2e", latency))
     print("|------------|--------------|")
-    refiner.teardown()
+
+
+def run_demo(args):
+    """Run Stable Diffusion XL Base + Refiner together (known as an ensemble of expert denoisers) to generate an image."""
+
+    prompt, negative_prompt = repeat_prompt(args)
+    batch_size = len(prompt)
+    base, refiner = load_pipelines(args, batch_size)
+    run_pipelines(args, base, refiner, prompt, negative_prompt)
+    base.teardown()
+    if refiner:
+        refiner.teardown()
+
+
+def run_dynamic_shape_demo(args):
+    """Run a demo that generates images with a variety of settings using the ORT CUDA provider."""
+    args.engine = "ORT_CUDA"
+    args.disable_cuda_graph = True
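+    # CUDA graphs capture fixed shapes, so they are disabled here since this demo changes shape on every run.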
+    base, refiner = load_pipelines(args, 1)
+
+    prompts = [
+        "starry night over Golden Gate Bridge by van gogh",
+        "beautiful photograph of Mt. Fuji during cherry blossom",
+        "little cute gremlin sitting on a bed, cinematic",
+        "cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
+        "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
+        "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
+    ]
+
+    # batch size, height, width, scheduler, steps, prompt, seed
+    configs = [
+        (1, 832, 1216, "UniPC", 8, prompts[0], None),
+        (1, 1024, 1024, "DDIM", 24, prompts[1], None),
+        (1, 1216, 832, "UniPC", 16, prompts[2], None),
+        (1, 1344, 768, "DDIM", 24, prompts[3], None),
+        (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712),
+        (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906),
+    ]
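+    # Each config exercises a different (batch size, height, width) combination to demonstrate dynamic shapes.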
+
+    # Warm up each combination of (batch size, height, width) once before serving.
+    args.prompt = ["warm up"]
+    args.num_warmup_runs = 1
+    for batch_size, height, width, _, _, _, _ in configs:
+        args.batch_size = batch_size
+        args.height = height
+        args.width = width
+        print(f"\nWarm up batch_size={batch_size}, height={height}, width={width}")
+        prompt, negative_prompt = repeat_prompt(args)
+        run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True)
+
+    # Run the pipeline on a list of prompts.
+    args.num_warmup_runs = 0
+    for batch_size, height, width, scheduler, steps, example_prompt, seed in configs:
+        args.prompt = [example_prompt]
+        args.batch_size = batch_size
+        args.height = height
+        args.width = width
+        args.scheduler = scheduler
+        args.denoising_steps = steps
+        args.seed = seed
+        base.set_scheduler(scheduler)
+        if refiner:
+            refiner.set_scheduler(scheduler)
+        print(
+            f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}"
+        )
+        prompt, negative_prompt = repeat_prompt(args)
+        run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
+
+    base.teardown()
+    if refiner:
+        refiner.teardown()
 
 
 if __name__ == "__main__":
     coloredlogs.install(fmt="%(funcName)20s: %(message)s")
-    run_demo()
+
+    args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")
+    no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
+    if no_prompt:
+        run_dynamic_shape_demo(args)
+    else:
+        run_demo(args)
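
A possible invocation sketch (assuming this file is the demo script, e.g. demo_txt2img_xl.py, and that parse_arguments exposes the prompt as a positional argument):

    # Base + refiner demo with an explicit prompt:
    python demo_txt2img_xl.py "starry night over Golden Gate Bridge by van gogh"

    # An empty prompt selects run_dynamic_shape_demo on the ORT_CUDA engine:
    python demo_txt2img_xl.py ""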