Skip to content

Unexpected result on multi-batch gather. #2299

@grimoire

Description

@grimoire

Description

Gathering topk index on a multi-batch tensor gives unexpected results.
Note that if we replace the profile with:

    C=10
    input_shapes = {
        'input': {
            'min_shape': [1, C, 4],
            'opt_shape': [2, C, 4],
            'max_shape': [4, C, 4]
        }
    }

we get the right result.

Please read the code below for more detail.

Environment

TensorRT Version: 8.4.1.5
NVIDIA GPU: 2060s
NVIDIA Driver Version: 510.85.02
CUDA Version: 11.3
CUDNN Version: 8.2.1
Operating System: Ubuntu18.04
Python Version (if applicable): 3.7
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 1.10.0
Baremetal or Container (if so, version):

Relevant Files

Steps To Reproduce

import torch
import tensorrt as trt
import onnx
from typing import Dict


def from_onnx(onnx_model, input_shapes, max_workspace_size):
    """Build a TensorRT engine from an ONNX model.

    Args:
        onnx_model: Path to an ONNX file, or a loaded ``onnx.ModelProto``.
        input_shapes: Mapping of input name to a dict with ``min_shape``,
            ``opt_shape`` and ``max_shape`` lists for the optimization
            profile.
        max_workspace_size: Builder workspace limit in bytes.

    Returns:
        trt.ICudaEngine: The built engine.

    Raises:
        RuntimeError: If ONNX parsing or engine building fails.
    """
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    # ONNX models require an explicit-batch network definition.
    EXPLICIT_BATCH = 1 << (int)(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)

    # parse onnx
    parser = trt.OnnxParser(network, logger)

    if isinstance(onnx_model, str):
        onnx_model = onnx.load(onnx_model)

    if not parser.parse(onnx_model.SerializeToString()):
        error_msgs = ''.join(
            f'{parser.get_error(i)}\n' for i in range(parser.num_errors))
        raise RuntimeError(f'Failed to parse onnx, {error_msgs}')

    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size

    # One optimization profile covering the dynamic shape ranges.
    profile = builder.create_optimization_profile()
    for input_name, param in input_shapes.items():
        profile.set_shape(input_name, param['min_shape'],
                          param['opt_shape'], param['max_shape'])
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)
    if engine is None:
        # build_engine signals failure by returning None, not by raising;
        # fail loudly here instead of letting callers hit a cryptic error.
        raise RuntimeError('Failed to build TensorRT engine.')
    return engine


# Maps each TensorRT binding dtype to the torch dtype used when allocating
# output tensors in TRTWrapper.forward.  Note there is no trt.int64 entry:
# long tensors must be downcast to int32 before being bound.
TORCH_DTYPE_MAP = {
    trt.bool: torch.bool,
    trt.int8: torch.int8,
    trt.int32: torch.int32,
    trt.float16: torch.float16,
    trt.float32: torch.float32
}


class TRTWrapper(torch.nn.Module):
    """Run inference on a built TensorRT engine with torch tensors as I/O.

    Wraps an ``trt.ICudaEngine`` so it can be called like a module with a
    dict of named input tensors, returning a dict of named output tensors.
    """

    def __init__(self, engine: trt.ICudaEngine):
        super().__init__()
        self.engine = engine

        # Only an already-built engine is accepted here (the original
        # message mentioned `str`, but no path handling exists).
        if not isinstance(self.engine, trt.ICudaEngine):
            raise TypeError(f'`engine` should be trt.ICudaEngine, \
                but given: {type(self.engine)}')

        self.context = self.engine.create_execution_context()
        self.__load_io_names()

    def __load_io_names(self):
        """Load input/output names from engine."""
        # Iterating the engine yields all binding names in binding order.
        names = [_ for _ in self.engine]
        input_names = list(filter(self.engine.binding_is_input, names))
        self._input_names = input_names

        output_names = list(set(names) - set(input_names))
        self._output_names = output_names

    def forward(self, inputs: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run forward inference.

        Args:
            inputs (Dict[str, torch.Tensor]): The input name and tensor pairs.
                Tensors must live on the GPU.

        Return:
            Dict[str, torch.Tensor]: The output name and tensor pairs.
        """
        bindings = [None] * (len(self._input_names) + len(self._output_names))

        # Keep references to any converted copies (.contiguous()/.int()
        # may allocate new tensors) so their memory is not released before
        # execute_async_v2 consumes the raw data pointers.
        keep_alive = []
        for input_name, input_tensor in inputs.items():
            idx = self.engine.get_binding_index(input_name)

            # All input tensors must be contiguous gpu variables.
            input_tensor = input_tensor.contiguous()
            if input_tensor.dtype == torch.long:
                # Downcast: TORCH_DTYPE_MAP has no int64 entry, so the
                # engine binding is presumably int32 — TODO confirm.
                input_tensor = input_tensor.int()
            keep_alive.append(input_tensor)
            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            bindings[idx] = input_tensor.data_ptr()

        # Create output tensors sized from the resolved dynamic shapes.
        outputs = {}
        for output_name in self._output_names:
            idx = self.engine.get_binding_index(output_name)
            dtype = TORCH_DTYPE_MAP[self.engine.get_binding_dtype(idx)]
            shape = tuple(self.context.get_binding_shape(idx))

            output = torch.empty(size=shape, dtype=dtype, device='cuda')
            outputs[output_name] = output
            bindings[idx] = output.data_ptr()

        self.context.execute_async_v2(bindings,
                                      torch.cuda.current_stream().cuda_stream)

        return outputs


class TestModel(torch.nn.Module):
    """Pick the top-4 rows of each batch, ranked by the row-wise maximum."""

    def forward(self, x):
        """Gather the top-4 rows per batch and return flattened indices.

        Args:
            x (Tensor): input of shape ``(batch, C, 4)``.

        Returns:
            tuple: ``(gathered, flat_inds)`` where ``gathered`` has shape
            ``(batch, 4, 4)`` and ``flat_inds = inds + batch_idx * C``
            indexes into ``x.flatten(0, 1)``.
        """
        num_batches, num_rows = x.size(0), x.size(1)
        row_scores = x.max(-1)[0]
        top_inds = row_scores.topk(4)[1]
        # Column vector of batch offsets, broadcasts against top_inds.
        offsets = torch.arange(
            num_batches, device=top_inds.device).unsqueeze(-1)
        gathered = x[offsets, top_inds, ...]
        return gathered, top_inds + offsets * num_rows


def main():
    """Reproduce the multi-batch gather mismatch between torch and TensorRT."""
    model = TestModel().cuda()
    dummy = torch.rand(1, 10, 4).cuda()

    # Export to ONNX with dynamic batch ('b') and row ('n') axes.
    input_names = ['input']
    output_names = ['output', 'inds']
    torch.onnx.export(
        model,
        dummy,
        'tmp.onnx',
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={'input': {
            0: 'b',
            1: 'n'
        }},
        opset_version=11)

    # Build an engine whose optimization profile spans the dynamic range.
    input_shapes = {
        'input': {
            'min_shape': [1, 5, 4],
            'opt_shape': [2, 10, 4],
            'max_shape': [4, 40, 4]
        }
    }
    engine = from_onnx(
        'tmp.onnx', input_shapes=input_shapes, max_workspace_size=1 << 30)
    wrapper = TRTWrapper(engine)

    # Run both backends on a 2-batch input and compare the results.
    sample = torch.rand(2, 10, 4).cuda()
    torch_out = model(sample)
    trt_out = wrapper({'input': sample})
    trt_out = [trt_out[name] for name in output_names]

    for produced, expected in zip(trt_out, torch_out):
        print(produced.shape)
        torch.testing.assert_allclose(produced, expected)


if __name__ == '__main__':
    main()

Metadata

Metadata

Assignees

Labels

triagedIssue has been triaged by maintainers

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions