@@ -547,6 +547,34 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
547
547
def get_tag (cls ) -> Tag :
548
548
return Tag (f"{ ModelType .Main .value } .{ ModelFormat .Diffusers .value } .{ BaseModelType .Bria .value } " )
549
549
550
+ class BriaControlNetDiffusersConfig (DiffusersConfigBase , ControlAdapterConfigBase , ModelConfigBase ):
551
+ """Model config for Bria/Diffusers ControlNet models."""
552
+
553
+ type : Literal [ModelType .ControlNet ] = ModelType .ControlNet
554
+ format : Literal [ModelFormat .Diffusers ] = ModelFormat .Diffusers
555
+ base : Literal [BaseModelType .Bria ] = BaseModelType .Bria
556
+
557
+ @classmethod
558
+ def matches (cls , mod : ModelOnDisk ) -> bool :
559
+ if mod .path .is_file ():
560
+ return False
561
+
562
+ config_path = mod .path / "config.json"
563
+ if config_path .exists ():
564
+ with open (config_path ) as file :
565
+ transformer_conf = json .load (file )
566
+ if transformer_conf ["_class_name" ] == "BriaTransformer2DModel" :
567
+ return True
568
+
569
+ return False
570
+
571
+ @classmethod
572
+ def parse (cls , mod : ModelOnDisk ) -> dict [str , Any ]:
573
+ return {}
574
+
575
+ @classmethod
576
+ def get_tag (cls ) -> Tag :
577
+ return Tag (f"{ ModelType .ControlNet .value } .{ ModelFormat .Diffusers .value } .{ BaseModelType .Bria .value } " )
550
578
551
579
552
580
class IPAdapterConfigBase (ABC , BaseModel ):
@@ -732,6 +760,7 @@ def get_model_discriminator_value(v: Any) -> str:
732
760
Annotated [ControlLoRADiffusersConfig , ControlLoRADiffusersConfig .get_tag ()],
733
761
Annotated [LoRADiffusersConfig , LoRADiffusersConfig .get_tag ()],
734
762
Annotated [BriaDiffusersConfig , BriaDiffusersConfig .get_tag ()],
763
+ Annotated [BriaControlNetDiffusersConfig , BriaControlNetDiffusersConfig .get_tag ()],
735
764
Annotated [T5EncoderConfig , T5EncoderConfig .get_tag ()],
736
765
Annotated [T5EncoderBnbQuantizedLlmInt8bConfig , T5EncoderBnbQuantizedLlmInt8bConfig .get_tag ()],
737
766
Annotated [TextualInversionFileConfig , TextualInversionFileConfig .get_tag ()],
0 commit comments