17 | 17 | import functools |
18 | 18 | import importlib |
19 | 19 | import inspect |
20 | | -import math |
21 | 20 | import os |
22 | 21 | from array import array |
23 | 22 | from collections import OrderedDict, defaultdict |
@@ -717,27 +716,33 @@ def _expand_device_map(device_map, param_names): |
717 | 716 |
|
718 | 717 |
|
719 | 718 | # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859 |
720 | | -def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None: |
| 719 | +def _caching_allocator_warmup( |
| 720 | + model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer] |
| 721 | +) -> None: |
721 | 722 | """ |
722 | 723 | This function warms up the caching allocator based on the size of the model tensors that will reside on each |
723 | 724 | device. It makes one large call to malloc up front, instead of calling it repeatedly later while loading the model, |
724 | 725 | which is the actual loading-speed bottleneck. Calling this function can cut the model loading time by a |
725 | 726 | very large margin. |
726 | 727 | """ |
| 728 | + factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor() |
727 | 729 | # Remove disk and cpu devices, and cast to proper torch.device |
728 | 730 | accelerator_device_map = { |
729 | 731 | param: torch.device(device) |
730 | 732 | for param, device in expanded_device_map.items() |
731 | 733 | if str(device) not in ["cpu", "disk"] |
732 | 734 | } |
733 | | - parameter_count = defaultdict(lambda: 0) |
| 735 | + total_byte_count = defaultdict(lambda: 0) |
734 | 736 | for param_name, device in accelerator_device_map.items(): |
735 | 737 | try: |
736 | 738 | param = model.get_parameter(param_name) |
737 | 739 | except AttributeError: |
738 | 740 | param = model.get_buffer(param_name) |
739 | | - parameter_count[device] += math.prod(param.shape) |
| 741 | + # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules` |
| 742 | + param_byte_count = param.numel() * param.element_size() |
| 743 | + # TODO: account for TP when needed. |
| 744 | + total_byte_count[device] += param_byte_count |
740 | 745 |
|
741 | 746 | # This will kick off the caching allocator to avoid having to Malloc afterwards |
742 | | - for device, param_count in parameter_count.items(): |
743 | | - _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) |
| 747 | + for device, byte_count in total_byte_count.items(): |
| 748 | + _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False) |
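For reference, here is a minimal standalone sketch of the technique this patch implements, covering only the plain, non-quantized case. The names `warmup_sketch`, `model`, and `expanded_device_map` are placeholders, and the real helper additionally takes an `hf_quantizer` and uses its `get_cuda_warm_up_factor()` to choose the divisor instead of deriving it from the dtype's element size.

```python
from collections import defaultdict

import torch


def warmup_sketch(model: torch.nn.Module, expanded_device_map: dict, dtype: torch.dtype) -> None:
    # Accumulate the number of bytes that will land on each accelerator device,
    # skipping parameters that are offloaded to CPU or disk.
    total_byte_count = defaultdict(int)
    for param_name, device in expanded_device_map.items():
        if str(device) in ("cpu", "disk"):
            continue
        try:
            param = model.get_parameter(param_name)
        except AttributeError:
            param = model.get_buffer(param_name)
        # Count bytes rather than elements: parameter dtypes can differ within one
        # model (composite models, keep_in_fp32_modules), so element counts alone
        # would mis-estimate the required memory.
        total_byte_count[torch.device(device)] += param.numel() * param.element_size()

    # One large, uninitialized allocation per device primes the caching allocator,
    # so later parameter loads reuse this block instead of hitting cudaMalloc.
    element_size = torch.empty((), dtype=dtype).element_size()
    for device, byte_count in total_byte_count.items():
        _ = torch.empty(byte_count // element_size, dtype=dtype, device=device, requires_grad=False)
```

The integer division converts the byte total back into an element count for `torch.empty`; in the patch, `factor` plays that role (2 when no quantizer is present, presumably matching a 2-byte fp16/bf16 load dtype, and quantizer-specific otherwise).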