-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Closed
Labels
triaged: Issue has been triaged by maintainers
Description
Description
Gathering with top-k indices on a multi-batch tensor gives unexpected results.
Note that if we replace the optimization profile with the following one:
C=10
input_shapes = {
'input': {
'min_shape': [1, C, 4],
'opt_shape': [2, C, 4],
'max_shape': [4, C, 4]
}
}
then the correct result is produced.
Please read the code below for more details.
Environment
TensorRT Version: 8.4.1.5
NVIDIA GPU: 2060s
NVIDIA Driver Version: 510.85.02
CUDA Version: 11.3
CUDNN Version: 8.2.1
Operating System: Ubuntu18.04
Python Version (if applicable): 3.7
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 1.10.0
Baremetal or Container (if so, version):
Relevant Files
Steps To Reproduce
import torch
import tensorrt as trt
import onnx
from typing import Dict
def from_onnx(onnx_model, input_shapes, max_workspace_size):
    """Build a TensorRT engine from an ONNX model.

    Args:
        onnx_model: Path to an ONNX file (``str``), or an already loaded
            ONNX model object exposing ``SerializeToString()``.
        input_shapes (dict): Maps each input name to a dict holding the
            'min_shape', 'opt_shape' and 'max_shape' entries of its
            optimization profile.
        max_workspace_size (int): Value assigned to the builder config's
            ``max_workspace_size``.

    Returns:
        The engine returned by ``builder.build_engine``.

    Raises:
        RuntimeError: If the ONNX model cannot be parsed, or if the engine
            build fails.
    """
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    explicit_batch = 1 << int(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)

    # parse onnx
    parser = trt.OnnxParser(network, logger)
    if isinstance(onnx_model, str):
        onnx_model = onnx.load(onnx_model)
    if not parser.parse(onnx_model.SerializeToString()):
        error_msgs = ''.join(
            f'{parser.get_error(index)}\n'
            for index in range(parser.num_errors))
        raise RuntimeError(f'Failed to parse onnx, {error_msgs}')

    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size

    # A single optimization profile covers every dynamic input.
    profile = builder.create_optimization_profile()
    for input_name, param in input_shapes.items():
        profile.set_shape(input_name, param['min_shape'],
                          param['opt_shape'], param['max_shape'])
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)
    # A failed build is reported through a None return value, not an
    # exception; fail here instead of handing None to the caller.
    if engine is None:
        raise RuntimeError('Failed to build TensorRT engine')
    return engine
# Lookup table from TensorRT data types to the matching torch dtypes.
# TRTWrapper.forward uses it to choose the dtype of each output buffer.
TORCH_DTYPE_MAP = {
trt.bool: torch.bool,
trt.int8: torch.int8,
trt.int32: torch.int32,
trt.float16: torch.float16,
trt.float32: torch.float32
}
class TRTWrapper(torch.nn.Module):
    """Run a TensorRT engine on torch tensors.

    Args:
        engine (trt.ICudaEngine): The engine to execute.

    Raises:
        TypeError: If ``engine`` is not a ``trt.ICudaEngine``.
    """

    def __init__(self, engine: trt.ICudaEngine):
        super().__init__()
        self.engine = engine
        if not isinstance(self.engine, trt.ICudaEngine):
            # Only engine objects pass the check above, so the message no
            # longer advertises ``str`` as an accepted type.
            raise TypeError('`engine` should be trt.ICudaEngine, '
                            f'but given: {type(self.engine)}')
        self.context = self.engine.create_execution_context()
        self.__load_io_names()

    def __load_io_names(self):
        """Load input/output names from engine."""
        names = list(self.engine)
        input_names = [
            name for name in names if self.engine.binding_is_input(name)
        ]
        self._input_names = input_names
        # Keep the outputs in the engine's iteration order rather than the
        # arbitrary order of a set difference.
        known_inputs = set(input_names)
        self._output_names = [
            name for name in names if name not in known_inputs
        ]

    def forward(self, inputs: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run forward inference.

        Args:
            inputs (Dict[str, torch.Tensor]): The input name and tensor pairs.

        Return:
            Dict[str, torch.Tensor]: The output name and tensor pairs.
        """
        bindings = [None] * (len(self._input_names) + len(self._output_names))

        # ``contiguous()`` / ``int()`` may allocate new tensors. Keep them
        # referenced until the engine has been launched; otherwise their
        # memory could be released and reused by the output buffers
        # allocated below while ``bindings`` still points at it.
        live_inputs = []
        for input_name, input_tensor in inputs.items():
            idx = self.engine.get_binding_index(input_name)

            # All input tensors must be gpu variables
            input_tensor = input_tensor.contiguous()
            if input_tensor.dtype == torch.long:
                # torch.long has no entry in TORCH_DTYPE_MAP; pass int32.
                input_tensor = input_tensor.int()
            live_inputs.append(input_tensor)

            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            bindings[idx] = input_tensor.data_ptr()

        # create output tensors
        outputs = {}
        for output_name in self._output_names:
            idx = self.engine.get_binding_index(output_name)
            dtype = TORCH_DTYPE_MAP[self.engine.get_binding_dtype(idx)]
            # Queried after every input shape is set on the context.
            shape = tuple(self.context.get_binding_shape(idx))
            output = torch.empty(size=shape, dtype=dtype, device='cuda')
            outputs[output_name] = output
            bindings[idx] = output.data_ptr()

        self.context.execute_async_v2(bindings,
                                      torch.cuda.current_stream().cuda_stream)
        return outputs
class TestModel(torch.nn.Module):
    """Minimal model for the repro: pick the top-4 rows of every batch item.

    Rows are ranked by their maximum over the last dimension.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        """Gather the top-4 rows of ``x`` along dim 1.

        Returns:
            tuple: ``(rows, flat_inds)`` where ``rows`` is ``x`` indexed with
            the per-batch top-4 indices, and ``flat_inds`` are those indices
            offset by ``batch_index * x.size(1)``.
        """
        num_batches = x.size(0)
        num_rows = x.size(1)

        row_peaks = x.max(-1)[0]
        top_inds = row_peaks.topk(4)[1]

        # Column vector of batch indices, broadcast against top_inds.
        batch_inds = torch.arange(
            num_batches, device=top_inds.device).unsqueeze(-1)

        # Alternative formulations left by the reporter:
        #   torch.gather(x, 1, inds.unsqueeze(-1).expand(batch_size, 4, 4))
        #   x.flatten(0, 1)[inds + batch_inds * C]
        picked = x[batch_inds, top_inds, ...]
        return picked, top_inds + batch_inds * num_rows
def main():
    """Reproduce the issue end to end.

    Exports TestModel to 'tmp.onnx' with dynamic axes 0 and 1, builds a
    TensorRT engine from it, then compares the engine outputs with the
    PyTorch outputs for a batch of two.
    """
    # models
    net = TestModel().cuda()
    sample = torch.rand(1, 10, 4).cuda()

    # export onnx
    input_names = ['input']
    output_names = ['output', 'inds']
    dynamic_axes = {'input': {0: 'b', 1: 'n'}}
    torch.onnx.export(
        net,
        sample,
        'tmp.onnx',
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=11)

    # export tensorrt
    profile_shapes = {
        'min_shape': [1, 5, 4],
        'opt_shape': [2, 10, 4],
        'max_shape': [4, 40, 4]
    }
    engine = from_onnx(
        'tmp.onnx',
        input_shapes={'input': profile_shapes},
        max_workspace_size=1 << 30)
    trt_model = TRTWrapper(engine)

    # Compare both backends on a fresh batch of two.
    sample = torch.rand(2, 10, 4).cuda()
    expected = net(sample)
    named_outputs = trt_model({'input': sample})
    actual = [named_outputs[name] for name in output_names]
    # print(sample)
    for trt_out, torch_out in zip(actual, expected):
        print(trt_out.shape)
        torch.testing.assert_allclose(trt_out, torch_out)
    # print(expected)
if __name__ == '__main__':
    main()

Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
triaged: Issue has been triaged by maintainers