|
| 1 | +""" |
| 2 | +.. _cuda_python_aot_annotation: |
| 3 | +
|
| 4 | +Non-Pointwise Custom Plugin via torch_tensorrt.annotation |
| 5 | +========================================================== |
| 6 | +
|
| 7 | +This example demonstrates a shape-changing (non-pointwise) CUDA kernel that |
| 8 | +duplicates each input element into two output elements: |
| 9 | + y[2*i] = x[i], y[2*i + 1] = x[i] |
| 10 | +""" |
| 11 | + |
| 12 | +import sys |
| 13 | + |
| 14 | +import torch |
| 15 | +import torch_tensorrt |
| 16 | + |
| 17 | +if not torch_tensorrt.ENABLED_FEATURES.qdp_plugin: |
| 18 | + print( |
| 19 | + "[cuda_python_aot_annotation] Skipping example: " |
| 20 | + "torch_tensorrt.annotation requires TensorRT QDP plugin support." |
| 21 | + ) |
| 22 | + sys.exit(0) |
| 23 | + |
| 24 | +try: |
| 25 | + import tensorrt.plugin as trtp |
| 26 | +except ImportError: |
| 27 | + print("[cuda_python_aot_annotation] Skipping example: tensorrt.plugin unavailable.") |
| 28 | + sys.exit(0) |
| 29 | + |
| 30 | +try: |
| 31 | + from cuda.core import Device as _Device |
| 32 | + from cuda.core import LaunchConfig as _LaunchConfig |
| 33 | + from cuda.core import Program as _Program |
| 34 | + from cuda.core import ProgramOptions as _ProgramOptions |
| 35 | + from cuda.core import launch as _cuda_launch |
| 36 | +except ImportError: |
| 37 | + try: |
| 38 | + from cuda.core.experimental import Device as _Device |
| 39 | + from cuda.core.experimental import LaunchConfig as _LaunchConfig |
| 40 | + from cuda.core.experimental import Program as _Program |
| 41 | + from cuda.core.experimental import ProgramOptions as _ProgramOptions |
| 42 | + from cuda.core.experimental import launch as _cuda_launch |
| 43 | + except ImportError: |
| 44 | + print( |
| 45 | + "[cuda_python_aot_annotation] Skipping example: cuda-python is not " |
| 46 | + "installed. Install with `pip install cuda-python` to run this example." |
| 47 | + ) |
| 48 | + sys.exit(0) |
| 49 | + |
| 50 | +import torch_tensorrt.annotation as tta |
| 51 | + |
| 52 | + |
| 53 | +CU_REPEAT2 = """ |
| 54 | +extern "C" __global__ void repeat2_kernel( |
| 55 | + const float* __restrict__ x, const int n, float* __restrict__ y) { |
| 56 | + const int i = blockIdx.x * blockDim.x + threadIdx.x; |
| 57 | + if (i < n) { |
| 58 | + const float v = x[i]; |
| 59 | + y[2 * i] = v; |
| 60 | + y[2 * i + 1] = v; |
| 61 | + } |
| 62 | +} |
| 63 | +""" |
| 64 | + |
| 65 | +_device = _Device() |
| 66 | +_device.set_current() |
| 67 | +_opts = _ProgramOptions( |
| 68 | + std="c++17", arch=f"sm_{_device.arch}", include_path=["/usr/local/cuda/include"] |
| 69 | +) |
| 70 | +_program = _Program(CU_REPEAT2, code_type="c++", options=_opts) |
| 71 | +_module = _program.compile("ptx", name_expressions=("repeat2_kernel",)) |
| 72 | +_kernel = _module.get_kernel("repeat2_kernel") |
| 73 | + |
| 74 | + |
| 75 | +class _PTStream: |
| 76 | + def __cuda_stream__(self): |
| 77 | + return (0, torch.cuda.current_stream().cuda_stream) |
| 78 | + |
| 79 | + |
| 80 | +def _eager_repeat2(x: torch.Tensor) -> torch.Tensor: |
| 81 | + if x.dtype != torch.float32: |
| 82 | + raise ValueError("This example expects float32 input") |
| 83 | + flat = x.contiguous().view(-1) |
| 84 | + n = int(flat.numel()) |
| 85 | + y = torch.empty((n * 2,), device=x.device, dtype=x.dtype) |
| 86 | + block = 256 |
| 87 | + grid = max(1, (n + block - 1) // block) |
| 88 | + stream = _device.create_stream(_PTStream()) |
| 89 | + _cuda_launch( |
| 90 | + stream, |
| 91 | + _LaunchConfig(grid=(grid,), block=(block,)), |
| 92 | + _kernel, |
| 93 | + flat.data_ptr(), |
| 94 | + n, |
| 95 | + y.data_ptr(), |
| 96 | + ) |
| 97 | + return y |
| 98 | + |
| 99 | + |
| 100 | +def _aot_repeat2(inputs, outputs, tactic): |
| 101 | + n = inputs[0].shape_expr.numel() |
| 102 | + params = trtp.KernelLaunchParams() |
| 103 | + params.grid_x = trtp.cdiv(n, 256) |
| 104 | + params.block_x = 256 |
| 105 | + params.shared_mem = 0 |
| 106 | + extra = trtp.SymIntExprs(1) |
| 107 | + extra[0] = trtp.SymInt32(n) |
| 108 | + return params, extra |
| 109 | + |
| 110 | + |
| 111 | +@tta.cuda_plugin( |
| 112 | + op_name="ann_ex::repeat2", |
| 113 | + kernel_source=CU_REPEAT2, |
| 114 | + kernel_name="repeat2_kernel", |
| 115 | + eager_fn=_eager_repeat2, |
| 116 | + aot_fn=_aot_repeat2, |
| 117 | + supports_dynamic_shapes=True, |
| 118 | +) |
| 119 | +def _repeat2_meta(x: torch.Tensor) -> torch.Tensor: |
| 120 | + return torch.empty((x.numel() * 2,), device=x.device, dtype=x.dtype) |
| 121 | + |
| 122 | + |
| 123 | +class Repeat2Model(torch.nn.Module): |
| 124 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 125 | + return torch.ops.ann_ex.repeat2(x) |
| 126 | + |
| 127 | + |
| 128 | +if __name__ == "__main__": |
| 129 | + x = torch.randn(1024, device="cuda", dtype=torch.float32) |
| 130 | + ref = torch.repeat_interleave(x, 2, dim=0) |
| 131 | + |
| 132 | + model = Repeat2Model().cuda().eval() |
| 133 | + eager_out = model(x) |
| 134 | + print("Eager result matches repeat_interleave:", torch.allclose(eager_out, ref, atol=1e-4)) |
| 135 | + |
| 136 | + print("Compiling with Torch-TensorRT...") |
| 137 | + with torch_tensorrt.logging.debug(): |
| 138 | + trt_model = torch_tensorrt.compile( |
| 139 | + model, |
| 140 | + inputs=[x], |
| 141 | + enabled_precisions={torch.float32}, |
| 142 | + min_block_size=1, |
| 143 | + ) |
| 144 | + |
| 145 | + with torch.no_grad(): |
| 146 | + for _ in range(5): |
| 147 | + out = trt_model(x) |
| 148 | + assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2), "Mismatch!" |
| 149 | + |
| 150 | + print("TRT inference successful - results match repeat_interleave") |
0 commit comments