
Commit aa4ec59

ilanbria authored and psychedelicious committed
cr fixes 2
1 parent 6566ef7 commit aa4ec59

File tree

14 files changed: +103 -73 lines


invokeai/app/invocations/bria_decoder.py

Lines changed: 23 additions & 1 deletion
@@ -29,11 +29,19 @@ class BriaDecoderInvocation(BaseInvocation):
         description=FieldDescriptions.latents,
         input=Input.Connection,
     )
+    height: int = InputField(
+        title="Height",
+        description="The height of the output image",
+    )
+    width: int = InputField(
+        title="Width",
+        description="The width of the output image",
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
-        latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
+        latents = _unpack_latents(latents, self.height, self.width)

         with context.models.load(self.vae.vae) as vae:
             assert isinstance(vae, AutoencoderKL)
@@ -48,3 +56,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         img = Image.fromarray(image)
         image_dto = context.images.save(image=img)
         return ImageOutput.build(image_dto)
+
+
+def _unpack_latents(latents, height, width, vae_scale_factor=16):
+    batch_size, num_patches, channels = latents.shape
+
+    height = height // vae_scale_factor
+    width = width // vae_scale_factor
+
+    latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+    latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+    latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+    return latents
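The helper generalizes the old hardcoded reshape (batch 1, 1024x1024) to arbitrary resolutions. A minimal sketch of the shape round-trip, assuming the denoiser emits 4 latent channels packed into 2x2 spatial patches (16 values per patch token) and a VAE scale factor of 16:

import torch

# Hypothetical walk-through of the unpacking above for a 1024x1024 image.
batch, height, width, vae_scale_factor = 1, 1024, 1024, 16
h, w = height // vae_scale_factor, width // vae_scale_factor  # 64 x 64 patch grid

packed = torch.randn(batch, h * w, 16)              # (1, 4096, 16) packed latents
latents = packed.view(batch, h, w, 4, 2, 2)         # split each token: 4 channels x 2x2 pixels
latents = latents.permute(0, 3, 1, 4, 2, 5)         # (1, 4, 64, 2, 64, 2)
latents = latents.reshape(batch, 4, h * 2, w * 2)   # (1, 4, 128, 128), ready for the VAE
print(latents.shape)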

invokeai/app/invocations/bria_denoiser.py

Lines changed: 3 additions & 3 deletions
@@ -22,6 +22,8 @@
 @invocation_output("bria_denoise_output")
 class BriaDenoiseInvocationOutput(BaseInvocationOutput):
     latents: LatentsField = OutputField(description=FieldDescriptions.latents)
+    height: int = OutputField(description="The height of the output image")
+    width: int = OutputField(description="The width of the output image")


 @invocation(
@@ -144,7 +146,6 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
             height=self.height,
             controlnet_conditioning_scale=control_scales,
             num_inference_steps=self.num_steps,
-            max_sequence_length=128,
             guidance_scale=self.guidance_scale,
             latents=latents,
             latent_image_ids=latent_image_ids,
@@ -158,7 +159,7 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:

         assert isinstance(output_latents, torch.Tensor)
         saved_input_latents_tensor = context.tensors.save(output_latents)
-        return BriaDenoiseInvocationOutput(latents=LatentsField(latents_name=saved_input_latents_tensor))
+        return BriaDenoiseInvocationOutput(latents=LatentsField(latents_name=saved_input_latents_tensor), height=self.height, width=self.width)

     def _prepare_multi_control(
         self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device
@@ -191,7 +192,6 @@ def _prepare_multi_control(

 def _build_step_callback(context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
     def step_callback(state: PipelineIntermediateState) -> None:
-        return
         context.util.sd_step_callback(state, BaseModelType.Bria)

     return step_callback

invokeai/app/invocations/bria_latent_noise.py

Lines changed: 2 additions & 2 deletions
@@ -26,8 +26,8 @@ class BriaLatentNoiseOutput(BaseModel):
 class BriaLatentNoiseInvocationOutput(BaseInvocationOutput):
     """Base class for nodes that output Bria latent tensors."""
     latent_noise: BriaLatentNoiseOutput = OutputField(description="The latent noise, containing latents and latent image ids.")
-    height: int = OutputField(description="The height of the output image", default=1024)
-    width: int = OutputField(description="The width of the output image", default=1024)
+    height: int = OutputField(description="The height of the output image")
+    width: int = OutputField(description="The width of the output image")

 @invocation(
     "bria_latent_noise",

invokeai/app/invocations/bria_text_encoder.py

Lines changed: 4 additions & 2 deletions
@@ -19,6 +19,7 @@
     invocation_output,
 )

+DEFAULT_NEGATIVE_PROMPT = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate"

 @invocation_output("bria_text_encoder_output")
 class BriaTextEncoderInvocationOutput(BaseInvocationOutput):
@@ -48,7 +49,7 @@ class BriaTextEncoderInvocation(BaseInvocation):
     negative_prompt: Optional[str] = InputField(
         title="Negative Prompt",
         description="The negative prompt to encode",
-        default="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate",
+        default="",
     )
     max_length: int = InputField(
         default=256,
@@ -74,11 +75,12 @@ def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
         assert isinstance(tokenizer, T5TokenizerFast)
         assert isinstance(text_encoder, T5EncoderModel)

+        negative_prompt = f"{DEFAULT_NEGATIVE_PROMPT}, {self.negative_prompt}"
         prompt_embeds, negative_prompt_embeds = encode_prompt(
             prompt=self.prompt,
             tokenizer=tokenizer,
             text_encoder=text_encoder,
-            negative_prompt=self.negative_prompt,
+            negative_prompt=negative_prompt,
             device=text_encoder.device,
             num_images_per_prompt=1,
             max_sequence_length=self.max_length,
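With this change the built-in negative prompt is always prepended and the user's field is appended to it, instead of being the editable default itself. A small sketch of the resulting string, assuming DEFAULT_NEGATIVE_PROMPT as defined above (truncated here for brevity):

DEFAULT_NEGATIVE_PROMPT = "Logo,Watermark,Text,Ugly,Morbid"  # truncated for the example

user_negative = "low contrast, oversaturated"
print(f"{DEFAULT_NEGATIVE_PROMPT}, {user_negative}")
# Logo,Watermark,Text,Ugly,Morbid, low contrast, oversaturated

print(f"{DEFAULT_NEGATIVE_PROMPT}, {''}")
# Logo,Watermark,Text,Ugly,Morbid,   <- an empty field still yields the built-in list, plus a trailing ", "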

invokeai/app/util/step_callback.py

Lines changed: 4 additions & 1 deletion
@@ -94,7 +94,10 @@
 ]

 BRIA_LATENT_RGB_FACTORS = [
-
+    [0.31115174, 0.38229316, 0.43620577],
+    [-0.26867455, 0.05353606, 0.1088054],
+    [0.09892498, 0.17854956, -0.12029117],
+    [-0.37774912, -0.17128916, -0.25255626],
 ]

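These factors appear to follow the same pattern as the other bases in this file: a 4x3 matrix projecting the 4 latent channels to RGB for low-resolution progress previews. A rough illustration of that projection, not the exact preview code:

import torch

BRIA_LATENT_RGB_FACTORS = [
    [0.31115174, 0.38229316, 0.43620577],
    [-0.26867455, 0.05353606, 0.1088054],
    [0.09892498, 0.17854956, -0.12029117],
    [-0.37774912, -0.17128916, -0.25255626],
]

latents = torch.randn(4, 128, 128)                        # (latent channels, H, W)
factors = torch.tensor(BRIA_LATENT_RGB_FACTORS)           # (4 latent channels, 3 RGB channels)
preview = torch.einsum("chw,cr->rhw", latents, factors)   # (3, 128, 128) rough RGB estimate
print(preview.shape)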

invokeai/backend/bria/pipeline_bria.py

Lines changed: 0 additions & 3 deletions
@@ -236,7 +236,6 @@ def check_inputs(
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -266,8 +265,6 @@ def check_inputs(
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )

-        if max_sequence_length is not None and max_sequence_length > 512:
-            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

     def to(self, *args, **kwargs):
         DiffusionPipeline.to(self, *args, **kwargs)

invokeai/backend/bria/pipeline_bria_controlnet.py

Lines changed: 18 additions & 20 deletions
@@ -251,7 +251,6 @@ def __call__(
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
-        max_sequence_length: int = 128,
         step_callback: Callable[[PipelineIntermediateState], None] = None,
     ):
         r"""
@@ -342,7 +341,6 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
-            max_sequence_length=max_sequence_length,
         )

         self._guidance_scale = guidance_scale
@@ -416,15 +414,15 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

         # Init Invoke step callback
-        # step_callback(
-        #     PipelineIntermediateState(
-        #         step=0,
-        #         order=1,
-        #         total_steps=num_inference_steps,
-        #         timestep=int(timesteps[0]),
-        #         latents=latents,
-        #     ),
-        # )
+        step_callback(
+            PipelineIntermediateState(
+                step=0,
+                order=1,
+                total_steps=num_inference_steps,
+                timestep=int(timesteps[0]),
+                latents=latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128),
+            ),
+        )

         # EYAL - added the CFG loop
         # 7. Denoising loop
@@ -513,15 +511,15 @@ def __call__(
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    # step_callback(
-                    #     PipelineIntermediateState(
-                    #         step=i + 1,
-                    #         order=1,
-                    #         total_steps=num_inference_steps,
-                    #         timestep=int(t),
-                    #         latents=latents,
-                    #     ),
-                    # )
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=i + 1,
+                            order=1,
+                            total_steps=num_inference_steps,
+                            timestep=int(t),
+                            latents=latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128),
+                        ),
+                    )

         if output_type == "latent":
             image = latents
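Note that the latents handed to the step callback are unpacked inline with shapes hardcoded to the batch-1, 1024x1024 case; this mirrors _unpack_latents from bria_decoder.py at the default resolution. A quick equivalence check, assuming that helper is in scope:

import torch

packed = torch.randn(1, 64 * 64, 16)
inline = packed.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
helper = _unpack_latents(packed, height=1024, width=1024)  # helper defined in bria_decoder.py above
assert torch.equal(inline, helper)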

invokeai/backend/model_manager/config.py

Lines changed: 29 additions & 21 deletions
@@ -442,27 +442,6 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
             "base": cls.base_model(mod),
         }

-class BriaDiffusersConfig(LoRAConfigBase, ModelConfigBase):
-    """Model config for Bria/Diffusers models."""
-
-    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
-
-    @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
-        if mod.path.is_file():
-            return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
-
-        suffixes = ["bin", "safetensors"]
-        weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
-        return any(wf.exists() for wf in weight_files)
-
-    @classmethod
-    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
-        return {
-            "base": cls.base_model(mod),
-        }
-

 class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
     """Model config for standalone VAE models."""
@@ -540,6 +519,35 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,

     pass

+class BriaDiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
+    """Model config for Bria/Diffusers models."""
+
+    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+    base: Literal[BaseModelType.Bria] = BaseModelType.Bria
+
+    @classmethod
+    def matches(cls, mod: ModelOnDisk) -> bool:
+        if mod.path.is_file():
+            return False
+
+        config_path = mod.path / "transformer" / "config.json"
+        if config_path.exists():
+            with open(config_path) as file:
+                transformer_conf = json.load(file)
+                if transformer_conf["_class_name"] == "BriaTransformer2DModel":
+                    return True
+
+        return False
+
+    @classmethod
+    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
+        return {}
+
+    @classmethod
+    def get_tag(cls) -> Tag:
+        return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}.{BaseModelType.Bria.value}")
+


 class IPAdapterConfigBase(ABC, BaseModel):
     type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
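The rewritten config registers Bria as a main/diffusers model and detects it from the transformer's config.json rather than from LoRA weight files. A rough illustration of the detection rule; the directory path and the extra "in_channels" key are made up for the example:

import json
from pathlib import Path

# matches() treats a diffusers-style folder as a Bria main model when
# transformer/config.json declares _class_name == "BriaTransformer2DModel".
model_dir = Path("/tmp/example-bria-model")  # hypothetical path
(model_dir / "transformer").mkdir(parents=True, exist_ok=True)
(model_dir / "transformer" / "config.json").write_text(
    json.dumps({"_class_name": "BriaTransformer2DModel", "in_channels": 4})  # "in_channels" is illustrative
)

conf = json.loads((model_dir / "transformer" / "config.json").read_text())
print(conf["_class_name"] == "BriaTransformer2DModel")  # True -> identified as a Bria main model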

invokeai/backend/model_manager/taxonomy.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class BaseModelType(str, Enum):
     Imagen4 = "imagen4"
     ChatGPT4o = "chatgpt-4o"
     FluxKontext = "flux-kontext"
-    Bria = "bria"
+    Bria = "bria-3"


 class ModelType(str, Enum):

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
   imagen4: 'pink',
   'chatgpt-4o': 'pink',
   'flux-kontext': 'pink',
-  bria: 'purple',
+  'bria-3': 'purple',
 };

 const ModelBaseBadge = ({ base }: Props) => {
