6 changes: 3 additions & 3 deletions README.md
@@ -1,11 +1,11 @@
# C++/CUDA Extensions in PyTorch

-An example of writing a C++/CUDA extension for PyTorch. See
+An example of writing a C++/CUDA/SYCL extension for PyTorch. See
[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
-custom op that has both custom CPU and CUDA kernels.
+custom op that has both custom CPU and CUDA/SYCL kernels.

-The examples in this repo work with PyTorch 2.4+.
+The examples in this repo work with PyTorch 2.4 or later for C++/CUDA, and with PyTorch 2.8 or later for SYCL.

To build:
```
…
```
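
For orientation, a minimal usage sketch of the op this diff extends (illustrative only: it assumes the extension built successfully, and that `torch.xpu` is present on PyTorch builds new enough for the SYCL backend):

```python
import torch
import extension_cpp  # importing the package registers the custom ops

# Fall back to the CPU kernel when no XPU device is available.
device = "xpu" if torch.xpu.is_available() else "cpu"
a = torch.randn(1024, device=device)
b = torch.randn(1024, device=device)

# Dispatch picks the CPU, CUDA, or SYCL kernel from the tensors' device.
out = extension_cpp.ops.mymuladd(a, b, 1.0)
torch.testing.assert_close(out, a * b + 1.0)
```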
189 changes: 189 additions & 0 deletions extension_cpp/csrc/sycl/muladd.sycl
@@ -0,0 +1,189 @@
#include <c10/xpu/XPUStream.h>
#include <sycl/sycl.hpp>
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

namespace extension_cpp {


// MulAdd Kernel: result = a * b + c
static void muladd_kernel(
    int numel, const float* a, const float* b, float c, float* result,
    const sycl::nd_item<1>& item) {
  int idx = item.get_global_id(0);
  if (idx < numel) {
    result[idx] = a[idx] * b[idx] + c;
  }
}

// Mul Kernel: result = a * b
static void mul_kernel(
    int numel, const float* a, const float* b, float* result,
    const sycl::nd_item<1>& item) {
  int idx = item.get_global_id(0);
  if (idx < numel) {
    result[idx] = a[idx] * b[idx];
  }
}

// Add Kernel: result = a + b
static void add_kernel(
    int numel, const float* a, const float* b, float* result,
    const sycl::nd_item<1>& item) {
  int idx = item.get_global_id(0);
  if (idx < numel) {
    result[idx] = a[idx] + b[idx];
  }
}


// Named functor types give each kernel a stable SYCL kernel name for
// parallel_for; they capture the arguments and forward to the free
// functions above.
class MulAddKernelFunctor {
public:
  MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
      : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}

  void operator()(const sycl::nd_item<1>& item) const {
    muladd_kernel(numel, a, b, c, result, item);
  }

private:
  int numel;
  const float* a;
  const float* b;
  float c;
  float* result;
};

class MulKernelFunctor {
public:
  MulKernelFunctor(int _numel, const float* _a, const float* _b, float* _result)
      : numel(_numel), a(_a), b(_b), result(_result) {}

  void operator()(const sycl::nd_item<1>& item) const {
    mul_kernel(numel, a, b, result, item);
  }

private:
  int numel;
  const float* a;
  const float* b;
  float* result;
};

class AddKernelFunctor {
public:
  AddKernelFunctor(int _numel, const float* _a, const float* _b, float* _result)
      : numel(_numel), a(_a), b(_b), result(_result) {}

  void operator()(const sycl::nd_item<1>& item) const {
    add_kernel(numel, a, b, result, item);
  }

private:
  int numel;
  const float* a;
  const float* b;
  float* result;
};


at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
  TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
  TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
  TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
  TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = at::empty_like(a_contig);

  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* res_ptr = result.data_ptr<float>();
  int numel = a_contig.numel();

  // Launch on the current XPU stream: one work-item per element, with the
  // global range rounded up to a whole number of 256-wide work-groups.
  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  constexpr int threads = 256;
  int blocks = (numel + threads - 1) / threads;

  queue.submit([&](sycl::handler& cgh) {
    cgh.parallel_for<MulAddKernelFunctor>(
        sycl::nd_range<1>(blocks * threads, threads),
        MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr));
  });
  return result;
}

at::Tensor mymul_xpu(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
  TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
  TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
  TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
  TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = at::empty_like(a_contig);

  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* res_ptr = result.data_ptr<float>();
  int numel = a_contig.numel();

  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  constexpr int threads = 256;
  int blocks = (numel + threads - 1) / threads;

  queue.submit([&](sycl::handler& cgh) {
    cgh.parallel_for<MulKernelFunctor>(
        sycl::nd_range<1>(blocks * threads, threads),
        MulKernelFunctor(numel, a_ptr, b_ptr, res_ptr));
  });
  return result;
}

void myadd_out_xpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
  TORCH_CHECK(b.sizes() == out.sizes(), "b and out must have the same shape");
  TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
  TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
  TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");
  TORCH_CHECK(out.device().is_xpu(), "out must be an XPU tensor");

  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();

  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* out_ptr = out.data_ptr<float>();
  int numel = a_contig.numel();

  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  constexpr int threads = 256;
  int blocks = (numel + threads - 1) / threads;

  queue.submit([&](sycl::handler& cgh) {
    cgh.parallel_for<AddKernelFunctor>(
        sycl::nd_range<1>(blocks * threads, threads),
        AddKernelFunctor(numel, a_ptr, b_ptr, out_ptr));
  });
}

// ==================================================
// Register SYCL implementations with the torch library
// ==================================================

TORCH_LIBRARY_IMPL(extension_cpp, XPU, m) {
  m.impl("mymuladd", mymuladd_xpu);
  m.impl("mymul", mymul_xpu);
  m.impl("myadd_out", myadd_out_xpu);
}

} // namespace extension_cpp
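
The three host functions above share one launch pattern: the 1-D global range is padded to a multiple of the 256-wide work-group, and the kernel bounds-checks its index, so any `numel` is handled. A sanity-check sketch for the new XPU registrations (hypothetical test code, assuming PyTorch 2.8+ with an available XPU device):

```python
import torch
from torch.library import opcheck
import extension_cpp

a = torch.randn(300, device="xpu")  # deliberately not a multiple of 256
b = torch.randn(300, device="xpu")

# opcheck exercises schema, fake-tensor, and registration invariants.
opcheck(torch.ops.extension_cpp.mymuladd.default, (a, b, 0.5))

out = torch.empty_like(a)
torch.ops.extension_cpp.myadd_out.default(a, b, out)
torch.testing.assert_close(out, a + b)
```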
109 changes: 79 additions & 30 deletions setup.py
@@ -2,63 +2,116 @@
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import glob

from setuptools import find_packages, setup

from torch.utils.cpp_extension import (
    CppExtension,
    CUDAExtension,
    BuildExtension,
    CUDA_HOME,
)
# Conditional import for SyclExtension
try:
    from torch.utils.cpp_extension import SyclExtension
except ImportError:
    SyclExtension = None

library_name = "extension_cpp"

# Configure Py_LIMITED_API based on PyTorch version
if torch.__version__ >= "2.6.0":
    py_limited_api = True
else:
    py_limited_api = False


def get_extensions():
    debug_mode = os.getenv("DEBUG", "0") == "1"
-    use_cuda = os.getenv("USE_CUDA", "1") == "1"
    if debug_mode:
        print("Compiling in debug mode")

-    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
-    extension = CUDAExtension if use_cuda else CppExtension
+    # Determine backend (CUDA, SYCL, or C++)
+    use_cuda = os.getenv("USE_CUDA", "auto")
+    use_sycl = os.getenv("USE_SYCL", "auto")
+
+    # Auto-detect CUDA
+    if use_cuda == "auto":
+        use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
+    else:
+        use_cuda = use_cuda.lower() in ("true", "1")
+
+    # Auto-detect SYCL
+    if use_sycl == "auto":
+        use_sycl = SyclExtension is not None and torch.xpu.is_available()
+    else:
+        use_sycl = use_sycl.lower() in ("true", "1")
+
+    if use_cuda and use_sycl:
+        raise RuntimeError("Cannot enable both CUDA and SYCL backends simultaneously.")
+
+    print(f"use_cuda={use_cuda}, use_sycl={use_sycl}")
+
+    if use_cuda:
+        extension = CUDAExtension
+        print("Building with CUDA backend")
+    elif use_sycl and SyclExtension is not None:
+        extension = SyclExtension
+        print("Building with SYCL backend")
+    else:
+        extension = CppExtension
+        print("Building with C++ backend")

    # Compilation arguments
    extra_link_args = []
-    extra_compile_args = {
-        "cxx": [
-            "-O3" if not debug_mode else "-O0",
-            "-fdiagnostics-color=always",
-            "-DPy_LIMITED_API=0x03090000",  # min CPython version 3.9
-        ],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-        ],
-    }
+    extra_compile_args = {"cxx": []}
+    if extension == CUDAExtension:
+        print("CUDA is available, compile using CUDAExtension")
+        extra_compile_args = {
+            "cxx": ["-O3" if not debug_mode else "-O0",
+                    "-fdiagnostics-color=always",
+                    "-DPy_LIMITED_API=0x03090000"],  # min CPython version 3.9
+            "nvcc": ["-O3" if not debug_mode else "-O0"],
+        }
+    elif extension == SyclExtension:
+        print("XPU is available, compile using SyclExtension")
+        extra_compile_args = {
+            "cxx": ["-O3" if not debug_mode else "-O0",
+                    "-fdiagnostics-color=always",
+                    "-DPy_LIMITED_API=0x03090000"],
+            "sycl": ["-O3" if not debug_mode else "-O0"],
+        }
+    else:
+        extra_compile_args["cxx"] = [
+            "-O3" if not debug_mode else "-O0",
+            "-fdiagnostics-color=always",
+            "-DPy_LIMITED_API=0x03090000"]

    if debug_mode:
        extra_compile_args["cxx"].append("-g")
-        extra_compile_args["nvcc"].append("-g")
-        extra_link_args.extend(["-O0", "-g"])
+        if extension == CUDAExtension:
+            extra_compile_args["nvcc"].append("-g")
+            extra_link_args.extend(["-O0", "-g"])
+        elif extension == SyclExtension:
+            extra_compile_args["sycl"].append("-g")
+            extra_link_args.extend(["-O0", "-g"])

    # Source files collection
    this_dir = os.path.dirname(os.path.curdir)
    extensions_dir = os.path.join(this_dir, library_name, "csrc")
    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))

-    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-    cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+    backend_sources = []
+    if extension == CUDAExtension:
+        backend_dir = os.path.join(extensions_dir, "cuda")
+        backend_sources = glob.glob(os.path.join(backend_dir, "*.cu"))
+    elif extension == SyclExtension:
+        backend_dir = os.path.join(extensions_dir, "sycl")
+        backend_sources = glob.glob(os.path.join(backend_dir, "*.sycl"))

-    if use_cuda:
-        sources += cuda_sources
+    sources += backend_sources
+
+    print(f"sources ({len(sources)}): {sources}")
    # Construct extension
    ext_modules = [
        extension(
            f"{library_name}._C",
            …

@@ -71,17 +124,13 @@ def get_extensions():

    return ext_modules


setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(),
    ext_modules=get_extensions(),
    install_requires=["torch"],
-    description="Example of PyTorch C++ and CUDA extensions",
+    description="Hybrid PyTorch extension supporting CUDA/SYCL/C++",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
    url="https://github.com/pytorch/extension-cpp",
    cmdclass={"build_ext": BuildExtension},
    options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
)
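
Given the detection logic above, a plain install auto-picks a backend, while `USE_CUDA`/`USE_SYCL` force one. A hypothetical invocation sketch (the variable names mirror `get_extensions()`; `--no-build-isolation` reflects the usual requirement that `torch` be importable while `setup.py` runs):

```python
import os
import subprocess

# Force the SYCL backend; "auto" (the default), "true"/"1", and anything
# else map to auto-detect, on, and off respectively in get_extensions().
env = dict(os.environ, USE_SYCL="1", USE_CUDA="0")
subprocess.run(
    ["python", "-m", "pip", "install", "--no-build-isolation", "-e", "."],
    env=env,
    check=True,
)
```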