[core] parallel loading of shards #12028

Merged (23 commits) on Aug 13, 2025

2 changes: 1 addition & 1 deletion src/diffusers/loaders/single_file_model.py
@@ -62,7 +62,7 @@
if is_accelerate_available():
from accelerate import dispatch_model, init_empty_weights

-from ..models.modeling_utils import load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta

if is_torch_version(">=", "1.9.0") and is_accelerate_available():
_LOW_CPU_MEM_USAGE_DEFAULT = True
2 changes: 1 addition & 1 deletion src/diffusers/loaders/single_file_utils.py
@@ -55,7 +55,7 @@
if is_accelerate_available():
from accelerate import init_empty_weights

-from ..models.modeling_utils import load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

3 changes: 2 additions & 1 deletion src/diffusers/loaders/transformer_flux.py
@@ -17,7 +17,8 @@
ImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache

3 changes: 2 additions & 1 deletion src/diffusers/loaders/transformer_sd3.py
@@ -16,7 +16,8 @@

from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache

3 changes: 2 additions & 1 deletion src/diffusers/loaders/unet.py
@@ -30,7 +30,8 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
158 changes: 158 additions & 0 deletions src/diffusers/models/model_loading_utils.py
@@ -14,12 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import functools
import importlib
import inspect
import math
import os
from array import array
from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -31,6 +33,7 @@

from ..quantizers import DiffusersQuantizer
from ..utils import (
+DEFAULT_HF_PARALLEL_LOADING_WORKERS,
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -310,6 +313,161 @@ def load_model_dict_into_meta(
return offload_index, state_dict_index


[Author comment: Moved it here from modeling_utils.py.]

def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
"""
Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
parameters.

"""
if model_to_load.device.type == "meta":
return False

if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False

# Some models explicitly do not support param buffer assignment
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
logger.debug(
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
)
return False

# If the model does support it, the incoming `state_dict` and `model_to_load` must share the same dtype
first_key = next(iter(model_to_load.state_dict().keys()))
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

return False
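
For context, a minimal sketch of what this check permits (the toy `TinyModel` below is hypothetical, not part of the diff; diffusers models expose a `.device` property, which the sketch emulates):

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):  # hypothetical toy model
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        @property
        def device(self):  # mimic diffusers' ModelMixin.device
            return next(self.parameters()).device

    model = TinyModel()
    fp32_sd = model.state_dict()                         # dtypes match -> assignment is allowed
    fp16_sd = {k: v.half() for k, v in fp32_sd.items()}  # dtype mismatch -> falls back to copying

    print(check_support_param_buffer_assignment(model, fp32_sd))  # True
    print(check_support_param_buffer_assignment(model, fp16_sd))  # False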


def _load_shard_file(
shard_file,
model,
model_state_dict,
device_map=None,
dtype=None,
hf_quantizer=None,
keep_in_fp32_modules=None,
dduf_entries=None,
loaded_keys=None,
unexpected_keys=None,
offload_index=None,
offload_folder=None,
state_dict_index=None,
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
):
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
)
error_msgs = []
if low_cpu_mem_usage:
offload_index, state_dict_index = load_model_dict_into_meta(
model,
state_dict,
device_map=device_map,
dtype=dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_index=state_dict_index,
state_dict_folder=state_dict_folder,
)
else:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)

error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
return offload_index, state_dict_index, mismatched_keys, error_msgs
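
For reference, the serial path simply folds `_load_shard_file` over the shard list; the threadpool variant below parallelizes exactly this loop. A minimal sketch, assuming `shard_files`, `model`, `model_state_dict`, and `loaded_keys` are prepared by the caller:

    offload_index = state_dict_index = None
    mismatched_keys, error_msgs = [], []
    for shard_file in shard_files:
        offload_index, state_dict_index, _mismatched, _errors = _load_shard_file(
            shard_file,
            model,
            model_state_dict,
            loaded_keys=loaded_keys,
            low_cpu_mem_usage=True,
        )
        mismatched_keys += _mismatched
        error_msgs += _errors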


def _load_shard_files_with_threadpool(
shard_files,
model,
model_state_dict,
device_map=None,
dtype=None,
hf_quantizer=None,
keep_in_fp32_modules=None,
dduf_entries=None,
loaded_keys=None,
unexpected_keys=None,
offload_index=None,
offload_folder=None,
state_dict_index=None,
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
):
# Do not spawn any more workers than needed
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)

logger.info(f"Loading model weights in parallel with {num_workers} workers...")

error_msgs = []
mismatched_keys = []

load_one = functools.partial(
_load_shard_file,
model=model,
model_state_dict=model_state_dict,
device_map=device_map,
dtype=dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
loaded_keys=loaded_keys,
unexpected_keys=unexpected_keys,
offload_index=offload_index,
offload_folder=offload_folder,
state_dict_index=state_dict_index,
state_dict_folder=state_dict_folder,
ignore_mismatched_sizes=ignore_mismatched_sizes,
low_cpu_mem_usage=low_cpu_mem_usage,
)

with ThreadPoolExecutor(max_workers=num_workers) as executor:
with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
for future in as_completed(futures):
result = future.result()
offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
error_msgs += _error_msgs
mismatched_keys += _mismatched_keys
pbar.update(1)

return offload_index, state_dict_index, mismatched_keys, error_msgs
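
How callers opt in is outside this hunk. One plausible gating sketch (the HF_ENABLE_PARALLEL_LOADING flag name is an assumption here, mirroring the analogous transformers feature, not something this diff defines):

    import os

    # Assumed opt-in flag; not defined in this hunk.
    use_parallel = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ("1", "TRUE", "YES")

    if use_parallel:
        offload_index, state_dict_index, mismatched_keys, error_msgs = _load_shard_files_with_threadpool(
            shard_files, model, model_state_dict, loaded_keys=loaded_keys
        )
    else:
        # Fall back to the serial per-shard loop sketched after _load_shard_file.
        ...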


[Author comment: Same. Moved it out of modeling_utils.py.]

def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue

if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
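
A toy illustration (shapes invented for the example): a checkpoint tensor whose shape disagrees with the model's is recorded and dropped from `state_dict` so it is never loaded.

    import torch

    model_state_dict = {"proj.weight": torch.zeros(4, 4)}
    state_dict = {"proj.weight": torch.zeros(8, 4)}  # deliberately wrong shape

    mismatched = _find_mismatched_keys(
        state_dict, model_state_dict, loaded_keys=["proj.weight"], ignore_mismatched_sizes=True
    )
    print(mismatched)                    # [('proj.weight', torch.Size([8, 4]), torch.Size([4, 4]))]
    print("proj.weight" in state_dict)   # False: the mismatched entry was removed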


def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]: