@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
     very large margin.
     """
     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
-    # Remove disk and cpu devices, and cast to proper torch.device
+
+    # Keep only accelerator devices
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    total_byte_count = defaultdict(lambda: 0)
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
     for param_name, device in accelerator_device_map.items():
         try:
-            param = model.get_parameter(param_name)
+            p = model.get_parameter(param_name)
         except AttributeError:
-            param = model.get_buffer(param_name)
-        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = param.numel() * param.element_size()
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
         # TODO: account for TP when needed.
-        total_byte_count[device] += param_byte_count
+        elements_per_device[device] += p.numel()
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
-        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
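Below is a minimal, self-contained sketch of the warmup pattern this hunk implements: sum parameter element counts per accelerator device, then issue a single throwaway torch.empty allocation per device so the CUDA caching allocator reserves memory upfront instead of allocating piecemeal while weights are loaded. The function name warmup_caching_allocator and the toy model/device map are illustrative assumptions, not part of the patched library; the model, expanded_device_map, dtype, and factor arguments mirror the names used in the diff.

from collections import defaultdict

import torch
import torch.nn as nn


def warmup_caching_allocator(model, expanded_device_map, dtype, factor=2):
    # Sketch of the per-device warmup allocation shown in the diff.
    # Keep only accelerator devices (skip "cpu" and "disk" placements).
    accelerator_device_map = {
        name: torch.device(device)
        for name, device in expanded_device_map.items()
        if str(device) not in ["cpu", "disk"]
    }
    if not accelerator_device_map:
        return

    # Count elements (not bytes) per device; the dtype is applied at allocation time.
    elements_per_device = defaultdict(int)
    for name, device in accelerator_device_map.items():
        try:
            p = model.get_parameter(name)
        except AttributeError:
            p = model.get_buffer(name)
        elements_per_device[device] += p.numel()

    # One large torch.empty per device makes the caching allocator reserve
    # memory upfront, so later per-parameter copies reuse the cached block.
    for device, elem_count in elements_per_device.items():
        _ = torch.empty(max(1, elem_count // factor), dtype=dtype, device=device, requires_grad=False)


if __name__ == "__main__":
    # Hypothetical toy usage; on a CPU-only machine the "cpu" placements are
    # filtered out and the warmup becomes a no-op.
    target = "cuda:0" if torch.cuda.is_available() else "cpu"
    toy = nn.Linear(4, 4)
    device_map = {"weight": target, "bias": target}
    warmup_caching_allocator(toy, device_map, dtype=torch.float16)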