Skip to content

Commit 7a3f542

Browse files
committed
Fix QuantState.as_dict
Signed-off-by: cyy <[email protected]>
1 parent 6e371de commit 7a3f542

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

bitsandbytes/functional.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
502502
)
503503
return quant_state
504504

505-
def as_dict(self, packed=False):
505+
def as_dict(self, packed: bool = False) -> dict[str, Any]:
506506
"""
507507
returns dict of tensors and strings to use in serialization via _save_to_state_dict()
508508
param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
@@ -531,7 +531,10 @@ def as_dict(self, packed=False):
531531
# packed format allows serialization of non-tensor components, critical for saving in safetensors format
532532
qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
533533
non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
534-
qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
534+
key = "quant_state.bitsandbytes__"
535+
if self.quant_type is not None:
536+
key += self.quant_type
537+
qs_packed_dict[key] = pack_dict_to_tensor(non_tensor_dict)
535538
return qs_packed_dict
536539

537540
def to(self, device):

0 commit comments

Comments
 (0)