@@ -236,17 +236,35 @@ def create_quantized_param(
                     if unexpected_keys is not None and k in unexpected_keys:
                         unexpected_keys.remove(k)
 
-            param_kwargs = {}
-            if self.is_bnb_supports_quant_storage_module:
-                param_kwargs["module"] = module
-
-            module._parameters[tensor_name] = bnb.nn.Params4bit.from_prequantized(
-                data=param_value,
-                quantized_stats=quantized_stats,
-                requires_grad=False,
-                device=target_device,
-                **param_kwargs,
-            )
+            if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+                param_kwargs = {}
+                if self.is_bnb_supports_quant_storage_module:
+                    param_kwargs["module"] = module
+
+                module._parameters[tensor_name] = bnb.nn.Params4bit.from_prequantized(
+                    data=param_value,
+                    quantized_stats=quantized_stats,
+                    requires_grad=False,
+                    device=target_device,
+                    **param_kwargs,
+                )
+            elif self.quantization_config.bnb_4bit_target_parameters:
+                # Normal nn.Parameter, i.e. outside of a Linear4bit layer.
+                import bitsandbytes.nn.parametrize
+
+                # Load the parameter on the target device
+                module._parameters[tensor_name] = torch.nn.Parameter(
+                    param_value.to(target_device), requires_grad=False
+                )
+
+                # Apply the bitsandbytes parametrization to support dequantization
+                bitsandbytes.nn.parametrize.replace_parameter_4bit_prequantized(
+                    module,
+                    tensor_name,
+                    qs_dict=quantized_stats,
+                    device=target_device,
+                )
+
         else:
             new_value = param_value.to("cpu")
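
A hand-driven sketch of the new `elif` path above, which handles a plain `nn.Parameter` living outside a `Linear4bit` layer. Everything checkpoint-related here (the file name, the "scales" key, the `Block` module) is hypothetical; `replace_parameter_4bit_prequantized` and its `qs_dict`/`device` arguments are exactly what the diff uses.

# Sketch of the prequantized load path, assuming a checkpoint that stores the
# packed uint8 payload under "scales" and its quant-state tensors under
# "scales.*" (all key names are hypothetical).
import torch
import bitsandbytes.nn.parametrize


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder; swapped for the packed 4-bit payload below.
        self.scales = torch.nn.Parameter(torch.empty(0), requires_grad=False)


module = Block()
state_dict = torch.load("quantized_checkpoint.pt")  # hypothetical file

param_value = state_dict["scales"]  # packed uint8 data
quantized_stats = {k: v for k, v in state_dict.items() if k.startswith("scales.")}

# Load the packed data on the target device, then attach the bitsandbytes
# parametrization so that accessing `module.scales` dequantizes on the fly.
module._parameters["scales"] = torch.nn.Parameter(
    param_value.to("cuda"), requires_grad=False
)
bitsandbytes.nn.parametrize.replace_parameter_4bit_prequantized(
    module,
    "scales",
    qs_dict=quantized_stats,
    device="cuda",
)
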
@@ -359,20 +377,17 @@ def _process_model_before_weight_loading(
         ]
 
         if any(matched_params):
-            import bitsandbytes.nn.parametrize
-
             for param_name in matched_params:
                 module, tensor_name = get_module_from_name(model, param_name)
 
-                # Fake quantize/replace parameter - we're in `init_empty_weights`
-                # TODO: we could probably just infer the dtype/shape
-                quantized_data, quant_state = bitsandbytes.functional.quantize_4bit(
-                    model.get_parameter(param_name).data,
-                    compress_statistics=self.quantization_config.bnb_4bit_use_double_quant,
-                    quant_type=self.quantization_config.bnb_4bit_quant_type,
+                param = model.get_parameter(param_name)
+
+                quant_param = torch.nn.Parameter(
+                    torch.empty((param.numel() + 1) // 2, dtype=torch.uint8),
+                    requires_grad=False,
                 )
 
-                setattr(module, tensor_name, torch.nn.Parameter(quantized_data, requires_grad=False))
+                setattr(module, tensor_name, quant_param)
 
         model.config.quantization_config = self.quantization_config
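
Note on the size arithmetic above: each `uint8` byte packs two 4-bit values, so a parameter with `numel()` elements needs `(numel() + 1) // 2` bytes of packed storage; allocating an empty placeholder of that size avoids running a real `quantize_4bit` while the model is still under `init_empty_weights`.

From the user side, the feature is driven by the config attribute the first hunk reads. A hedged usage sketch, assuming `bnb_4bit_target_parameters` is accepted as a `BitsAndBytesConfig` constructor argument (the diff only shows it being read) with a hypothetical pattern and model id:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_target_parameters=["mlp.gate"],  # hypothetical parameter pattern
)

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # placeholder model id
    quantization_config=quantization_config,
)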