From 1bc34f556c2925c886d1eafc363572ad89f81c9c Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 18 Aug 2025 10:15:15 +0530
Subject: [PATCH 1/3] fix: caching allocator behaviour for quantization.

---
 src/diffusers/models/model_loading_utils.py | 30 ++++++++++++---------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 332a6ce49b8c..9af3cd189069 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
     very large margin.
     """
     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
-    # Remove disk and cpu devices, and cast to proper torch.device
+
+    # Keep only accelerator devices
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
-        if str(device) not in ["cpu", "disk"]
+        if str(device) not in ["cpu", "disk", "meta"]
     }
-    total_byte_count = defaultdict(lambda: 0)
-    for param_name, device in accelerator_device_map.items():
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
+    for name, device in accelerator_device_map.items():
         try:
-            param = model.get_parameter(param_name)
+            p = model.get_parameter(name)
         except AttributeError:
-            param = model.get_buffer(param_name)
-        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = param.numel() * param.element_size()
-        # TODO: account for TP when needed.
-        total_byte_count[device] += param_byte_count
+            try:
+                p = model.get_buffer(name)
+            except AttributeError:
+                raise AttributeError(f"Parameter {name} not found in model")
+
+        elements_per_device[device] += p.numel()
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
-        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)

From 51588508b823fdb3156b4c0d7846f4be33b7ce4b Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 18 Aug 2025 10:34:44 +0530
Subject: [PATCH 2/3] up

---
 src/diffusers/models/model_loading_utils.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 9af3cd189069..a69f9a752aae 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -731,21 +731,21 @@ def _caching_allocator_warmup(
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
-        if str(device) not in ["cpu", "disk", "meta"]
+        if str(device) not in ["cpu", "disk"]
     }
     if not accelerator_device_map:
         return
 
     elements_per_device = defaultdict(int)
-    for name, device in accelerator_device_map.items():
+    for param_name, device in accelerator_device_map.items():
         try:
-            p = model.get_parameter(name)
+            p = model.get_parameter(param_name)
         except AttributeError:
             try:
-                p = model.get_buffer(name)
+                p = model.get_buffer(param_name)
             except AttributeError:
-                raise AttributeError(f"Parameter {name} not found in model")
-
+                raise AttributeError(f"Parameter {param_name} not found in model")
+        # TODO: account for TP when needed.
         elements_per_device[device] += p.numel()
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards

From 6deece1cb11fdbc1082cbce89492ede90c0731a1 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 18 Aug 2025 10:59:19 +0530
Subject: [PATCH 3/3] Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Aryan
---
 src/diffusers/models/model_loading_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index a69f9a752aae..2e07f55e0064 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -744,7 +744,7 @@ def _caching_allocator_warmup(
             try:
                 p = model.get_buffer(param_name)
             except AttributeError:
-                raise AttributeError(f"Parameter {param_name} not found in model")
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
         # TODO: account for TP when needed.
         elements_per_device[device] += p.numel()
 
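
Note on the fix, with a minimal standalone sketch (needs a CUDA device; the toy nn.Linear model, the hand-built expanded_device_map, and the hard-coded factor = 2 below are illustrative assumptions, not diffusers API): torch.empty takes a number of *elements*, so the pre-patch code, which passed a byte count, over-allocated the warmup tensor by the element size of dtype (2x for fp16); counting elements per device, as the patch does, sizes the warmup allocation correctly.

    from collections import defaultdict

    import torch
    from torch import nn

    # Hypothetical toy model and device map, only to exercise the warmup logic.
    model = nn.Linear(1024, 1024, dtype=torch.float16)
    expanded_device_map = {name: "cuda:0" for name, _ in model.named_parameters()}

    # Count elements (not bytes) per accelerator device, mirroring the patch.
    elements_per_device = defaultdict(int)
    for name, device in expanded_device_map.items():
        if str(device) in ["cpu", "disk"]:
            continue
        elements_per_device[torch.device(device)] += model.get_parameter(name).numel()

    # In the real function the divisor comes from hf_quantizer.get_cuda_warm_up_factor()
    # when a quantizer is present; 2 is the unquantized default.
    factor = 2
    for device, elem_count in elements_per_device.items():
        # One large allocation per device primes the CUDA caching allocator, so the
        # later per-parameter copies reuse the cached block instead of calling malloc.
        _ = torch.empty(max(1, elem_count // factor), dtype=torch.float16, device=device, requires_grad=False)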