diff --git a/CMakeLists.txt b/CMakeLists.txt index 770b4ba30..c56aa0c99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) set(HIP_FILES csrc/ops.hip csrc/kernels.hip) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) +set(NPU_FILES csrc/npu_ops.cpp) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, npu)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps npu) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -50,6 +51,7 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") set(BUILD_CUDA ON) set(BUILD_HIP OFF) set(BUILD_MPS OFF) + set(BUILD_NPU OFF) elseif(${COMPUTE_BACKEND} STREQUAL "hip") if(APPLE) message(FATAL_ERROR "HIP is not supported on macOS" ) @@ -57,6 +59,7 @@ elseif(${COMPUTE_BACKEND} STREQUAL "hip") set(BUILD_CUDA OFF) set(BUILD_HIP ON) set(BUILD_MPS OFF) + set(BUILD_NPU OFF) elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) @@ -64,10 +67,20 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps") set(BUILD_CUDA OFF) set(BUILD_HIP OFF) set(BUILD_MPS ON) + set(BUILD_NPU OFF) +elseif(${COMPUTE_BACKEND} STREQUAL "npu") + if(APPLE) + message(FATAL_ERROR "NPU is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) + set(BUILD_NPU ON) else() set(BUILD_CUDA OFF) set(BUILD_HIP OFF) set(BUILD_MPS OFF) + set(BUILD_NPU OFF) endif() @@ -217,6 +230,40 @@ elseif(BUILD_MPS) COMMENT "Compiling Metal kernels" VERBATIM) add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_NPU) + list(APPEND SRC_FILES ${NPU_FILES}) + execute_process( + COMMAND bash -c "npu-smi info|awk -F' ' 'NF > 0 && NR==7 {print $3}'" + OUTPUT_VARIABLE npu_info + RESULT_VARIABLE npu_result + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if("${npu_info}" STREQUAL "" OR ${npu_result}) + message(FATAL_ERROR "Auto-detech ascend soc type failed, please specify manually or check ascend device working normally.") + endif() + + set(SOC_VERSION "Ascend${npu_info}" CACHE STRING "system on chip type") + set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH} CACHE STRING "ASCEND CANN package installation directory") + + # ${KERNEL_FILES} are used to compile library, push files written by ascendc in ${KERNEL_FILES}. + # ref to cmake/npu.cmake ascendc_library, cmake/cpu.cmake add_library + # file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/npu_kernels.cpp) + file(GLOB KERNEL_FILES csrc/npu_kernels.cpp) + + if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) + elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake) + else() + message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the can package is installed") + endif() + include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) + + # ascendc_library use to add kernel file to generate ascendc library + ascendc_library(ascendc_kernels_npu STATIC ${KERNEL_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_npu") + add_compile_definitions(BUILD_NPU) else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) @@ -234,7 +281,11 @@ endif() set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) add_library(bitsandbytes SHARED ${SRC_FILES}) -target_compile_features(bitsandbytes PUBLIC cxx_std_14) +if(BUILD_NPU) + target_compile_features(bitsandbytes PUBLIC cxx_std_17) +else() + target_compile_features(bitsandbytes PUBLIC cxx_std_14) +endif() target_include_directories(bitsandbytes PUBLIC csrc include) @@ -285,6 +336,10 @@ if(BUILD_MPS) add_dependencies(bitsandbytes metallib) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") endif() +if(BUILD_NPU) + target_compile_options(bitsandbytes PRIVATE -O2 -std=c++17) + target_link_libraries(bitsandbytes PRIVATE $ ascendc_kernels_npu) +endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 516afa51f..57efcbba2 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -46,6 +46,9 @@ if hasattr(torch, "hpu") and torch.hpu.is_available(): from .backends.hpu import ops as hpu_ops +if importlib.util.find_spec("torch") and importlib.util.find_spec("torch_npu"): + from .backends.npu import ops as npu_ops + def _import_backends(): """ diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 80fc86861..5b9522388 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -451,7 +451,7 @@ def matmul_4bit( else: return MatMul4Bit.apply(A, B, out, bias, quant_state) - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": + if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu" and A.device.type != "npu": if A.shape[-1] % quant_state.blocksize != 0: warn( f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", diff --git a/bitsandbytes/backends/npu/__init__.py b/bitsandbytes/backends/npu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/npu/ops.py b/bitsandbytes/backends/npu/ops.py new file mode 100644 index 000000000..e87b324fd --- /dev/null +++ b/bitsandbytes/backends/npu/ops.py @@ -0,0 +1,120 @@ +import ctypes as ct +from collections.abc import Sequence + +import torch + +from bitsandbytes.functional import get_ptr + +from ..._ops import register_kernel +from ...cextension import lib +from ..utils import _NF4_QUANT_TABLE + + +@register_kernel("bitsandbytes::quantize_4bit", "npu") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on NPU, got {quant_type}") + n = A.numel() + + global _NF4_QUANT_TABLE + if _NF4_QUANT_TABLE.device != A.device: + _NF4_QUANT_TABLE = _NF4_QUANT_TABLE.to(A.device) + + # TODO: Support when weight matrix is not divisible by blocksize + torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") + + # Process tensor in chunks to avoid high memory usage from large intermediate tensors + # (e.g., during broadcasting with FP32 quant table) + chunks_absmax = [] + chunks_out = [] + total_blocks = A.numel() // blocksize + chunks = 8 if A.numel() > 1024 * 1024 else 1 + chunksize = (total_blocks + chunks - 1) // chunks + + for i in range(chunks): + start = i * chunksize * blocksize + end = min((i + 1) * chunksize * blocksize, A.numel()) + chunk_data = A.view(-1)[start:end].view(-1, blocksize) + + absmax = chunk_data.abs().max(dim=1, keepdim=True).values + chunks_absmax.append(absmax) + + a = chunk_data / absmax.float() + diff = torch.abs(a.unsqueeze(-1) - _NF4_QUANT_TABLE) + out = (torch.argmin(diff, dim=-1) + 8) % 16 + + out = out.reshape(-1, 2) + # Pack 4-bit values in NPU-compatible order (low nibble first) to match NPU-specific unpacking logic; + # differs from CUDA's packing + out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8) + chunks_out.append(out) + + absmax = torch.cat(chunks_absmax, dim=0) + packed = torch.cat(chunks_out, dim=0).reshape(-1, 1) + return packed, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "npu") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "npu") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + args = ( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + torch.npu.current_stream(), + ) + + if out.dtype == torch.bfloat16: + # bf16: bf16 -> fp32 -> op -> fp32 -> bf16 + absmax = absmax.to(torch.float32) + out = out.to(torch.float32) + lib.cdequantize_blockwise_fp32_nf4(*args) + out = out.to(torch.bfloat16) + elif out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32_nf4(*args) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index bb301e712..5d8d2bb86 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -5,6 +5,7 @@ from pathlib import Path import re from typing import Optional +import importlib import torch @@ -42,6 +43,13 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: return PACKAGE_DIR / library_name +def is_npu_available() -> bool: + """Checks if `torch_npu` is installed and potentially if a NPU is in the environment""" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + return True + + class BNBNativeLibrary: _lib: ct.CDLL compiled_with_cuda = False @@ -282,7 +290,8 @@ def get_native_library() -> BNBNativeLibrary: raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}") binary_path = cuda_binary_path - + elif is_npu_available(): + binary_path = PACKAGE_DIR / f"libbitsandbytes_npu{DYNAMIC_LIBRARY_SUFFIX}" logger.debug(f"Loading bitsandbytes native library from: {binary_path}") # Try to load the library - any errors will propagate up @@ -313,6 +322,8 @@ def get_native_library() -> BNBNativeLibrary: try: if torch.version.hip: HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" + elif is_npu_available(): + HIP_ENVIRONMENT, BNB_BACKEND = False, "NPU" else: HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" diff --git a/csrc/npu_kernels.cpp b/csrc/npu_kernels.cpp new file mode 100644 index 000000000..1db0e9eb4 --- /dev/null +++ b/csrc/npu_kernels.cpp @@ -0,0 +1,218 @@ +#include "kernel_operator.h" +#include "npu_ops.h" + +using namespace AscendC; + +constexpr int32_t BUFFER_NUM = 1; + +#define CEIL32(num) (((num) + 32 - 1) / 32 * 32) + + +template +class KernelDequantizeBlockwiseNf4 { +public: + __aicore__ inline KernelDequantizeBlockwiseNf4() {} + __aicore__ inline void Init( + GM_ADDR A, + GM_ADDR absmax, + GM_ADDR out, + uint32_t blocksize, + uint32_t coreNum, + uint32_t singleCoreNumel, + uint32_t singleCoreNumelTail, + uint32_t numel, + uint32_t ubSize, + TPipe &pipe + ) + { + ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); + this->blocksize = blocksize; + uint32_t blockIdx = (uint32_t)GetBlockIdx(); + if (coreNum - blockIdx == 1) { + this->CurCoreFP16Num = singleCoreNumelTail; + } else { + this->CurCoreFP16Num = singleCoreNumel; + } + constexpr uint32_t ELEMENT_BYTES = (TypeMode == 1) ? 4 : 2; // FP32: 4bytes, FP16/BF16: 2bytes + uint32_t eachBatchPkgNum = (ubSize - 16 * ELEMENT_BYTES) / + (this->blocksize / 2 * BUFFER_NUM + ELEMENT_BYTES * BUFFER_NUM + this->blocksize * + (ELEMENT_BYTES * BUFFER_NUM + sizeof(half) + sizeof(uint32_t) + ELEMENT_BYTES)); + if (eachBatchPkgNum >= 32 / ELEMENT_BYTES) { + eachBatchPkgNum = (eachBatchPkgNum / (32 / ELEMENT_BYTES)) * (32 / ELEMENT_BYTES); + } else { + eachBatchPkgNum = (eachBatchPkgNum / 2) * 2; + } + this->eachBatchFP16Num = this->blocksize * eachBatchPkgNum; // 64 * 288 + + // gm, 32-byte alignment + uint32_t AOffset = singleCoreNumel / 2 * blockIdx; + uint32_t ABufferSize = singleCoreNumel / 2; + AGm.SetGlobalBuffer((__gm__ int8_t*)A + AOffset, ABufferSize); + uint32_t absmaxOffset = singleCoreNumel / this->blocksize * blockIdx; + uint32_t absmaxBufferSize = singleCoreNumel / this->blocksize; + absmaxGm.SetGlobalBuffer((__gm__ T*)absmax + absmaxOffset, absmaxBufferSize); + uint32_t outOffset = singleCoreNumel * blockIdx; + uint32_t outBufferSize = singleCoreNumel; + outGm.SetGlobalBuffer((__gm__ T*)out + outOffset, outBufferSize); + + // TQue, 32-byte alignment + pipe.InitBuffer(inQueueA, BUFFER_NUM, this->eachBatchFP16Num / 2); + pipe.InitBuffer(inQueueAbsmax, BUFFER_NUM, CEIL32(eachBatchPkgNum * ELEMENT_BYTES)); + pipe.InitBuffer(outQueueOut, BUFFER_NUM, this->eachBatchFP16Num * ELEMENT_BYTES); + + // TBuf, 32-byte alignment + pipe.InitBuffer(calcNf4ToFloat, 16 * ELEMENT_BYTES); + pipe.InitBuffer(calcAFP16, this->eachBatchFP16Num * sizeof(half)); + pipe.InitBuffer(calcAUint32, this->eachBatchFP16Num * sizeof(uint32_t)); + pipe.InitBuffer(calcAbsmaxBuf, this->eachBatchFP16Num * ELEMENT_BYTES); + } + + __aicore__ inline void Process(void) + { + Compute(); + } + +private: + __aicore__ inline void initNf4ToFloat(LocalTensor &nf4ToFloat) + { + nf4ToFloat(0) = static_cast(-1.0); + nf4ToFloat(1) = static_cast(-0.6961928009986877); + nf4ToFloat(2) = static_cast(-0.5250730514526367); + nf4ToFloat(3) = static_cast(-0.39491748809814453); + nf4ToFloat(4) = static_cast(-0.28444138169288635); + nf4ToFloat(5) = static_cast(-0.18477343022823334); + nf4ToFloat(6) = static_cast(-0.09105003625154495); + nf4ToFloat(7) = static_cast(0.0); + nf4ToFloat(8) = static_cast(0.07958029955625534); + nf4ToFloat(9) = static_cast(0.16093020141124725); + nf4ToFloat(10) = static_cast(0.24611230194568634); + nf4ToFloat(11) = static_cast(0.33791524171829224); + nf4ToFloat(12) = static_cast(0.44070982933044434); + nf4ToFloat(13) = static_cast(0.5626170039176941); + nf4ToFloat(14) = static_cast(0.7229568362236023); + nf4ToFloat(15) = static_cast(1.0); + } + + __aicore__ inline void Compute(void) + { + constexpr uint32_t ELEMENT_BYTES = (TypeMode == 1) ? 4 : 2; // FP32: 4bytes, FP16/BF16: 2bytes + LocalTensor ALocal = inQueueA.AllocTensor(); + LocalTensor absmaxLocal = inQueueAbsmax.AllocTensor(); + LocalTensor outLocal = outQueueOut.AllocTensor(); + + LocalTensor AFP16 = calcAFP16.Get(); + LocalTensor AInt32 = calcAUint32.Get(); + LocalTensor absmaxBuf = calcAbsmaxBuf.Get(); + LocalTensor nf4ToFloat = calcNf4ToFloat.Get(); + initNf4ToFloat(nf4ToFloat); + + DataCopyParams dataCopyParams = {1, 0, 0, 0}; + uint32_t curBatchNumel = this->eachBatchFP16Num; + uint32_t curBatchPkgNum = curBatchNumel / this->blocksize; + + uint32_t batchCount = (this->CurCoreFP16Num + this->eachBatchFP16Num - 1) / this->eachBatchFP16Num; + for (uint32_t batchIdx = 0; batchIdx < batchCount; batchIdx++) { + if (batchCount - batchIdx == 1) { + curBatchNumel = this->CurCoreFP16Num - this->eachBatchFP16Num * batchIdx; + curBatchPkgNum = (curBatchNumel + this->blocksize - 1) / this->blocksize; + } + + dataCopyParams.blockLen = curBatchNumel / 2; // Byte + DataCopyPad(ALocal, AGm[this->eachBatchFP16Num / 2 * batchIdx], dataCopyParams, {true, 0, 0, 0}); + dataCopyParams.blockLen = ELEMENT_BYTES * curBatchPkgNum; // Byte + uint32_t gmOffset = this->eachBatchFP16Num / this->blocksize * batchIdx; + DataCopyPad(absmaxLocal, absmaxGm[gmOffset], dataCopyParams, {true, 0, 0, 0}); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_ALL); + + LocalTensor AInt4 = ALocal.ReinterpretCast(); + Cast(AFP16, AInt4, RoundMode::CAST_NONE, curBatchNumel); + pipe_barrier(PIPE_V); + Adds(AFP16, AFP16, static_cast(8), curBatchNumel); + pipe_barrier(PIPE_V); + if constexpr (TypeMode == 1) { + Muls(AFP16, AFP16, static_cast(4), curBatchNumel); + } else { + Muls(AFP16, AFP16, static_cast(2), curBatchNumel); + } + pipe_barrier(PIPE_V); + Cast(AInt32, AFP16, RoundMode::CAST_ROUND, curBatchNumel); + pipe_barrier(PIPE_V); + LocalTensor AUint32 = AInt32.ReinterpretCast(); + Gather(outLocal, nf4ToFloat, AUint32, 0, curBatchNumel); + pipe_barrier(PIPE_V); + uint32_t dstShape[] = {curBatchPkgNum, this->blocksize}; + uint32_t srcShape[] = {curBatchPkgNum, 1}; + BroadCast(absmaxBuf, absmaxLocal, dstShape, srcShape); + pipe_barrier(PIPE_ALL); + Mul(outLocal, outLocal, absmaxBuf, curBatchNumel); + pipe_barrier(PIPE_ALL); + + dataCopyParams.blockLen = ELEMENT_BYTES * curBatchNumel; // Byte + DataCopyPad(outGm[batchIdx * this->eachBatchFP16Num], outLocal, dataCopyParams); + pipe_barrier(PIPE_MTE3); + } + pipe_barrier(PIPE_ALL); + + inQueueA.FreeTensor(ALocal); + inQueueAbsmax.FreeTensor(absmaxLocal); + outQueueOut.FreeTensor(outLocal); + } + +private: + TQue inQueueA; + TQue inQueueAbsmax; + TQue outQueueOut; + TBuf calcAFP16; + TBuf calcAUint32; + TBuf calcNf4ToFloat; + TBuf calcAbsmaxBuf; + GlobalTensor AGm; + GlobalTensor absmaxGm; + GlobalTensor outGm; + uint32_t blocksize; + uint32_t CurCoreFP16Num; + uint32_t eachBatchFP16Num; +}; + + +extern "C" { + +__global__ __aicore__ void dequantize_blockwise_fp32_nf4( + GM_ADDR A, + GM_ADDR absmax, + GM_ADDR out, + uint32_t blocksize, + uint32_t coreNum, + uint32_t singleCoreNumel, + uint32_t singleCoreNumelTail, + uint32_t numel, + uint32_t ubSize +) +{ + TPipe pipe; + KernelDequantizeBlockwiseNf4 op; + op.Init(A, absmax, out, blocksize, coreNum, singleCoreNumel, singleCoreNumelTail, numel, ubSize, pipe); + op.Process(); +} + +__global__ __aicore__ void dequantize_blockwise_fp16_nf4( + GM_ADDR A, + GM_ADDR absmax, + GM_ADDR out, + uint32_t blocksize, + uint32_t coreNum, + uint32_t singleCoreNumel, + uint32_t singleCoreNumelTail, + uint32_t numel, + uint32_t ubSize +) +{ + TPipe pipe; + KernelDequantizeBlockwiseNf4 op; + op.Init(A, absmax, out, blocksize, coreNum, singleCoreNumel, singleCoreNumelTail, numel, ubSize, pipe); + op.Process(); +} + +} \ No newline at end of file diff --git a/csrc/npu_ops.cpp b/csrc/npu_ops.cpp new file mode 100644 index 000000000..9946b7c54 --- /dev/null +++ b/csrc/npu_ops.cpp @@ -0,0 +1,72 @@ +#include +#include +#include "acl/acl.h" +#include "tiling/platform/platform_ascendc.h" +#include "npu_ops.h" + +#include "aclrtlaunch_dequantize_blockwise_fp32_nf4.h" +#include "aclrtlaunch_dequantize_blockwise_fp16_nf4.h" + + +extern "C" { + +int32_t get_dequantize_blockwise_nf4_tiling(uint32_t blocksize, uint32_t n, BlockwiseNf4TilingData *tiling) { + tiling->ubSize = 196 * 1024; + uint32_t coreNum = 40; + uint32_t totalPkgNum = (n + blocksize - 1) / blocksize; + uint32_t singleCorePkgNum = (totalPkgNum + coreNum - 1) / coreNum; + coreNum = (totalPkgNum + singleCorePkgNum - 1) / singleCorePkgNum; + uint32_t singleCoreNumel = singleCorePkgNum * blocksize; + uint32_t singleCoreNumelTail = n % singleCoreNumel; + if (singleCoreNumelTail == 0) { + singleCoreNumelTail = singleCoreNumel; + } + tiling->coreNum = coreNum; + tiling->blocksize = blocksize; + tiling->numel = n; + tiling->singleCoreNumel = singleCoreNumel; + tiling->singleCoreNumelTail = singleCoreNumelTail; + return 0; +} + +void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream, const uint32_t type_mode) { + uint32_t blockDim = 40; + size_t tilingSize = sizeof(struct BlockwiseNf4TilingData); + BlockwiseNf4TilingData *tilingHost; + tilingHost = (struct BlockwiseNf4TilingData *)malloc(tilingSize); + uint32_t error = get_dequantize_blockwise_nf4_tiling(blocksize, n, tilingHost); + if (error != 0) { + printf("An error occurred.\n"); + } + if (type_mode == 1) { + ACLRT_LAUNCH_KERNEL(dequantize_blockwise_fp32_nf4)( + blockDim, + stream, + A, + absmax, + out, + tilingHost->blocksize, + tilingHost->coreNum, + tilingHost->singleCoreNumel, + tilingHost->singleCoreNumelTail, + tilingHost->numel, + tilingHost->ubSize + ); + } else if (type_mode == 2) { + ACLRT_LAUNCH_KERNEL(dequantize_blockwise_fp16_nf4)( + blockDim, + stream, + A, + absmax, + out, + tilingHost->blocksize, + tilingHost->coreNum, + tilingHost->singleCoreNumel, + tilingHost->singleCoreNumelTail, + tilingHost->numel, + tilingHost->ubSize + ); + } +} + +} diff --git a/csrc/npu_ops.h b/csrc/npu_ops.h new file mode 100644 index 000000000..491ef007b --- /dev/null +++ b/csrc/npu_ops.h @@ -0,0 +1,20 @@ +#ifndef NPU_OPS_H +#define NPU_OPS_H +#include + + +struct BlockwiseNf4TilingData { + uint32_t coreNum; + uint32_t blocksize; + uint32_t numel; + uint32_t singleCoreNumel; + uint32_t singleCoreNumelTail; + uint32_t ubSize; +}; + +extern "C" { + +void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream, const uint32_t type_mode); + +} +#endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 9c4cab9cc..a8c15cf13 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -12,6 +12,9 @@ #if BUILD_MPS // #include #endif +#if BUILD_NPU +#include +#endif #include // Compatibility between HIP/CUDA APIs @@ -658,6 +661,22 @@ void cgemm_4bit_inference_naive_fp32( #endif +#if BUILD_NPU + +void cdequantize_blockwise_fp32_nf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, + void* stream +) { + dequantizeBlockwiseNf4(A, absmax, out, blocksize, n, stream, 1); +} + +void cdequantize_blockwise_fp16_nf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, + void* stream +) { + dequantizeBlockwiseNf4(A, absmax, out, blocksize, n, stream, 2); +} + +#endif + void cquantize_blockwise_cpu_fp32( float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n ) { diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index e61ce4655..cc5c57913 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -263,10 +263,19 @@ pip install git+https://github.com/bitsandbytes-foundation/bitsandbytes.git Please refer to [the official Ascend installations instructions](https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/configandinstg/instg/insg_0001.html) for guidance on how to install the necessary `torch_npu` dependency. +> **Before building on Ascend NPU:** +> Make sure to source the CANN environment script (adjust path based on your installation): +> +> ```bash +> source /usr/local/Ascend/ascend-toolkit/set_env.sh +> ``` +> +> CANN toolkit can be downloaded from: [https://www.hiascend.com/zh/developer/download/community/result?module=cann](https://www.hiascend.com/zh/developer/download/community/result?module=cann) + ```bash # Install bitsandbytes from source -# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch -git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ +# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on main branch +git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ # Compile & install apt-get install -y build-essential cmake # install build tools dependencies, unless present