import torchao
import torch_tensorrt
import torch
import torch.nn.functional as F
class Model(torch.nn.Module):
    """Minimal module that upsamples a 4-D input to twice its height and width.

    The target size is passed to ``F.interpolate`` as an explicit
    ``[h * 2, w * 2]`` list (not as a scale factor) and no ``mode`` is given,
    so the library default interpolation mode is used.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` resized to double its last two dimensions.

        Args:
            x: A 4-D tensor; only the last two dimensions are read as
                height and width.

        Returns:
            The interpolated tensor whose last two dimensions are
            ``(2 * h, 2 * w)``.
        """
        _, _, h, w = x.shape
        z = F.interpolate(x, [h * 2, w * 2])
        return z
# Repro driver: export the model on CUDA, then compile the exported program
# with Torch-TensorRT.
model = Model().cuda()
inputs = (torch.randn((2, 4, 8, 8)).cuda(),)

# Export under no_grad so no autograd state is recorded while tracing.
with torch.no_grad():
    ep = torch.export.export(
        model,
        args=inputs,
        strict=True
    )

# NOTE(review): the original indentation was lost, so it is unclear whether
# this block was nested inside the no_grad() context above; it is assumed to
# be a separate top-level statement -- confirm against the original script.
# Debug logging is enabled so the compile step emits verbose output.
with torch_tensorrt.logging.debug():
    # Engine caching is disabled in both directions so every run builds
    # from scratch rather than reusing a previously built engine.
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs,
        reuse_cached_engines=False,
        cache_built_engines=False,
        require_full_compilation=True,
        min_block_size=1,
    )
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/projects/scripts/mre/torchao_tensorrt_import.py", line 25, in <module>
trt_gm = torch_tensorrt.dynamo.compile(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 228, in compile
exported_program = exported_program.run_decompositions(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 116, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 1111, in run_decompositions
return _decompose_exported_program(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 654, in _decompose_exported_program
gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 446, in _decompose_and_get_gm_with_new_signature_constants
gm, graph_signature = aot_export_module(
^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1262, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1497, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 625, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 194, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 859, in functional_call
out = PropagateUnbackedSymInts(mod).run(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6495, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 275, in call_function
return target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 64, in inner
return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 37, in autograd_not_implemented_inner
result = operator(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__
r = func.decompose(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
return self.py_kernels[dk](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 3184, in upsample_nearest2d_vec
return torch.ops.aten.upsample_nearest2d.default(input, osize)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__
r = func.decompose(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
return self.py_kernels[dk](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_decomp/__init__.py", line 376, in _special_op_to_decompose_cia
raise AssertionError(
AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel
While executing %upsample_nearest2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_nearest2d.vec](args = (%x, [16, 16], None), kwargs = {})
Original traceback:
File "/projects/scripts/mre/torchao_tensorrt_import.py", line 12, in forward
z = F.interpolate(x, [h*2, w*2])
## Bug Description

Importing `torchao` before importing `torch_tensorrt` causes `F.interpolate` to fail during `run_decompositions` with:

`AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel`

## To Reproduce
Logs:
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models Traceback (most recent call last): File "<frozen runpy>", line 198, in _run_module_as_main File "<frozen runpy>", line 88, in _run_code File "/projects/scripts/mre/torchao_tensorrt_import.py", line 25, in <module> trt_gm = torch_tensorrt.dynamo.compile( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 228, in compile exported_program = exported_program.run_decompositions( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 116, in wrapper return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 1111, in run_decompositions return _decompose_exported_program( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 654, in _decompose_exported_program gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 446, in _decompose_and_get_gm_with_new_signature_constants gm, graph_signature = aot_export_module( ^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1262, in aot_export_module fx_g, metadata, in_spec, out_spec = _aot_export_function( ^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1497, in _aot_export_function fx_g, meta = 
create_aot_dispatcher_function( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function return _create_aot_dispatcher_function( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 625, in _create_aot_dispatcher_function fw_metadata = run_functionalized_fw_and_collect_metadata( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 194, in inner flat_f_outs = f(*flat_f_args) ^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn tree_out = fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 859, in functional_call out = PropagateUnbackedSymInts(mod).run( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run self.env[node] = self.run_node(node) ^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6495, in run_node result = super().run_node(n) ^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node return getattr(self, n.op)(n.target, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 275, in call_function return target(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__ return self._op(*args, **kwargs) 
^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 64, in inner return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 37, in autograd_not_implemented_inner result = operator(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__ return self._op(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__ r = func.decompose(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose return self.py_kernels[dk](*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 3184, in upsample_nearest2d_vec return torch.ops.aten.upsample_nearest2d.default(input, osize) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__ return self._op(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__ r = func.decompose(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose return self.py_kernels[dk](*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_decomp/__init__.py", line 376, in 
_special_op_to_decompose_cia raise AssertionError( AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel While executing %upsample_nearest2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_nearest2d.vec](args = (%x, [16, 16], None), kwargs = {}) Original traceback: File "/projects/scripts/mre/torchao_tensorrt_import.py", line 12, in forward z = F.interpolate(x, [h*2, w*2])

## Expected behavior
The import order between `torch_tensorrt` and `torchao` should not matter.

## Environment
- How PyTorch was installed (`conda`, `pip`, `libtorch`, source): pip

## Additional context
If you import `torch_tensorrt` first and then `torchao`, the error disappears.