
🐛 [Bug] FLUX Attention Bug #4194

@cehongwang

Bug Description

File "/home/TensorRT/experiments/refit_flux_benchmark.py", line 112, in run_flux_benchmark
    trt_gm = torch_trt.dynamo.compile(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 798, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 1044, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 343, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 277, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 605, in run
    self._construct_trt_network_def()
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 412, in _construct_trt_network_def
    super().run()
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 200, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 678, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 297, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 785, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 4044, in aten_ops_scaled_dot_product_flash_attention
    return impl.attention.scaled_dot_product_flash_attention(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py", line 285, in scaled_dot_product_flash_attention
    assert attention_layer is not None, "attention layer is None"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: attention layer is None
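For context, the failing converter (`aten_ops_scaled_dot_product_flash_attention`) lowers PyTorch's scaled dot-product attention into a TensorRT layer, and the assertion fires when that layer fails to build. As a reference for what the op itself computes, here is a minimal plain-Python sketch of `softmax(Q·Kᵀ/√d)·V` (this is illustrative only and is not the TensorRT lowering):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def sdpa(q, k, v):
    """Reference scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

    q, k, v are lists of row vectors (lists of floats).
    """
    d = len(q[0])
    out = []
    for qi in q:
        # Scaled dot-product scores of this query against every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        w = softmax(scores)
        # Weighted sum of value rows.
        out.append([sum(wi * vj[c] for wi, vj in zip(w, v))
                    for c in range(len(v[0]))])
    return out
```

A query aligned with the first key should weight the first value row most heavily, e.g. `sdpa([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])` returns a single row whose first entry exceeds its second.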

To Reproduce

Run flux_demo.py in examples/apps

Expected behavior

flux_demo.py compiles and runs without the assertion error.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Labels

bug (Something isn't working)