Add low_cpu_mem_usage option to from_single_file to align with from_pretrained #12114

Merged: 4 commits, Aug 12, 2025
16 changes: 13 additions & 3 deletions src/diffusers/loaders/single_file_model.py
@@ -23,7 +23,7 @@

 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
@@ -64,6 +64,10 @@

 from ..models.modeling_utils import load_model_dict_into_meta

+if is_torch_version(">=", "1.9.0") and is_accelerate_available():
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False

 SINGLE_FILE_LOADABLE_CLASSES = {
     "StableCascadeUNet": {
@@ -236,6 +240,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
+                is_accelerate_available() else `False`): Speed up model loading by only loading the pretrained weights
+                and not initializing the weights. This also tries to not use more than 1x the model size in CPU memory
+                (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
+                an older version of PyTorch, setting this argument to `True` will raise an error.
             disable_mmap (`bool`, *optional*, defaults to `False`):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
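With the flag exposed, callers can opt out of the meta-device path explicitly, for example to debug weight initialization. A hedged usage sketch (the checkpoint path is hypothetical and the model class is only illustrative, neither is taken from the PR):

```python
import torch
from diffusers import UNet2DConditionModel

# Hypothetical local checkpoint path, for illustration only.
unet = UNet2DConditionModel.from_single_file(
    "./sd_xl_base_1.0.safetensors",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,  # fall back to regular init + state-dict copy
)
```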
@@ -285,6 +294,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         config_revision = kwargs.pop("config_revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

@@ -389,7 +399,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
         diffusers_model_config.update(model_kwargs)

-        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
        with ctx():
             model = cls.from_config(diffusers_model_config)

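This is the core behavioral change: the empty-weights context is now gated on the user-facing flag rather than on accelerate's mere presence. A minimal sketch of what `init_empty_weights` does, using plain accelerate and `torch.nn` (not code from this PR):

```python
import torch.nn as nn
from accelerate import init_empty_weights

# Inside the context, parameters land on the meta device: the module
# skeleton exists, but no weight memory is allocated yet.
with init_empty_weights():
    model = nn.Linear(4096, 4096)

print(model.weight.device)  # meta
```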
@@ -427,7 +437,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         )

         device_map = None
-        if is_accelerate_available():
+        if low_cpu_mem_usage:
             param_device = torch.device(device) if device else torch.device("cpu")
             empty_state_dict = model.state_dict()
             unexpected_keys = [
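Downstream of the diff shown here, the meta-initialized model still has to receive real tensors; the PR keeps using the internal `load_model_dict_into_meta` helper for that. A self-contained sketch of the same pattern in plain PyTorch, assuming torch >= 2.1 for `assign=True`:

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights

# Build the skeleton on the meta device, filter checkpoint keys the model
# does not expect, then assign the real tensors over the meta parameters.
with init_empty_weights():
    model = nn.Linear(4, 4)

checkpoint = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4), "extra": torch.zeros(1)}
empty_state_dict = model.state_dict()
unexpected_keys = [k for k in checkpoint if k not in empty_state_dict]

model.load_state_dict(
    {k: v for k, v in checkpoint.items() if k not in unexpected_keys},
    assign=True,  # replace meta tensors instead of copying into them
)
print(model.weight.device)  # cpu
```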