
Commit 221feb7

example: using nvrtc kernel for aot plugin
1 parent a80572d commit 221feb7

2 files changed: +247 −2 lines changed

examples/dynamo/aot_plugin.py

Lines changed: 1 addition & 2 deletions

@@ -98,9 +98,8 @@ def add_plugin_aot_impl(
             "x_ptr": f"*{type_str}",
             "n_elements": "i32",
             "y_ptr": f"*{type_str}",
-            "BLOCK_SIZE": "constexpr",
         },
-        constants={
+        constexprs={
             "BLOCK_SIZE": block_size,
         },
     )
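
In effect, BLOCK_SIZE moves out of the signature dict, and compile-time constants are now passed through the constexprs keyword instead of the older constants keyword. As a minimal, self-contained sketch of the same API shape (hypothetical add_one_kernel; assumes a Triton version whose triton.compiler.ASTSource accepts constexprs, as this commit implies):

import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1.0, mask=mask)


# BLOCK_SIZE no longer appears in `signature`; its value is supplied via
# `constexprs` (hypothetical example mirroring the hunk above).
src = triton.compiler.ASTSource(
    fn=add_one_kernel,
    signature={
        "x_ptr": "*fp32",
        "n_elements": "i32",
        "y_ptr": "*fp32",
    },
    constexprs={
        "BLOCK_SIZE": 256,
    },
)
compiled = triton.compile(src)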
Lines changed: 246 additions & 0 deletions

@@ -0,0 +1,246 @@
"""
Using Custom Kernels with NVRTC in TensorRT AOT Plugins
=======================================================

This example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC) library
to compile custom CUDA kernels at runtime and integrate them into a TensorRT
Ahead-Of-Time (AOT) plugin.

This approach is powerful because it allows you to:

1. Write raw CUDA C++ code for maximum performance.
2. Compile it on the fly for the specific GPU architecture you are running on.
3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library.
4. Integrate it seamlessly into Torch-TensorRT's compilation flow.

The example performs a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)).
"""

from typing import Tuple, Union

import torch

import torch_tensorrt

# ============================================================================
# 1. Define the CUDA Kernel Source
# ============================================================================
# We define the CUDA kernel source code as a Python string.
# This code will be compiled by NVRTC.
# Note that we use extern "C" to avoid name mangling, making it easier to
# retrieve the kernel function by name later.

cu_code = """
// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,
                                                          const int size,
                                                          float* __restrict__ output) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < size) {
        const float x = input[idx];
        // use fast device intrinsic to avoid headers
        output[idx] = 1.0f / (1.0f + __expf(-x));
    }
}
"""
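
# Note: __expf is a fast but approximate device intrinsic, which is why the
# accuracy check at the end of this example uses relaxed tolerances.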

# ============================================================================
# 2. Compile the Kernel using NVRTC (for eager mode)
# ============================================================================
# Before defining the Torch custom op, we compile the kernel so we can run it
# in standard PyTorch (eager mode) for verification and testing.
# We use the cuda.core NVRTC bindings from the cuda-python project.

from cuda.core.experimental import Device as _CudaDevice
from cuda.core.experimental import LaunchConfig as _LaunchConfig
from cuda.core.experimental import Program as _CudaProgram
from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
from cuda.core.experimental import launch as _cuda_launch

# Initialize the CUDA device and a stream
_cuda_device = _CudaDevice()
_cuda_device.set_current()
_cuda_stream = _cuda_device.create_stream()

# Configure compilation options
_program_options = _CudaProgramOptions(
    std="c++17",
    arch=f"sm_{_cuda_device.arch}",  # Target the current GPU architecture
    include_path=["/usr/local/cuda/include"],
)

# Create the program and compile it to PTX
_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
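
# Optional sanity check: because the kernel has extern "C" linkage, its name
# appears verbatim in the generated PTX, which `_module.code` exposes as
# bytes. The same code object is reused by the AOT implementation below.
assert b"pointwise_sigmoid_kernel_nvrtc" in _module.code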

# ============================================================================
# 3. Register Custom Op in PyTorch
# ============================================================================
# We register the custom operation with PyTorch so it can be used in models.
# The `mutates_args=()` argument tells PyTorch this op is functional, i.e. it
# does not modify its inputs in-place.


@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    """
    Implementation of the custom op for PyTorch eager execution.
    This function launches the pre-compiled NVRTC kernel.
    """
    assert X.is_cuda, "Tensor must be on a CUDA device."
    assert X.dtype == torch.float32, "For this example, a float32 input is expected."

    Y = torch.empty_like(X)
    N = int(X.numel())

    block = 256
    grid_x = max(1, (N + block - 1) // block)
    config = _LaunchConfig(grid=grid_x, block=block)

    # Helper class that adapts PyTorch's stream to the __cuda_stream__
    # protocol expected by cuda.core
    class _PyTorchStreamWrapper:
        def __init__(self, pt_stream):
            self.pt_stream = pt_stream

        def __cuda_stream__(self):
            stream_id = self.pt_stream.cuda_stream
            return (0, stream_id)  # (protocol version, stream handle)

    pt_stream = torch.cuda.current_stream()
    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))

    # Launch the kernel with raw pointers on PyTorch's current stream
    _cuda_launch(
        s,
        config,
        _kernel,
        X.data_ptr(),
        N,
        Y.data_ptr(),
    )

    return Y
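
# The op is now callable from eager PyTorch. Hypothetical usage:
#
#   x = torch.randn(8, device="cuda")
#   y = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(x)
#   torch.testing.assert_close(y, torch.sigmoid(x), rtol=1e-2, atol=1e-2)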

# ============================================================================
# 4. Register Fake Implementation (Meta Kernel)
# ============================================================================
# The fake implementation is crucial for TorchDynamo. It tells the compiler
# about the output shape and data type without actually running the kernel.
# It is used during the tracing phase.


@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation for TorchDynamo tracing of the base operation."""
    return torch.empty_like(input)
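
# Because the fake implementation only produces metadata (shape and dtype),
# it also works with symbolic shapes when compiling with dynamic inputs.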

# ============================================================================
# 5. Define TensorRT AOT Plugin
# ============================================================================
# Now we define how this operation should be handled within TensorRT.
# We use torch_tensorrt's plugin auto-generation feature together with an AOT
# implementation backed by the NVRTC-compiled kernel.

import tensorrt.plugin as trtp

torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "pointwise_sigmoid_ops::pointwise_sigmoid"
)


# This is where the magic happens. We provide the compiled PTX code and
# launch parameters to TensorRT. This code runs during engine building.
@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    X: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    # Get the PTX code from our pre-compiled module
    compiled_kernel = _module.code.decode("utf-8")

    # Calculate grid and block dimensions based on the (symbolic) input shape
    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()
    block = 256
    launch_params.grid_x = trtp.cdiv(N, block)
    launch_params.block_x = block
    launch_params.shared_mem = 0

    # Pass the number of elements (N) as an extra argument to the kernel
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    # Return: kernel name, PTX code, launch params, extra kernel arguments
    return (
        "pointwise_sigmoid_kernel_nvrtc",
        compiled_kernel,
        launch_params,
        extra_args,
    )

# ============================================================================
# 6. Generate Plugin Converter
# ============================================================================
# This registers the mapping between the PyTorch custom op and the TensorRT
# plugin. It tells Torch-TensorRT: "When you see
# 'pointwise_sigmoid_ops::pointwise_sigmoid', replace it with the TensorRT
# plugin we just defined."

torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "pointwise_sigmoid_ops::pointwise_sigmoid",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)
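
# `supports_dynamic_shapes=True` lines up with the AOT implementation above,
# which computes its launch grid and `size` argument from symbolic shape
# expressions rather than concrete values.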

# ============================================================================
# 7. Test the Model
# ============================================================================


class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
    """
    Test model that uses the TRT wrapper with custom_op() registration.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
        return z


if __name__ == "__main__":
    model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
    input = torch.randn(1, 1024, device="cuda", dtype=torch.float32)

    print("PyTorch baseline result:")
    print(torch.sigmoid(input))

    print("Custom Op eager result:")
    print(model(input))

    print("\nCompiling with Torch-TensorRT...")
    with torch_tensorrt.logging.debug():
        trt_inputs = [input]
        model_trt = torch_tensorrt.compile(
            model,
            inputs=trt_inputs,
            enabled_precisions={torch.float32},
            min_block_size=1,
        )
    print("Model compiled successfully!")

    print("Running inference with compiled model...")
    with torch.no_grad():
        for i in range(10):
            res = model_trt(input)
            assert torch.allclose(
                res, model(input), rtol=1e-2, atol=1e-2
            ), "Results do not match!"

    print("Inference successful!")
