Merged

Commits (57; changes shown from 19 commits)
dd7b173
Add SYCL Kernels for XPU backend
xiaolil1 Jun 15, 2025
df93cdd
Merge pull request #1 from xiaolil1/jiqing
xiaolil1 Jun 16, 2025
872aa02
fix transpose
jiqing-feng Jun 16, 2025
04437a3
fix log and format
jiqing-feng Jun 16, 2025
d585bea
revert cpu changes
jiqing-feng Jun 16, 2025
1781611
clean ipex_xpu
jiqing-feng Jun 16, 2025
c982781
clean ipex import
jiqing-feng Jun 16, 2025
a4c5f8c
fix ipex cpu import
jiqing-feng Jun 16, 2025
4f076bb
fix typo
jiqing-feng Jun 16, 2025
76d7178
fix comments
jiqing-feng Jun 16, 2025
b31ea62
Merge pull request #2 from xiaolil1/jiqing
xiaolil1 Jun 16, 2025
452aa84
refine gemv_4bit kernel
xiaolil1 Jun 17, 2025
e8ac8b5
Merge branch 'main' into main
jiqing-feng Jun 17, 2025
8620a95
enable FP4 for dequant_4bit and gemv_4bit
xiaolil1 Jun 17, 2025
00f064b
refine FP4 dequantization performance
xiaolil1 Jun 17, 2025
d60750f
remove check for better performance
jiqing-feng Jun 17, 2025
59f2aa8
Merge pull request #3 from xiaolil1/jiqing
xiaolil1 Jun 17, 2025
aad358f
fix doc
jiqing-feng Jun 17, 2025
45e4451
Merge pull request #4 from xiaolil1/jiqing
xiaolil1 Jun 17, 2025
1e21ee9
clean code
xiaolil1 Jun 18, 2025
4e7f5c1
Merge branch 'main' into main
xiaolil1 Jun 18, 2025
1601652
fix tests
jiqing-feng Jun 18, 2025
1cc25ff
rm comments
jiqing-feng Jun 18, 2025
c44f38e
Merge pull request #5 from xiaolil1/jiqing
xiaolil1 Jun 18, 2025
9f283bd
fix memory issue
xiaolil1 Jun 20, 2025
9897eae
fix ut failure
xiaolil1 Jun 20, 2025
411a276
adjust threshold
jiqing-feng Jun 20, 2025
b6a3524
fix xpu check
jiqing-feng Jun 20, 2025
1c4f478
change test_functional check
jiqing-feng Jun 20, 2025
e5cf821
fix test_module
jiqing-feng Jun 20, 2025
502fe83
Merge pull request #6 from xiaolil1/jiqing
xiaolil1 Jun 20, 2025
8b54381
fix device check
jiqing-feng Jun 23, 2025
1e0f661
Merge pull request #7 from xiaolil1/jiqing_test
jiqing-feng Jun 23, 2025
99698d2
fix tests
jiqing-feng Jun 23, 2025
b88236a
Merge pull request #8 from xiaolil1/jiqing
jiqing-feng Jun 23, 2025
56c48bc
Merge branch 'main' into main
jiqing-feng Jun 24, 2025
302413e
Merge branch 'main' into main
jiqing-feng Jun 25, 2025
685962c
Enable Windows build and refine code
xiaolil1 Jun 27, 2025
7842f9d
Merge branch 'main' into main
jiqing-feng Jun 30, 2025
041b442
Merge branch 'main' into main
jiqing-feng Jul 1, 2025
aa0cf92
Merge branch 'main' into main
jiqing-feng Jul 2, 2025
b3db4bf
fix xpu log
jiqing-feng Jul 2, 2025
d66f93d
Merge pull request #9 from xiaolil1/jiqing
xiaolil1 Jul 2, 2025
5bf3159
remove ipex entirely
jiqing-feng Jul 3, 2025
005a63c
fix cpu int8 CB
jiqing-feng Jul 3, 2025
683f37c
Merge pull request #10 from xiaolil1/jiqing
xiaolil1 Jul 3, 2025
223d7d7
fix lint
jiqing-feng Jul 3, 2025
dc75ad8
Merge pull request #11 from xiaolil1/jiqing
xiaolil1 Jul 3, 2025
883d693
fix logs (#12)
jiqing-feng Jul 4, 2025
732022d
Fix sycl lint error and tests (#13)
jiqing-feng Jul 7, 2025
fc4480f
skip typo check for xpu kernel codes (#14)
jiqing-feng Jul 9, 2025
c42a38f
Merge branch 'main' into main
jiqing-feng Jul 9, 2025
38054f6
register triton kernel for quantization (#15)
jiqing-feng Jul 14, 2025
622c0ab
Merge branch 'main' into main
jiqing-feng Sep 2, 2025
bdab075
rebase main branch
jiqing-feng Sep 4, 2025
3e53783
Fix version comparison issue (#18)
shangerxin Sep 8, 2025
7d7c74c
Merge branch 'main' into main
xiaolil1 Sep 10, 2025
32 changes: 30 additions & 2 deletions CMakeLists.txt
@@ -27,11 +27,12 @@ set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps)
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps, xpu)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps xpu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
@@ -54,9 +55,17 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
endif()
set(BUILD_CUDA OFF)
set(BUILD_MPS ON)
elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
if(APPLE)
message(FATAL_ERROR "XPU is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU OFF)
endif()


@@ -179,6 +188,12 @@ elseif(BUILD_MPS)
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
elseif(BUILD_XPU)
list(APPEND SRC_FILES ${XPU_FILES})
string(APPEND BNB_OUTPUT_NAME "_xpu")
add_compile_definitions(BUILD_XPU)
set(CMAKE_C_COMPILER icx)
set(CMAKE_CXX_COMPILER icpx)
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
@@ -212,6 +227,19 @@ if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
endif()
if(BUILD_XPU)
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")

target_link_libraries(bitsandbytes PUBLIC ${SYCL_LIBRARY})
target_include_directories(bitsandbytes PUBLIC ${SYCL_INCLUDE_DIR})
target_link_directories(bitsandbytes PUBLIC ${SYCL_LIBRARY_DIR})

set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)
target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS})
target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS})

endif()

if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
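A quick way to sanity-check a build configured with the new backend (for example, cmake -DCOMPUTE_BACKEND=xpu -B build followed by cmake --build build) is the minimal Python sketch below; it only assumes a PyTorch build with XPU support and that the "_xpu"-suffixed native library produced by this CMake change is importable through bitsandbytes.

# Minimal sketch: confirm torch sees an XPU device and bitsandbytes imports cleanly.
# The library file name (with the "_xpu" suffix from BNB_OUTPUT_NAME above) is an
# assumption based on this diff, not a documented guarantee.
import torch
import bitsandbytes as bnb  # noqa: F401

print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())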
4 changes: 2 additions & 2 deletions bitsandbytes/_ops.py
@@ -4,7 +4,7 @@

import torch

from .cextension import ipex_cpu, ipex_xpu
from .utils import ipex_cpu

_IS_TORCH_GTE_24 = False

@@ -331,7 +331,7 @@ def _(
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")


if ipex_cpu or ipex_xpu:
if ipex_cpu:
# Register the dequantize_nf4_ipex implementation
torch.library.define(
"bitsandbytes::dequantize_nf4_ipex",
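For context on the guarded block above: bitsandbytes declares custom ops with torch.library and then registers per-device implementations for them. A minimal, self-contained sketch of that pattern, using a made-up op name ("mylib::double_it") rather than the library's real schema:

# Hedged sketch of the torch.library define/impl pattern used in _ops.py.
import torch

torch.library.define("mylib::double_it", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::double_it", "cpu")
def _(x: torch.Tensor) -> torch.Tensor:
    return x * 2

print(torch.ops.mylib.double_it(torch.ones(3)))  # tensor([2., 2., 2.])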
6 changes: 2 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -8,7 +8,6 @@
from typing_extensions import deprecated

import bitsandbytes.functional as F
from bitsandbytes.functional import ipex_cpu, ipex_xpu

# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -426,7 +425,7 @@ def matmul(
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
if state.is_training:
if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
if A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)

@@ -440,7 +439,7 @@ def matmul_4bit(
):
assert quant_state is not None

if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if A.device.type == "cpu" and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
# IPEX CPU will change weight to 4D so don't need transpose
B = B.t() if B.dim() == 2 else B
@@ -450,7 +449,6 @@ def matmul_4bit(
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
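The A.numel() == A.shape[-1] guard above simply checks that A carries a single row (one token, as in a decode step), which is the only shape the gemv_4bit path serves; a tiny illustration with made-up shapes:

import torch

A_decode = torch.randn(1, 1, 4096)   # batch 1, one token
A_prefill = torch.randn(1, 8, 4096)  # batch 1, eight tokens

print(A_decode.numel() == A_decode.shape[-1])    # True  -> eligible for gemv_4bit
print(A_prefill.numel() == A_prefill.shape[-1])  # False -> falls back to MatMul4Bit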
2 changes: 1 addition & 1 deletion bitsandbytes/backends/cpu/ops.py
@@ -7,7 +7,7 @@

from ..._ops import register_kernel
from ...cextension import lib
from ..utils import ipex_cpu
from ...utils import ipex_cpu

# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
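The comment above refers to torch._int_mm, the s8 @ s8 -> s32 primitive this backend builds on. A minimal sketch (a private API; shape and backend constraints vary across torch versions, so treat it as illustrative only):

import torch

A = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
B = torch.randint(-128, 128, (64, 32), dtype=torch.int8)

C = torch._int_mm(A, B)  # accumulates in int32, avoiding int8 overflow
print(C.dtype, C.shape)  # torch.int32 torch.Size([32, 32])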
10 changes: 0 additions & 10 deletions bitsandbytes/backends/utils.py
100755 → 100644
@@ -3,16 +3,6 @@
from packaging import version
import torch

try:
# to support Intel CPU/XPU (IPEX) backend
import intel_extension_for_pytorch as ipex

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_cpu = None
ipex_xpu = None

try:
import triton # noqa: F401
import triton.language as tl # noqa: F401
Empty file modified bitsandbytes/backends/xpu/__init__.py
100755 → 100644
Empty file.
214 changes: 186 additions & 28 deletions bitsandbytes/backends/xpu/ops.py
100755 → 100644
@@ -1,58 +1,214 @@
from collections.abc import Sequence
import ctypes as ct
import warnings

import torch

from bitsandbytes.functional import _get_tensor_stream, get_ptr

from ..._ops import register_kernel
from ..utils import ipex_xpu, triton_available
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
from ..utils import triton_available

# TODO: Enable _int_mm in torch
# if torch.__version__ >= (2, 9):
# @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
# def _(A: torch.Tensor, B: torch.Tensor):
# return torch._int_mm(
# A.reshape(-1, A.shape[-1]),
# B.t(),
# ).reshape(*A.shape[:-1], B.shape[0])


def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
args = (
None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(out.numel()),
_get_tensor_stream(A),
)
if dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(*args)
else:
lib.cdequantize_blockwise_bf16_nf4(*args)
elif dtype == torch.float16:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(*args)
else:
lib.cdequantize_blockwise_fp16_nf4(*args)
elif dtype == torch.float32:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(*args)
else:
lib.cdequantize_blockwise_fp32_nf4(*args)

# _int_mm is available in torch starting from 2.7 version,
# but currently it's don't have xpu implementation.
if ipex_xpu and torch.__version__ >= (2, 7):

@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
def _(A: torch.Tensor, B: torch.Tensor):
return torch._int_mm(
A.reshape(-1, A.shape[-1]),
B.t(),
).reshape(*A.shape[:-1], B.shape[0])
def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
args = (
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(A.numel()),
_get_tensor_stream(A),
)
if dtype == torch.float16:
lib.cdequantize_blockwise_fp16(*args)
elif dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(*args)
elif dtype == torch.float32:
lib.cdequantize_blockwise_fp32(*args)


# IPEX should be faster for xpu, so at first checking if it is available.
if ipex_xpu:
def _gemv_4bit_impl(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
m = ct.c_int32(1)
n = ct.c_int32(shapeB[0])
k = ct.c_int32(shapeB[1])

@register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
lda = m
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
ldc = m

stream = _get_tensor_stream(A)
if A.dtype == torch.float16:
lib.cgemv_4bit_inference_fp16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.bfloat16:
lib.cgemv_4bit_inference_bf16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.float32:
lib.cgemv_4bit_inference_fp32(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)


# SYCL should be faster for XPU, so check first whether it is available.
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
Egor-Krivov (Contributor) commented on Jul 7, 2025:

Currently you either pick all methods from SYCL or all methods from Triton. However, the SYCL implementation is currently missing these methods, which are available in Triton:

quantize_blockwise
quantize_4bit

I suggest we keep using these Triton methods even with SYCL, since that's the only option on XPU for now.

Reply from the author (Contributor):

These two kernels don't affect QLoRA performance; they currently run on default PyTorch ops, and we will implement them with SYCL kernels later.


@register_kernel("bitsandbytes::dequantize_4bit", "xpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype)
out = torch.zeros(shape, dtype=dtype, device=A.device)
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out

@register_kernel("bitsandbytes::dequantize_blockwise", "xpu")
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
out = torch.empty_like(A, dtype=dtype)
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
return out

@register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)

@register_kernel("bitsandbytes::gemv_4bit", "xpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
shape = A.shape
out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device)
# void cdequantize_blockwise_fp32(
# float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
if dtype == torch.float16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
elif dtype == torch.bfloat16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
elif dtype == torch.float32:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
shape = (*A.shape[:-1], shapeB[0])
out = torch.empty(shape, device=A.device, dtype=A.dtype)
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
return out

return out.reshape(shape)
@register_kernel("bitsandbytes::gemv_4bit.out", "xpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check(
out.shape == (*A.shape[:-1], shapeB[0]),
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
)
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
elif triton_available:
from ..triton import ops as triton_ops

@@ -64,4 +220,6 @@ def _(
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
warnings.warn("XPU available but no ipex or triton packages found.")
warnings.warn(
"XPU available but no native library or triton packages found. Please follow the installation instructions in the documentation."
)
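Putting the registrations above together: per the review discussion, quantize_4bit still runs through the Triton/PyTorch path on XPU, while dequantize_4bit and gemv_4bit use the SYCL kernels when the native library is present. A hedged end-to-end sketch of the 4-bit round trip (assuming an XPU-enabled torch and either the native library or Triton installed):

import torch
import bitsandbytes.functional as F

W = torch.randn(256, 256, dtype=torch.float16, device="xpu")

# quantize_4bit currently dispatches to the Triton/PyTorch path on XPU
packed, quant_state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

# dequantize_4bit uses the SYCL kernel registered above when the native
# library loaded; otherwise the Triton fallback handles it
W_restored = F.dequantize_4bit(packed, quant_state)

print((W - W_restored).abs().max())  # quantization error, not exactly zero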