Skip to content

Commit 4890e9b

Browse files
authored
Allow Automodel to use from_config with custom code. (#13123)
* update * update
1 parent f1e5914 commit 4890e9b

File tree

2 files changed

+186
-2
lines changed

2 files changed

+186
-2
lines changed

src/diffusers/models/auto_model.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,128 @@ class AutoModel(ConfigMixin):
3030
def __init__(self, *args, **kwargs):
3131
raise EnvironmentError(
3232
f"{self.__class__.__name__} is designed to be instantiated "
33-
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
33+
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`, "
34+
f"`{self.__class__.__name__}.from_config(config)`, or "
3435
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
3536
)
3637

38+
@classmethod
39+
def from_config(
40+
cls, pretrained_model_name_or_path_or_dict: Optional[Union[str, os.PathLike, dict]] = None, **kwargs
41+
):
42+
r"""
43+
Instantiate a model from a config dictionary or a pretrained model configuration file with random weights (no
44+
pretrained weights are loaded).
45+
46+
Parameters:
47+
pretrained_model_name_or_path_or_dict (`str`, `os.PathLike`, or `dict`):
48+
Can be either:
49+
50+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model
51+
configuration hosted on the Hub.
52+
- A path to a *directory* (for example `./my_model_directory`) containing a model configuration
53+
file.
54+
- A config dictionary.
55+
56+
cache_dir (`Union[str, os.PathLike]`, *optional*):
57+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
58+
is not used.
59+
force_download (`bool`, *optional*, defaults to `False`):
60+
Whether or not to force the (re-)download of the model configuration, overriding the cached version if
61+
it exists.
62+
proxies (`Dict[str, str]`, *optional*):
63+
A dictionary of proxy servers to use by protocol or endpoint.
64+
local_files_only(`bool`, *optional*, defaults to `False`):
65+
Whether to only load local model configuration files or not.
66+
token (`str` or *bool*, *optional*):
67+
The token to use as HTTP bearer authorization for remote files.
68+
revision (`str`, *optional*, defaults to `"main"`):
69+
The specific model version to use.
70+
trust_remote_code (`bool`, *optional*, defaults to `False`):
71+
Whether to trust remote code.
72+
subfolder (`str`, *optional*, defaults to `""`):
73+
The subfolder location of a model file within a larger model repository on the Hub or locally.
74+
75+
Returns:
76+
A model object instantiated from the config with random weights.
77+
78+
Example:
79+
80+
```py
81+
from diffusers import AutoModel
82+
83+
model = AutoModel.from_config("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
84+
```
85+
"""
86+
subfolder = kwargs.pop("subfolder", None)
87+
trust_remote_code = kwargs.pop("trust_remote_code", False)
88+
89+
hub_kwargs_names = [
90+
"cache_dir",
91+
"force_download",
92+
"local_files_only",
93+
"proxies",
94+
"revision",
95+
"token",
96+
]
97+
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
98+
99+
if pretrained_model_name_or_path_or_dict is None:
100+
raise ValueError(
101+
"Please provide a `pretrained_model_name_or_path_or_dict` as the first positional argument."
102+
)
103+
104+
if isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)):
105+
pretrained_model_name_or_path = pretrained_model_name_or_path_or_dict
106+
config = cls.load_config(pretrained_model_name_or_path, subfolder=subfolder, **hub_kwargs)
107+
else:
108+
config = pretrained_model_name_or_path_or_dict
109+
pretrained_model_name_or_path = config.get("_name_or_path", None)
110+
111+
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
112+
trust_remote_code = resolve_trust_remote_code(
113+
trust_remote_code, pretrained_model_name_or_path, has_remote_code
114+
)
115+
116+
if has_remote_code and trust_remote_code:
117+
class_ref = config["auto_map"][cls.__name__]
118+
module_file, class_name = class_ref.split(".")
119+
module_file = module_file + ".py"
120+
model_cls = get_class_from_dynamic_module(
121+
pretrained_model_name_or_path,
122+
subfolder=subfolder,
123+
module_file=module_file,
124+
class_name=class_name,
125+
**hub_kwargs,
126+
)
127+
else:
128+
if "_class_name" in config:
129+
class_name = config["_class_name"]
130+
library = "diffusers"
131+
elif "model_type" in config:
132+
class_name = "AutoModel"
133+
library = "transformers"
134+
else:
135+
raise ValueError(
136+
f"Couldn't find a model class associated with the config: {config}. Make sure the config "
137+
"contains a `_class_name` or `model_type` key."
138+
)
139+
140+
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
141+
142+
model_cls, _ = get_class_obj_and_candidates(
143+
library_name=library,
144+
class_name=class_name,
145+
importable_classes=ALL_IMPORTABLE_CLASSES,
146+
pipelines=None,
147+
is_pipeline_module=False,
148+
)
149+
150+
if model_cls is None:
151+
raise ValueError(f"AutoModel can't find a model linked to {class_name}.")
152+
153+
return model_cls.from_config(config, **kwargs)
154+
37155
@classmethod
38156
@validate_hf_hub_args
39157
def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = None, **kwargs):

tests/models/test_models_auto.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from unittest.mock import patch
2+
from unittest.mock import MagicMock, patch
33

44
from transformers import CLIPTextModel, LongformerModel
55

@@ -30,3 +30,69 @@ def test_load_from_config_without_subfolder(self):
3030
def test_load_from_model_index(self):
3131
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
3232
assert isinstance(model, CLIPTextModel)
33+
34+
35+
class TestAutoModelFromConfig(unittest.TestCase):
36+
@patch(
37+
"diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates",
38+
return_value=(MagicMock(), None),
39+
)
40+
def test_from_config_with_dict_diffusers_class(self, mock_get_class):
41+
config = {"_class_name": "UNet2DConditionModel", "sample_size": 64}
42+
mock_model = MagicMock()
43+
mock_get_class.return_value[0].from_config.return_value = mock_model
44+
45+
result = AutoModel.from_config(config)
46+
47+
mock_get_class.assert_called_once_with(
48+
library_name="diffusers",
49+
class_name="UNet2DConditionModel",
50+
importable_classes=unittest.mock.ANY,
51+
pipelines=None,
52+
is_pipeline_module=False,
53+
)
54+
mock_get_class.return_value[0].from_config.assert_called_once_with(config)
55+
assert result is mock_model
56+
57+
@patch(
58+
"diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates",
59+
return_value=(MagicMock(), None),
60+
)
61+
@patch("diffusers.models.AutoModel.load_config", return_value={"_class_name": "UNet2DConditionModel"})
62+
def test_from_config_with_string_path(self, mock_load_config, mock_get_class):
63+
mock_model = MagicMock()
64+
mock_get_class.return_value[0].from_config.return_value = mock_model
65+
66+
result = AutoModel.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
67+
68+
mock_load_config.assert_called_once()
69+
assert result is mock_model
70+
71+
def test_from_config_raises_on_missing_class_info(self):
72+
config = {"some_key": "some_value"}
73+
with self.assertRaises(ValueError, msg="Couldn't find a model class"):
74+
AutoModel.from_config(config)
75+
76+
@patch(
77+
"diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates",
78+
return_value=(MagicMock(), None),
79+
)
80+
def test_from_config_with_model_type_routes_to_transformers(self, mock_get_class):
81+
config = {"model_type": "clip_text_model"}
82+
mock_model = MagicMock()
83+
mock_get_class.return_value[0].from_config.return_value = mock_model
84+
85+
result = AutoModel.from_config(config)
86+
87+
mock_get_class.assert_called_once_with(
88+
library_name="transformers",
89+
class_name="AutoModel",
90+
importable_classes=unittest.mock.ANY,
91+
pipelines=None,
92+
is_pipeline_module=False,
93+
)
94+
assert result is mock_model
95+
96+
def test_from_config_raises_on_none(self):
97+
with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
98+
AutoModel.from_config(None)

0 commit comments

Comments
 (0)