diff --git a/README.md b/README.md index 0d9e561ce..47510db9b 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ bitsandbytes has the following minimum requirements for all platforms: #### Accelerator support: Note: this table reflects the status of the current development branch. For the latest stable release, see the -[document in the v0.46.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.46.0/README.md#accelerator-support). +[document in the 0.47.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.47.0/README.md#accelerator-support). ##### Legend: diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 516afa51f..d58b7b441 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -38,7 +38,6 @@ if hasattr(torch, "xpu") and torch.xpu.is_available(): from .backends.xpu import ops as xpu_ops - if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"): # In case not automatically imported import habana_frameworks.torch @@ -76,4 +75,4 @@ def _import_backends(): "optim.optimizer.MockArgs": False, } -__version__ = "0.47.0.dev0" +__version__ = "0.48.0.dev0" diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ba134f52a..e599643cc 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -356,6 +356,46 @@ def to(self, *args, **kwargs): return new_param + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func in [torch.chunk, torch.split]: + tensor = args[0] + + result = super().__torch_function__(func, types, args, kwargs) + + if isinstance(result, tuple): + return tuple( + cls( + data=chunk, + requires_grad=tensor.requires_grad, + quant_state=tensor.quant_state, + blocksize=tensor.blocksize, + compress_statistics=tensor.compress_statistics, + quant_type=tensor.quant_type, + quant_storage=tensor.quant_storage, + module=tensor.module, + bnb_quantized=tensor.bnb_quantized, + ) + for chunk in result + ) + else: + return cls( + data=result, + requires_grad=tensor.requires_grad, + quant_state=tensor.quant_state, + blocksize=tensor.blocksize, + compress_statistics=tensor.compress_statistics, + quant_type=tensor.quant_type, + quant_storage=tensor.quant_storage, + module=tensor.module, + bnb_quantized=tensor.bnb_quantized, + ) + + return super().__torch_function__(func, types, args, kwargs) + def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]): if getattr(module.weight, "quant_state", None) is not None: diff --git a/bitsandbytes/nn/parametrize.py b/bitsandbytes/nn/parametrize.py new file mode 100644 index 000000000..d8f94c283 --- /dev/null +++ b/bitsandbytes/nn/parametrize.py @@ -0,0 +1,175 @@ +from functools import partial +from typing import Any, Literal, Optional + +import torch +import torch.nn as nn +import torch.nn.utils.parametrize as P + +from .. import functional as F + + +class Bnb4bitParametrization(nn.Module): + """ + A parametrization module that handles dequantization of a 4-bit quantized parameter. + + The parameter data is expected to be already quantized when this parametrization is applied. + This module will dequantize the parameter data to its original floating-point representation + when the forward method is called (i.e. when the parameter is accessed). + + Args: + quant_state (`F.QuantState`): + The quantization state containing the necessary information for dequantization. + """ + + def __init__(self, quant_state: F.QuantState, p_name="unknown"): + super().__init__() + self.quant_state = quant_state + self.p_name = p_name + + def forward(self, quantized_param: torch.Tensor) -> torch.Tensor: + """ + Forward pass to dequantize the parameter. + + Args: + quantized_param (`torch.Tensor`): The quantized parameter tensor (from .original) + + Returns: + `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype. + """ + return F.dequantize_4bit(quantized_param, self.quant_state) + + +def replace_parameter_4bit_prequantized( + module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device +): + if not hasattr(module, param_name): + raise AttributeError(f"Module does not have parameter '{param_name}'") + + original_param = getattr(module, param_name) + + if not isinstance(original_param, nn.Parameter): + raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter") + + quant_state = F.QuantState.from_dict(qs_dict, device=device) + + # Apply a parametrization to the module to handle dequantization. + P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True) + + # Next, register state dict hook for saving. + module.register_state_dict_post_hook( + partial( + _parametrized_state_dict_post_hook, + param_name=param_name, + ) + ) + + +def replace_parameter_4bit( + module: nn.Module, + param_name: str, + compress_statistics: bool = False, + quant_type: Literal["nf4", "fp4"] = "nf4", + blocksize: Optional[int] = None, +): + """ + Replace a module parameter with a 4-bit quantized version using parametrization. + + This function quantizes an existing parameter in a PyTorch module to 4-bit precision + and sets up parametrization to handle automatic dequantization during forward passes. + The original parameter is replaced with quantized data, and a parametrization layer + is registered to manage the quantization state and dequantization process. + + Additional, it registers a state dict post-hook to ensure that the quantization state + is saved correctly when the model's state dict is saved. + + It is useful for MoE models or other scenarios where you want to quantize parameters + outside of nn.Linear layers without changing the model's architecture. + + This feature is experimental and may change in future releases. + + Args: + module (`nn.Module`): + The PyTorch module containing the parameter to be quantized. + param_name (`str`): + The name of the parameter within the module to quantize. + compress_statistics (`bool`, *optional*, defaults to `False`): + Whether to compress quantization statistics to reduce memory usage. + quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`): + The quantization format to use. + blocksize (`int`, *optional*, defaults to `None`): + The block size for quantization. If None, uses the default block size. + + Raises: + AttributeError: If the module does not have the specified parameter. + TypeError: If the specified attribute is not an instance of nn.Parameter. + """ + + if not hasattr(module, param_name): + raise AttributeError(f"Module does not have parameter '{param_name}'") + + original_param = getattr(module, param_name) + + if not isinstance(original_param, nn.Parameter): + raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter") + + # Quantize the original parameter. + quantized_data, quant_state = F.quantize_4bit( + original_param.data, + blocksize=blocksize, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) + + # Replace the parameter with the quantized data. + setattr(module, param_name, nn.Parameter(quantized_data, requires_grad=False)) + del original_param + + # Apply a parametrization to the module to handle dequantization. + P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True) + + # Next, register state dict hook for saving. + module.register_state_dict_post_hook( + partial( + _parametrized_state_dict_post_hook, + param_name=param_name, + ) + ) + + +def _parametrized_state_dict_post_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + local_metadata: Any, + *, + param_name: str = "weight", + **kwargs: dict[str, Any], +) -> None: + """ + Hook to modify the state dict to include the quantization state. + """ + + original_key = f"{prefix}parametrizations.{param_name}.original" + + if original_key in state_dict: + # Create a clean entry. + # The `parametrizations.{param_name}.original` key will have the quantized data, + # but we would like it to keep it in the state_dict as `{param_name}`. + clean_key = f"{prefix}{param_name}" + state_dict[clean_key] = state_dict.pop(original_key) + + assert P.is_parametrized(module, param_name) + + # Find the parametrization, which should have the quantization state. + parametrization: Bnb4bitParametrization = next( + filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None + ) + + assert parametrization is not None, "Parametrization not found for the parameter." + + quant_state = parametrization.quant_state + + # Next, we need to store the quantization state. + if quant_state is not None: + for k, v in quant_state.as_dict(packed=True).items(): + state_dict[f"{prefix}{param_name}.{k}"] = v diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 649f2ee1f..97b80f050 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise( LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - unsigned char packed_4bit = 0; switch (DATA_TYPE) { case General8bit: #pragma unroll NUM_PER_TH @@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise( case FP4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH / 2; j++) { - packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; - packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); } break; case NF4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH / 2; j++) { - packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; - packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); } break; } diff --git a/setup.py b/setup.py index 7aa50c1b8..a04630b8a 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ def run(self): setup( - version="0.47.0.dev0", + version="0.48.0.dev0", packages=find_packages(), distclass=BinaryDistribution, cmake_source_dir=".", diff --git a/tests/test_functional.py b/tests/test_functional.py index b84db6502..fc37cb4c3 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1125,21 +1125,52 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # With larger block sizes, we can expect this to blow up. # At blocksize>=1024, don't even bother looking at relerr. - if blocksize <= 64: - assert err.item() < 0.1 - assert relerr.item() < 0.28 - elif blocksize <= 256: - assert err.item() < 0.11 - assert relerr.item() < 0.30 - elif blocksize <= 512: - assert err.item() < 0.12 - assert relerr.item() < 0.31 - elif quant_type == "fp4": - # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56 - assert err.item() < 0.08 + math.log2(blocksize) * 4e-2 - else: - # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 - assert err.item() < math.log2(blocksize) * 8e-2 + # + # Actually, the above is not true anymore after fixing the integer packing bug. + # The following values were taken from averaging 1k samples per test configuration after fixing the bug. + error_dict = dict() + error_dict["fp4"] = dict() + error_dict["nf4"] = dict() + error_dict["fp4"]["err"] = { + 64: 0.096545, + 128: 0.102947, + 256: 0.108685, + 512: 0.114087, + 1024: 0.119312, + 2048: 0.124460, + 4096: 0.129573, + } + error_dict["fp4"]["rel_err"] = { + 64: 0.260130, + 128: 0.275734, + 256: 0.289842, + 512: 0.302852, + 1024: 0.314982, + 2048: 0.326402, + 4096: 0.337228, + } + + error_dict["nf4"]["err"] = { + 64: 0.072792, + 128: 0.076835, + 256: 0.080326, + 512: 0.083535, + 1024: 0.086603, + 2048: 0.089592, + 4096: 0.092537, + } + error_dict["nf4"]["rel_err"] = { + 64: 0.203299, + 128: 0.215252, + 256: 0.226044, + 512: 0.236021, + 1024: 0.245365, + 2048: 0.254146, + 4096: 0.262457, + } + + assert err < error_dict[quant_type]["err"][blocksize] + 1e-3 + assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3 @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index e07b54d2d..1c5e77a32 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +def test_params4bit_torch_chunk_split(device, quant_type): + """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility.""" + if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8): + pytest.skip("This configuration is not supported on HPU.") + + if device == "cpu": + pytest.skip("CPU quantization causes segfault, skipping CPU test") + + original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu") + + params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False) + + if device != "cpu": + params4bit = params4bit.to(device) + + chunks = torch.chunk(params4bit, 2, dim=0) + + assert isinstance(chunks, tuple), "torch.chunk should return tuple" + for chunk in chunks: + assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass" + assert hasattr(chunk, "quant_type"), "Should preserve metadata" + assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value" + + splits = torch.split(params4bit, 2, dim=0) + + assert isinstance(splits, tuple), "torch.split should return tuple" + assert len(splits) > 0, "Should have at least one split" + for split in splits: + assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass" + assert hasattr(split, "quant_type"), "Should preserve metadata" + assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value" + + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])