Commit e824660

sayakpaul and a-r-r-o-w authored
fix: caching allocator behaviour for quantization. (#12172)
* fix: caching allocator behaviour for quantization.
* up
* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Aryan <[email protected]>
1 parent 03be15e commit e824660

File tree: 1 file changed, +15 −9 lines

src/diffusers/models/model_loading_utils.py

Lines changed: 15 additions & 9 deletions
@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
         very large margin.
     """
     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
-    # Remove disk and cpu devices, and cast to proper torch.device
+
+    # Keep only accelerator devices
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    total_byte_count = defaultdict(lambda: 0)
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
     for param_name, device in accelerator_device_map.items():
         try:
-            param = model.get_parameter(param_name)
+            p = model.get_parameter(param_name)
         except AttributeError:
-            param = model.get_buffer(param_name)
-        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = param.numel() * param.element_size()
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
         # TODO: account for TP when needed.
-        total_byte_count[device] += param_byte_count
+        elements_per_device[device] += p.numel()
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
-        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
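For context: the old code summed bytes per device from the checkpoint-side parameters and buffers, which stops matching the real on-device footprint once a quantizer changes the storage format. The fix counts elements per device, divides by the quantizer-aware factor from hf_quantizer.get_cuda_warm_up_factor(), returns early when everything is offloaded to cpu/disk, and never asks torch.empty for zero elements. Below is a minimal, self-contained sketch of that post-fix logic; toy_warmup, warm_up_factor, and the nn.Linear device map are illustrative assumptions, not the diffusers API.

    from collections import defaultdict

    import torch
    import torch.nn as nn


    def toy_warmup(model, expanded_device_map, dtype, warm_up_factor=2):
        # Keep only accelerator devices; params offloaded to cpu/disk are skipped.
        accelerator_device_map = {
            name: torch.device(device)
            for name, device in expanded_device_map.items()
            if str(device) not in ["cpu", "disk"]
        }
        if not accelerator_device_map:
            return  # nothing to warm up

        # Count elements (not bytes) per device; warm_up_factor stands in for
        # hf_quantizer.get_cuda_warm_up_factor() and covers the gap between the
        # load dtype and the final (possibly quantized) storage.
        elements_per_device = defaultdict(int)
        for name, device in accelerator_device_map.items():
            try:
                p = model.get_parameter(name)
            except AttributeError:
                p = model.get_buffer(name)
            elements_per_device[device] += p.numel()

        # One allocation per device primes the caching allocator before loading.
        for device, elem_count in elements_per_device.items():
            _ = torch.empty(max(1, elem_count // warm_up_factor), dtype=dtype, device=device, requires_grad=False)


    # Toy usage: warm up for a tiny module (falls back to an early return on CPU-only machines).
    model = nn.Linear(4, 4)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device_map = {name: device for name, _ in model.named_parameters()}
    toy_warmup(model, device_map, dtype=torch.float16)

Allocating max(1, elem_count // factor) elements of dtype keeps the warmup proportional to what will actually live on the device, rather than to the byte size of the unquantized checkpoint tensors.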
