diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index ecccf3c11311..16bd0441072a 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -62,7 +62,7 @@
 if is_accelerate_available():
     from accelerate import dispatch_model, init_empty_weights
 
-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta
 
 if is_torch_version(">=", "1.9.0") and is_accelerate_available():
     _LOW_CPU_MEM_USAGE_DEFAULT = True
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 723f0c136f48..ef6c41e3ce97 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -55,7 +55,7 @@
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
index ced81960fae5..ef7b921b7ddf 100644
--- a/src/diffusers/loaders/transformer_flux.py
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -17,7 +17,8 @@
     ImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
 
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
index 1bc3a9c7a851..e3728082efdd 100644
--- a/src/diffusers/loaders/transformer_sd3.py
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -16,7 +16,8 @@
 
 from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
 
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 1d698e5a8b53..c5e56af156fc 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -30,7 +30,8 @@
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 4e2d24b75011..1fcaedcb87d5 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -14,12 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
 import importlib
 import inspect
 import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -31,6 +33,7 @@
 
 from ..quantizers import DiffusersQuantizer
 from ..utils import (
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -310,6 +313,161 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index
 
 
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+    parameters.
+
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+def _load_shard_file(
+    shard_file,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+    mismatched_keys = _find_mismatched_keys(
+        state_dict,
+        model_state_dict,
+        loaded_keys,
+        ignore_mismatched_sizes,
+    )
+    error_msgs = []
+    if low_cpu_mem_usage:
+        offload_index, state_dict_index = load_model_dict_into_meta(
+            model,
+            state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            unexpected_keys=unexpected_keys,
+            offload_folder=offload_folder,
+            offload_index=offload_index,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+        )
+    else:
+        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _load_shard_files_with_threadpool(
+    shard_files,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    # Do not spawn any more workers than needed
+    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+    mismatched_keys = []
+
+    load_one = functools.partial(
+        _load_shard_file,
+        model=model,
+        model_state_dict=model_state_dict,
+        device_map=device_map,
+        dtype=dtype,
+        hf_quantizer=hf_quantizer,
+        keep_in_fp32_modules=keep_in_fp32_modules,
+        dduf_entries=dduf_entries,
+        loaded_keys=loaded_keys,
+        unexpected_keys=unexpected_keys,
+        offload_index=offload_index,
+        offload_folder=offload_folder,
+        state_dict_index=state_dict_index,
+        state_dict_folder=state_dict_folder,
+        ignore_mismatched_sizes=ignore_mismatched_sizes,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+    )
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+            for future in as_completed(futures):
+                result = future.result()
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+                pbar.update(1)
+
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+    state_dict,
+    model_state_dict,
+    loaded_keys,
+    ignore_mismatched_sizes,
+):
+    mismatched_keys = []
+    if ignore_mismatched_sizes:
+        for checkpoint_key in loaded_keys:
+            model_key = checkpoint_key
+            # If the checkpoint is sharded, we may not have the key here.
+            if checkpoint_key not in state_dict:
+                continue
+
+            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                mismatched_keys.append(
+                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                )
+                del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 815f12a70774..8ab301426263 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import copy
+import functools
 import inspect
 import itertools
 import json
@@ -41,7 +42,9 @@
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
+    ENV_VARS_TRUE_VALUES,
     FLAX_WEIGHTS_NAME,
+    HF_PARALLEL_LOADING_FLAG,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -69,9 +72,8 @@
     _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
-    _find_mismatched_keys,
-    _load_state_dict_into_model,
-    load_model_dict_into_meta,
+    _load_shard_file,
+    _load_shard_files_with_threadpool,
     load_state_dict,
 )
 
@@ -208,34 +210,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
     return last_tuple[1].dtype
 
 
-def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
-    """
-    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
-    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
-    parameters.
- - """ - if model_to_load.device.type == "meta": - return False - - if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: - return False - - # Some models explicitly do not support param buffer assignment - if not getattr(model_to_load, "_supports_param_buffer_assignment", True): - logger.debug( - f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" - ) - return False - - # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype - first_key = next(iter(model_to_load.state_dict().keys())) - if start_prefix + first_key in state_dict: - return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype - - return False - - @contextmanager def no_init_weights(): """ @@ -988,6 +962,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) + is_parallel_loading_enabled = os.environ.get(HF_PARALLEL_LOADING_FLAG, "").upper() in ENV_VARS_TRUE_VALUES + if is_parallel_loading_enabled and not low_cpu_mem_usage: + raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.") + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 logger.warning( @@ -1323,6 +1301,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, dduf_entries=dduf_entries, + is_parallel_loading_enabled=is_parallel_loading_enabled, ) loading_info = { "missing_keys": missing_keys, @@ -1518,6 +1497,7 @@ def _load_pretrained_model( offload_state_dict: Optional[bool] = None, offload_folder: Optional[Union[str, os.PathLike]] = None, dduf_entries: Optional[Dict[str, DDUFEntry]] = None, + is_parallel_loading_enabled: Optional[bool] = False, ): model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) @@ -1531,6 +1511,9 @@ def _load_pretrained_model( for pat in cls._keys_to_ignore_on_load_unexpected: unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + mismatched_keys = [] + error_msgs = [] + # Deal with offload if device_map is not None and "disk" in device_map.values(): if offload_folder is None: @@ -1566,37 +1549,39 @@ def _load_pretrained_model( # if state dict is not None, it means that we don't need to read the files from resolved_model_file also resolved_model_file = [state_dict] - if len(resolved_model_file) > 1: - resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards") - - mismatched_keys = [] - assign_to_params_buffers = None - error_msgs = [] - - for shard_file in resolved_model_file: - state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) - mismatched_keys += _find_mismatched_keys( - state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes - ) + # Prepare the loading function sharing the attributes shared between them. 
+        load_fn = functools.partial(
+            _load_shard_files_with_threadpool if is_parallel_loading_enabled else _load_shard_file,
+            model=model,
+            model_state_dict=model_state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            dduf_entries=dduf_entries,
+            loaded_keys=loaded_keys,
+            unexpected_keys=unexpected_keys,
+            offload_index=offload_index,
+            offload_folder=offload_folder,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
-            if low_cpu_mem_usage:
-                offload_index, state_dict_index = load_model_dict_into_meta(
-                    model,
-                    state_dict,
-                    device_map=device_map,
-                    dtype=dtype,
-                    hf_quantizer=hf_quantizer,
-                    keep_in_fp32_modules=keep_in_fp32_modules,
-                    unexpected_keys=unexpected_keys,
-                    offload_folder=offload_folder,
-                    offload_index=offload_index,
-                    state_dict_index=state_dict_index,
-                    state_dict_folder=state_dict_folder,
-                )
-            else:
-                if assign_to_params_buffers is None:
-                    assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
-                error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+        if is_parallel_loading_enabled:
+            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(resolved_model_file)
+            error_msgs += _error_msgs
+            mismatched_keys += _mismatched_keys
+        else:
+            shard_files = resolved_model_file
+            if len(resolved_model_file) > 1:
+                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+
+            for shard_file in shard_files:
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
 
         empty_device_cache()
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 5f49f5e75734..32bae015e37c 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -20,11 +20,13 @@
 from .. import __version__
 from .constants import (
     CONFIG_NAME,
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     DEPRECATED_REVISION_ARGS,
     DIFFUSERS_DYNAMIC_MODULE_NAME,
     FLAX_WEIGHTS_NAME,
     GGUF_FILE_EXTENSION,
     HF_MODULES_CACHE,
+    HF_PARALLEL_LOADING_FLAG,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     MIN_PEFT_VERSION,
     ONNX_EXTERNAL_WEIGHTS_NAME,
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index f8f04cc03abd..6313d33dddb9 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -43,6 +43,8 @@
 DIFFUSERS_REQUEST_TIMEOUT = 60
 DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
 DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
+DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
+HF_PARALLEL_LOADING_FLAG = "HF_ENABLE_PARALLEL_LOADING"
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 0254e7e8c8e7..0e16f95a4276 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1428,6 +1428,41 @@ def test_sharded_checkpoints_with_variant(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_torch_accelerator
+    def test_sharded_checkpoints_with_parallel_loading(self):
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_persistent_sizes(model)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            # Now check if the right number of shards exists. First, let's get the number of shards.
+            # Since this number can be dependent on the model being tested, it's important that we calculate it
+            # instead of hardcoding it.
+            expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
+            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+            self.assertTrue(actual_num_shards == expected_num_shards)
+
+            # Load with parallel loading
+            os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
+            new_model = self.model_class.from_pretrained(tmp_dir).eval()
+            new_model = new_model.to(torch_device)
+
+            torch.manual_seed(0)
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            new_output = new_model(**inputs_dict)
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            # Set back to "no" so subsequent tests are unaffected.
+            os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"
+
     @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
         if self.model_class._no_split_modules is None:
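
Usage sketch (not part of the patch): with this change, parallel shard loading is opt-in via the `HF_ENABLE_PARALLEL_LOADING` environment variable introduced above, and it requires `low_cpu_mem_usage` (the default when accelerate is installed). The flag is read from the environment each time `from_pretrained` runs, so it can be set in-process before loading. The checkpoint id below is a placeholder; any sharded diffusers checkpoint works.

    import os

    # Must be set before from_pretrained() is called; any value in
    # ENV_VARS_TRUE_VALUES ("1", "ON", "YES", "TRUE") enables it.
    os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"

    import torch
    from diffusers import FluxTransformer2DModel

    # Shards are read by a thread pool capped at
    # DEFAULT_HF_PARALLEL_LOADING_WORKERS (8) workers.
    model = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # placeholder; any sharded checkpoint
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )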