Skip to content

Commit 4bc05b7

Browse files
committed
add requirements.txt, annotate the script and add reference to index.rst
1 parent 62d8773 commit 4bc05b7

File tree

4 files changed

+56
-1
lines changed

4 files changed

+56
-1
lines changed

docsrc/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ Tutorials
111111
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
112112
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
113113
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
114+
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
115+
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
116+
114117

115118
Python API Documentation
116119
------------------------

examples/distributed_inference/data_parallel_gpt2.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
"""
2+
.. _data_parallel_gpt2:
3+
4+
Torch-TensorRT Distributed Inference
5+
======================================================
6+
7+
This interactive script is intended as a sample of distributed inference using data
8+
parallelism with the Accelerate
9+
library in the Torch-TensorRT workflow on the GPT2 model.
10+
11+
"""
12+
13+
# %%
14+
# Imports and Model Definition
15+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
16+
117
import torch
218
from accelerate import PartialState
319
from transformers import AutoTokenizer, GPT2LMHeadModel
@@ -6,6 +22,7 @@
622

723
tokenizer = AutoTokenizer.from_pretrained("gpt2")
824

25+
# Set input prompts for different devices
926
prompt1 = "GPT2 is a model developed by."
1027
prompt2 = "Llama is a model developed by "
1128

@@ -14,8 +31,11 @@
1431

1532
distributed_state = PartialState()
1633

34+
# Import GPT2 model and load to distributed devices
1735
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)
1836

37+
38+
# Instantiate model with Torch-TensorRT backend
1939
model.forward = torch.compile(
2040
model.forward,
2141
backend="torch_tensorrt",
@@ -27,6 +47,11 @@
2747
dynamic=False,
2848
)
2949

50+
# %%
51+
# Inference
52+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
53+
54+
# Assume there are 2 processes (2 devices)
3055
with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
3156
cur_input = torch.clone(prompt[0]).to(distributed_state.device)
3257

examples/distributed_inference/data_parallel_stable_diffusion.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
"""
2+
.. _data_parallel_stable_diffusion:
3+
4+
Torch-TensorRT Distributed Inference
5+
======================================================
6+
7+
This interactive script is intended as a sample of distributed inference using data
8+
parallelism with the Accelerate
9+
library in the Torch-TensorRT workflow on the Stable Diffusion model.
10+
11+
"""
12+
13+
# %%
14+
# Imports and Model Definition
15+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
116
import torch
217
from accelerate import PartialState
318
from diffusers import DiffusionPipeline
@@ -17,7 +32,10 @@
1732
backend = "torch_tensorrt"
1833

1934
# Optimize the UNet portion with Torch-TensorRT
20-
pipe.unet = torch.compile(
35+
pipe.unet = torch.compile(
2139
pipe.unet,
2240
backend=backend,
2341
options={
@@ -30,6 +48,12 @@
3048
)
3149
torch_tensorrt.runtime.set_multi_device_safe_mode(True)
3250

51+
52+
# %%
53+
# Inference
54+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
55+
56+
# Assume there are 2 processes (2 devices)
3357
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
3458
print("before \n")
3559
result = pipe(prompt).images[0]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
accelerate
2+
transformers
3+
diffusers

0 commit comments

Comments
 (0)