diff --git a/README.md b/README.md
index 0d9e561ce..47510db9b 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ bitsandbytes has the following minimum requirements for all platforms:
#### Accelerator support:
Note: this table reflects the status of the current development branch. For the latest stable release, see the
-[document in the v0.46.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.46.0/README.md#accelerator-support).
+[document in the 0.47.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.47.0/README.md#accelerator-support).
##### Legend:
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 516afa51f..d58b7b441 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -38,7 +38,6 @@
if hasattr(torch, "xpu") and torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops
-
if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
# In case not automatically imported
import habana_frameworks.torch
@@ -76,4 +75,4 @@ def _import_backends():
"optim.optimizer.MockArgs": False,
}
-__version__ = "0.47.0.dev0"
+__version__ = "0.48.0.dev0"
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ba134f52a..e599643cc 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -356,6 +356,46 @@ def to(self, *args, **kwargs):
return new_param
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+
+ if func in [torch.chunk, torch.split]:
+ tensor = args[0]
+
+ result = super().__torch_function__(func, types, args, kwargs)
+
+ if isinstance(result, tuple):
+ return tuple(
+ cls(
+ data=chunk,
+ requires_grad=tensor.requires_grad,
+ quant_state=tensor.quant_state,
+ blocksize=tensor.blocksize,
+ compress_statistics=tensor.compress_statistics,
+ quant_type=tensor.quant_type,
+ quant_storage=tensor.quant_storage,
+ module=tensor.module,
+ bnb_quantized=tensor.bnb_quantized,
+ )
+ for chunk in result
+ )
+ else:
+ return cls(
+ data=result,
+ requires_grad=tensor.requires_grad,
+ quant_state=tensor.quant_state,
+ blocksize=tensor.blocksize,
+ compress_statistics=tensor.compress_statistics,
+ quant_type=tensor.quant_type,
+ quant_storage=tensor.quant_storage,
+ module=tensor.module,
+ bnb_quantized=tensor.bnb_quantized,
+ )
+
+ return super().__torch_function__(func, types, args, kwargs)
+
def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
if getattr(module.weight, "quant_state", None) is not None:
diff --git a/bitsandbytes/nn/parametrize.py b/bitsandbytes/nn/parametrize.py
new file mode 100644
index 000000000..d8f94c283
--- /dev/null
+++ b/bitsandbytes/nn/parametrize.py
@@ -0,0 +1,175 @@
+from functools import partial
+from typing import Any, Literal, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.utils.parametrize as P
+
+from .. import functional as F
+
+
+class Bnb4bitParametrization(nn.Module):
+ """
+ A parametrization module that handles dequantization of a 4-bit quantized parameter.
+
+ The parameter data is expected to be already quantized when this parametrization is applied.
+ This module will dequantize the parameter data to its original floating-point representation
+ when the forward method is called (i.e. when the parameter is accessed).
+
+ Args:
+ quant_state (`F.QuantState`):
+ The quantization state containing the necessary information for dequantization.
+ """
+
+ def __init__(self, quant_state: F.QuantState, p_name="unknown"):
+ super().__init__()
+ self.quant_state = quant_state
+ self.p_name = p_name
+
+ def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass to dequantize the parameter.
+
+ Args:
+ quantized_param (`torch.Tensor`): The quantized parameter tensor (from .original)
+
+ Returns:
+ `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
+ """
+ return F.dequantize_4bit(quantized_param, self.quant_state)
+
+
+def replace_parameter_4bit_prequantized(
+ module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device
+):
+ if not hasattr(module, param_name):
+ raise AttributeError(f"Module does not have parameter '{param_name}'")
+
+ original_param = getattr(module, param_name)
+
+ if not isinstance(original_param, nn.Parameter):
+ raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")
+
+ quant_state = F.QuantState.from_dict(qs_dict, device=device)
+
+ # Apply a parametrization to the module to handle dequantization.
+ P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
+
+ # Next, register state dict hook for saving.
+ module.register_state_dict_post_hook(
+ partial(
+ _parametrized_state_dict_post_hook,
+ param_name=param_name,
+ )
+ )
+
+
+def replace_parameter_4bit(
+ module: nn.Module,
+ param_name: str,
+ compress_statistics: bool = False,
+ quant_type: Literal["nf4", "fp4"] = "nf4",
+ blocksize: Optional[int] = None,
+):
+ """
+ Replace a module parameter with a 4-bit quantized version using parametrization.
+
+ This function quantizes an existing parameter in a PyTorch module to 4-bit precision
+ and sets up parametrization to handle automatic dequantization during forward passes.
+ The original parameter is replaced with quantized data, and a parametrization layer
+ is registered to manage the quantization state and dequantization process.
+
+ Additional, it registers a state dict post-hook to ensure that the quantization state
+ is saved correctly when the model's state dict is saved.
+
+ It is useful for MoE models or other scenarios where you want to quantize parameters
+ outside of nn.Linear layers without changing the model's architecture.
+
+ This feature is experimental and may change in future releases.
+
+ Args:
+ module (`nn.Module`):
+ The PyTorch module containing the parameter to be quantized.
+ param_name (`str`):
+ The name of the parameter within the module to quantize.
+ compress_statistics (`bool`, *optional*, defaults to `False`):
+ Whether to compress quantization statistics to reduce memory usage.
+ quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`):
+ The quantization format to use.
+ blocksize (`int`, *optional*, defaults to `None`):
+ The block size for quantization. If None, uses the default block size.
+
+ Raises:
+ AttributeError: If the module does not have the specified parameter.
+ TypeError: If the specified attribute is not an instance of nn.Parameter.
+ """
+
+ if not hasattr(module, param_name):
+ raise AttributeError(f"Module does not have parameter '{param_name}'")
+
+ original_param = getattr(module, param_name)
+
+ if not isinstance(original_param, nn.Parameter):
+ raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")
+
+ # Quantize the original parameter.
+ quantized_data, quant_state = F.quantize_4bit(
+ original_param.data,
+ blocksize=blocksize,
+ compress_statistics=compress_statistics,
+ quant_type=quant_type,
+ )
+
+ # Replace the parameter with the quantized data.
+ setattr(module, param_name, nn.Parameter(quantized_data, requires_grad=False))
+ del original_param
+
+ # Apply a parametrization to the module to handle dequantization.
+ P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
+
+ # Next, register state dict hook for saving.
+ module.register_state_dict_post_hook(
+ partial(
+ _parametrized_state_dict_post_hook,
+ param_name=param_name,
+ )
+ )
+
+
+def _parametrized_state_dict_post_hook(
+ module: nn.Module,
+ state_dict: dict[str, Any],
+ prefix: str,
+ local_metadata: Any,
+ *,
+ param_name: str = "weight",
+ **kwargs: dict[str, Any],
+) -> None:
+ """
+ Hook to modify the state dict to include the quantization state.
+ """
+
+ original_key = f"{prefix}parametrizations.{param_name}.original"
+
+ if original_key in state_dict:
+ # Create a clean entry.
+ # The `parametrizations.{param_name}.original` key will have the quantized data,
+ # but we would like it to keep it in the state_dict as `{param_name}`.
+ clean_key = f"{prefix}{param_name}"
+ state_dict[clean_key] = state_dict.pop(original_key)
+
+ assert P.is_parametrized(module, param_name)
+
+ # Find the parametrization, which should have the quantization state.
+ parametrization: Bnb4bitParametrization = next(
+ filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
+ )
+
+ assert parametrization is not None, "Parametrization not found for the parameter."
+
+ quant_state = parametrization.quant_state
+
+ # Next, we need to store the quantization state.
+ if quant_state is not None:
+ for k, v in quant_state.as_dict(packed=True).items():
+ state_dict[f"{prefix}{param_name}.{k}"] = v
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 649f2ee1f..97b80f050 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise(
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}
- unsigned char packed_4bit = 0;
switch (DATA_TYPE) {
case General8bit:
#pragma unroll NUM_PER_TH
@@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise(
case FP4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) {
- packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
- packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
- qvals[j] = packed_4bit;
+ qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
+ qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) {
- packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
- packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
- qvals[j] = packed_4bit;
+ qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
+ qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
}
diff --git a/setup.py b/setup.py
index 7aa50c1b8..a04630b8a 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@ def run(self):
setup(
- version="0.47.0.dev0",
+ version="0.48.0.dev0",
packages=find_packages(),
distclass=BinaryDistribution,
cmake_source_dir=".",
diff --git a/tests/test_functional.py b/tests/test_functional.py
index b84db6502..fc37cb4c3 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1125,21 +1125,52 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
# With larger block sizes, we can expect this to blow up.
# At blocksize>=1024, don't even bother looking at relerr.
- if blocksize <= 64:
- assert err.item() < 0.1
- assert relerr.item() < 0.28
- elif blocksize <= 256:
- assert err.item() < 0.11
- assert relerr.item() < 0.30
- elif blocksize <= 512:
- assert err.item() < 0.12
- assert relerr.item() < 0.31
- elif quant_type == "fp4":
- # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
- assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
- else:
- # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
- assert err.item() < math.log2(blocksize) * 8e-2
+ #
+ # Actually, the above is not true anymore after fixing the integer packing bug.
+ # The following values were taken from averaging 1k samples per test configuration after fixing the bug.
+ error_dict = dict()
+ error_dict["fp4"] = dict()
+ error_dict["nf4"] = dict()
+ error_dict["fp4"]["err"] = {
+ 64: 0.096545,
+ 128: 0.102947,
+ 256: 0.108685,
+ 512: 0.114087,
+ 1024: 0.119312,
+ 2048: 0.124460,
+ 4096: 0.129573,
+ }
+ error_dict["fp4"]["rel_err"] = {
+ 64: 0.260130,
+ 128: 0.275734,
+ 256: 0.289842,
+ 512: 0.302852,
+ 1024: 0.314982,
+ 2048: 0.326402,
+ 4096: 0.337228,
+ }
+
+ error_dict["nf4"]["err"] = {
+ 64: 0.072792,
+ 128: 0.076835,
+ 256: 0.080326,
+ 512: 0.083535,
+ 1024: 0.086603,
+ 2048: 0.089592,
+ 4096: 0.092537,
+ }
+ error_dict["nf4"]["rel_err"] = {
+ 64: 0.203299,
+ 128: 0.215252,
+ 256: 0.226044,
+ 512: 0.236021,
+ 1024: 0.245365,
+ 2048: 0.254146,
+ 4096: 0.262457,
+ }
+
+ assert err < error_dict[quant_type]["err"][blocksize] + 1e-3
+ assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index e07b54d2d..1c5e77a32 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+def test_params4bit_torch_chunk_split(device, quant_type):
+ """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
+ if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
+ pytest.skip("This configuration is not supported on HPU.")
+
+ if device == "cpu":
+ pytest.skip("CPU quantization causes segfault, skipping CPU test")
+
+ original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
+
+ params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)
+
+ if device != "cpu":
+ params4bit = params4bit.to(device)
+
+ chunks = torch.chunk(params4bit, 2, dim=0)
+
+ assert isinstance(chunks, tuple), "torch.chunk should return tuple"
+ for chunk in chunks:
+ assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
+ assert hasattr(chunk, "quant_type"), "Should preserve metadata"
+ assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+ splits = torch.split(params4bit, 2, dim=0)
+
+ assert isinstance(splits, tuple), "torch.split should return tuple"
+ assert len(splits) > 0, "Should have at least one split"
+ for split in splits:
+ assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
+ assert hasattr(split, "quant_type"), "Should preserve metadata"
+ assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])