Add SYCL Kernels for XPU backend #1679
Changes from 19 commits
@@ -1,58 +1,214 @@ | ||
from collections.abc import Sequence | ||
import ctypes as ct | ||
import warnings | ||
|
||
import torch | ||
|
||
from bitsandbytes.functional import _get_tensor_stream, get_ptr | ||
|
||
from ..._ops import register_kernel | ||
from ..utils import ipex_xpu, triton_available | ||
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib | ||
from ..utils import triton_available | ||
|
||
# TODO: Enable _int_mm in torch | ||
# if torch.__version__ >= (2, 9): | ||
# @register_kernel("bitsandbytes::int8_linear_matmul", "xpu") | ||
# def _(A: torch.Tensor, B: torch.Tensor): | ||
# return torch._int_mm( | ||
# A.reshape(-1, A.shape[-1]), | ||
# B.t(), | ||
# ).reshape(*A.shape[:-1], B.shape[0]) | ||
|
||
xiaolil1 marked this conversation as resolved.
|
||
|
||
def _dequantize_4bit_impl( | ||
A: torch.Tensor, | ||
absmax: torch.Tensor, | ||
blocksize: int, | ||
quant_type: str, | ||
dtype: torch.dtype, | ||
out: torch.Tensor, | ||
) -> None: | ||
args = ( | ||
None, | ||
get_ptr(A), | ||
get_ptr(absmax), | ||
get_ptr(out), | ||
ct.c_int(blocksize), | ||
ct.c_int(out.numel()), | ||
_get_tensor_stream(A), | ||
) | ||
if dtype == torch.bfloat16: | ||
if quant_type == "fp4": | ||
lib.cdequantize_blockwise_bf16_fp4(*args) | ||
else: | ||
lib.cdequantize_blockwise_bf16_nf4(*args) | ||
elif dtype == torch.float16: | ||
if quant_type == "fp4": | ||
lib.cdequantize_blockwise_fp16_fp4(*args) | ||
else: | ||
lib.cdequantize_blockwise_fp16_nf4(*args) | ||
elif dtype == torch.float32: | ||
if quant_type == "fp4": | ||
lib.cdequantize_blockwise_fp32_fp4(*args) | ||
else: | ||
lib.cdequantize_blockwise_fp32_nf4(*args) | ||
|
||
# _int_mm is available in torch starting from version 2.7, | ||
# but it currently has no XPU implementation. | ||
if ipex_xpu and torch.__version__ >= (2, 7): | ||
|
||
@register_kernel("bitsandbytes::int8_linear_matmul", "xpu") | ||
def _(A: torch.Tensor, B: torch.Tensor): | ||
return torch._int_mm( | ||
A.reshape(-1, A.shape[-1]), | ||
B.t(), | ||
).reshape(*A.shape[:-1], B.shape[0]) | ||
def _dequantize_blockwise_impl( | ||
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor | ||
) -> None: | ||
args = ( | ||
get_ptr(code), | ||
get_ptr(A), | ||
get_ptr(absmax), | ||
get_ptr(out), | ||
ct.c_int(blocksize), | ||
ct.c_int(A.numel()), | ||
_get_tensor_stream(A), | ||
) | ||
if dtype == torch.float16: | ||
lib.cdequantize_blockwise_fp16(*args) | ||
elif dtype == torch.bfloat16: | ||
lib.cdequantize_blockwise_bf16(*args) | ||
elif dtype == torch.float32: | ||
lib.cdequantize_blockwise_fp32(*args) | ||
|
||
|
||
# IPEX should be faster for XPU, so check first whether it is available. | ||
if ipex_xpu: | ||
def _gemv_4bit_impl( | ||
A: torch.Tensor, | ||
B: torch.Tensor, | ||
shapeB: Sequence[int], | ||
absmax: torch.Tensor, | ||
code: torch.Tensor, | ||
blocksize: int, | ||
out: torch.Tensor, | ||
) -> None: | ||
m = ct.c_int32(1) | ||
n = ct.c_int32(shapeB[0]) | ||
k = ct.c_int32(shapeB[1]) | ||
|
||
@register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu") | ||
lda = m | ||
ldb = ct.c_int32((A.shape[-1] + 1) // 2) | ||
ldc = m | ||
|
||
stream = _get_tensor_stream(A) | ||
if A.dtype == torch.float16: | ||
lib.cgemv_4bit_inference_fp16( | ||
m, | ||
n, | ||
k, | ||
get_ptr(A), | ||
get_ptr(B), | ||
get_ptr(absmax), | ||
get_ptr(code), | ||
get_ptr(out), | ||
lda, | ||
ldb, | ||
ldc, | ||
ct.c_int32(blocksize), | ||
stream, | ||
) | ||
elif A.dtype == torch.bfloat16: | ||
lib.cgemv_4bit_inference_bf16( | ||
m, | ||
n, | ||
k, | ||
get_ptr(A), | ||
get_ptr(B), | ||
get_ptr(absmax), | ||
get_ptr(code), | ||
get_ptr(out), | ||
lda, | ||
ldb, | ||
ldc, | ||
ct.c_int32(blocksize), | ||
stream, | ||
) | ||
elif A.dtype == torch.float32: | ||
lib.cgemv_4bit_inference_fp32( | ||
m, | ||
n, | ||
k, | ||
get_ptr(A), | ||
get_ptr(B), | ||
get_ptr(absmax), | ||
get_ptr(code), | ||
get_ptr(out), | ||
lda, | ||
ldb, | ||
ldc, | ||
ct.c_int32(blocksize), | ||
stream, | ||
) | ||
|
||
|
||
# SYCL should be faster for XPU, so check first whether it is available. | ||
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): | ||
Currently you either pick all methods from SYCL or all methods from Triton. However, the SYCL implementation is still missing some of the methods available in Triton. I suggest we keep using those Triton methods even with SYCL, since that's the only option on XPU for now.
These two kernels don't affect the performance of QLoRA; they currently default to running with PyTorch ops, and we will implement them with SYCL kernels later.
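
For context, a rough pure-PyTorch sketch of the kind of blockwise-quantization fallback these comments refer to (illustrative only; the function name, default blocksize, and layout below are assumptions, not the actual bitsandbytes default op):

import torch

def quantize_blockwise_reference(A: torch.Tensor, code: torch.Tensor, blocksize: int = 256):
    # Illustrative pure-PyTorch fallback, not the bitsandbytes implementation.
    # Assumes A is a float tensor whose numel() is divisible by blocksize and
    # that code holds the 256-entry quantization lookup table.
    blocks = A.reshape(-1, blocksize)
    absmax = blocks.abs().amax(dim=1, keepdim=True)            # per-block scale
    normalized = blocks / absmax.clamp_min(torch.finfo(A.dtype).tiny)
    # Map each value to the index of the nearest entry in the code table.
    distances = (normalized.unsqueeze(-1) - code.to(A.dtype).view(1, 1, -1)).abs()
    quantized = distances.argmin(dim=-1).to(torch.uint8)
    return quantized.reshape(A.shape), absmax.flatten()

A real device kernel would fuse the per-block reduction and table lookup instead of materializing the full distance tensor as this sketch does.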
||
|
||
@register_kernel("bitsandbytes::dequantize_4bit", "xpu") | ||
def _( | ||
A: torch.Tensor, | ||
absmax: torch.Tensor, | ||
blocksize: int, | ||
quant_type: str, | ||
shape: Sequence[int], | ||
dtype: torch.dtype, | ||
) -> torch.Tensor: | ||
return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype) | ||
out = torch.zeros(shape, dtype=dtype, device=A.device) | ||
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) | ||
return out | ||
|
||
@register_kernel("bitsandbytes::dequantize_blockwise", "xpu") | ||
def _( | ||
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype | ||
) -> torch.Tensor: | ||
out = torch.empty_like(A, dtype=dtype) | ||
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) | ||
return out | ||
|
||
@register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu") | ||
def _( | ||
A: torch.Tensor, | ||
absmax: torch.Tensor, | ||
code: torch.Tensor, | ||
blocksize: int, | ||
dtype: torch.dtype, | ||
out: torch.Tensor, | ||
) -> None: | ||
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") | ||
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") | ||
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) | ||
|
||
@register_kernel("bitsandbytes::gemv_4bit", "xpu") | ||
def _( | ||
A: torch.Tensor, | ||
B: torch.Tensor, | ||
shapeB: Sequence[int], | ||
absmax: torch.Tensor, | ||
code: torch.Tensor, | ||
blocksize: int, | ||
) -> torch.Tensor: | ||
shape = A.shape | ||
out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) | ||
# void cdequantize_blockwise_fp32( | ||
# float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) | ||
if dtype == torch.float16: | ||
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) | ||
elif dtype == torch.bfloat16: | ||
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) | ||
elif dtype == torch.float32: | ||
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) | ||
else: | ||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") | ||
shape = (*A.shape[:-1], shapeB[0]) | ||
out = torch.empty(shape, device=A.device, dtype=A.dtype) | ||
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) | ||
return out | ||
|
||
return out.reshape(shape) | ||
@register_kernel("bitsandbytes::gemv_4bit.out", "xpu") | ||
def _( | ||
A: torch.Tensor, | ||
B: torch.Tensor, | ||
shapeB: Sequence[int], | ||
absmax: torch.Tensor, | ||
code: torch.Tensor, | ||
blocksize: int, | ||
out: torch.Tensor, | ||
) -> None: | ||
torch._check( | ||
out.shape == (*A.shape[:-1], shapeB[0]), | ||
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", | ||
) | ||
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") | ||
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) | ||
elif triton_available: | ||
from ..triton import ops as triton_ops | ||
|
||
|
@@ -64,4 +220,6 @@ def _( | |
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) | ||
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) | ||
else: | ||
warnings.warn("XPU available but no ipex or triton packages found.") | ||
warnings.warn( | ||
"XPU available but no native library or triton packages found. Please follow the installation instructions in the documentation." | ||
) |
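
As a usage sketch (not part of the diff): once bitsandbytes is imported, these registrations are reached through the torch custom-op registry. The shapes, blocksize, and code table below are illustrative assumptions, and an XPU-enabled PyTorch build is required:

import torch
import bitsandbytes  # importing registers the bitsandbytes::* ops, including the XPU kernels above

device = torch.device("xpu")
blocksize = 256
A = torch.randint(0, 256, (4096,), dtype=torch.uint8, device=device)       # quantized payload
absmax = torch.rand(A.numel() // blocksize, dtype=torch.float32, device=device)
code = torch.linspace(-1.0, 1.0, 256, dtype=torch.float32, device=device)  # stand-in for the real quant map

# Dispatches to _dequantize_blockwise_impl above when the native SYCL library is available,
# otherwise to the Triton implementation.
out = torch.ops.bitsandbytes.dequantize_blockwise(A, absmax, code, blocksize, torch.float16)
print(out.shape, out.dtype)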