Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(NPU_FILES csrc/npu_ops.cpp)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, npu)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps npu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
Expand All @@ -50,24 +51,36 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
set(BUILD_CUDA ON)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_NPU OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "hip")
if(APPLE)
message(FATAL_ERROR "HIP is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_HIP ON)
set(BUILD_MPS OFF)
set(BUILD_NPU OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
set(BUILD_NPU OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "npu")
if(APPLE)
message(FATAL_ERROR "NPU is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_NPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_NPU OFF)
endif()


Expand Down Expand Up @@ -217,6 +230,40 @@ elseif(BUILD_MPS)
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
elseif(BUILD_NPU)
list(APPEND SRC_FILES ${NPU_FILES})
execute_process(
COMMAND bash -c "npu-smi info|awk -F' ' 'NF > 0 && NR==7 {print $3}'"
OUTPUT_VARIABLE npu_info
RESULT_VARIABLE npu_result
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if("${npu_info}" STREQUAL "" OR ${npu_result})
message(FATAL_ERROR "Auto-detech ascend soc type failed, please specify manually or check ascend device working normally.")
endif()

set(SOC_VERSION "Ascend${npu_info}" CACHE STRING "system on chip type")
set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH} CACHE STRING "ASCEND CANN package installation directory")

# ${KERNEL_FILES} are used to compile library, push files written by ascendc in ${KERNEL_FILES}.
# ref to cmake/npu.cmake ascendc_library, cmake/cpu.cmake add_library
# file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/npu_kernels.cpp)
file(GLOB KERNEL_FILES csrc/npu_kernels.cpp)

if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
else()
message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the can package is installed")
endif()
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)

# ascendc_library use to add kernel file to generate ascendc library
ascendc_library(ascendc_kernels_npu STATIC ${KERNEL_FILES})

string(APPEND BNB_OUTPUT_NAME "_npu")
add_compile_definitions(BUILD_NPU)
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
Expand All @@ -234,7 +281,11 @@ endif()

set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
add_library(bitsandbytes SHARED ${SRC_FILES})
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
if(BUILD_NPU)
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
else()
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
endif()
target_include_directories(bitsandbytes PUBLIC csrc include)


Expand Down Expand Up @@ -285,6 +336,10 @@ if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
endif()
if(BUILD_NPU)
target_compile_options(bitsandbytes PRIVATE -O2 -std=c++17)
target_link_libraries(bitsandbytes PRIVATE $<BUILD_INTERFACE:host_intf_pub> ascendc_kernels_npu)
endif()

if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
Expand Down
3 changes: 3 additions & 0 deletions bitsandbytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
if hasattr(torch, "hpu") and torch.hpu.is_available():
from .backends.hpu import ops as hpu_ops

if importlib.util.find_spec("torch") and importlib.util.find_spec("torch_npu"):
from .backends.npu import ops as npu_ops


def _import_backends():
"""
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def matmul_4bit(
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu" and A.device.type != "npu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
Expand Down
Empty file.
120 changes: 120 additions & 0 deletions bitsandbytes/backends/npu/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import ctypes as ct
from collections.abc import Sequence

import torch

from bitsandbytes.functional import get_ptr

from ..._ops import register_kernel
from ...cextension import lib
from ..utils import _NF4_QUANT_TABLE


@register_kernel("bitsandbytes::quantize_4bit", "npu")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on NPU, got {quant_type}")
n = A.numel()

global _NF4_QUANT_TABLE
if _NF4_QUANT_TABLE.device != A.device:
_NF4_QUANT_TABLE = _NF4_QUANT_TABLE.to(A.device)

# TODO: Support when weight matrix is not divisible by blocksize
torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

# Process tensor in chunks to avoid high memory usage from large intermediate tensors
# (e.g., during broadcasting with FP32 quant table)
chunks_absmax = []
chunks_out = []
total_blocks = A.numel() // blocksize
chunks = 8 if A.numel() > 1024 * 1024 else 1
chunksize = (total_blocks + chunks - 1) // chunks

for i in range(chunks):
start = i * chunksize * blocksize
end = min((i + 1) * chunksize * blocksize, A.numel())
chunk_data = A.view(-1)[start:end].view(-1, blocksize)

absmax = chunk_data.abs().max(dim=1, keepdim=True).values
chunks_absmax.append(absmax)

a = chunk_data / absmax.float()
diff = torch.abs(a.unsqueeze(-1) - _NF4_QUANT_TABLE)
out = (torch.argmin(diff, dim=-1) + 8) % 16

out = out.reshape(-1, 2)
# Pack 4-bit values in NPU-compatible order (low nibble first) to match NPU-specific unpacking logic;
# differs from CUDA's packing
out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
chunks_out.append(out)

absmax = torch.cat(chunks_absmax, dim=0)
packed = torch.cat(chunks_out, dim=0).reshape(-1, 1)
return packed, absmax


@register_kernel("bitsandbytes::dequantize_4bit", "npu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
out = torch.empty(shape, dtype=dtype, device=A.device)
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out


@register_kernel("bitsandbytes::dequantize_4bit.out", "npu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(quant_type in ["nf4"])
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
args = (
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(out.numel()),
torch.npu.current_stream(),
)

if out.dtype == torch.bfloat16:
# bf16: bf16 -> fp32 -> op -> fp32 -> bf16
absmax = absmax.to(torch.float32)
out = out.to(torch.float32)
lib.cdequantize_blockwise_fp32_nf4(*args)
out = out.to(torch.bfloat16)
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16_nf4(*args)
elif out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32_nf4(*args)
13 changes: 12 additions & 1 deletion bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
import re
from typing import Optional
import importlib

import torch

Expand Down Expand Up @@ -42,6 +43,13 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
return PACKAGE_DIR / library_name


def is_npu_available() -> bool:
"""Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
return False
return True


class BNBNativeLibrary:
_lib: ct.CDLL
compiled_with_cuda = False
Expand Down Expand Up @@ -282,7 +290,8 @@ def get_native_library() -> BNBNativeLibrary:
raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}")

binary_path = cuda_binary_path

elif is_npu_available():
binary_path = PACKAGE_DIR / f"libbitsandbytes_npu{DYNAMIC_LIBRARY_SUFFIX}"
logger.debug(f"Loading bitsandbytes native library from: {binary_path}")

# Try to load the library - any errors will propagate up
Expand Down Expand Up @@ -313,6 +322,8 @@ def get_native_library() -> BNBNativeLibrary:
try:
if torch.version.hip:
HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"
elif is_npu_available():
HIP_ENVIRONMENT, BNB_BACKEND = False, "NPU"
else:
HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA"

Expand Down
Loading