Skip to content

Commit abaaf96

Browse files
committed
feat: TorchTRT Annotation Layer for Cuda generated kernels
1 parent 2361ec5 commit abaaf96

16 files changed

Lines changed: 2335 additions & 5 deletions

File tree

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,12 @@ jobs:
459459
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
460460
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_flashinfer_rmsnorm.py
461461
popd
462+
pushd .
463+
# cuda-python is an optional runtime dep for the torch_tensorrt.annotation QDP layer.
464+
python -m pip install cuda-python
465+
cd tests/py/annotation
466+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_annotation_test_results.xml .
467+
popd
462468
463469
L2-torchscript-tests:
464470
name: ${{ matrix.display-name }}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""
2+
.. _cuda_python_aot_annotation:
3+
4+
Non-Pointwise Custom Plugin via torch_tensorrt.annotation
5+
==========================================================
6+
7+
This example demonstrates a shape-changing (non-pointwise) CUDA kernel that
8+
duplicates each input element into two output elements:
9+
y[2*i] = x[i], y[2*i + 1] = x[i]
10+
"""
11+
12+
import sys
13+
14+
import torch
15+
import torch_tensorrt
16+
17+
if not torch_tensorrt.ENABLED_FEATURES.qdp_plugin:
18+
print(
19+
"[cuda_python_aot_annotation] Skipping example: "
20+
"torch_tensorrt.annotation requires TensorRT QDP plugin support."
21+
)
22+
sys.exit(0)
23+
24+
try:
25+
import tensorrt.plugin as trtp
26+
except ImportError:
27+
print("[cuda_python_aot_annotation] Skipping example: tensorrt.plugin unavailable.")
28+
sys.exit(0)
29+
30+
try:
31+
from cuda.core import Device as _Device
32+
from cuda.core import LaunchConfig as _LaunchConfig
33+
from cuda.core import Program as _Program
34+
from cuda.core import ProgramOptions as _ProgramOptions
35+
from cuda.core import launch as _cuda_launch
36+
except ImportError:
37+
try:
38+
from cuda.core.experimental import Device as _Device
39+
from cuda.core.experimental import LaunchConfig as _LaunchConfig
40+
from cuda.core.experimental import Program as _Program
41+
from cuda.core.experimental import ProgramOptions as _ProgramOptions
42+
from cuda.core.experimental import launch as _cuda_launch
43+
except ImportError:
44+
print(
45+
"[cuda_python_aot_annotation] Skipping example: cuda-python is not "
46+
"installed. Install with `pip install cuda-python` to run this example."
47+
)
48+
sys.exit(0)
49+
50+
import torch_tensorrt.annotation as tta
51+
52+
53+
CU_REPEAT2 = """
54+
extern "C" __global__ void repeat2_kernel(
55+
const float* __restrict__ x, const int n, float* __restrict__ y) {
56+
const int i = blockIdx.x * blockDim.x + threadIdx.x;
57+
if (i < n) {
58+
const float v = x[i];
59+
y[2 * i] = v;
60+
y[2 * i + 1] = v;
61+
}
62+
}
63+
"""
64+
65+
_device = _Device()
66+
_device.set_current()
67+
_opts = _ProgramOptions(
68+
std="c++17", arch=f"sm_{_device.arch}", include_path=["/usr/local/cuda/include"]
69+
)
70+
_program = _Program(CU_REPEAT2, code_type="c++", options=_opts)
71+
_module = _program.compile("ptx", name_expressions=("repeat2_kernel",))
72+
_kernel = _module.get_kernel("repeat2_kernel")
73+
74+
75+
class _PTStream:
76+
def __cuda_stream__(self):
77+
return (0, torch.cuda.current_stream().cuda_stream)
78+
79+
80+
def _eager_repeat2(x: torch.Tensor) -> torch.Tensor:
81+
if x.dtype != torch.float32:
82+
raise ValueError("This example expects float32 input")
83+
flat = x.contiguous().view(-1)
84+
n = int(flat.numel())
85+
y = torch.empty((n * 2,), device=x.device, dtype=x.dtype)
86+
block = 256
87+
grid = max(1, (n + block - 1) // block)
88+
stream = _device.create_stream(_PTStream())
89+
_cuda_launch(
90+
stream,
91+
_LaunchConfig(grid=(grid,), block=(block,)),
92+
_kernel,
93+
flat.data_ptr(),
94+
n,
95+
y.data_ptr(),
96+
)
97+
return y
98+
99+
100+
def _aot_repeat2(inputs, outputs, tactic):
101+
n = inputs[0].shape_expr.numel()
102+
params = trtp.KernelLaunchParams()
103+
params.grid_x = trtp.cdiv(n, 256)
104+
params.block_x = 256
105+
params.shared_mem = 0
106+
extra = trtp.SymIntExprs(1)
107+
extra[0] = trtp.SymInt32(n)
108+
return params, extra
109+
110+
111+
@tta.cuda_plugin(
112+
op_name="ann_ex::repeat2",
113+
kernel_source=CU_REPEAT2,
114+
kernel_name="repeat2_kernel",
115+
eager_fn=_eager_repeat2,
116+
aot_fn=_aot_repeat2,
117+
supports_dynamic_shapes=True,
118+
)
119+
def _repeat2_meta(x: torch.Tensor) -> torch.Tensor:
120+
return torch.empty((x.numel() * 2,), device=x.device, dtype=x.dtype)
121+
122+
123+
class Repeat2Model(torch.nn.Module):
124+
def forward(self, x: torch.Tensor) -> torch.Tensor:
125+
return torch.ops.ann_ex.repeat2(x)
126+
127+
128+
if __name__ == "__main__":
129+
x = torch.randn(1024, device="cuda", dtype=torch.float32)
130+
ref = torch.repeat_interleave(x, 2, dim=0)
131+
132+
model = Repeat2Model().cuda().eval()
133+
eager_out = model(x)
134+
print("Eager result matches repeat_interleave:", torch.allclose(eager_out, ref, atol=1e-4))
135+
136+
print("Compiling with Torch-TensorRT...")
137+
with torch_tensorrt.logging.debug():
138+
trt_model = torch_tensorrt.compile(
139+
model,
140+
inputs=[x],
141+
enabled_precisions={torch.float32},
142+
min_block_size=1,
143+
)
144+
145+
with torch.no_grad():
146+
for _ in range(5):
147+
out = trt_model(x)
148+
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2), "Mismatch!"
149+
150+
print("TRT inference successful - results match repeat_interleave")

py/torch_tensorrt/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def _register_with_torch() -> None:
9999
from torch_tensorrt.dynamo import backend # noqa: F401
100100
from torch_tensorrt import dynamo # noqa: F401
101101

102+
if ENABLED_FEATURES.qdp_plugin:
103+
from torch_tensorrt import annotation # noqa: F401
104+
102105
from torch_tensorrt._compile import * # noqa: F403
103106
from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import (
104107
MutableTorchTensorRTModule,

py/torch_tensorrt/_features.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
linked_file_full_path = os.path.join(trtorch_dir, linked_file)
4444
linked_file_runtime_full_path = os.path.join(trtorch_dir, linked_file_runtime)
4545

46-
_TENSORRT_RTX = tensorrt._package_name == "tensorrt_rtx"
46+
_TENSORRT_RTX = getattr(tensorrt, "_package_name", "") == "tensorrt_rtx"
4747
_TS_FE_AVAIL = os.path.isfile(linked_file_full_path)
4848
_TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
4949
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
@@ -57,10 +57,8 @@
5757
elif importlib.util.find_spec("tensorrt.plugin") and importlib.util.find_spec(
5858
"tensorrt.plugin._lib"
5959
):
60-
# there is a bug in tensorrt 10.14.* and 10.15.* that causes the plugin to not work, disable it for now
61-
if tensorrt.__version__.startswith("10.15.") or tensorrt.__version__.startswith(
62-
"10.14."
63-
):
60+
# TensorRT 10.14.* has a known bug that breaks QDP plugins; 10.15 and later work.
 61+
        if tensorrt.__version__.startswith("10.14."):
6462
_QDP_PLUGIN_AVAIL = False
6563
else:
6664
_QDP_PLUGIN_AVAIL = True
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""
2+
torch_tensorrt.annotation (experimental)
3+
==========================================
4+
High-level decorators for registering custom CUDA C++ kernels — compiled at
5+
runtime with NVRTC via **cuda-python** — as TensorRT Quick Deployable Plugins
6+
(QDP) with full AOT support.
7+
8+
Two registration paths are offered:
9+
10+
``cuda_plugin``
11+
One-shot decorator that combines ``cuda_python`` + ``custom_plugin`` for
12+
lower boilerplate in common cases.
13+
14+
``custom_plugin``
15+
Full auto-registration. Provide an eager (CUDA) implementation and a
16+
meta/fake implementation; the framework registers both the PyTorch custom
17+
op **and** the TensorRT plugin.
18+
19+
``register_custom_plugin``
20+
TRT-only registration. Use when ``@torch.library.custom_op`` has already
21+
been called. Only the TRT plugin descriptor, AOT implementation, and
22+
Torch-TensorRT converter are added.
23+
24+
``pointwise_aot``
25+
Helper that creates a standard AOT launch-config function for pointwise
26+
kernels using 1D launch geometry.
27+
28+
``pointwise_eager``
29+
Helper that builds a unary pointwise eager CUDA implementation from kernel
30+
source, reducing cuda-python compile/launch boilerplate in examples.
31+
32+
``kernel_template_aot`` / ``kernel_template_eager``
33+
Generic templates for non-pointwise and multi-dimensional kernels where
34+
users provide output allocation and launch/argument mapping callbacks.
35+
36+
Minimal example (full auto-registration)::
37+
38+
import torch, torch_tensorrt
39+
import torch_tensorrt.annotation as tta
40+
import tensorrt.plugin as trtp
41+
42+
cu_code = \"\"\"
43+
extern "C" __global__ void pointwise_relu(const float* x, int n, float* y) {
44+
int i = blockIdx.x * blockDim.x + threadIdx.x;
45+
if (i < n) y[i] = x[i] > 0.f ? x[i] : 0.f;
46+
}
47+
\"\"\"
48+
49+
def _eager_relu(x: torch.Tensor) -> torch.Tensor:
50+
from cuda.core import Device, LaunchConfig, launch as cuda_launch
51+
y = torch.empty_like(x)
52+
n = x.numel()
53+
block = 256
54+
cfg = LaunchConfig(grid=(max(1, (n + block - 1) // block),), block=(block,))
55+
dev = Device(); dev.set_current()
56+
# wrap current PyTorch stream so the kernel stays on the same stream
57+
class _Stream:
58+
def __cuda_stream__(self): return (0, torch.cuda.current_stream().cuda_stream)
59+
s = dev.create_stream(_Stream())
60+
cuda_launch(s, cfg, _kernel_obj, x.data_ptr(), n, y.data_ptr())
61+
return y
62+
63+
def _aot_relu(inputs, outputs, tactic):
64+
N = inputs[0].shape_expr.numel()
65+
p = trtp.KernelLaunchParams()
66+
p.grid_x = trtp.cdiv(N, 256)
67+
p.block_x = 256
68+
p.shared_mem = 0
69+
extra = trtp.SymIntExprs(1)
70+
extra[0] = trtp.SymInt32(N)
71+
return p, extra
72+
73+
spec = tta.cuda_python(cu_code, "pointwise_relu", aot_fn=_aot_relu, eager_fn=_eager_relu)
74+
75+
@tta.custom_plugin("myns::relu", spec, supports_dynamic_shapes=True)
76+
def _(x: torch.Tensor) -> torch.Tensor:
77+
return torch.empty_like(x)
78+
79+
# Use in a model
80+
class M(torch.nn.Module):
81+
def forward(self, x): return torch.ops.myns.relu(x)
82+
83+
model_trt = torch_tensorrt.compile(M().cuda().eval(), inputs=[torch.randn(1024, device="cuda")])
84+
"""
85+
86+
from torch_tensorrt.annotation._specs import CudaPythonSpec
87+
from torch_tensorrt.annotation._custom_plugin import (
88+
cuda_plugin,
89+
cuda_python,
90+
custom_plugin,
91+
pointwise_aot,
92+
pointwise_eager,
93+
register_custom_plugin,
94+
)
95+
from torch_tensorrt.annotation._kernel_spec import (
96+
Custom,
97+
DimSize,
98+
Elementwise,
99+
InputDecl,
100+
KernelSpec,
101+
Numel,
102+
OutputDecl,
103+
ReduceDims,
104+
Reduction,
105+
SameAs,
106+
)
107+
from torch_tensorrt.annotation._kernel_plugin import kernel_plugin
108+
109+
__all__ = [
110+
"CudaPythonSpec",
111+
"Custom",
112+
"DimSize",
113+
"Elementwise",
114+
"InputDecl",
115+
"KernelSpec",
116+
"Numel",
117+
"OutputDecl",
118+
"ReduceDims",
119+
"Reduction",
120+
"SameAs",
121+
"cuda_plugin",
122+
"cuda_python",
123+
"custom_plugin",
124+
"kernel_plugin",
125+
"pointwise_aot",
126+
"pointwise_eager",
127+
"register_custom_plugin",
128+
]

0 commit comments

Comments
 (0)