22 changes: 13 additions & 9 deletions bitsandbytes/functional.py
@@ -462,12 +462,13 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState"
 
         # unpacking tensor with non-tensor components
         qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
-        if not len(qs_key) and "quant_type" not in qs_dict:
-            raise ValueError("Expected packed or unpacked quant_state items, found neither")
-        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
-            raise ValueError(
-                f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
-            )
+        if "quant_type" not in qs_dict:
+            if not qs_key:
+                raise ValueError("Expected packed or unpacked quant_state items, found neither")
+            elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
+                raise ValueError(
+                    f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
+                )
 
         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
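For readers following along, here is a small standalone sketch of the detection logic this hunk restructures. The key names are illustrative, following the `quant_state.bitsandbytes__*` convention visible in the diff; the `nf4` suffix is just one example of a `valid_qs_type_keys` ending:

```python
import torch

# Packed input: all non-tensor fields are serialized into a single tensor
# whose key contains "quant_state" (the suffix here is an example).
packed = {
    "absmax": torch.ones(4),
    "quant_state.bitsandbytes__nf4": torch.zeros(8, dtype=torch.uint8),
}
# Unpacked input: fields are plain dict entries, including "quant_type".
unpacked = {"absmax": torch.ones(4), "quant_type": "nf4", "blocksize": 64}

for qs_dict in (packed, unpacked):
    qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
    # After this change, key validation runs only when "quant_type" is absent,
    # so unpacked dicts (which carry no quant_state key) no longer trip it.
    print(qs_key, "quant_type" not in qs_dict)
# -> ['quant_state.bitsandbytes__nf4'] True
# -> [] False
```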
@@ -500,7 +501,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState"
         )
         return quant_state
 
-    def as_dict(self, packed=False):
+    def as_dict(self, packed: bool = False) -> dict[str, Any]:
         """
         returns dict of tensors and strings to use in serialization via _save_to_state_dict()
         param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
@@ -511,7 +512,7 @@ def as_dict(self, packed=False):
             "blocksize": self.blocksize,
             "quant_map": self.code,
             "dtype": str(self.dtype).strip("torch."),
-            "shape": tuple(self.shape),
+            "shape": tuple(self.shape) if self.shape is not None else None,
         }
         if self.nested:
             qs_dict.update(
@@ -529,7 +530,10 @@
         # packed format allows serialization of non-tensor components, critical for saving in safetensors format
         qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
         non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
-        qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
+        key = "quant_state.bitsandbytes__"
+        if self.quant_type is not None:
+            key += self.quant_type
+        qs_packed_dict[key] = pack_dict_to_tensor(non_tensor_dict)
         return qs_packed_dict
 
     def to(self, device):
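The `quant_type is None` branch matters because 8-bit blockwise quant states have no quant type, so the old string concatenation raised a `TypeError` when packing them. For the 4-bit case, the packed format already round-trips; below is a minimal sketch of the safetensors workflow the packed format targets, assuming a CUDA device for `quantize_4bit` (the `save_file`/`load_file` calls are the standard safetensors API; the file and tensor names are illustrative):

```python
import torch
import bitsandbytes.functional as F
from safetensors.torch import save_file, load_file

A = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
q, state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")

# packed=True folds every non-tensor field into one packed tensor, so the
# result is a dict[str, torch.Tensor] that safetensors will accept.
tensors = state.as_dict(packed=True)
tensors["weight"] = q  # illustrative key for the quantized payload itself
save_file(tensors, "quantized.safetensors")

loaded = load_file("quantized.safetensors")
q2 = loaded.pop("weight")
state2 = F.QuantState.from_dict(loaded, device=torch.device("cuda"))
A2 = F.dequantize_4bit(q2, state2, blocksize=64, quant_type="nf4")
```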
11 changes: 11 additions & 0 deletions tests/test_functional.py
@@ -118,6 +118,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         for i in range(iters):
             A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
             C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+            if i == 0:
+                d = S.as_dict()
+                S = F.QuantState.from_dict(d, device=torch.device(device))
             A2 = F.dequantize_blockwise(C, S)
             diff = torch.abs(A1 - A2).float()
             reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -134,6 +137,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         for i in range(iters):
             A1 = torch.rand(1024, 1024, device=device, dtype=dtype)
             C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
+            if i == 0:
+                d = S.as_dict()
+                S = F.QuantState.from_dict(d, device=torch.device(device))
             A2 = F.dequantize_blockwise(C, S)
             diff = torch.abs(A1 - A2).float()
             reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -243,6 +249,9 @@ def test_fp8_quant(self, device):
         for i in range(10):
             A1 = torch.randn(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1, code=code)
+            if i == 0:
+                d = SC.as_dict()
+                SC = F.QuantState.from_dict(d, device=torch.device(device))
             A2 = F.dequantize_blockwise(C, SC)
             diff = torch.abs(A1 - A2)
             reldiff = diff / torch.abs(A1 + 1e-8)
@@ -1116,6 +1125,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
 
         A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
         qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+        d = SA.as_dict()
+        SA = F.QuantState.from_dict(d, device=torch.device(device))
         A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
 
         err = (A1 - A2).abs().float()
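Taken together, the new test lines all exercise the same pattern, which downstream code can reuse to serialize a quant state and rebuild it later. A condensed sketch of that round trip (device and tolerance are illustrative; the unpacked `as_dict()` output contains non-tensor entries, so it suits pickle-style checkpoints rather than safetensors):

```python
import torch
import bitsandbytes.functional as F

A1 = torch.randn(1024, 1024)
C, S = F.quantize_blockwise(A1)

d = S.as_dict()  # unpacked: plain dict with tensors, ints, and strings
S2 = F.QuantState.from_dict(d, device=torch.device("cpu"))

A2 = F.dequantize_blockwise(C, S2)
assert torch.abs(A1 - A2).mean() < 0.02  # loose, illustrative tolerance
```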