Skip to content

Commit ebd06f9

Browse files
authored
[Modular] loader related (#13025)
* tag loader_id from Automodel * style * load_components by default only load components that are not already loaded * by default, skip loading the componeneets does not have the repo id
1 parent b712042 commit ebd06f9

File tree

5 files changed

+24
-6
lines changed

5 files changed

+24
-6
lines changed

src/diffusers/models/auto_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from huggingface_hub.utils import validate_hf_hub_args
1919

2020
from ..configuration_utils import ConfigMixin
21-
from ..utils import logging
21+
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
2222
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
2323

2424

@@ -220,4 +220,11 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
220220
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
221221

222222
kwargs = {**load_config_kwargs, **kwargs}
223-
return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
223+
model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
224+
225+
load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
226+
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
227+
load_id = "|".join("null" if p is None else p for p in parts)
228+
model._diffusers_load_id = load_id
229+
230+
return model

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,6 +2143,8 @@ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwarg
21432143
name
21442144
for name in self._component_specs.keys()
21452145
if self._component_specs[name].default_creation_method == "from_pretrained"
2146+
and self._component_specs[name].pretrained_model_name_or_path is not None
2147+
and getattr(self, name, None) is None
21462148
]
21472149
elif isinstance(names, str):
21482150
names = [names]

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515
import inspect
1616
import re
1717
from collections import OrderedDict
18-
from dataclasses import dataclass, field, fields
18+
from dataclasses import dataclass, field
1919
from typing import Any, Dict, List, Literal, Optional, Type, Union
2020

2121
import PIL.Image
2222
import torch
2323

2424
from ..configuration_utils import ConfigMixin, FrozenDict
2525
from ..loaders.single_file_utils import _is_single_file_path_or_url
26-
from ..utils import is_torch_available, logging
26+
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
2727

2828

2929
if is_torch_available():
@@ -186,7 +186,7 @@ def loading_fields(cls) -> List[str]:
186186
"""
187187
Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True).
188188
"""
189-
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
189+
return DIFFUSERS_LOAD_ID_FIELDS.copy()
190190

191191
@property
192192
def load_id(self) -> str:
@@ -198,7 +198,7 @@ def load_id(self) -> str:
198198
return "null"
199199
parts = [getattr(self, k) for k in self.loading_fields()]
200200
parts = ["null" if p is None else p for p in parts]
201-
return "|".join(p for p in parts if p)
201+
return "|".join(parts)
202202

203203
@classmethod
204204
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
2424
DEPRECATED_REVISION_ARGS,
2525
DIFFUSERS_DYNAMIC_MODULE_NAME,
26+
DIFFUSERS_LOAD_ID_FIELDS,
2627
FLAX_WEIGHTS_NAME,
2728
GGUF_FILE_EXTENSION,
2829
HF_ENABLE_PARALLEL_LOADING,

src/diffusers/utils/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,11 @@
7373
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
7474
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
7575
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
76+
77+
78+
DIFFUSERS_LOAD_ID_FIELDS = [
79+
"pretrained_model_name_or_path",
80+
"subfolder",
81+
"variant",
82+
"revision",
83+
]

0 commit comments

Comments
 (0)