@@ -63,7 +63,7 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy
63
63
try :
64
64
config = self ._load_diffusers_config (model_path , config_name = "config.json" )
65
65
if class_name := config .get ("_class_name" ):
66
- result = self ._hf_definition_to_type (module = "diffusers" , class_name = class_name )
66
+ result = self ._hf_definition_to_type (module = "diffusers" , class_name = class_name , model_name = model_path . name )
67
67
elif class_name := config .get ("architectures" ):
68
68
result = self ._hf_definition_to_type (module = "transformers" , class_name = class_name [0 ])
69
69
else :
@@ -74,19 +74,19 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy
74
74
return result
75
75
76
76
# TO DO: Add exception handling
77
- def _hf_definition_to_type (self , module : str , class_name : str ) -> ModelMixin : # fix with correct type
77
+ def _hf_definition_to_type (self , module : str , class_name : str , model_name : Optional [ str ] = None ) -> ModelMixin : # fix with correct type
78
78
if module in [
79
79
"diffusers" ,
80
80
"transformers" ,
81
81
"invokeai.backend.quantization.fast_quantized_transformers_model" ,
82
82
"invokeai.backend.quantization.fast_quantized_diffusion_model" ,
83
83
"transformer_bria" ,
84
84
]:
85
- if module == "transformer_bria" :
86
- module = "invokeai.backend.bria.transformer_bria"
87
- elif class_name == "BriaTransformer2DModel" :
85
+ if model_name == "BRIA-3.2-ControlNet-Union" :
88
86
class_name = "BriaControlNetModel"
89
87
module = "invokeai.backend.bria.controlnet_bria"
88
+ elif module == "transformer_bria" or class_name == "BriaTransformer2DModel" :
89
+ module = "invokeai.backend.bria.transformer_bria"
90
90
res_type = sys .modules [module ]
91
91
else :
92
92
res_type = sys .modules ["diffusers" ].pipelines
0 commit comments