Commit 282df32

Ubuntuhipsterusername authored and committed
fixed node issue
1 parent 8523ea8 commit 282df32

File tree

4 files changed: +23 -18 lines changed


invokeai/app/invocations/bria_controlnet.py

Lines changed: 6 additions & 5 deletions

@@ -1,13 +1,13 @@
 from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
 from pydantic import BaseModel, Field
-from invokeai.invocation_api import ImageOutput
+from invokeai.invocation_api import ImageOutput, Classification
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType, WithBoard, WithMetadata
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 import numpy as np
@@ -26,9 +26,9 @@ class BriaControlNetField(BaseModel):
     mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
     conditioning_scale: float = Field(description="The weight given to the ControlNet")

-@invocation_output("flux_controlnet_output")
+@invocation_output("bria_controlnet_output")
 class BriaControlNetOutput(BaseInvocationOutput):
-    """FLUX ControlNet info"""
+    """Bria ControlNet info"""

     control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
     preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
@@ -40,8 +40,9 @@ class BriaControlNetOutput(BaseInvocationOutput):
     tags=["controlnet", "bria"],
     category="controlnet",
     version="1.0.0",
+    classification=Classification.Prototype,
 )
-class BriaControlNetInvocation(BaseInvocation):
+class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Collect Bria ControlNet info to pass to denoiser node."""

     control_image: ImageField = InputField(description="The control image")
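
What the change does: mixing WithMetadata and WithBoard into BriaControlNetInvocation gives the node the standard board and metadata input fields, and classification=Classification.Prototype flags it as experimental in the UI. A minimal sketch (not part of this commit) of how an invocation with these mixins typically saves its image; get_pil, save, and ImageOutput.build are assumed from the public invocation API:

    # Hedged sketch: a save path for a WithMetadata/WithBoard invocation.
    # context.images.save() picks up the node's board and metadata fields
    # automatically, so the invoke body only handles the image itself.
    def invoke(self, context: InvocationContext) -> ImageOutput:
        pil_image = context.images.get_pil(self.control_image.image_name)
        image_dto = context.images.save(image=pil_image)  # board/metadata attached here
        return ImageOutput.build(image_dto)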

invokeai/app/invocations/bria_denoiser.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
 from invokeai.backend.bria.controlnet_utils import prepare_control_images
 from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
-from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField
+from invokeai.app.invocations.bria_controlnet import BriaControlNetField

 import torch
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
@@ -12,7 +12,7 @@
 from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
 from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output
+from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
 from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel

 @invocation_output("bria_denoise_output")

invokeai/app/invocations/bria_latent_sampler.py

Lines changed: 7 additions & 3 deletions

@@ -48,18 +48,22 @@ class BriaLatentSamplerInvocation(BaseInvocation):
         title="Transformer",
     )

+    @torch.no_grad()
     def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
-        device = torch.device("cuda")
+        with context.models.load(self.transformer.transformer) as transformer:
+            device = transformer.device
+            dtype = transformer.dtype
+
         height, width = 1024, 1024
         generator = torch.Generator(device=device).manual_seed(self.seed)

-        num_channels_latents = 4  # due to patch=2, we devide by 4
+        num_channels_latents = 4
         latents, latent_image_ids = prepare_latents(
             batch_size=1,
             num_channels_latents=num_channels_latents,
             height=height,
             width=width,
-            dtype=torch.float32,
+            dtype=dtype,
             device=device,
             generator=generator,
         )
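
Why this matters: the old code hardcoded torch.device("cuda"), which breaks on CPU and MPS installs and ignores where the model manager actually placed the model. Loading the transformer via context.models.load() and reading device/dtype off the loaded module keeps the sampler consistent with the denoiser. A minimal sketch of the pattern; the model_field name is a placeholder, not from this commit:

    # Hedged sketch: derive placement from the loaded model instead of assuming CUDA.
    # context.models.load() returns a context manager; entering it locks the model
    # in memory and yields the raw torch module.
    with context.models.load(self.model_field) as model:
        device = model.device  # e.g. cuda:0, mps, or cpu
        dtype = model.dtype    # e.g. torch.bfloat16
    generator = torch.Generator(device=device).manual_seed(42)
    noise = torch.randn(1, 4, 128, 128, device=device, dtype=dtype, generator=generator)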

invokeai/backend/bria/pipeline_bria_controlnet.py

Lines changed: 8 additions & 8 deletions

@@ -612,14 +612,14 @@ def encode_prompt(


 def prepare_latents(
-    batch_size,
-    num_channels_latents,
-    height,
-    width,
-    dtype,
-    device,
-    generator,
-    latents=None,
+    batch_size: int,
+    num_channels_latents: int,
+    height: int,
+    width: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    generator: torch.Generator,
+    latents: Optional[torch.FloatTensor] = None,
 ):
     # VAE applies 8x compression on images but we must also account for packing which requires
     # latent height and width to be divisible by 2.
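
The comment's shape arithmetic, worked out for the 1024x1024 default used by the latent sampler (values derived from the comment and the num_channels_latents = 4 constant above, not copied from the pipeline):

    # VAE compresses 8x spatially; 2x2 patch packing halves the latent grid again,
    # which is why latent height/width must be divisible by 2.
    height, width = 1024, 1024
    latent_h, latent_w = height // 8, width // 8       # 128 x 128 latent grid
    packed_h, packed_w = latent_h // 2, latent_w // 2  # 64 x 64 after packing
    seq_len = packed_h * packed_w                      # 4096 tokens for the transformer
    channels = 4 * 2 * 2                               # num_channels_latents * patch area = 16

This is also where the removed "due to patch=2, we divide by 4" comment came from: 16 packed channels divided by the 2x2 patch area gives num_channels_latents = 4.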
