.. _cpp-custom-ops-tutorial-sycl:

Custom SYCL Operators
=====================

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to integrate custom operators written in SYCL with PyTorch

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.8 or later
       * Basic understanding of SYCL programming

.. note::

   ``SYCL`` serves as the backend programming language for Intel GPUs (device label ``xpu``). For configuration details, see
   `Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html>`_. The Intel Compiler, which comes bundled with Intel Deep Learning Essentials, handles ``SYCL`` compilation. Ensure you install and activate the compiler environment before running the code examples in this tutorial.

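Before building anything, it is worth confirming that PyTorch can see your Intel GPU. The following is a minimal sanity check, assuming a PyTorch build with XPU support and a correctly configured driver and compiler environment:

.. code-block:: python

   import torch

   # The "xpu" device is only usable if PyTorch was built with XPU support
   # and the Intel GPU driver/compiler environment is set up correctly.
   assert torch.xpu.is_available(), "XPU device not available"
   print(torch.xpu.get_device_name(0))

   # Run a trivial op on the device to confirm everything works end to end.
   x = torch.ones(4, device="xpu")
   print(x * 2 + 1)
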
PyTorch offers a large library of operators that work on Tensors (e.g. ``torch.add``, ``torch.sum``, etc.).
However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
best path to authoring a custom operator written in SYCL. Tutorials for C++ and CUDA operators are available in :ref:`cpp-custom-ops-tutorial`.

Use the following directory structure for the custom SYCL operator:

.. code-block:: text

   sycl_example/
   ├── setup.py
   ├── sycl_extension
   │   ├── __init__.py
   │   ├── muladd.sycl
   │   └── ops.py
   └── test_sycl_extension.py

Setting up the Build System
---------------------------

If you need to compile **SYCL** code (for example, ``.sycl`` files), use `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension>`_.
The setup process is very similar to C++/CUDA, except that the compilation arguments need to be adjusted for SYCL.

Using ``SyclExtension`` is as straightforward as writing the following ``setup.py``:

.. code-block:: python

   import os
   import torch
   import glob
   from setuptools import find_packages, setup
   from torch.utils.cpp_extension import SyclExtension, BuildExtension

   library_name = "sycl_extension"
   py_limited_api = True
   extra_compile_args = {
       "cxx": ["-O3",
               "-fdiagnostics-color=always",
               "-DPy_LIMITED_API=0x03090000"],
       "sycl": ["-O3"]
   }

   assert torch.xpu.is_available(), "XPU is not available, please check your environment"
   # Source files collection
   this_dir = os.path.dirname(os.path.curdir)
   extensions_dir = os.path.join(this_dir, library_name)
   sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))
   # Construct extension
   ext_modules = [
       SyclExtension(
           f"{library_name}._C",
           sources,
           extra_compile_args=extra_compile_args,
           py_limited_api=py_limited_api,
       )
   ]
   setup(
       name=library_name,
       packages=find_packages(),
       ext_modules=ext_modules,
       install_requires=["torch"],
       description="Simple Example of PyTorch SYCL extensions",
       cmdclass={"build_ext": BuildExtension},
       options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
   )

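With this ``setup.py`` in place, the extension builds like any other Python package. Running ``pip install --no-build-isolation -e .`` from the project root (with the Intel compiler environment activated) is one typical way to compile the ``.sycl`` sources and install the package in editable mode; ``--no-build-isolation`` keeps the already-installed ``torch`` visible during the build. The exact invocation may differ depending on your packaging workflow.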
|
Defining the custom op and adding backend implementations
----------------------------------------------------------

First, let's write a SYCL kernel and an XPU wrapper function that compute ``mymuladd``.
To use it from PyTorch's Python frontend, we need to register it as a PyTorch operator
using the ``TORCH_LIBRARY`` API; this automatically binds the operator to Python.
The SYCL implementation is then registered for the ``XPU`` dispatch key in a separate
``TORCH_LIBRARY_IMPL`` block:

| 101 | + |
| 102 | +.. code-block:: cpp |
| 103 | +
|
| 104 | + #include <c10/xpu/XPUStream.h> |
| 105 | + #include <sycl/sycl.hpp> |
| 106 | + #include <ATen/Operators.h> |
| 107 | + #include <torch/all.h> |
| 108 | + #include <torch/library.h> |
| 109 | +
|
| 110 | + namespace sycl_extension { |
| 111 | + // MulAdd Kernel: result = a * b + c |
| 112 | + static void muladd_kernel( |
| 113 | + int numel, const float* a, const float* b, float c, float* result, |
| 114 | + const sycl::nd_item<1>& item) { |
| 115 | + int idx = item.get_global_id(0); |
| 116 | + if (idx < numel) { |
| 117 | + result[idx] = a[idx] * b[idx] + c; |
| 118 | + } |
| 119 | + } |
| 120 | +
|
| 121 | + class MulAddKernelFunctor { |
| 122 | + public: |
| 123 | + MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result) |
| 124 | + : numel(_numel), a(_a), b(_b), c(_c), result(_result) {} |
| 125 | + void operator()(const sycl::nd_item<1>& item) const { |
| 126 | + muladd_kernel(numel, a, b, c, result, item); |
| 127 | + } |
| 128 | +
|
| 129 | + private: |
| 130 | + int numel; |
| 131 | + const float* a; |
| 132 | + const float* b; |
| 133 | + float c; |
| 134 | + float* result; |
| 135 | + }; |
| 136 | +
|
| 137 | + at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) { |
| 138 | + TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape"); |
| 139 | + TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor"); |
| 140 | + TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor"); |
| 141 | + TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor"); |
| 142 | + TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor"); |
| 143 | +
|
| 144 | + at::Tensor a_contig = a.contiguous(); |
| 145 | + at::Tensor b_contig = b.contiguous(); |
| 146 | + at::Tensor result = at::empty_like(a_contig); |
| 147 | +
|
| 148 | + const float* a_ptr = a_contig.data_ptr<float>(); |
| 149 | + const float* b_ptr = b_contig.data_ptr<float>(); |
| 150 | + float* res_ptr = result.data_ptr<float>(); |
| 151 | + int numel = a_contig.numel(); |
| 152 | +
|
| 153 | + sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue(); |
| 154 | + constexpr int threads = 256; |
| 155 | + int blocks = (numel + threads - 1) / threads; |
| 156 | +
|
| 157 | + queue.submit([&](sycl::handler& cgh) { |
| 158 | + cgh.parallel_for<MulAddKernelFunctor>( |
| 159 | + sycl::nd_range<1>(blocks * threads, threads), |
| 160 | + MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr) |
| 161 | + ); |
| 162 | + }); |
| 163 | +
|
| 164 | + return result; |
| 165 | + } |
| 166 | + // Defines the operators |
| 167 | + TORCH_LIBRARY(sycl_extension, m) { |
| 168 | + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); |
| 169 | + } |
| 170 | +
|
| 171 | + // ================================================== |
| 172 | + // Register SYCL Implementations to Torch Library |
| 173 | + // ================================================== |
| 174 | + TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) { |
| 175 | + m.impl("mymuladd", &mymuladd_xpu); |
| 176 | + } |
| 177 | +
|
| 178 | + } // namespace sycl_extension |
| 179 | +
|
| 180 | +
|
| 181 | +
|
Create a Python Interface
-------------------------

Create a Python interface for our operator in the ``sycl_extension/ops.py`` file:

.. code-block:: python

   import torch
   from torch import Tensor

   __all__ = ["mymuladd"]

   def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
       """Performs a * b + c in an efficient fused kernel"""
       return torch.ops.sycl_extension.mymuladd.default(a, b, c)

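Here ``torch.ops.sycl_extension.mymuladd`` is the operator handle created by the ``TORCH_LIBRARY`` registration above; ``.default`` selects its default overload, which corresponds to the single schema we defined.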
|
Initialize Package
------------------

Create a ``sycl_extension/__init__.py`` file to make the package importable:

.. code-block:: python

   import ctypes
   from pathlib import Path

   import torch

   current_dir = Path(__file__).parent.parent
   build_dir = current_dir / "build"
   so_files = list(build_dir.glob("**/*.so"))

   assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"

   with torch._ops.dl_open_guard():
       loaded_lib = ctypes.CDLL(so_files[0])

   from . import ops

   __all__ = [
       "loaded_lib",
       "ops",
   ]

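Once the extension has been built and installed, a quick interactive smoke test might look like the following. This is a minimal sketch; it assumes the build produced the ``_C`` shared library and that an XPU device is available:

.. code-block:: python

   import torch
   import sycl_extension

   # Two small input tensors on the Intel GPU.
   a = torch.randn(8, device="xpu")
   b = torch.randn(8, device="xpu")

   # Calls the custom fused kernel registered for the XPU backend.
   out = sycl_extension.ops.mymuladd(a, b, 2.0)
   print(torch.allclose(out, a * b + 2.0))  # expected: True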
|
Testing the SYCL extension operator
------------------------------------

Use a simple test in ``test_sycl_extension.py`` to verify that the operator works correctly:

.. code-block:: python

   import torch
   from torch.testing._internal.common_utils import TestCase
   import unittest
   import sycl_extension

   def reference_muladd(a, b, c):
       return a * b + c

   class TestMyMulAdd(TestCase):
       def sample_inputs(self, device, *, requires_grad=False):
           def make_tensor(*size):
               return torch.randn(size, device=device, requires_grad=requires_grad)

           def make_nondiff_tensor(*size):
               return torch.randn(size, device=device, requires_grad=False)

           return [
               [make_tensor(3), make_tensor(3), 1],
               [make_tensor(20), make_tensor(20), 3.14],
               [make_tensor(20), make_nondiff_tensor(20), -123],
               [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
           ]

       def _test_correctness(self, device):
           samples = self.sample_inputs(device)
           for args in samples:
               result = sycl_extension.ops.mymuladd(*args)
               expected = reference_muladd(*args)
               torch.testing.assert_close(result, expected)

       @unittest.skipIf(not torch.xpu.is_available(), "requires Intel GPU")
       def test_correctness_xpu(self):
           self._test_correctness("xpu")

   if __name__ == "__main__":
       unittest.main()

This test checks the correctness of the custom operator by comparing its output against a reference implementation.
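Run it with ``python test_sycl_extension.py``; the XPU test is skipped automatically when no Intel GPU is available.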

Conclusion
----------

In this tutorial, we demonstrated how to implement and compile a custom SYCL operator for PyTorch. We specifically showcased an inference-only operator, ``mymuladd``. For adding backward support or enabling ``torch.compile`` compatibility, please refer to :ref:`cpp-custom-ops-tutorial`.