Skip to content

Commit bb7bb96

Browse files
feat(mm): support bria-3 controlnets
1 parent 0f53b9b commit bb7bb96

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,34 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
547547
def get_tag(cls) -> Tag:
548548
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}.{BaseModelType.Bria.value}")
549549

550+
class BriaControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
551+
"""Model config for Bria/Diffusers ControlNet models."""
552+
553+
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
554+
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
555+
base: Literal[BaseModelType.Bria] = BaseModelType.Bria
556+
557+
@classmethod
558+
def matches(cls, mod: ModelOnDisk) -> bool:
559+
if mod.path.is_file():
560+
return False
561+
562+
config_path = mod.path / "config.json"
563+
if config_path.exists():
564+
with open(config_path) as file:
565+
transformer_conf = json.load(file)
566+
if transformer_conf["_class_name"] == "BriaTransformer2DModel":
567+
return True
568+
569+
return False
570+
571+
@classmethod
572+
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
573+
return {}
574+
575+
@classmethod
576+
def get_tag(cls) -> Tag:
577+
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}.{BaseModelType.Bria.value}")
550578

551579

552580
class IPAdapterConfigBase(ABC, BaseModel):
@@ -732,6 +760,7 @@ def get_model_discriminator_value(v: Any) -> str:
732760
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
733761
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
734762
Annotated[BriaDiffusersConfig, BriaDiffusersConfig.get_tag()],
763+
Annotated[BriaControlNetDiffusersConfig, BriaControlNetDiffusersConfig.get_tag()],
735764
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
736765
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
737766
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,6 @@ class ModelProbe(object):
126126
}
127127

128128
CLASS2TYPE = {
129-
"BriaPipeline": ModelType.Main,
130-
"BriaTransformer2DModel": ModelType.ControlNet,
131129
"FluxPipeline": ModelType.Main,
132130
"StableDiffusionPipeline": ModelType.Main,
133131
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -876,8 +874,6 @@ def get_base_type(self) -> BaseModelType:
876874
return BaseModelType.StableDiffusion3
877875
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
878876
return BaseModelType.CogView4
879-
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
880-
return BaseModelType.Bria
881877
else:
882878
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
883879

@@ -1027,9 +1023,6 @@ def get_base_type(self) -> BaseModelType:
10271023
if config.get("_class_name", None) == "FluxControlNetModel":
10281024
return BaseModelType.Flux
10291025

1030-
if config.get("_class_name", None) == "BriaTransformer2DModel":
1031-
return BaseModelType.Bria
1032-
10331026
# no obvious way to distinguish between sd2-base and sd2-768
10341027
dimension = config["cross_attention_dim"]
10351028
if dimension == 768:

0 commit comments

Comments
 (0)