Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
c04b6be
remove mixin classes
andompesta Feb 20, 2025
84b68bd
remove vaeMixin
andompesta Feb 20, 2025
81450d5
remove exporters classes
andompesta Feb 20, 2025
9dfc659
remove mixin classes
andompesta Feb 20, 2025
9596aa3
remove vaeMixin
andompesta Feb 20, 2025
19d3e3e
Merge pull request #8 from andompesta/trt-refactor-remove-exporters
andompesta Feb 21, 2025
9d8837e
add trt-config base classes to collect all trt attribute needed to bu…
andompesta Feb 20, 2025
b1f7df6
add build docstring
andompesta Feb 20, 2025
7c8dc8c
clip trt config
andompesta Feb 20, 2025
e665bc9
t5 trt config
andompesta Feb 20, 2025
7b162c0
transformer trt config
andompesta Feb 20, 2025
f305c97
vae trt config
andompesta Feb 20, 2025
7489b4f
trt config init package
andompesta Feb 20, 2025
c79d2d7
guidance-embed is a bool not an int
andompesta Feb 20, 2025
0bb5bf1
update engine to use trt-config
andompesta Feb 20, 2025
33a2844
update engine to use trt-config
andompesta Feb 20, 2025
b57be34
simplify registry keys
andompesta Feb 20, 2025
7e5e43e
add missing dataclass
andompesta Feb 21, 2025
4587729
Merge pull request #9 from andompesta/trt-refactor-add-configs
andompesta Feb 21, 2025
1fbcfae
refactor trt-manager to use trt-configs to build engines
andompesta Feb 20, 2025
b0a51e8
update `load_engines` naming to match `trt_` conventions
andompesta Feb 21, 2025
72c2655
no need to know the cuda device to use
andompesta Feb 21, 2025
cea2f80
onnx dir needs to exist
andompesta Feb 21, 2025
666ef26
Merge pull request #10 from andompesta/trt-refactor-trt-manager
andompesta Feb 21, 2025
db02180
implement cli-control trt modification
andompesta Feb 21, 2025
99eb549
better handle trt dependencies
andompesta Feb 21, 2025
ae4ab3d
fix assertion error
andompesta Feb 21, 2025
ef44038
add clis support for TRT engines inference
andompesta Feb 21, 2025
3f0aa2a
Merge pull request #11 from andompesta/trt-refactor-clis
andompesta Feb 21, 2025
da05f35
parse model precision from string
andompesta Mar 3, 2025
2c1794c
use string for model precisions, not boolean flags
andompesta Mar 3, 2025
2a49b57
use precision string instead of boolean flags
andompesta Mar 3, 2025
4787420
format
andompesta Mar 3, 2025
2e3d077
t5 fp8 config
andompesta Mar 3, 2025
0281e1a
use precision as string to check proper ckp
andompesta Mar 3, 2025
3b7b82c
from {model-name}={precision} to {model-name}-{precision}
andompesta Mar 3, 2025
e024471
use new format
andompesta Mar 3, 2025
9889768
add t5-fp8 config
andompesta Mar 3, 2025
a3d5826
use same config for both t5 model precisions
andompesta Mar 6, 2025
e75dc12
Revert "add clis support for TRT engines inference"
andompesta Mar 6, 2025
7835c31
add cleanup memory function
andompesta Mar 6, 2025
7750516
transformer and t5 precision provided as separate variables
andompesta Mar 6, 2025
fdb13c6
add check if trt is installed or not
andompesta Mar 6, 2025
0ea2c10
add missing arg description
andompesta Mar 6, 2025
e9443d9
new interface where all trt-args have a `trt` prefix
andompesta Mar 6, 2025
c16bb18
stream and runtime-init done when engines are build
andompesta Mar 6, 2025
c12de3a
use private cleanup function
andompesta Mar 6, 2025
c05245c
improve logging
andompesta Mar 6, 2025
2e87f66
Revert "implement cli-control trt modification"
andompesta Mar 6, 2025
60885e9
add checks to import trt
andompesta Mar 6, 2025
e441347
add support for t5 in multiple precisions
andompesta Mar 6, 2025
28ca974
add stop trt runtime
andompesta Mar 6, 2025
2881d48
Merge pull request #12 from andompesta/t5-fp8
andompesta Mar 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 21 additions & 24 deletions src/flux/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.trt.trt_manager import TRTManager
try:
from flux.trt.trt_manager import TRTManager
TRT_AVAIABLE = True
except: # noqa: E722
TRT_AVAIABLE = False
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image

NSFW_THRESHOLD = 0.85
Expand Down Expand Up @@ -132,6 +136,7 @@ def main(
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
trt: use TensorRT backend for optimized inference
trt_transformer_precision: specify transformer precision for inference
kwargs: additional arguments for TensorRT support
"""

Expand Down Expand Up @@ -179,42 +184,34 @@ def main(
ae = load_ae(name, device="cpu" if offload else torch_device)

if trt:
# offload to CPU to save memory
ae = ae.cpu()
model = model.cpu()
clip = clip.cpu()
t5 = t5.cpu()

torch.cuda.empty_cache()
if not TRT_AVAIABLE:
raise ModuleNotFoundError(
"TRT dependencies are needed. Follow README instruction to setup the tensorrt environment."
)

trt_ctx_manager = TRTManager(
bf16=True,
device=torch_device,
static_batch=kwargs.get("static_batch", True),
static_shape=kwargs.get("static_shape", True),
trt_transformer_precision=trt_transformer_precision,
trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
)
ae.decoder.params = ae.params
engines = trt_ctx_manager.load_engines(
models={
"clip": clip,
"transformer": model,
"t5": t5,
"vae": ae.decoder,
"clip": clip.cpu(),
"transformer": model.cpu(),
"t5": t5.cpu(),
"vae": ae.decoder.cpu(),
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
transformer_precision=trt_transformer_precision,
trt_image_height=height,
trt_image_width=width,
trt_batch_size=1,
trt_static_batch=kwargs.get("static_batch", True),
trt_static_shape=kwargs.get("static_shape", True),
)

torch.cuda.synchronize()

trt_ctx_manager.init_runtime()
# TODO: refactor. stream should be part of engine constructor maybe !!
for _, engine in engines.items():
engine.set_stream(stream=trt_ctx_manager.stream)

if not offload:
for _, engine in engines.items():
engine.load()
Expand Down
38 changes: 24 additions & 14 deletions src/flux/cli_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@

from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
from flux.trt.trt_manager import TRTManager

try:
from flux.trt.trt_manager import TRTManager

TRT_AVAIABLE = True
except: # noqa: E722
TRT_AVAIABLE = False
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image


Expand Down Expand Up @@ -198,6 +204,7 @@ def main(
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
trt: use TensorRT backend for optimized inference
trt_transformer_precision: specify transformer precision for inference
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

Expand Down Expand Up @@ -240,7 +247,7 @@ def main(

# set lora scale
if "lora" in name and lora_scale is not None:
assert not trt, "TRT does not support LORA yet"
assert not trt, "TRT does not support LORA"
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(lora_scale)
Expand All @@ -253,11 +260,14 @@ def main(
raise NotImplementedError()

if trt:
if not TRT_AVAIABLE:
raise ModuleNotFoundError(
"TRT dependencies are needed. Follow README instruction to setup the tensorrt environment."
)

trt_ctx_manager = TRTManager(
bf16=True,
device=torch_device,
static_batch=kwargs.get("static_batch", True),
static_shape=kwargs.get("static_shape", True),
trt_transformer_precision=trt_transformer_precision,
trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
)
ae.decoder.params = ae.params
ae.encoder.params = ae.params
Expand All @@ -271,17 +281,14 @@ def main(
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
transformer_precision=trt_transformer_precision,
trt_image_height=height,
trt_image_width=width,
trt_batch_size=1,
trt_static_batch=kwargs.get("static_batch", True),
trt_static_shape=kwargs.get("static_shape", True),
)
torch.cuda.synchronize()

trt_ctx_manager.init_runtime()
# TODO: refactor. stream should be part of engine constructor maybe !!
for _, engine in engines.items():
engine.set_stream(stream=trt_ctx_manager.stream)

if not offload:
for _, engine in engines.items():
engine.load()
Expand Down Expand Up @@ -390,6 +397,9 @@ def main(
else:
opts = None

if trt:
trt_ctx_manager.stop_runtime()


def app():
Fire(main)
Expand Down
30 changes: 15 additions & 15 deletions src/flux/trt/engine/base_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import engine_from_bytes

from flux.trt.trt_config import TRTBaseConfig

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)


Expand Down Expand Up @@ -57,12 +59,9 @@ def cpu(self) -> "BaseEngine":
pass

@abstractmethod
def to(self, device: str) -> "BaseEngine":
def to(self, device: str | torch.device) -> "BaseEngine":
pass

@abstractmethod
def set_stream(self, stream):
pass

@abstractmethod
def load(self):
Expand All @@ -80,10 +79,11 @@ def activate(
class Engine(BaseEngine):
def __init__(
self,
engine_path: str,
trt_config: TRTBaseConfig,
stream: cudart.cudaStream_t,
):
self.engine_path = engine_path
self.stream = None
self.trt_config = trt_config
self.stream = stream
self.engine: trt.ICudaEngine | None = None
self.context = None
self.tensors = OrderedDict()
Expand Down Expand Up @@ -120,29 +120,29 @@ def set_stream(self, stream):

def load(self):
if self.engine is not None:
print(f"[W]: Engine {self.engine_path} already loaded, skip reloading")
print(f"[W]: Engine {self.trt_config.engine_path} already loaded, skip reloading")
return

if not hasattr(self, "engine_bytes_cpu") or self.engine_bytes_cpu is None:
# keep a cpu copy of the engine to reduce reloading time.
print(f"Loading TensorRT engine to cpu bytes: {self.engine_path}")
self.engine_bytes_cpu = bytes_from_path(self.engine_path)
print(f"[I] Loading TensorRT engine to cpu bytes: {self.trt_config.engine_path}")
self.engine_bytes_cpu = bytes_from_path(self.trt_config.engine_path)

print(f"Loading TensorRT engine: {self.engine_path}")
print(f"[I] Loading TensorRT engine: {self.trt_config.engine_path}")
self.engine = engine_from_bytes(self.engine_bytes_cpu)

def unload(self):
if self.engine is not None:
print(f"Unloading TensorRT engine: {self.engine_path}")
print(f"[I] Unloading TensorRT engine: {self.trt_config.engine_path}")
del self.engine
self.engine = None
gc.collect()
else:
print(f"[W]: Unload an unloaded engine {self.engine_path}, skip unloading")
print(f"[W]: Unload an unloaded engine {self.trt_config.engine_path}, skip unloading")

def activate(
self,
device: str,
device: str | torch.device,
device_memory: int | None = None,
):
self.device = device
Expand All @@ -165,7 +165,7 @@ def deactivate(self):
def allocate_buffers(
self,
shape_dict: dict[str, tuple],
device="cuda",
device: str | torch.device = "cuda",
):
for binding in range(self.engine.num_io_tensors):
tensor_name = self.engine.get_tensor_name(binding)
Expand Down
25 changes: 12 additions & 13 deletions src/flux/trt/engine/clip_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,26 @@
# limitations under the License.

import torch
from cuda.cudart import cudaStream_t
from transformers import CLIPTokenizer

from flux.trt.engine import Engine
from flux.trt.mixin import CLIPMixin
from flux.trt.trt_config import ClipConfig


class CLIPEngine(CLIPMixin, Engine):
class CLIPEngine(Engine):
def __init__(
self,
text_maxlen: int,
hidden_size: int,
engine_path: str,
trt_config: ClipConfig,
stream: cudaStream_t,
):
super().__init__(
text_maxlen=text_maxlen,
hidden_size=hidden_size,
engine_path=engine_path,
trt_config=trt_config,
stream=stream,
)
self.tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14",
max_length=self.text_maxlen,
max_length=self.trt_config.text_maxlen,
)

def __call__(
Expand All @@ -48,7 +47,7 @@ def __call__(
feed_dict = self.tokenizer(
prompt,
truncation=True,
max_length=self.text_maxlen,
max_length=self.trt_config.text_maxlen,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
Expand All @@ -65,8 +64,8 @@ def get_shape_dict(
batch_size: int,
) -> dict[str, tuple]:
return {
"input_ids": (batch_size, self.text_maxlen),
"pooled_embeddings": (batch_size, self.hidden_size),
"input_ids": (batch_size, self.trt_config.text_maxlen),
"pooled_embeddings": (batch_size, self.trt_config.hidden_size),
# Onnx model coming from HF has also this input
"text_embeddings": (batch_size, self.text_maxlen, self.hidden_size),
"text_embeddings": (batch_size, self.trt_config.text_maxlen, self.trt_config.hidden_size),
}
23 changes: 11 additions & 12 deletions src/flux/trt/engine/t5_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,26 @@
# limitations under the License.

import torch
from cuda.cudart import cudaStream_t
from transformers import T5Tokenizer

from flux.trt.engine import Engine
from flux.trt.mixin import T5Mixin
from flux.trt.trt_config import T5Config


class T5Engine(T5Mixin, Engine):
class T5Engine(Engine):
def __init__(
self,
text_maxlen: int,
hidden_size: int,
engine_path: str,
trt_config: T5Config,
stream: cudaStream_t,
):
super().__init__(
text_maxlen=text_maxlen,
hidden_size=hidden_size,
engine_path=engine_path,
trt_config=trt_config,
stream=stream,
)
self.tokenizer = T5Tokenizer.from_pretrained(
"google/t5-v1_1-xxl",
max_length=self.text_maxlen,
max_length=self.trt_config.text_maxlen,
)

def __call__(
Expand All @@ -49,7 +48,7 @@ def __call__(
feed_dict = self.tokenizer(
prompt,
truncation=True,
max_length=self.text_maxlen,
max_length=self.trt_config.text_maxlen,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
Expand All @@ -66,6 +65,6 @@ def get_shape_dict(
batch_size: int,
) -> dict[str, tuple]:
return {
"input_ids": (batch_size, self.text_maxlen),
"text_embeddings": (batch_size, self.text_maxlen, self.hidden_size),
"input_ids": (batch_size, self.trt_config.text_maxlen),
"text_embeddings": (batch_size, self.trt_config.text_maxlen, self.trt_config.hidden_size),
}
Loading