File tree Expand file tree Collapse file tree 4 files changed +56
-1
lines changed
examples/distributed_inference Expand file tree Collapse file tree 4 files changed +56
-1
lines changed Original file line number Diff line number Diff line change @@ -111,6 +111,9 @@ Tutorials
111111 tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
112112 tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
113113 tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
114+ tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
115+ tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
116+
114117
 115118Python API Documentation
116119------------------------
Original file line number Diff line number Diff line change 1+ """
2+ .. _data_parallel_gpt2:
3+
4+ Torch-TensorRT Distributed Inference
5+ ======================================================
6+
 7+ This interactive script is intended as a sample of distributed inference using data
 8+ parallelism with the Accelerate
 9+ library and the Torch-TensorRT workflow on the GPT-2 model.
10+
11+ """
12+
13+ # %%
14+ # Imports and Model Definition
15+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
16+
117import torch
218from accelerate import PartialState
319from transformers import AutoTokenizer , GPT2LMHeadModel
622
723tokenizer = AutoTokenizer .from_pretrained ("gpt2" )
824
25+ # Set input prompts for different devices
926prompt1 = "GPT2 is a model developed by."
1027prompt2 = "Llama is a model developed by "
1128
1431
1532distributed_state = PartialState ()
1633
34+ # Import GPT2 model and load to distributed devices
1735model = GPT2LMHeadModel .from_pretrained ("gpt2" ).eval ().to (distributed_state .device )
1836
37+
38+ # Instantiate model with Torch-TensorRT backend
1939model .forward = torch .compile (
2040 model .forward ,
2141 backend = "torch_tensorrt" ,
2747 dynamic = False ,
2848)
2949
50+ # %%
51+ # Inference
52+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
53+
54+ # Assume there are 2 processes (2 devices)
3055with distributed_state .split_between_processes ([input_id1 , input_id2 ]) as prompt :
3156 cur_input = torch .clone (prompt [0 ]).to (distributed_state .device )
3257
Original file line number Diff line number Diff line change 1+ """
2+ .. _data_parallel_stable_diffusion:
3+
4+ Torch-TensorRT Distributed Inference
5+ ======================================================
6+
 7+ This interactive script is intended as a sample of distributed inference using data
 8+ parallelism with the Accelerate
 9+ library and the Torch-TensorRT workflow on the Stable Diffusion model.
10+
11+ """
12+
13+ # %%
14+ # Imports and Model Definition
15+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
116import torch
217from accelerate import PartialState
318from diffusers import DiffusionPipeline
1732backend = "torch_tensorrt"
1833
1934# Optimize the UNet portion with Torch-TensorRT
20- pipe .unet = torch .compile (
35+ pipe .unet = torch .compile ( # %%
36+ # Inference
37+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
38+ # Assume there are 2 processes (2 devices)
2139 pipe .unet ,
2240 backend = backend ,
2341 options = {
3048)
3149torch_tensorrt .runtime .set_multi_device_safe_mode (True )
3250
51+
52+ # %%
53+ # Inference
54+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
55+
56+ # Assume there are 2 processes (2 devices)
3357with distributed_state .split_between_processes (["a dog" , "a cat" ]) as prompt :
3458 print ("before \n " )
3559 result = pipe (prompt ).images [0 ]
Original file line number Diff line number Diff line change 1+ accelerate
2+ transformers
3+ diffusers
You can’t perform that action at this time.
0 commit comments