
Commit 3f90de8

ZhaoqiongZ, AlannaBurke, svekars, and sekyondaMeta authored

Add a standalone tutorial for integrating custom op using sycl for Intel GPU (#3470)

* add tutorial for integrate custom op using sycl
* Update cpp_custom_ops_sycl.rst
* Update advanced_source/cpp_custom_ops_sycl.rst (Co-authored-by: Alanna Burke <[email protected]>)
* Update advanced_source/cpp_custom_ops_sycl.rst (Co-authored-by: Alanna Burke <[email protected]>)
* Update advanced_source/cpp_custom_ops_sycl.rst (Co-authored-by: Alanna Burke <[email protected]>)
* Update advanced_source/cpp_custom_ops_sycl.rst (Co-authored-by: Alanna Burke <[email protected]>)
* lintrunner apply
* Update advanced_source/cpp_custom_ops_sycl.rst (Co-authored-by: Alanna Burke <[email protected]>)

---------

Co-authored-by: Alanna Burke <[email protected]>
Co-authored-by: Alanna Burke <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: sekyondaMeta <[email protected]>
1 parent ef98a6b commit 3f90de8

File tree

2 files changed: +278 -0 lines changed

advanced_source/cpp_custom_ops_sycl.rst

Lines changed: 274 additions & 0 deletions

@@ -0,0 +1,274 @@
.. _cpp-custom-ops-tutorial-sycl:

Custom SYCL Operators
=====================

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to integrate custom operators written in SYCL with PyTorch

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.8 or later
       * Basic understanding of SYCL programming

.. note::

   ``SYCL`` serves as the backend programming language for Intel GPUs (device label ``xpu``). For configuration details, see
   `Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html>`_. The Intel Compiler, which comes bundled with Intel Deep Learning Essentials, handles ``SYCL`` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.

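Before building the extension, you can confirm that PyTorch actually sees the Intel GPU. The snippet below is a minimal sanity check; it assumes a PyTorch build with XPU support:

.. code-block:: python

   import torch

   # True only if PyTorch was built with XPU support and an Intel GPU is visible
   print(torch.xpu.is_available())
   print(torch.xpu.device_count())
   if torch.xpu.is_available():
       print(torch.xpu.get_device_name(0))
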
PyTorch offers a large library of operators that work on Tensors (for example, ``torch.add`` and ``torch.sum``).
However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
best path to authoring a custom operator written in SYCL. Tutorials for C++ and CUDA operators are available in :ref:`cpp-custom-ops-tutorial`.

Use the following project structure for the custom SYCL operator:

.. code-block:: text

   sycl_example/
   ├── setup.py
   ├── sycl_extension
   │   ├── __init__.py
   │   ├── muladd.sycl
   │   └── ops.py
   └── test_sycl_extension.py

Setting up the Build System
---------------------------

If you need to compile **SYCL** code (for example, ``.sycl`` files), use `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension>`_.
The setup process is very similar to C++/CUDA, except that the compilation arguments need to be adjusted for SYCL.

Using ``SyclExtension`` is as straightforward as writing the following ``setup.py``:

.. code-block:: python

   import os
   import torch
   import glob
   from setuptools import find_packages, setup
   from torch.utils.cpp_extension import SyclExtension, BuildExtension

   library_name = "sycl_extension"
   py_limited_api = True
   extra_compile_args = {
       "cxx": ["-O3",
               "-fdiagnostics-color=always",
               "-DPy_LIMITED_API=0x03090000"],
       "sycl": ["-O3"],
   }

   assert torch.xpu.is_available(), "XPU is not available, please check your environment"

   # Source files collection
   this_dir = os.path.dirname(os.path.curdir)
   extensions_dir = os.path.join(this_dir, library_name)
   sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))

   # Construct extension
   ext_modules = [
       SyclExtension(
           f"{library_name}._C",
           sources,
           extra_compile_args=extra_compile_args,
           py_limited_api=py_limited_api,
       )
   ]

   setup(
       name=library_name,
       packages=find_packages(),
       ext_modules=ext_modules,
       install_requires=["torch"],
       description="Simple Example of PyTorch Sycl extensions",
       cmdclass={"build_ext": BuildExtension},
       options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
   )

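With ``setup.py`` in place, build and install the extension from the project root (``sycl_example/``), for example with ``pip install --no-build-isolation -e .``; the exact command may differ depending on your packaging workflow.
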
Defining the custom op and adding backend implementations
---------------------------------------------------------

First, let's write a SYCL function that computes ``mymuladd``.

In order to use this from PyTorch's Python frontend, we need to register it
as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
bind the operator to Python.

The SYCL implementation of ``mymuladd`` is then registered for the ``XPU`` backend
in a separate ``TORCH_LIBRARY_IMPL`` block:

.. code-block:: cpp

   #include <c10/xpu/XPUStream.h>
   #include <sycl/sycl.hpp>
   #include <ATen/Operators.h>
   #include <torch/all.h>
   #include <torch/library.h>

   namespace sycl_extension {
   // MulAdd Kernel: result = a * b + c
   static void muladd_kernel(
       int numel, const float* a, const float* b, float c, float* result,
       const sycl::nd_item<1>& item) {
     int idx = item.get_global_id(0);
     if (idx < numel) {
       result[idx] = a[idx] * b[idx] + c;
     }
   }

   // Functor wrapping the kernel so it can be launched as a named SYCL kernel
   class MulAddKernelFunctor {
   public:
     MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
         : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
     void operator()(const sycl::nd_item<1>& item) const {
       muladd_kernel(numel, a, b, c, result, item);
     }

   private:
     int numel;
     const float* a;
     const float* b;
     float c;
     float* result;
   };

   at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
     TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
     TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
     TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
     TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = at::empty_like(a_contig);

     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* res_ptr = result.data_ptr<float>();
     int numel = a_contig.numel();

     // Launch the kernel on the SYCL queue of the current XPU stream
     sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
     constexpr int threads = 256;
     int blocks = (numel + threads - 1) / threads;

     queue.submit([&](sycl::handler& cgh) {
       cgh.parallel_for<MulAddKernelFunctor>(
           sycl::nd_range<1>(blocks * threads, threads),
           MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
       );
     });

     return result;
   }

   // Defines the operator schema
   TORCH_LIBRARY(sycl_extension, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   }

   // ==================================================
   // Register SYCL Implementations to Torch Library
   // ==================================================
   TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
     m.impl("mymuladd", &mymuladd_xpu);
   }

   } // namespace sycl_extension

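Once the extension has been built and loaded (the package ``__init__.py`` shown below takes care of loading the shared library), the registered operator becomes visible under ``torch.ops``. The following quick check is a sketch and assumes the extension has already been installed:

.. code-block:: python

   import torch
   import sycl_extension  # importing the package loads the compiled library and registers the op

   print(torch.ops.sycl_extension.mymuladd)          # overload packet created by TORCH_LIBRARY
   print(torch.ops.sycl_extension.mymuladd.default)  # the registered overload used by ops.py
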
Create a Python Interface
-------------------------

Create a Python interface for our operator in the ``sycl_extension/ops.py`` file:

.. code-block:: python

   import torch
   from torch import Tensor

   __all__ = ["mymuladd"]

   def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
       """Performs a * b + c in an efficient fused kernel"""
       return torch.ops.sycl_extension.mymuladd.default(a, b, c)

Initialize Package
------------------

Create the ``sycl_extension/__init__.py`` file to make the package importable:

.. code-block:: python

   import ctypes
   from pathlib import Path

   import torch

   current_dir = Path(__file__).parent.parent
   build_dir = current_dir / "build"
   so_files = list(build_dir.glob("**/*.so"))

   assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"

   with torch._ops.dl_open_guard():
       loaded_lib = ctypes.CDLL(so_files[0])

   from . import ops

   __all__ = [
       "loaded_lib",
       "ops",
   ]

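With the package in place, the operator can be called like any other PyTorch function. The snippet below is a small usage sketch; it assumes the extension has been built and an Intel GPU is available:

.. code-block:: python

   import torch
   import sycl_extension

   a = torch.randn(16, device="xpu")
   b = torch.randn(16, device="xpu")

   # result = a * b + 2.0, computed by the custom SYCL kernel
   result = sycl_extension.ops.mymuladd(a, b, 2.0)
   print(torch.allclose(result, a * b + 2.0))
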
Testing SYCL extension operator
-------------------------------

Use a simple test to verify that the operator works correctly.

.. code-block:: python

   import torch
   from torch.testing._internal.common_utils import TestCase
   import unittest
   import sycl_extension

   def reference_muladd(a, b, c):
       return a * b + c

   class TestMyMulAdd(TestCase):
       def sample_inputs(self, device, *, requires_grad=False):
           def make_tensor(*size):
               return torch.randn(size, device=device, requires_grad=requires_grad)

           def make_nondiff_tensor(*size):
               return torch.randn(size, device=device, requires_grad=False)

           return [
               [make_tensor(3), make_tensor(3), 1],
               [make_tensor(20), make_tensor(20), 3.14],
               [make_tensor(20), make_nondiff_tensor(20), -123],
               [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
           ]

       def _test_correctness(self, device):
           samples = self.sample_inputs(device)
           for args in samples:
               result = sycl_extension.ops.mymuladd(*args)
               expected = reference_muladd(*args)
               torch.testing.assert_close(result, expected)

       @unittest.skipIf(not torch.xpu.is_available(), "requires Intel GPU")
       def test_correctness_xpu(self):
           self._test_correctness("xpu")

   if __name__ == "__main__":
       unittest.main()

This test checks the correctness of the custom operator by comparing its output against a reference implementation.

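You can run the test directly with ``python test_sycl_extension.py`` (the file calls ``unittest.main()``), or collect it with ``pytest``.
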
Conclusion
----------

In this tutorial, we demonstrated how to implement and compile custom SYCL operators for PyTorch. We specifically showcased an inference operation, ``muladd``. To add backward support or enable ``torch.compile`` compatibility, refer to :ref:`cpp-custom-ops-tutorial`.

advanced_source/custom_ops_landing_page.rst

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,10 @@ Integrating custom C++ and/or CUDA code with PyTorch
Please see :ref:`cpp-custom-ops-tutorial`.

.. note::

   ``SYCL`` serves as the backend programming language for Intel GPUs. To integrate custom SYCL code, refer to :ref:`cpp-custom-ops-tutorial-sycl`.

You may wish to author a custom operator from C++ (as opposed to Python) if:

- you have custom C++ and/or CUDA code.
