diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 332a6ce49b8c..2e07f55e0064 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -726,23 +726,29 @@ def _caching_allocator_warmup( very large margin. """ factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor() - # Remove disk and cpu devices, and cast to proper torch.device + + # Keep only accelerator devices accelerator_device_map = { param: torch.device(device) for param, device in expanded_device_map.items() if str(device) not in ["cpu", "disk"] } - total_byte_count = defaultdict(lambda: 0) + if not accelerator_device_map: + return + + elements_per_device = defaultdict(int) for param_name, device in accelerator_device_map.items(): try: - param = model.get_parameter(param_name) + p = model.get_parameter(param_name) except AttributeError: - param = model.get_buffer(param_name) - # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules` - param_byte_count = param.numel() * param.element_size() + try: + p = model.get_buffer(param_name) + except AttributeError: + raise AttributeError(f"Parameter or buffer with name={param_name} not found in model") # TODO: account for TP when needed. - total_byte_count[device] += param_byte_count + elements_per_device[device] += p.numel() # This will kick off the caching allocator to avoid having to Malloc afterwards - for device, byte_count in total_byte_count.items(): - _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False) + for device, elem_count in elements_per_device.items(): + warmup_elems = max(1, elem_count // factor) + _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)