diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index c9f5ece60..b2d3b5266 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -462,12 +462,13 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
 
         # unpacking tensor with non-tensor components
         qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
-        if not len(qs_key) and "quant_type" not in qs_dict:
-            raise ValueError("Expected packed or unpacked quant_state items, found neither")
-        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
-            raise ValueError(
-                f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
-            )
+        if "quant_type" not in qs_dict:
+            if not qs_key:
+                raise ValueError("Expected packed or unpacked quant_state items, found neither")
+            elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
+                raise ValueError(
+                    f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
+                )
 
         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
@@ -500,7 +501,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
             )
         return quant_state
 
-    def as_dict(self, packed=False):
+    def as_dict(self, packed: bool = False) -> dict[str, Any]:
         """
         returns dict of tensors and strings to use in serialization via _save_to_state_dict()
         param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
@@ -511,7 +512,7 @@ def as_dict(self, packed=False):
             "blocksize": self.blocksize,
             "quant_map": self.code,
             "dtype": str(self.dtype).strip("torch."),
-            "shape": tuple(self.shape),
+            "shape": tuple(self.shape) if self.shape is not None else None,
         }
         if self.nested:
             qs_dict.update(
@@ -529,7 +530,10 @@ def as_dict(self, packed=False):
         # packed format allows serialization of non-tensor components, critical for saving in safetensors format
         qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
         non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
-        qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
+        key = "quant_state.bitsandbytes__"
+        if self.quant_type is not None:
+            key += self.quant_type
+        qs_packed_dict[key] = pack_dict_to_tensor(non_tensor_dict)
         return qs_packed_dict
 
     def to(self, device):
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 6a008d847..a184d4298 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -118,6 +118,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         for i in range(iters):
             A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
             C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+            if i == 0:
+                d = S.as_dict()
+                S = F.QuantState.from_dict(d, device=torch.device(device))
             A2 = F.dequantize_blockwise(C, S)
             diff = torch.abs(A1 - A2).float()
             reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -134,6 +137,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         for i in range(iters):
             A1 = torch.rand(1024, 1024, device=device, dtype=dtype)
             C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
+            if i == 0:
+                d = S.as_dict()
+                S = F.QuantState.from_dict(d, device=torch.device(device))
             A2 = F.dequantize_blockwise(C, S)
             diff = torch.abs(A1 - A2).float()
             reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -243,6 +249,9 @@ def test_fp8_quant(self, device):
         for i in range(10):
             A1 = torch.randn(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1, code=code)
+            if i == 0:
+                d = SC.as_dict()
+                SC = F.QuantState.from_dict(d, device=torch.device(device))
             A2 = F.dequantize_blockwise(C, SC)
             diff = torch.abs(A1 - A2)
             reldiff = diff / torch.abs(A1 + 1e-8)
@@ -1116,6 +1125,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
         A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
 
         qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+        d = SA.as_dict()
+        SA = F.QuantState.from_dict(d, device=torch.device(device))
         A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
         err = (A1 - A2).abs().float()