Skip to content
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ bitsandbytes has the following minimum requirements for all platforms:
#### Accelerator support:

<small>Note: this table reflects the status of the current development branch. For the latest stable release, see the
[document in the v0.46.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.46.0/README.md#accelerator-support).
[document in the 0.47.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.47.0/README.md#accelerator-support).
</small>

##### Legend:
Expand Down
3 changes: 1 addition & 2 deletions bitsandbytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
if hasattr(torch, "xpu") and torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops


if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
# In case not automatically imported
import habana_frameworks.torch
Expand Down Expand Up @@ -76,4 +75,4 @@ def _import_backends():
"optim.optimizer.MockArgs": False,
}

__version__ = "0.47.0.dev0"
__version__ = "0.48.0.dev0"
40 changes: 40 additions & 0 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,46 @@ def to(self, *args, **kwargs):

return new_param

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """
    Keep the outputs of `torch.chunk` / `torch.split` as instances of this class.

    Any other function is forwarded unchanged to the parent implementation.
    For the two handled functions, every resulting piece is re-wrapped with
    `cls(...)`, copying the quantization metadata attributes of the source tensor.

    Args:
        func: The torch function being dispatched.
        types: The tensor subclass types taking part in the call.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call (`None` is treated as empty).

    Returns:
        A tuple of `cls` instances when the parent returns a tuple, a single `cls`
        instance for any other result of a handled function, and otherwise
        whatever the parent implementation returns.
    """
    if kwargs is None:
        kwargs = {}

    if func not in (torch.chunk, torch.split):
        return super().__torch_function__(func, types, args, kwargs)

    # The source tensor is normally the first positional argument, but it may also
    # be passed by keyword (`input=` for torch.chunk, `tensor=` for torch.split).
    # Indexing args[0] unconditionally would raise IndexError in that case.
    if args:
        tensor = args[0]
    else:
        tensor = kwargs.get("input", kwargs.get("tensor"))

    result = super().__torch_function__(func, types, args, kwargs)

    if tensor is None:
        # Nothing to copy metadata from; hand back the parent's result untouched.
        return result

    def rewrap(piece):
        # Build a new instance around `piece`, carrying over the source's metadata.
        return cls(
            data=piece,
            requires_grad=tensor.requires_grad,
            quant_state=tensor.quant_state,
            blocksize=tensor.blocksize,
            compress_statistics=tensor.compress_statistics,
            quant_type=tensor.quant_type,
            quant_storage=tensor.quant_storage,
            module=tensor.module,
            bnb_quantized=tensor.bnb_quantized,
        )

    if isinstance(result, tuple):
        return tuple(rewrap(piece) for piece in result)
    return rewrap(result)


def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
if getattr(module.weight, "quant_state", None) is not None:
Expand Down
175 changes: 175 additions & 0 deletions bitsandbytes/nn/parametrize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from functools import partial
from typing import Any, Literal, Optional

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

from .. import functional as F


class Bnb4bitParametrization(nn.Module):
    """
    Parametrization that dequantizes a 4-bit quantized parameter on access.

    The wrapped parameter must already hold quantized data at the time this
    parametrization is registered. Each call to `forward` (i.e. each access to the
    parametrized parameter) passes that data and the stored quantization state to
    `F.dequantize_4bit`.

    Args:
        quant_state (`F.QuantState`):
            The quantization state holding the information needed for dequantization.
        p_name (`str`, *optional*, defaults to `"unknown"`):
            Stored on the module as `p_name`; `forward` does not use it.
    """

    def __init__(self, quant_state: F.QuantState, p_name="unknown"):
        super().__init__()
        self.p_name = p_name
        self.quant_state = quant_state

    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
        """
        Dequantize the stored parameter data.

        Args:
            quantized_param (`torch.Tensor`): The quantized parameter tensor (from .original)

        Returns:
            `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
        """
        dequantized = F.dequantize_4bit(quantized_param, self.quant_state)
        return dequantized


def replace_parameter_4bit_prequantized(
    module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device
):
    """
    Attach 4-bit dequantization to a module parameter that is already quantized.

    The parameter data is left as-is. A quantization state is rebuilt from `qs_dict`,
    a `Bnb4bitParametrization` using it is registered on the parameter, and a state
    dict post-hook is registered so the quantization state is written out on save.

    Args:
        module (`nn.Module`):
            The module that owns the parameter.
        param_name (`str`):
            The name of the parameter within the module.
        qs_dict (`dict[str, Any]`):
            The serialized quantization state, passed to `F.QuantState.from_dict`.
        device (`torch.device`):
            The device passed to `F.QuantState.from_dict`.

    Raises:
        AttributeError: If the module does not have the specified parameter.
        TypeError: If the specified attribute is not an instance of nn.Parameter.
    """
    missing = object()
    current = getattr(module, param_name, missing)

    if current is missing:
        raise AttributeError(f"Module does not have parameter '{param_name}'")
    if not isinstance(current, nn.Parameter):
        raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")

    # Rebuild the quantization state and let a parametrization handle dequantization.
    dequantizer = Bnb4bitParametrization(F.QuantState.from_dict(qs_dict, device=device))
    P.register_parametrization(module, param_name, dequantizer, unsafe=True)

    # Make sure the quantization state is emitted when the state dict is saved.
    save_hook = partial(_parametrized_state_dict_post_hook, param_name=param_name)
    module.register_state_dict_post_hook(save_hook)


def replace_parameter_4bit(
    module: nn.Module,
    param_name: str,
    compress_statistics: bool = False,
    quant_type: Literal["nf4", "fp4"] = "nf4",
    blocksize: Optional[int] = None,
):
    """
    Quantize a module parameter to 4-bit and expose it through a parametrization.

    The existing parameter is quantized with `F.quantize_4bit` and replaced by a
    non-trainable parameter holding the quantized data. A `Bnb4bitParametrization`
    carrying the resulting quantization state is then registered, so the parameter
    is dequantized whenever it is accessed.

    Additionally, a state dict post-hook is registered so that the quantization
    state is saved correctly when the model's state dict is saved.

    This is useful for MoE models or other scenarios where parameters outside of
    nn.Linear layers should be quantized without changing the model's architecture.

    <Tip warning={true}>This feature is experimental and may change in future releases.</Tip>

    Args:
        module (`nn.Module`):
            The PyTorch module containing the parameter to be quantized.
        param_name (`str`):
            The name of the parameter within the module to quantize.
        compress_statistics (`bool`, *optional*, defaults to `False`):
            Whether to compress quantization statistics to reduce memory usage.
        quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`):
            The quantization format to use.
        blocksize (`int`, *optional*, defaults to `None`):
            The block size for quantization. If None, uses the default block size.

    Raises:
        AttributeError: If the module does not have the specified parameter.
        TypeError: If the specified attribute is not an instance of nn.Parameter.
    """
    missing = object()
    original_param = getattr(module, param_name, missing)

    if original_param is missing:
        raise AttributeError(f"Module does not have parameter '{param_name}'")
    if not isinstance(original_param, nn.Parameter):
        raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")

    # Quantize the current parameter data.
    packed, state = F.quantize_4bit(
        original_param.data,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
    )

    # Swap in the quantized data and drop the local reference to the old parameter.
    frozen = nn.Parameter(packed, requires_grad=False)
    setattr(module, param_name, frozen)
    del original_param

    # Dequantization on access is handled by a parametrization.
    dequantizer = Bnb4bitParametrization(state)
    P.register_parametrization(module, param_name, dequantizer, unsafe=True)

    # Make sure the quantization state is emitted when the state dict is saved.
    save_hook = partial(_parametrized_state_dict_post_hook, param_name=param_name)
    module.register_state_dict_post_hook(save_hook)


def _parametrized_state_dict_post_hook(
module: nn.Module,
state_dict: dict[str, Any],
prefix: str,
local_metadata: Any,
*,
param_name: str = "weight",
**kwargs: dict[str, Any],
) -> None:
"""
Hook to modify the state dict to include the quantization state.
"""

original_key = f"{prefix}parametrizations.{param_name}.original"

if original_key in state_dict:
# Create a clean entry.
# The `parametrizations.{param_name}.original` key will have the quantized data,
# but we would like it to keep it in the state_dict as `{param_name}`.
clean_key = f"{prefix}{param_name}"
state_dict[clean_key] = state_dict.pop(original_key)

assert P.is_parametrized(module, param_name)

# Find the parametrization, which should have the quantization state.
parametrization: Bnb4bitParametrization = next(
filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
)

assert parametrization is not None, "Parametrization not found for the parameter."

quant_state = parametrization.quant_state

# Next, we need to store the quantization state.
if quant_state is not None:
for k, v in quant_state.as_dict(packed=True).items():
state_dict[f"{prefix}{param_name}.{k}"] = v
11 changes: 4 additions & 7 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise(
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}

unsigned char packed_4bit = 0;
switch (DATA_TYPE) {
case General8bit:
#pragma unroll NUM_PER_TH
Expand All @@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise(
case FP4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) {
packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
qvals[j] = packed_4bit;
qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) {
packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
qvals[j] = packed_4bit;
qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def run(self):


setup(
version="0.47.0.dev0",
version="0.48.0.dev0",
packages=find_packages(),
distclass=BinaryDistribution,
cmake_source_dir=".",
Expand Down
61 changes: 46 additions & 15 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,21 +1125,52 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):

# With larger block sizes, we can expect this to blow up.
# At blocksize>=1024, don't even bother looking at relerr.
if blocksize <= 64:
assert err.item() < 0.1
assert relerr.item() < 0.28
elif blocksize <= 256:
assert err.item() < 0.11
assert relerr.item() < 0.30
elif blocksize <= 512:
assert err.item() < 0.12
assert relerr.item() < 0.31
elif quant_type == "fp4":
# 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
else:
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
assert err.item() < math.log2(blocksize) * 8e-2
#
# Actually, the above is not true anymore after fixing the integer packing bug.
# The following values were taken from averaging 1k samples per test configuration after fixing the bug.
error_dict = dict()
error_dict["fp4"] = dict()
error_dict["nf4"] = dict()
error_dict["fp4"]["err"] = {
64: 0.096545,
128: 0.102947,
256: 0.108685,
512: 0.114087,
1024: 0.119312,
2048: 0.124460,
4096: 0.129573,
}
error_dict["fp4"]["rel_err"] = {
64: 0.260130,
128: 0.275734,
256: 0.289842,
512: 0.302852,
1024: 0.314982,
2048: 0.326402,
4096: 0.337228,
}

error_dict["nf4"]["err"] = {
64: 0.072792,
128: 0.076835,
256: 0.080326,
512: 0.083535,
1024: 0.086603,
2048: 0.089592,
4096: 0.092537,
}
error_dict["nf4"]["rel_err"] = {
64: 0.203299,
128: 0.215252,
256: 0.226044,
512: 0.236021,
1024: 0.245365,
2048: 0.254146,
4096: 0.262457,
}

assert err < error_dict[quant_type]["err"][blocksize] + 1e-3
assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
Expand Down
35 changes: 35 additions & 0 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_params4bit_torch_chunk_split(device, quant_type):
    """Check that torch.chunk and torch.split keep the Params4bit subclass and its quant_type."""
    hpu_unsupported = device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8)
    if hpu_unsupported:
        pytest.skip("This configuration is not supported on HPU.")

    if device == "cpu":
        pytest.skip("CPU quantization causes segfault, skipping CPU test")

    # Build the parameter on CPU first, then move it to the device under test.
    source = torch.randn(8, 4, dtype=torch.float16, device="cpu")
    params4bit = bnb.nn.Params4bit(data=source, quant_type=quant_type, requires_grad=False)
    if device != "cpu":
        params4bit = params4bit.to(device)

    def check_piece(piece, subclass_message):
        # Each piece must still be a Params4bit and carry the same quant_type.
        assert isinstance(piece, bnb.nn.Params4bit), subclass_message
        assert hasattr(piece, "quant_type"), "Should preserve metadata"
        assert piece.quant_type == params4bit.quant_type, "Should preserve quant_type value"

    chunks = torch.chunk(params4bit, 2, dim=0)
    assert isinstance(chunks, tuple), "torch.chunk should return tuple"
    for piece in chunks:
        check_piece(piece, "Chunk should preserve Params4bit subclass")

    splits = torch.split(params4bit, 2, dim=0)
    assert isinstance(splits, tuple), "torch.split should return tuple"
    assert len(splits) > 0, "Should have at least one split"
    for piece in splits:
        check_piece(piece, "Split should preserve Params4bit subclass")


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
Expand Down
Loading