Skip to content

Commit b256087

Browse files
committed
cr fixes 3
1 parent 1eee328 commit b256087

File tree

5 files changed: 311 additions (+), 79 deletions (−)

invokeai/app/invocations/bria_denoiser.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
66

77
from invokeai.app.invocations.bria_controlnet import BriaControlNetField
8-
from invokeai.app.invocations.bria_latent_noise import BriaLatentNoiseOutput
98
from invokeai.app.invocations.fields import FluxConditioningField, Input, InputField, LatentsField, OutputField
109
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
1110
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
@@ -72,11 +71,6 @@ class BriaDenoiseInvocation(BaseInvocation):
7271
title="Width",
7372
description="The width of the output image",
7473
)
75-
latent_noise: BriaLatentNoiseOutput = InputField(
76-
description="Latent noise to denoise",
77-
input=Input.Connection,
78-
title="Latent Noise",
79-
)
8074
pos_embeds: FluxConditioningField = InputField(
8175
description="Positive Prompt Embeds",
8276
input=Input.Connection,
@@ -87,6 +81,16 @@ class BriaDenoiseInvocation(BaseInvocation):
8781
input=Input.Connection,
8882
title="Negative Prompt Embeds",
8983
)
84+
latents: LatentsField = InputField(
85+
description="Latent noise with latent image ids to denoise",
86+
input=Input.Connection,
87+
title="Latent Noise",
88+
)
89+
latent_image_ids: LatentsField = InputField(
90+
description="Latent image ids to denoise",
91+
input=Input.Connection,
92+
title="Latent Image IDs",
93+
)
9094
control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
9195
description="ControlNet",
9296
input=Input.Connection,
@@ -96,10 +100,10 @@ class BriaDenoiseInvocation(BaseInvocation):
96100

97101
@torch.no_grad()
98102
def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
99-
latents = context.tensors.load(self.latent_noise.latents.latents_name)
103+
latents = context.tensors.load(self.latents.latents_name)
104+
latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
100105
pos_embeds = context.tensors.load(self.pos_embeds.conditioning_name)
101106
neg_embeds = context.tensors.load(self.neg_embeds.conditioning_name)
102-
latent_image_ids = context.tensors.load(self.latent_noise.latent_image_ids.latents_name)
103107
scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})
104108

105109
device = None

invokeai/app/invocations/bria_latent_noise.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -18,14 +18,11 @@
1818
)
1919

2020

21-
class BriaLatentNoiseOutput(BaseModel):
22-
latents: LatentsField
23-
latent_image_ids: LatentsField
24-
2521
@invocation_output("bria_latent_noise_output")
2622
class BriaLatentNoiseInvocationOutput(BaseInvocationOutput):
2723
"""Base class for nodes that output Bria latent tensors."""
28-
latent_noise: BriaLatentNoiseOutput = OutputField(description="The latent noise, containing latents and latent image ids.")
24+
latents: LatentsField = OutputField(description="The latent noise")
25+
latent_image_ids: LatentsField = OutputField(description="The latent image ids.")
2926
height: int = OutputField(description="The height of the output image")
3027
width: int = OutputField(description="The width of the output image")
3128

@@ -86,10 +83,8 @@ def invoke(self, context: InvocationContext) -> BriaLatentNoiseInvocationOutput:
8683
latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)
8784

8885
return BriaLatentNoiseInvocationOutput(
89-
latent_noise=BriaLatentNoiseOutput(
90-
latents=latents_output,
91-
latent_image_ids=latent_image_ids_output,
92-
),
86+
latents=latents_output,
87+
latent_image_ids=latent_image_ids_output,
9388
height=self.height,
9489
width=self.width,
9590
)

invokeai/app/invocations/bria_decoder.py renamed to invokeai/app/invocations/bria_latents_to_image.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,14 @@
99

1010

1111
@invocation(
12-
"bria_decoder",
13-
title="Decoder - Bria",
12+
"bria_latents_to_image",
13+
title="Latents to Image - Bria",
1414
tags=["image", "bria"],
1515
category="image",
1616
version="1.0.0",
1717
classification=Classification.Prototype,
1818
)
19-
class BriaDecoderInvocation(BaseInvocation):
19+
class BriaLatentsToImageInvocation(BaseInvocation):
2020
"""
2121
Decode Bria latents to an image.
2222
"""

invokeai/backend/bria/pipeline_bria_controlnet.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -259,10 +259,6 @@ def __call__(
259259
prompt (`str` or `List[str]`, *optional*):
260260
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
261261
instead.
262-
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
263-
The height in pixels of the generated image. This is set to 1024 by default for the best results.
264-
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
265-
The width in pixels of the generated image. This is set to 1024 by default for the best results.
266262
num_inference_steps (`int`, *optional*, defaults to 50):
267263
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
268264
expense of slower inference.
@@ -323,8 +319,6 @@ def __call__(
323319
`tuple`. When returning a tuple, the first element is a list with the generated images.
324320
"""
325321

326-
height = height or self.default_sample_size * self.vae_scale_factor
327-
width = width or self.default_sample_size * self.vae_scale_factor
328322
control_guidance_start, control_guidance_end = self.get_control_start_end(
329323
control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
330324
)
@@ -335,8 +329,8 @@ def __call__(
335329
)
336330
self.check_inputs(
337331
prompt,
338-
height,
339-
width,
332+
height=height,
333+
width=width,
340334
negative_prompt=negative_prompt,
341335
prompt_embeds=prompt_embeds,
342336
negative_prompt_embeds=negative_prompt_embeds,
@@ -517,7 +511,7 @@ def __call__(
517511
order=1,
518512
total_steps=num_inference_steps,
519513
timestep=int(t),
520-
latents=latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128),
514+
latents=self._unpack_latents(latents, height, width, self.vae_scale_factor),
521515
),
522516
)
523517

0 commit comments

Comments (0)