|
| 1 | +""" |
| 2 | +Using Custom Kernels with NVRTC in TensorRT AOT Plugins |
| 3 | +======================================================= |
| 4 | +
|
| 5 | +This example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC) library |
| 6 | +to compile custom CUDA kernels at runtime and integrate them into a TensorRT |
| 7 | +Ahead-Of-Time (AOT) plugin. |
| 8 | +
|
| 9 | +This approach is powerful because it allows you to: |
| 10 | +1. Write raw CUDA C++ code for maximum performance. |
| 11 | +2. Compile it on-the-fly, adapting to the specific GPU architecture. |
| 12 | +3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library. |
| 13 | +4. Integrate it seamlessly into Torch-TensorRT's compilation flow. |
| 14 | +
|
| 15 | +The example performs a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)). |
| 16 | +""" |
| 17 | + |
| 18 | +from typing import List, Tuple, Union |
| 19 | + |
| 20 | +import torch |
| 21 | + |
| 22 | +import torch_tensorrt |
| 23 | + |
| 24 | +# ============================================================================ |
| 25 | +# 1. Define the CUDA Kernel Source |
| 26 | +# ============================================================================ |
| 27 | +# We define the CUDA kernel source code as a Python string. |
| 28 | +# This code will be compiled by NVRTC. |
| 29 | +# Note that we use extern "C" to avoid name mangling, making it easier to |
| 30 | +# retrieve the kernel function by name later. |
| 31 | + |
| 32 | +cu_code = """ |
| 33 | +// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x)) |
| 34 | +extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input, |
| 35 | + const int size, |
| 36 | + float* __restrict__ output) { |
| 37 | + const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 38 | +
|
| 39 | + if (idx < size) { |
| 40 | + const float x = input[idx]; |
| 41 | + // use fast device intrinsic to avoid headers |
| 42 | + output[idx] = 1.0f / (1.0f + __expf(-x)); |
| 43 | + } |
| 44 | +} |
| 45 | +""" |
| 46 | + |
| 47 | +# ============================================================================ |
| 48 | +# 2. Compile the Kernel using NVRTC (for eager mode) |
| 49 | +# ============================================================================ |
| 50 | +# Before defining the Torch custom op, we compile the kernel so we can run it |
| 51 | +# in standard PyTorch (eager mode) for verification and testing. |
| 52 | +# We use the cuda-python library's NVRTC bindings. |
| 53 | + |
| 54 | +from cuda.core.experimental import Device as _CudaDevice |
| 55 | +from cuda.core.experimental import LaunchConfig as _LaunchConfig |
| 56 | +from cuda.core.experimental import Program as _CudaProgram |
| 57 | +from cuda.core.experimental import ProgramOptions as _CudaProgramOptions |
| 58 | +from cuda.core.experimental import launch as _cuda_launch |
| 59 | + |
| 60 | +# Initialize CUDA device and stream |
| 61 | +_cuda_device = _CudaDevice() |
| 62 | +_cuda_device.set_current() |
| 63 | +_cuda_stream = _cuda_device.create_stream() |
| 64 | + |
| 65 | +# Configure compilation options |
| 66 | +_program_options = _CudaProgramOptions( |
| 67 | + std="c++17", |
| 68 | + arch=f"sm_{_cuda_device.arch}", # Target the current GPU architecture |
| 69 | + include_path=["/usr/local/cuda/include"], |
| 70 | +) |
| 71 | + |
| 72 | +# Create and compile the program |
| 73 | +_program = _CudaProgram(cu_code, code_type="c++", options=_program_options) |
| 74 | +_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",)) |
| 75 | +_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc") |
| 76 | + |
| 77 | + |
| 78 | +# ============================================================================ |
| 79 | +# 3. Register Custom Op in PyTorch |
| 80 | +# ============================================================================ |
| 81 | +# We register the custom operation with PyTorch so it can be used in models. |
| 82 | +# The 'mutates_args=()' argument tells PyTorch this op is functional (doesn't modify inputs in-place). |
| 83 | + |
| 84 | + |
| 85 | +@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=()) # type: ignore[misc] |
| 86 | +def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor: |
| 87 | + """ |
| 88 | + Implementation of the custom op for PyTorch eager execution. |
| 89 | + This function launches the pre-compiled NVRTC kernel. |
| 90 | + """ |
| 91 | + assert X.is_cuda, "Tensor must be on CUDA device." |
| 92 | + assert X.dtype == torch.float32, "For this test, expected float32 input" |
| 93 | + |
| 94 | + Y = torch.empty_like(X) |
| 95 | + N = int(X.numel()) |
| 96 | + |
| 97 | + block = 256 |
| 98 | + grid_x = max(1, (N + block - 1) // block) |
| 99 | + config = _LaunchConfig(grid=(grid_x), block=(block)) |
| 100 | + |
| 101 | + # Helper class to wrap PyTorch's stream for cuda-python |
| 102 | + class _PyTorchStreamWrapper: |
| 103 | + def __init__(self, pt_stream): |
| 104 | + self.pt_stream = pt_stream |
| 105 | + |
| 106 | + def __cuda_stream__(self): |
| 107 | + stream_id = self.pt_stream.cuda_stream |
| 108 | + return (0, stream_id) |
| 109 | + |
| 110 | + pt_stream = torch.cuda.current_stream() |
| 111 | + s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream)) |
| 112 | + |
| 113 | + # Launch kernel with raw pointers |
| 114 | + _cuda_launch( |
| 115 | + s, |
| 116 | + config, |
| 117 | + _kernel, |
| 118 | + X.data_ptr(), |
| 119 | + N, |
| 120 | + Y.data_ptr(), |
| 121 | + ) |
| 122 | + |
| 123 | + return Y |
| 124 | + |
| 125 | + |
| 126 | +# ============================================================================ |
| 127 | +# 4. Register Fake Implementation (Meta Kernel) |
| 128 | +# ============================================================================ |
| 129 | +# The fake implementation is crucial for TorchDynamo. It tells the compiler |
| 130 | +# about the output shape and data type without actually running the kernel. |
| 131 | +# This is used during the tracing phase. |
| 132 | + |
| 133 | + |
| 134 | +@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid") |
| 135 | +def _(input: torch.Tensor) -> torch.Tensor: |
| 136 | + """Fake implementation for TorchDynamo tracing of base operation.""" |
| 137 | + return torch.empty_like(input) |
| 138 | + |
| 139 | + |
| 140 | +# ============================================================================ |
| 141 | +# 5. Define TensorRT AOT Plugin |
| 142 | +# ============================================================================ |
| 143 | +# Now we define how this operation should be handled within TensorRT. |
| 144 | +# We use the torch_tensorrt plugin auto generation feature and the AOT implementation using NVRTC. |
| 145 | + |
| 146 | +import tensorrt.plugin as trtp |
| 147 | +from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions |
| 148 | + |
| 149 | +torch_tensorrt.dynamo.conversion.plugins.generate_plugin( |
| 150 | + "pointwise_sigmoid_ops::pointwise_sigmoid" |
| 151 | +) |
| 152 | + |
| 153 | + |
| 154 | +# This is where the magic happens. We provide the compiled PTX code and |
| 155 | +# launch parameters to TensorRT. This code runs during engine building. |
| 156 | +@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid") |
| 157 | +def sigmoid_aot_nvrtc_impl( |
| 158 | + X: trtp.TensorDesc, |
| 159 | + outputs: Tuple[trtp.TensorDesc], |
| 160 | + tactic: int, |
| 161 | +) -> Tuple[ |
| 162 | + Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs |
| 163 | +]: |
| 164 | + # Get the PTX code from our pre-compiled module |
| 165 | + compiled_kernel = _module.code.decode("utf-8") |
| 166 | + |
| 167 | + # Calculate grid and block dimensions based on input shape |
| 168 | + N = X.shape_expr.numel() |
| 169 | + launch_params = trtp.KernelLaunchParams() |
| 170 | + block = 256 |
| 171 | + launch_params.grid_x = trtp.cdiv(N, block) |
| 172 | + launch_params.block_x = block |
| 173 | + launch_params.shared_mem = 0 |
| 174 | + |
| 175 | + # Pass the number of elements (N) as an extra argument to the kernel |
| 176 | + extra_args = trtp.SymIntExprs(1) |
| 177 | + extra_args[0] = trtp.SymInt32(N) |
| 178 | + |
| 179 | + # Return: kernel name, PTX code, launch params, kernel arguments |
| 180 | + return ( |
| 181 | + "pointwise_sigmoid_kernel_nvrtc", |
| 182 | + compiled_kernel, |
| 183 | + launch_params, |
| 184 | + extra_args, |
| 185 | + ) |
| 186 | + |
| 187 | + |
| 188 | +# ============================================================================ |
| 189 | +# 6. Generate Plugin Converter |
| 190 | +# ============================================================================ |
| 191 | +# This registers the mapping between the PyTorch custom op and the TensorRT plugin. |
| 192 | +# It tells Torch-TensorRT: "When you see 'pointwise_sigmoid_ops::pointwise_sigmoid', |
| 193 | +# replace it with the TensorRT plugin we just defined." |
| 194 | + |
| 195 | +torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter( |
| 196 | + "pointwise_sigmoid_ops::pointwise_sigmoid", |
| 197 | + supports_dynamic_shapes=True, |
| 198 | + requires_output_allocator=False, |
| 199 | +) |
| 200 | + |
| 201 | + |
| 202 | +# ============================================================================ |
| 203 | +# 7. Test the Model |
| 204 | +# ============================================================================ |
| 205 | + |
| 206 | + |
| 207 | +class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module): |
| 208 | + """ |
| 209 | + Test model that uses the TRT wrapper with custom_op() registration. |
| 210 | + """ |
| 211 | + |
| 212 | + def forward(self, input: torch.Tensor) -> torch.Tensor: |
| 213 | + z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input) |
| 214 | + return z |
| 215 | + |
| 216 | + |
| 217 | +if __name__ == "__main__": |
| 218 | + model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval() |
| 219 | + input = torch.randn(1, 1024, device="cuda", dtype=torch.float32) |
| 220 | + |
| 221 | + print("PyTorch baseline result:") |
| 222 | + print(torch.sigmoid(input)) |
| 223 | + |
| 224 | + print("Custom Op eager result:") |
| 225 | + print(model(input)) |
| 226 | + |
| 227 | + print("\nCompiling with Torch-TensorRT...") |
| 228 | + with torch_tensorrt.logging.debug(): |
| 229 | + trt_inputs = [input] |
| 230 | + model_trt = torch_tensorrt.compile( |
| 231 | + model, |
| 232 | + inputs=trt_inputs, |
| 233 | + enabled_precisions={torch.float32}, |
| 234 | + min_block_size=1, |
| 235 | + ) |
| 236 | + print("Model compiled successfully!") |
| 237 | + |
| 238 | + print("Running inference with compiled model...") |
| 239 | + with torch.no_grad(): |
| 240 | + for i in range(10): |
| 241 | + res = model_trt(input) |
| 242 | + assert torch.allclose( |
| 243 | + res, model(input), rtol=1e-2, atol=1e-2 |
| 244 | + ), "Results do not match!" |
| 245 | + |
| 246 | + print("Inference successful!") |
0 commit comments