diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index ca0da0574165c0..47b369da64687f 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -38,9 +38,9 @@
 #include "paddle/fluid/platform/cpu_helper.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/init.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
-#include "paddle/fluid/platform/init.h"
 
 #ifdef PADDLE_WITH_MKLML
 #include "paddle/fluid/platform/dynload/mklml.h"
@@ -166,45 +166,45 @@ bool AnalysisPredictor::Init(
   return true;
 }
 
-void ParseCommandLineFlags(){
-    std::vector<char*> internal_argv;
-    std::string dummy = "dummy";
-    internal_argv.push_back(strdup(dummy.c_str()));
-    std::vector<std::string> envs;
-    std::vector<std::string> undefok;
+// Builds a synthetic argv from whitelisted environment variables and feeds it
+// to the global flag parser, so inference-only binaries honor the memory
+// tuning flags without a real command line.
+void ParseCommandLineFlags() {
+  std::vector<char *> internal_argv;
+  std::string dummy = "dummy";
+  internal_argv.push_back(strdup(dummy.c_str()));
+  std::vector<std::string> envs;
+  std::vector<std::string> undefok;
 #ifdef PADDLE_WITH_CUDA
-    envs.push_back("fraction_of_gpu_memory_to_use");
-    envs.push_back("initial_gpu_memory_in_mb");
-    envs.push_back("reallocate_gpu_memory_in_mb");
+  envs.push_back("fraction_of_gpu_memory_to_use");
+  envs.push_back("initial_gpu_memory_in_mb");
+  envs.push_back("reallocate_gpu_memory_in_mb");
 #endif
-    envs.push_back("allocator_strategy");
-    envs.push_back("initial_cpu_memory_in_mb");
-    undefok.push_back("initial_cpu_memory_in_mb");
-    char* env_str = nullptr;
-    if (envs.size() > 0) {
-      std::string env_string = "--tryfromenv=";
-      for (auto t : envs) {
-        env_string += t + ",";
-      }
-      env_string = env_string.substr(0, env_string.length() - 1);
-      env_str = strdup(env_string.c_str());
-      internal_argv.push_back(env_str);
-      LOG(INFO) << "get env_string" << env_string;
-    }
-    char* undefok_str = nullptr;
-    if (undefok.size() > 0) {
-      std::string undefok_string = "--undefok=";
-      for (auto t : undefok) {
-        undefok_string += t + ",";
-      }
-      undefok_string = undefok_string.substr(0, undefok_string.length() - 1);
-      undefok_str = strdup(undefok_string.c_str());
-      internal_argv.push_back(undefok_str);
-    }
-    int internal_argc = internal_argv.size();
-    char** arr = internal_argv.data();
-    paddle::platform::ParseCommandLineFlags(internal_argc, arr, true);
+  envs.push_back("allocator_strategy");
+  envs.push_back("initial_cpu_memory_in_mb");
+  undefok.push_back("initial_cpu_memory_in_mb");
+  char *env_str = nullptr;
+  if (envs.size() > 0) {
+    std::string env_string = "--tryfromenv=";
+    for (auto t : envs) {
+      env_string += t + ",";
+    }
+    // Drop the trailing comma left by the join loop.
+    env_string = env_string.substr(0, env_string.length() - 1);
+    env_str = strdup(env_string.c_str());
+    internal_argv.push_back(env_str);
+    LOG(INFO) << "get env_string" << env_string;
+  }
+  char *undefok_str = nullptr;
+  if (undefok.size() > 0) {
+    std::string undefok_string = "--undefok=";
+    for (auto t : undefok) {
+      undefok_string += t + ",";
+    }
+    undefok_string = undefok_string.substr(0, undefok_string.length() - 1);
+    undefok_str = strdup(undefok_string.c_str());
+    internal_argv.push_back(undefok_str);
+  }
+  int internal_argc = internal_argv.size();
+  char **arr = internal_argv.data();
+  paddle::platform::ParseCommandLineFlags(internal_argc, arr, true);
 }
 
 bool AnalysisPredictor::PrepareScope(
@@ -1136,6 +1136,7 @@ USE_TRT_CONVERTER(hard_swish);
 USE_TRT_CONVERTER(split);
 USE_TRT_CONVERTER(transpose);
 USE_TRT_CONVERTER(prelu);
+USE_TRT_CONVERTER(box_coder);
 USE_TRT_CONVERTER(conv2d_transpose);
 USE_TRT_CONVERTER(leaky_relu);
 USE_TRT_CONVERTER(shuffle_channel);
@@ -1148,17 +1149,18 @@ USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
 USE_TRT_CONVERTER(skip_layernorm);
 USE_TRT_CONVERTER(slice);
 USE_TRT_CONVERTER(scale);
+USE_TRT_CONVERTER(cast);
 USE_TRT_CONVERTER(stack);
 USE_TRT_CONVERTER(reshape);
 USE_TRT_CONVERTER(flatten);
-//USE_TRT_CONVERTER(clip);
-//USE_TRT_CONVERTER(gather);
+// USE_TRT_CONVERTER(clip);
+// USE_TRT_CONVERTER(gather);
 // USE_TRT_CONVERTER(anchor_generator);
-//USE_TRT_CONVERTER(yolo_box);
-//USE_TRT_CONVERTER(roi_align);
-//USE_TRT_CONVERTER(affine_channel);
-//USE_TRT_CONVERTER(multiclass_nms);
-//USE_TRT_CONVERTER(nearest_interp);
+// USE_TRT_CONVERTER(yolo_box);
+// USE_TRT_CONVERTER(roi_align);
+// USE_TRT_CONVERTER(affine_channel);
+// USE_TRT_CONVERTER(multiclass_nms);
+// USE_TRT_CONVERTER(nearest_interp);
 #endif
 
 namespace paddle_infer {
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 8de6b5dca9c3a8..158481a20cbbc5 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,10 +1,11 @@
 # Add TRT tests
 nv_library(tensorrt_converter
   SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
+  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc box_coder_op.cc
   pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc
   multihead_matmul_op.cc shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
   stack_op.cc transpose_op.cc reshape_op.cc flatten_op.cc emb_eltwise_layernorm.cc
   skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
+  cast_op.cc
   DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
diff --git a/paddle/fluid/inference/tensorrt/convert/box_coder_op.cc b/paddle/fluid/inference/tensorrt/convert/box_coder_op.cc
new file mode 100644
index 00000000000000..52ed08d033ee03
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/box_coder_op.cc
@@ -0,0 +1,84 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/box_coder_op_plugin.h"
+
+namespace nvinfer1 {
+class ILayer;
+}  // namespace nvinfer1
+namespace paddle {
+namespace framework {
+class Scope;
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * box_coder converter: feeds TargetBox through a plugin that bakes the
+ * PriorBox / PriorBoxVar weights onto the device at engine build time.
+ */
+class BoxCoderOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    framework::OpDesc op_desc(op, nullptr);
+
+    auto* input = engine_->GetITensor(op_desc.Input("TargetBox")[0]);
+    int input_num = op_desc.Input("TargetBox").size();
+    // PriorBox / PriorBoxVar are persistent weights: fetch them from the
+    // scope and snapshot them to CPU so the plugin can serialize them.
+    auto* prior_box_var = scope.FindVar(op_desc.Input("PriorBox")[0]);
+    auto* prior_box_var_var = scope.FindVar(op_desc.Input("PriorBoxVar")[0]);
+    PADDLE_ENFORCE_NOT_NULL(
+        prior_box_var,
+        platform::errors::NotFound(
+            "Variable PriorBox of box_coder TRT converter is not found."));
+    auto* prior_box_tensor = prior_box_var->GetMutable<framework::LoDTensor>();
+    auto* prior_box_var_tensor =
+        prior_box_var_var->GetMutable<framework::LoDTensor>();
+    platform::CPUPlace cpu_place;
+    std::unique_ptr<framework::LoDTensor> prior_box_tensor_temp(
+        new framework::LoDTensor());
+    std::unique_ptr<framework::LoDTensor> prior_box_var_tensor_temp(
+        new framework::LoDTensor());
+    prior_box_tensor_temp->Resize(prior_box_tensor->dims());
+    // NOTE: the temp must take the dims of the *source* tensor; resizing it
+    // to its own (empty) dims would make the copy below allocate nothing.
+    prior_box_var_tensor_temp->Resize(prior_box_var_tensor->dims());
+    TensorCopySync(*prior_box_tensor, cpu_place, prior_box_tensor_temp.get());
+    TensorCopySync(*prior_box_var_tensor, cpu_place,
+                   prior_box_var_tensor_temp.get());
+    float* prior_box_data =
+        prior_box_tensor_temp->mutable_data<float>(cpu_place);
+    float* prior_box_var_data =
+        prior_box_var_tensor_temp->mutable_data<float>(cpu_place);
+
+    nvinfer1::ILayer* layer = nullptr;
+
+    plugin::BoxCoderPlugin* plugin = new plugin::BoxCoderPlugin(
+        prior_box_data, prior_box_var_data, prior_box_tensor_temp->numel());
+    layer = engine_->AddPlugin(&input, input_num, plugin);
+
+    auto output_name = op_desc.Output("OutputBox")[0];
+    RreplenishLayerAndOutput(layer, "box_coder", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(box_coder, BoxCoderOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/cast_op.cc b/paddle/fluid/inference/tensorrt/convert/cast_op.cc
new file mode 100644
index 00000000000000..d4756377241fb1
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/cast_op.cc
@@ -0,0 +1,56 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace nvinfer1 {
+class ILayer;
+}  // namespace nvinfer1
+namespace paddle {
+namespace framework {
+class Scope;
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * cast converter from fluid to tensorRT.
+ *
+ * The allowed in_dtype/out_dtype pairs are gated by OpTeller, and the actual
+ * uint8 -> fp32 conversion of engine inputs happens in TensorRTEngineOp, so
+ * inside the network a pass-through Identity layer is sufficient.
+ */
+class CastOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    nvinfer1::ILayer* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input);
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "cast", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(cast, CastOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 20671bc361e63f..64b144ffb18381 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -100,6 +100,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "split",
       "instance_norm",
       "gelu",
+      "box_coder",
       "layer_norm",
       "scale",
       "stack",
@@ -108,6 +109,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "reshape",
       "flatten2",
       "flatten",
+      "cast",
   };
 };
 
@@ -137,6 +139,14 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
   if (op_type == "transpose2") {
     if (!desc.HasAttr("axis")) return false;
   }
+  if (op_type == "cast") {
+    if (!desc.HasAttr("in_dtype") || !desc.HasAttr("out_dtype")) return false;
+    int in_dtype = BOOST_GET_CONST(int, desc.GetAttr("in_dtype"));
+    int out_dtype = BOOST_GET_CONST(int, desc.GetAttr("out_dtype"));
+    if (!(in_dtype == 20 && out_dtype == 5) &&
+        !(in_dtype == 3 && out_dtype == 5))
+      return false;
+  }
   if (op_type == "matmul") {
     auto* block = desc.Block();
     for (auto& param_name : desc.Inputs()) {
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index 913f575de05e9e..fe49dcd42183b2 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -6,6 +6,7 @@ nv_library(tensorrt_plugin
   instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
   qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
   hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
+  box_coder_op_plugin.cu
   #anchor_generator_op_plugin.cu
   #yolo_box_op_plugin.cu
   #roi_align_op_plugin.cu
diff --git a/paddle/fluid/inference/tensorrt/plugin/box_coder_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/box_coder_op_plugin.cu
new file mode 100644
index 00000000000000..e4fc206f7cb07a
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/box_coder_op_plugin.cu
@@ -0,0 +1,148 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cuda_runtime_api.h>
+
+#include <cassert>
+#include <vector>
+
+#include "glog/logging.h"
+#include "paddle/fluid/inference/tensorrt/plugin/box_coder_op_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+// Deserialization entry point registered with the plugin factory.
+// NOTE: must return BoxCoderPlugin (earlier draft said PReluPlugin).
+BoxCoderPlugin *CreateBoxCoderPluginDeserialize(const void *buffer,
+                                                size_t length) {
+  return new BoxCoderPlugin(buffer, length);
+}
+REGISTER_TRT_PLUGIN("box_coder_plugin", CreateBoxCoderPluginDeserialize);
+
+// Uploads the host-side PriorBox / PriorBoxVar weights to device memory.
+int BoxCoderPlugin::initialize() {
+  cudaMalloc(&p_gpu_prior_box_weight_,
+             sizeof(float) * prior_box_weight_.size());
+  cudaMalloc(&p_gpu_prior_box_var_weight_,
+             sizeof(float) * prior_box_var_weight_.size());
+  cudaMemcpy(p_gpu_prior_box_weight_, prior_box_weight_.data(),
+             prior_box_weight_.size() * sizeof(float), cudaMemcpyHostToDevice);
+  cudaMemcpy(p_gpu_prior_box_var_weight_, prior_box_var_weight_.data(),
+             prior_box_var_weight_.size() * sizeof(float),
+             cudaMemcpyHostToDevice);
+  return 0;
+}
+
+nvinfer1::Dims BoxCoderPlugin::getOutputDimensions(
+    int index, const nvinfer1::Dims *inputDims, int nbInputs) {
+  assert(nbInputs == 1);
+  assert(index < this->getNbOutputs());
+  // Decoded boxes have the same shape as the TargetBox input.
+  nvinfer1::Dims const &input_dims = inputDims[0];
+  nvinfer1::Dims output_dims = input_dims;
+  return output_dims;
+}
+
+// Decode-center-size for one (row, col) box per thread. Mirrors the CUDA
+// kernel of the fluid box_coder op with variance read from PriorBoxVar.
+__global__ void DecodeCenterSizeKernel(const float *prior_box_data,
+                                       const float *prior_box_var_data,
+                                       const float *target_box_data,
+                                       const int row, const int col,
+                                       const int len, const bool normalized,
+                                       const float prior_box_var_size,
+                                       const int axis, float *output) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int prior_box_offset = 0;
+  if (idx < row * col) {
+    const int col_idx = idx % col;
+    const int row_idx = idx / col;
+    // axis selects whether PriorBox broadcasts along rows or columns.
+    prior_box_offset = axis == 0 ? col_idx * len : row_idx * len;
+    // "+ (normalized == false)" reproduces box_coder's convention that
+    // unnormalized (pixel) boxes are one pixel wider/taller.
+    float prior_box_width = prior_box_data[prior_box_offset + 2] -
+                            prior_box_data[prior_box_offset] +
+                            (normalized == false);
+    float prior_box_height = prior_box_data[prior_box_offset + 3] -
+                             prior_box_data[prior_box_offset + 1] +
+                             (normalized == false);
+    float prior_box_center_x =
+        prior_box_data[prior_box_offset] + prior_box_width / 2;
+    float prior_box_center_y =
+        prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
+    float target_box_width, target_box_height;
+    float target_box_center_x, target_box_center_y;
+    float box_var_x = 1, box_var_y = 1;
+    float box_var_w = 1, box_var_h = 1;
+
+    int prior_var_offset = axis == 0 ? col_idx * len : row_idx * len;
+    box_var_x = prior_box_var_data[prior_var_offset];
+    box_var_y = prior_box_var_data[prior_var_offset + 1];
+    box_var_w = prior_box_var_data[prior_var_offset + 2];
+    box_var_h = prior_box_var_data[prior_var_offset + 3];
+
+    target_box_width =
+        exp(box_var_w * target_box_data[idx * len + 2]) * prior_box_width;
+    target_box_height =
+        exp(box_var_h * target_box_data[idx * len + 3]) * prior_box_height;
+    target_box_center_x =
+        box_var_x * target_box_data[idx * len] * prior_box_width +
+        prior_box_center_x;
+    target_box_center_y =
+        box_var_y * target_box_data[idx * len + 1] * prior_box_height +
+        prior_box_center_y;
+
+    output[idx * len] = target_box_center_x - target_box_width / 2;
+    output[idx * len + 1] = target_box_center_y - target_box_height / 2;
+    output[idx * len + 2] =
+        target_box_center_x + target_box_width / 2 - (normalized == false);
+    output[idx * len + 3] =
+        target_box_center_y + target_box_height / 2 - (normalized == false);
+  }
+}
+
+int BoxCoderPlugin::enqueue(int batch_size, const void *const *inputs,
+                            void **outputs, void *workspace,
+                            cudaStream_t stream) {
+  // input dims is CHW (TensorRT implicit-batch layout).
+  const float *target_box = reinterpret_cast<const float *>(inputs[0]);
+  float *output = reinterpret_cast<float *>(outputs[0]);
+  const float *prior_box = p_gpu_prior_box_weight_;
+  const float *prior_box_var = p_gpu_prior_box_var_weight_;
+
+  // TODO(reviewer): row/col/len are hard-coded for one model; derive them
+  // from this->getInputDims(0) before the plugin is used generally.
+  const int row = 16;    // target_box->dims()[0]
+  const int col = 2550;  // target_box->dims()[1]
+  const int len = 4;     // prior_box->dims()[1]
+  const int block = 512;
+  const int grid = (row * col + block - 1) / block;
+
+  // prior_box_var_size is unused by the kernel (variance comes from the
+  // PriorBoxVar tensor), so pass 0.
+  DecodeCenterSizeKernel<<<grid, block, 0, stream>>>(
+      prior_box, prior_box_var, target_box, row, col, len,
+      /*normalized=*/true, /*prior_box_var_size=*/0.f, /*axis=*/0, output);
+
+  return cudaGetLastError() != cudaSuccess;
+}
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/box_coder_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/box_coder_op_plugin.h
new file mode 100644
index 00000000000000..01acb1e05a6e83
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/box_coder_op_plugin.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/tensor_util.h"
+
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+// TRT plugin implementing box_coder decode with PriorBox / PriorBoxVar
+// captured as weights at conversion time.
+class BoxCoderPlugin : public PluginTensorRT {
+  // Host copies of the weights; serialized with the engine.
+  std::vector<float> prior_box_weight_;
+  std::vector<float> prior_box_var_weight_;
+  // Device copies, allocated in initialize(); nullptr until then so the
+  // destructor's cudaFree is safe even if initialize() never ran.
+  float* p_gpu_prior_box_weight_{nullptr};
+  float* p_gpu_prior_box_var_weight_{nullptr};
+
+ protected:
+  size_t getSerializationSize() override {
+    return getBaseSerializationSize() + SerializedSize(prior_box_weight_) +
+           SerializedSize(prior_box_var_weight_) +
+           SerializedSize(getPluginType());
+  }
+
+  // Order must mirror the deserialization ctor: type tag, base, weights.
+  void serialize(void* buffer) override {
+    SerializeValue(&buffer, getPluginType());
+    serializeBase(buffer);
+    SerializeValue(&buffer, prior_box_weight_);
+    SerializeValue(&buffer, prior_box_var_weight_);
+  }
+
+ public:
+  BoxCoderPlugin(const float* prior_box_weight,
+                 const float* prior_box_var_weight, const int weight_num) {
+    prior_box_weight_.resize(weight_num);
+    prior_box_var_weight_.resize(weight_num);
+    std::copy(prior_box_weight, prior_box_weight + weight_num,
+              prior_box_weight_.data());
+    std::copy(prior_box_var_weight, prior_box_var_weight + weight_num,
+              prior_box_var_weight_.data());
+  }
+
+  BoxCoderPlugin(void const* serialData, size_t serialLength) {
+    deserializeBase(serialData, serialLength);
+    DeserializeValue(&serialData, &serialLength, &prior_box_weight_);
+    DeserializeValue(&serialData, &serialLength, &prior_box_var_weight_);
+  }
+
+  ~BoxCoderPlugin() {
+    cudaFree(p_gpu_prior_box_weight_);
+    cudaFree(p_gpu_prior_box_var_weight_);
+  }
+
+  int initialize() override;
+
+  BoxCoderPlugin* clone() const override {
+    return new BoxCoderPlugin(prior_box_weight_.data(),
+                              prior_box_var_weight_.data(),
+                              prior_box_weight_.size());
+  }
+
+  const char* getPluginType() const override { return "box_coder_plugin"; }
+  int getNbOutputs() const override { return 1; }
+  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
+                                     int nbInputDims) override;
+  int enqueue(int batchSize, const void* const* inputs, void** outputs,
+              void* workspace, cudaStream_t stream) override;
+};
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
index b8805c025a768e..2ec9622ee851c7 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
@@ -11,18 +11,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-
 #pragma once
-
 #ifdef PADDLE_WITH_CUDA
+#include <cuda_runtime_api.h>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
 #include <vector>
-
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -30,7 +28,6 @@
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
-
 namespace paddle {
 namespace inference {
 namespace tensorrt {
@@ -76,6 +73,13 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
                                         model_input_shape_str,
                                         runtime_input_shape_str));
 }
+
+#ifdef PADDLE_WITH_CUDA
+// Elementwise uint8 -> fp32 cast; TRT engines take fp32 input bindings.
+static __global__ void CastCUDAKernel(const uint8_t *in, const int64_t N,
+                                      float *out) {
+  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<float>(in[index]); }
+}
+#endif
 
 class TensorRTEngineOp : public framework::OperatorBase {
  private:
   std::vector<std::string> input_names_;
@@ -280,9 +284,32 @@ class TensorRTEngineOp : public framework::OperatorBase {
         buffers[bind_index] = static_cast<void *>(t.data<int64_t>());
       } else if (type == framework::proto::VarType::INT32) {
         buffers[bind_index] = static_cast<void *>(t.data<int32_t>());
+      } else if (type == framework::proto::VarType::UINT8) {
+        const int theory_thread_count = element_count;
+        // Get Max threads in all SM
+        int max_pyhsical_threads = dev_ctx.GetMaxPhysicalThreadCount();
+        int sm = dev_ctx.GetSMCount();
+        // Compute pyhsical threads we need, should small than max sm threads
+        const int physical_thread_count =
+            std::min(max_pyhsical_threads, theory_thread_count);
+        // Need get from device
+        const int thread_per_block =
+            std::min(1024, dev_ctx.GetMaxThreadsPerBlock());
+        const int block_count = std::min(
+            (physical_thread_count + thread_per_block - 1) / thread_per_block,
+            sm);
+        // Cast into a scratch tensor: a kernel launch returns nothing, so it
+        // cannot be assigned to buffers, and casting in place would
+        // reallocate t and dangle the uint8 source pointer.
+        // TODO(reviewer): hoist fp32_t so it outlives engine execution.
+        framework::Tensor fp32_t;
+        fp32_t.Resize(t.dims());
+        float *fp32_data = fp32_t.mutable_data<float>(
+            BOOST_GET_CONST(platform::CUDAPlace, dev_place));
+        CastCUDAKernel<<<block_count, thread_per_block, 0, dev_ctx.stream()>>>(
+            t.data<uint8_t>(), framework::product(t.dims()), fp32_data);
+        buffers[bind_index] = static_cast<void *>(fp32_data);
       } else {
-        PADDLE_THROW(platform::errors::Fatal(
-            "The TRT Engine OP only support float/int32_t/int64_t input."));
+        PADDLE_THROW(
+            platform::errors::Fatal("The TRT Engine OP only support "
+                                    "float/int32_t/int64_t/uint8_t input."));
       }