@@ -30,10 +30,128 @@ class AutoModel(ConfigMixin):
3030 def __init__ (self , * args , ** kwargs ):
3131 raise EnvironmentError (
3232 f"{ self .__class__ .__name__ } is designed to be instantiated "
33- f"using the `{ self .__class__ .__name__ } .from_pretrained(pretrained_model_name_or_path)` or "
33+ f"using the `{ self .__class__ .__name__ } .from_pretrained(pretrained_model_name_or_path)`, "
34+ f"`{ self .__class__ .__name__ } .from_config(config)`, or "
3435 f"`{ self .__class__ .__name__ } .from_pipe(pipeline)` methods."
3536 )
3637
38+ @classmethod
39+ def from_config (
40+ cls , pretrained_model_name_or_path_or_dict : Optional [Union [str , os .PathLike , dict ]] = None , ** kwargs
41+ ):
42+ r"""
43+ Instantiate a model from a config dictionary or a pretrained model configuration file with random weights (no
44+ pretrained weights are loaded).
45+
46+ Parameters:
47+ pretrained_model_name_or_path_or_dict (`str`, `os.PathLike`, or `dict`):
48+ Can be either:
49+
50+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model
51+ configuration hosted on the Hub.
52+ - A path to a *directory* (for example `./my_model_directory`) containing a model configuration
53+ file.
54+ - A config dictionary.
55+
56+ cache_dir (`Union[str, os.PathLike]`, *optional*):
57+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
58+ is not used.
59+ force_download (`bool`, *optional*, defaults to `False`):
60+ Whether or not to force the (re-)download of the model configuration, overriding the cached version if
61+ it exists.
62+ proxies (`Dict[str, str]`, *optional*):
63+ A dictionary of proxy servers to use by protocol or endpoint.
64+ local_files_only(`bool`, *optional*, defaults to `False`):
65+ Whether to only load local model configuration files or not.
66+ token (`str` or *bool*, *optional*):
67+ The token to use as HTTP bearer authorization for remote files.
68+ revision (`str`, *optional*, defaults to `"main"`):
69+ The specific model version to use.
70+ trust_remote_code (`bool`, *optional*, defaults to `False`):
71+ Whether to trust remote code.
72+ subfolder (`str`, *optional*, defaults to `""`):
73+ The subfolder location of a model file within a larger model repository on the Hub or locally.
74+
75+ Returns:
76+ A model object instantiated from the config with random weights.
77+
78+ Example:
79+
80+ ```py
81+ from diffusers import AutoModel
82+
83+ model = AutoModel.from_config("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
84+ ```
85+ """
86+ subfolder = kwargs .pop ("subfolder" , None )
87+ trust_remote_code = kwargs .pop ("trust_remote_code" , False )
88+
89+ hub_kwargs_names = [
90+ "cache_dir" ,
91+ "force_download" ,
92+ "local_files_only" ,
93+ "proxies" ,
94+ "revision" ,
95+ "token" ,
96+ ]
97+ hub_kwargs = {name : kwargs .pop (name , None ) for name in hub_kwargs_names }
98+
99+ if pretrained_model_name_or_path_or_dict is None :
100+ raise ValueError (
101+ "Please provide a `pretrained_model_name_or_path_or_dict` as the first positional argument."
102+ )
103+
104+ if isinstance (pretrained_model_name_or_path_or_dict , (str , os .PathLike )):
105+ pretrained_model_name_or_path = pretrained_model_name_or_path_or_dict
106+ config = cls .load_config (pretrained_model_name_or_path , subfolder = subfolder , ** hub_kwargs )
107+ else :
108+ config = pretrained_model_name_or_path_or_dict
109+ pretrained_model_name_or_path = config .get ("_name_or_path" , None )
110+
111+ has_remote_code = "auto_map" in config and cls .__name__ in config ["auto_map" ]
112+ trust_remote_code = resolve_trust_remote_code (
113+ trust_remote_code , pretrained_model_name_or_path , has_remote_code
114+ )
115+
116+ if has_remote_code and trust_remote_code :
117+ class_ref = config ["auto_map" ][cls .__name__ ]
118+ module_file , class_name = class_ref .split ("." )
119+ module_file = module_file + ".py"
120+ model_cls = get_class_from_dynamic_module (
121+ pretrained_model_name_or_path ,
122+ subfolder = subfolder ,
123+ module_file = module_file ,
124+ class_name = class_name ,
125+ ** hub_kwargs ,
126+ )
127+ else :
128+ if "_class_name" in config :
129+ class_name = config ["_class_name" ]
130+ library = "diffusers"
131+ elif "model_type" in config :
132+ class_name = "AutoModel"
133+ library = "transformers"
134+ else :
135+ raise ValueError (
136+ f"Couldn't find a model class associated with the config: { config } . Make sure the config "
137+ "contains a `_class_name` or `model_type` key."
138+ )
139+
140+ from ..pipelines .pipeline_loading_utils import ALL_IMPORTABLE_CLASSES , get_class_obj_and_candidates
141+
142+ model_cls , _ = get_class_obj_and_candidates (
143+ library_name = library ,
144+ class_name = class_name ,
145+ importable_classes = ALL_IMPORTABLE_CLASSES ,
146+ pipelines = None ,
147+ is_pipeline_module = False ,
148+ )
149+
150+ if model_cls is None :
151+ raise ValueError (f"AutoModel can't find a model linked to { class_name } ." )
152+
153+ return model_cls .from_config (config , ** kwargs )
154+
37155 @classmethod
38156 @validate_hf_hub_args
39157 def from_pretrained (cls , pretrained_model_or_path : str | os .PathLike | None = None , ** kwargs ):
0 commit comments