Skip to content

Commit 0f53b9b

Browse files
ilanbriapsychedelicious
authored andcommitted
readded support for bria3.2 and controlnet
1 parent aa4ec59 commit 0f53b9b

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy
6363
try:
6464
config = self._load_diffusers_config(model_path, config_name="config.json")
6565
if class_name := config.get("_class_name"):
66-
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
66+
result = self._hf_definition_to_type(module="diffusers", class_name=class_name, model_name=model_path.name)
6767
elif class_name := config.get("architectures"):
6868
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
6969
else:
@@ -74,19 +74,19 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy
7474
return result
7575

7676
# TO DO: Add exception handling
77-
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
77+
def _hf_definition_to_type(self, module: str, class_name: str, model_name: Optional[str] = None) -> ModelMixin: # fix with correct type
7878
if module in [
7979
"diffusers",
8080
"transformers",
8181
"invokeai.backend.quantization.fast_quantized_transformers_model",
8282
"invokeai.backend.quantization.fast_quantized_diffusion_model",
8383
"transformer_bria",
8484
]:
85-
if module == "transformer_bria":
86-
module = "invokeai.backend.bria.transformer_bria"
87-
elif class_name == "BriaTransformer2DModel":
85+
if model_name == "BRIA-3.2-ControlNet-Union":
8886
class_name = "BriaControlNetModel"
8987
module = "invokeai.backend.bria.controlnet_bria"
88+
elif module == "transformer_bria" or class_name == "BriaTransformer2DModel":
89+
module = "invokeai.backend.bria.transformer_bria"
9090
res_type = sys.modules[module]
9191
else:
9292
res_type = sys.modules["diffusers"].pipelines

0 commit comments

Comments
 (0)