Skip to content

TensorRT show no improvement in inference speed #43

@Tigerold

Description

@Tigerold

I attempted to deploy the DSVT model to TensorRT following your deployment code. Based on the official TensorRT example code, I used dynamic shapes for the dsvt_block model input; model inference time is about 260 ms. However, the PyTorch version takes less time, about 140 ms. Why does inference take longer with the TensorRT C++ code?

Environment
TensorRT Version: 8.5.1.7
CUDA Version: 11.8
CUDNN Version: 8.6
Hardware GPU: p4000
(all other environment details are standard/default)

inference code

#include "trt_infer.h"

#include <algorithm>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <vector>

#include "cnpy.h"
// Constructs the TensorRT inference wrapper: builds (or deserializes) the
// engine from the ONNX model referenced by trt_config and creates the CUDA
// stream used for all subsequent transfers and kernel launches.
//
// Throws std::runtime_error if engine construction fails.
TRTInfer::TRTInfer(TrtConfig trt_config): mEngine_(nullptr)
{
    // Accumulators for profiling host<->device copy time across calls.
    sum_cpy_feature_ = 0.0f;
    sum_cpy_output_ = 0.0f;
    count_ = 0;
    trt_config_ = trt_config;

    input_cpy_kind_ = cudaMemcpyHostToDevice;
    output_cpy_kind_ = cudaMemcpyDeviceToHost;

    // BUG FIX: build() reports failure through its return value, but the
    // original constructor ignored it and continued with a null engine,
    // deferring the crash to the first inference call. Fail fast instead.
    if (!build())
    {
        throw std::runtime_error("TRTInfer: failed to build TensorRT engine from " + trt_config_.model_file);
    }

    CHECKCUDA(cudaStreamCreate(&stream_), "failed to create cuda stream");

    std::cout << "tensorrt init done." << std::endl;
}


// Builds a TensorRT engine from the ONNX file named in trt_config_ and
// creates an execution context (context_). Returns false on any failure.
//
// Pipeline: create builder/runtime -> parse ONNX into an explicit-batch
// network -> configure precision and the dynamic-shape optimization
// profile -> serialize + deserialize the engine -> create the context.
bool TRTInfer::build()
{
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }

    // NOTE(review): this runtime is a local and is destroyed when build()
    // returns, yet mEngine_ below is deserialized by it. TensorRT requires
    // the IRuntime to outlive engines it created — consider promoting the
    // runtime to a class member; verify against the TensorRT release notes
    // for version 8.5.
    SampleUniquePtr<nvinfer1::IRuntime> runtime{createInferRuntime(sample::gLogger.getTRTLogger())};
    if (!runtime)
    {
        return false;
    }

    // CUDA stream used for profiling by the builder.
    auto profileStream = samplesCommon::makeCudaStream();
    if (!profileStream)
    {
        return false;
    }

    // ONNX models require an explicit-batch network definition.
    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
    if (!network)
    {
        return false;
    }

    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }

    auto parser = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
    if (!parser)
    {
        return false;
    }

    // auto constructed = constructNetwork(builder, network, config, parser);
    // if (!constructed)
    // {
    //     return false;
    // }

    //replace conscructNetwork with following code:
    // Parse the ONNX model straight from disk into `network`.
    auto parsed = parser->parseFromFile(trt_config_.model_file.c_str(), static_cast<int>(sample::gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }

    // Cache the (possibly dynamic, i.e. containing -1) I/O dimensions; the
    // dynamic extents are filled in per-frame by infer_dynamic().
    for (int i = 0; i < network->getNbInputs(); i++) {
        std::cout << "network->getInput(i)->getDimensions(): " << network->getInput(i)->getDimensions() << std::endl;
        mInputDims.push_back(network->getInput(i)->getDimensions());
    }
    for (int i = 0; i < network->getNbOutputs(); i++) {
        mOutputDims.push_back(network->getOutput(i)->getDimensions());
    }

    config->setProfileStream(*profileStream);


    // A single timing iteration speeds up engine building but gives the
    // auto-tuner less data; tactic choices may be noisier as a result.
    config->setAvgTimingIterations(1);
    config->setMinTimingIterations(1);
    // Workspace budget in MiB from the config (shifted to bytes). Too small
    // a workspace can exclude the fastest tactics — worth checking given
    // the slow inference reported for this engine.
    config->setMaxWorkspaceSize(static_cast<size_t>(trt_config_.max_workspace)<<20);
    if (builder->platformHasFastFp16() && trt_config_.fp16mode)
    {
        // NOTE(review): the Quadro P4000 (Pascal) has very slow native FP16
        // throughput; enabling kFP16 there can make the engine SLOWER than
        // FP32 — confirm against the GPU's FP16 rate before enabling.
        config->setFlag(BuilderFlag::kFP16);
    }
    if (builder->platformHasFastInt8() && trt_config_.int8mode)
    {
        config->setFlag(BuilderFlag::kINT8);
        // samplesCommon::setAllDynamicRanges(network.get(), 127.0f, 127.0f); // in case use int8 without calibration
    }
    // Explicit-batch networks ignore max batch size; harmless, kept as-is.
    builder->setMaxBatchSize(1);
    
    std::unique_ptr<nvinfer1::IInt8Calibrator> calibrator;
    if (builder->platformHasFastInt8() && trt_config_.int8mode)
    {
        // Entropy calibrator fed from the calibration data set; only used
        // during buildSerializedNetwork(), so locals here are safe.
        MNISTBatchStream calibrationStream(trt_config_.calib_data);
        calibrator.reset(new Int8EntropyCalibrator2<MNISTBatchStream>(calibrationStream, -1, trt_config_.net_name.c_str(), trt_config_.input_name.c_str()));
        config->setInt8Calibrator(calibrator.get());
    }

    // Dynamic-shape optimization profile. Kernels are tuned for the kOPT
    // shapes; the further a runtime shape is from kOPT (ranges here span
    // 1000..100000 points), the worse the selected tactics may perform.
    // NOTE(review): tightening kMIN/kMAX around realistic workloads, or
    // adding multiple profiles, may address the reported slowness — verify
    // with trtexec profiling.
    IOptimizationProfile* profile = builder->createOptimizationProfile();
    profile->setDimensions("src", OptProfileSelector::kMIN, Dims2(1000,128));
    profile->setDimensions("src", OptProfileSelector::kOPT, Dims2(24629,128));
    profile->setDimensions("src", OptProfileSelector::kMAX, Dims2(100000,128));
    profile->setDimensions("set_voxel_inds_tensor_shift_0", OptProfileSelector::kMIN, Dims3(2,50,36));
    profile->setDimensions("set_voxel_inds_tensor_shift_0", OptProfileSelector::kOPT, Dims3(2,1156,36));
    profile->setDimensions("set_voxel_inds_tensor_shift_0", OptProfileSelector::kMAX, Dims3(2,5000,36));
    profile->setDimensions("set_voxel_inds_tensor_shift_1", OptProfileSelector::kMIN, Dims3(2,50,36));
    profile->setDimensions("set_voxel_inds_tensor_shift_1", OptProfileSelector::kOPT, Dims3(2,834,36));
    profile->setDimensions("set_voxel_inds_tensor_shift_1", OptProfileSelector::kMAX, Dims3(2,3200,36));
    profile->setDimensions("set_voxel_masks_tensor_shift_0", OptProfileSelector::kMIN, Dims3(2,50,36));
    profile->setDimensions("set_voxel_masks_tensor_shift_0", OptProfileSelector::kOPT, Dims3(2,1156,36));
    profile->setDimensions("set_voxel_masks_tensor_shift_0", OptProfileSelector::kMAX, Dims3(2,5000,36));
    profile->setDimensions("set_voxel_masks_tensor_shift_1", OptProfileSelector::kMIN, Dims3(2,50,36));
    profile->setDimensions("set_voxel_masks_tensor_shift_1", OptProfileSelector::kOPT, Dims3(2,834,36));
    profile->setDimensions("set_voxel_masks_tensor_shift_1", OptProfileSelector::kMAX, Dims3(2,3200,36));
    profile->setDimensions("pos_embed_tensor", OptProfileSelector::kMIN, Dims4(4,2,1000,128));
    profile->setDimensions("pos_embed_tensor", OptProfileSelector::kOPT, Dims4(4,2,24629,128));
    profile->setDimensions("pos_embed_tensor", OptProfileSelector::kMAX, Dims4(4,2,100000,128));
    config->addOptimizationProfile(profile);

    // Build the serialized plan, then immediately deserialize it into a
    // usable engine (avoids a second disk round-trip).
    SampleUniquePtr<nvinfer1::IHostMemory> plan{builder->buildSerializedNetwork(*network, *config)};
    if (!plan)
    {
        return false;
    }

    mEngine_ = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(plan->data(), plan->size()), samplesCommon::InferDeleter());
    if (!mEngine_)
    {
        return false;
    }

    // Create RAII buffer manager object
    context_ = mEngine_->createExecutionContext();
    if (!context_)
    {
        return false;
    }

    return true;

}


// Public entry point: runs one inference pass and blocks until all work
// queued on stream_ has completed.
//
// BUG FIX: the bool result of infer_dynamic() was silently discarded; the
// signature must stay void for existing callers, so a failure is now at
// least reported on stderr instead of vanishing.
void TRTInfer::doinference(std::vector<void*> &inputs, std::vector<float*> &outputs, std::vector<int> &input_dynamic)
{
    if (!infer_dynamic(inputs, outputs, input_dynamic))
    {
        std::cerr << "TRTInfer::doinference: infer_dynamic failed" << std::endl;
    }
    cudaStreamSynchronize(stream_);
}


bool TRTInfer::infer_dynamic(std::vector<void*> &inputs, std::vector<float*> &outputs, std::vector<int> &input_dynamic)
{
    double t0 = getTime();
    mInputDims[0] = Dims2{input_dynamic[0], 128};
    mInputDims[1] = Dims3{2, input_dynamic[1], 36};
    mInputDims[2] = Dims3{2, input_dynamic[2], 36};
    mInputDims[3] = Dims3{2, input_dynamic[3], 36};
    mInputDims[4] = Dims3{2, input_dynamic[4], 36};
    mInputDims[5] = Dims4{4, 2, input_dynamic[5], 128};

    mInput[0].hostBuffer.resize(mInputDims[0]);
    mInput[1].hostBuffer.resize(mInputDims[1]);
    mInput[2].hostBuffer.resize(mInputDims[2]);
    mInput[3].hostBuffer.resize(mInputDims[3]);
    mInput[4].hostBuffer.resize(mInputDims[4]);
    mInput[5].hostBuffer.resize(mInputDims[5]);
    

    std::copy((float*)(inputs[0]), (float*)(inputs[0]) + 1, static_cast<float*>(mInput[0].hostBuffer.data()));
    std::copy((int*)inputs[1], (int*)inputs[1] + 2* input_dynamic[1] * 36, static_cast<int*>(mInput[1].hostBuffer.data()));
    std::copy((int*)inputs[2], (int*)inputs[2] + 2* input_dynamic[2] * 36, static_cast<int*>(mInput[2].hostBuffer.data()));
    std::copy((bool*)inputs[3], (bool*)inputs[3] + 2* input_dynamic[3] * 36, static_cast<bool*>(mInput[3].hostBuffer.data()));
    std::copy((bool*)inputs[4], (bool*)inputs[4] + 2* input_dynamic[4] * 36, static_cast<bool*>(mInput[4].hostBuffer.data()));
    std::copy((float*)inputs[5], (float*)inputs[5] + 4* 2* input_dynamic[5] * 128, static_cast<float*>(mInput[5].hostBuffer.data()));
    cudaStreamSynchronize(stream_);
    double t1 = getTime();

    mInput[0].deviceBuffer.resize(mInputDims[0]);
    mInput[1].deviceBuffer.resize(mInputDims[1]);
    mInput[2].deviceBuffer.resize(mInputDims[2]);
    mInput[3].deviceBuffer.resize(mInputDims[3]);
    mInput[4].deviceBuffer.resize(mInputDims[4]);
    mInput[5].deviceBuffer.resize(mInputDims[5]);

    CHECK(cudaMemcpy(mInput[0].deviceBuffer.data(), mInput[0].hostBuffer.data(), mInput[0].hostBuffer.nbBytes(), cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(mInput[1].deviceBuffer.data(), mInput[1].hostBuffer.data(), mInput[1].hostBuffer.nbBytes(), cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(mInput[2].deviceBuffer.data(), mInput[2].hostBuffer.data(), mInput[2].hostBuffer.nbBytes(), cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(mInput[3].deviceBuffer.data(), mInput[3].hostBuffer.data(), mInput[3].hostBuffer.nbBytes(), cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(mInput[4].deviceBuffer.data(), mInput[4].hostBuffer.data(), mInput[4].hostBuffer.nbBytes(), cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(mInput[5].deviceBuffer.data(), mInput[5].hostBuffer.data(), mInput[5].hostBuffer.nbBytes(), cudaMemcpyHostToDevice));
    cudaStreamSynchronize(stream_);
    double t2 = getTime();

    context_->setBindingDimensions(0, mInputDims[0]);
    context_->setBindingDimensions(1, mInputDims[1]);
    context_->setBindingDimensions(2, mInputDims[2]);
    context_->setBindingDimensions(3, mInputDims[3]);
    context_->setBindingDimensions(4, mInputDims[4]);
    context_->setBindingDimensions(5, mInputDims[5]);
    // context_->setBindingDimensions(6, mInputDims[6]);
    std::cout << "mEngine_->getNbBindings(): " << mEngine_->getNbBindings() << std::endl;
    std::cout << " mEngine_->getBindingDimensions(i)" <<  mEngine_->getBindingDimensions(0) << std::endl;
    std::cout << " context_->getBindingDimensions(i)" <<  context_->getBindingDimensions(0) << std::endl;
    cudaStreamSynchronize(stream_);
    double t3 = getTime();

    // We can only run inference once all dynamic input shapes have been specified.
    if (!context_->allInputDimensionsSpecified())
    {
        return false;
    }
    mOutputDims[0] = mInputDims[0];
    mOutput[0].deviceBuffer.resize(mOutputDims[0]);
    mOutput[0].hostBuffer.resize(mOutputDims[0]);
    std::vector<void*> processorBindings = {mInput[0].deviceBuffer.data(),
                                            mInput[1].deviceBuffer.data(),
                                            mInput[2].deviceBuffer.data(),
                                            mInput[3].deviceBuffer.data(),
                                            mInput[4].deviceBuffer.data(),
                                            mInput[5].deviceBuffer.data(),
                                            mOutput[0].deviceBuffer.data()};
    cudaStreamSynchronize(stream_);
    double t4 = getTime();
    bool status = context_->executeV2(processorBindings.data());
    if (!status)
    {
        return false;
    }
    cudaStreamSynchronize(stream_);
    double t5 = getTime();

    CHECK(cudaMemcpy(mOutput[0].hostBuffer.data(), mOutput[0].deviceBuffer.data(), mOutput[0].deviceBuffer.nbBytes(),
        cudaMemcpyDeviceToHost));
    cudaStreamSynchronize(stream_);
    double t6 = getTime();
    // cnpy::npy_save("dsvt_output_tensor.npy", static_cast<float*>(mOutput[0].hostBuffer.data()), {mOutput[0].deviceBuffer.nbBytes()/4},"w");
    std::cout << "time elapse:" << t1-t0 << std::endl;
    std::cout << "time elapse:" << t2-t1 << std::endl;
    std::cout << "time elapse:" << t3-t2 << std::endl;
    std::cout << "time elapse:" << t4-t3 << std::endl;
    std::cout << "time elapse:" << t5-t4 << std::endl;
    std::cout << "time elapse:" << t6-t5 << std::endl;
    return true;

}

According to the results, the average time cost of each stage is as follows:
t1-t0:0.00860953
t2-t1:0.0124242
t3-t2:4.72069e-05
t4-t3:8.10623e-06
t5-t4:0.260188
t6-t5:0.00110817

Why does the C++ code take more time? Are there mistakes in my inference code?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions