Skip to content

Commit 61fdac5

Browse files
Enable bnb 4bit nn.Parameter from prequantized checkpoint
1 parent a282c57 commit 61fdac5

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

src/transformers/quantizers/quantizer_bnb_4bit.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -236,17 +236,35 @@ def create_quantized_param(
236236
if unexpected_keys is not None and k in unexpected_keys:
237237
unexpected_keys.remove(k)
238238

239-
param_kwargs = {}
240-
if self.is_bnb_supports_quant_storage_module:
241-
param_kwargs["module"] = module
242-
243-
module._parameters[tensor_name] = bnb.nn.Params4bit.from_prequantized(
244-
data=param_value,
245-
quantized_stats=quantized_stats,
246-
requires_grad=False,
247-
device=target_device,
248-
**param_kwargs,
249-
)
239+
if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
240+
param_kwargs = {}
241+
if self.is_bnb_supports_quant_storage_module:
242+
param_kwargs["module"] = module
243+
244+
module._parameters[tensor_name] = bnb.nn.Params4bit.from_prequantized(
245+
data=param_value,
246+
quantized_stats=quantized_stats,
247+
requires_grad=False,
248+
device=target_device,
249+
**param_kwargs,
250+
)
251+
elif self.quantization_config.bnb_4bit_target_parameters:
252+
# Normal nn.Parameter, i.e. outside of a Linear4bit layer.
253+
import bitsandbytes.nn.parametrize
254+
255+
# Load the parameter on the target device
256+
module._parameters[tensor_name] = torch.nn.Parameter(
257+
param_value.to(target_device), requires_grad=False
258+
)
259+
260+
# Apply the bitsandbytes parametrization to support dequantization
261+
bitsandbytes.nn.parametrize.replace_parameter_4bit_prequantized(
262+
module,
263+
tensor_name,
264+
qs_dict=quantized_stats,
265+
device=target_device,
266+
)
267+
250268
else:
251269
new_value = param_value.to("cpu")
252270

@@ -359,20 +377,17 @@ def _process_model_before_weight_loading(
359377
]
360378

361379
if any(matched_params):
362-
import bitsandbytes.nn.parametrize
363-
364380
for param_name in matched_params:
365381
module, tensor_name = get_module_from_name(model, param_name)
366382

367-
# Fake quantize/replace parameter - we're in `init_empty_weights`
368-
# TODO: we could probably just infer the dtype/shape
369-
quantized_data, quant_state = bitsandbytes.functional.quantize_4bit(
370-
model.get_parameter(param_name).data,
371-
compress_statistics=self.quantization_config.bnb_4bit_use_double_quant,
372-
quant_type=self.quantization_config.bnb_4bit_quant_type,
383+
param = model.get_parameter(param_name)
384+
385+
quant_param = torch.nn.Parameter(
386+
torch.empty((param.numel() + 1) // 2, dtype=torch.uint8),
387+
requires_grad=False,
373388
)
374389

375-
setattr(module, tensor_name, torch.nn.Parameter(quantized_data, requires_grad=False))
390+
setattr(module, tensor_name, quant_param)
376391

377392
model.config.quantization_config = self.quantization_config
378393

src/transformers/utils/quantization_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,9 +581,9 @@ def post_init(self):
581581
"4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
582582
)
583583

584-
if self.bnb_4bit_target_parameters is not None and bnb_version < version.parse("0.47.0"):
584+
if self.bnb_4bit_target_parameters is not None and bnb_version < version.parse("0.48.0"):
585585
raise ValueError(
586-
"bnb_4bit_target_parameters requires bitsandbytes>=0.47.0 - please upgrade your bitsandbytes version"
586+
"bnb_4bit_target_parameters requires bitsandbytes>=0.48.0 - please upgrade your bitsandbytes version"
587587
)
588588

589589
def is_quantizable(self):

0 commit comments

Comments
 (0)