Integrate Bria 3.1/3.2 Models and ControlNet Pipelines into InvokeAI #8248
Open

ilanbria wants to merge 20 commits into invoke-ai:main from ilanbria:ilan/support_bria_3.2_pipeline (base: main)
Commits (20, showing changes from all commits)
9d0b84f  Setup Probe and UI to accept bria main models (brandonrising)
5d05120  added support for loading bria transformer
9c90223  front end support for bria
da96729  addded bria nodes for bria3.1 and bria3.2
b06f21d  Setup Probe and UI to accept bria controlnet models
72c2f75  Add Bria text to image model and controlnet support
9c1fbc9  Added scikit-image required for Bria's OpenposeDetector model
4ad5647  removed unused file
484fcde  Small cosmetic fixes
4ec7bca  moved bria's nodes to invocations folder
2a8e750  fixed node issue
2dcff77  ruff fix
e945696  fixed schema
7f58ff5  cr fixes 1 (ilanbria)
c64dce2  cr fixes 2 (ilanbria)
35382ac  readded support for bria3.2 and controlnet (ilanbria)
ed6b954  feat(mm): support bria-3 controlnets (psychedelicious)
60953b1  cr fixes 3 (ilanbria)
25d1a5c  chore: ruff (psychedelicious)
7a53114  feat(nodes): use TorchDevice to get device/dtype in bria latent noise… (psychedelicious)
invokeai/app/invocations/bria_controlnet.py (new file, +116 lines)
```python
from PIL import Image
from pydantic import BaseModel, Field

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    OutputField,
    UIType,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
from invokeai.invocation_api import Classification

DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"
HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/"


class BriaControlNetField(BaseModel):
    image: ImageField = Field(description="The control image")
    model: ModelIdentifierField = Field(description="The ControlNet model to use")
    mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
    conditioning_scale: float = Field(description="The weight given to the ControlNet")


@invocation_output("bria_controlnet_output")
class BriaControlNetOutput(BaseInvocationOutput):
    """Bria ControlNet info"""

    control: BriaControlNetField = OutputField(description=FieldDescriptions.control)


@invocation(
    "bria_controlnet",
    title="ControlNet - Bria",
    tags=["controlnet", "bria"],
    category="controlnet",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Collect Bria ControlNet info to pass to denoiser node."""

    control_image: ImageField = InputField(description="The control image")
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
    )
    control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet")
    control_weight: float = InputField(default=1.0, ge=-1, le=2, description="The weight given to the ControlNet")

    def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
        image_in = resize_img(context.images.get_pil(self.control_image.image_name))
        if self.control_mode == "colorgrid":
            control_image = tile(64, image_in)
        elif self.control_mode == "recolor":
            control_image = convert_to_grayscale(image_in)
        elif self.control_mode == "tile":
            control_image = tile(16, image_in)
        else:
            control_image = image_in

        control_image = resize_img(control_image)
        image_dto = context.images.save(image=control_image)
        return BriaControlNetOutput(
            control=BriaControlNetField(
                image=ImageField(image_name=image_dto.image_name),
                model=self.control_model,
                mode=self.control_mode,
                conditioning_scale=self.control_weight,
            ),
        )


RATIO_CONFIGS_1024 = {
    0.6666666666666666: {"width": 832, "height": 1248},
    0.7432432432432432: {"width": 880, "height": 1184},
    0.8028169014084507: {"width": 912, "height": 1136},
    1.0: {"width": 1024, "height": 1024},
    1.2456140350877194: {"width": 1136, "height": 912},
    1.3454545454545455: {"width": 1184, "height": 880},
    1.4339622641509433: {"width": 1216, "height": 848},
    1.5: {"width": 1248, "height": 832},
    1.5490196078431373: {"width": 1264, "height": 816},
    1.62: {"width": 1296, "height": 800},
    1.7708333333333333: {"width": 1360, "height": 768},
}


def convert_to_grayscale(image: Image.Image) -> Image.Image:
    gray_image = image.convert("L").convert("RGB")
    return gray_image


def tile(downscale_factor: int, input_image: Image.Image) -> Image.Image:
    control_image = input_image.resize(
        (input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
    ).resize(input_image.size, Image.Resampling.NEAREST)
    return control_image


def resize_img(control_image: Image.Image) -> Image.Image:
    image_ratio = control_image.width / control_image.height
    ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
    to_height = RATIO_CONFIGS_1024[ratio]["height"]
    to_width = RATIO_CONFIGS_1024[ratio]["width"]
    resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS)
    return resized_image
```
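For reviewers, the helpers above are the whole preprocessing story: `resize_img` snaps any input onto Bria's fixed 1024-scale resolution grid, and the mode-specific transforms (`tile`, `convert_to_grayscale`) run between two such resizes. A minimal sketch of their behavior, assuming the module lands at the import path used by the denoise node below (the sample image is fabricated for illustration):

```python
from PIL import Image

from invokeai.app.invocations.bria_controlnet import convert_to_grayscale, resize_img, tile

# A 1920x1080 input has aspect ratio ~1.7778; the closest key in
# RATIO_CONFIGS_1024 is 1.7708..., so resize_img returns a 1360x768 image.
src = Image.new("RGB", (1920, 1080), color=(128, 64, 32))
resized = resize_img(src)
assert resized.size == (1360, 768)

# "tile" pixelates: downscale by the factor, then nearest-neighbor upscale
# back to the original size (factor 16 for "tile" mode, 64 for "colorgrid").
pixelated = tile(16, resized)
assert pixelated.size == resized.size

# "recolor" strips color but keeps a 3-channel image for the ControlNet.
gray = convert_to_grayscale(resized)
assert gray.mode == "RGB"
```

Note that every output resolution's exact ratio is itself a key in `RATIO_CONFIGS_1024`, so `resize_img` is idempotent on its own outputs; the second `resize_img` call in `invoke` only matters if a mode transform were to change the image size.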
Bria denoise node, added under invokeai/app/invocations/ (new file, +199 lines)
```python
from typing import Callable, List, Tuple

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

from invokeai.app.invocations.bria_controlnet import BriaControlNetField
from invokeai.app.invocations.fields import FluxConditioningField, Input, InputField, LatentsField, OutputField
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
from invokeai.backend.bria.controlnet_utils import prepare_control_images
from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output


@invocation_output("bria_denoise_output")
class BriaDenoiseInvocationOutput(BaseInvocationOutput):
    latents: LatentsField = OutputField(description=FieldDescriptions.latents)
    height: int = OutputField(description="The height of the output image")
    width: int = OutputField(description="The width of the output image")


@invocation(
    "bria_denoise",
    title="Denoise - Bria",
    tags=["image", "bria"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaDenoiseInvocation(BaseInvocation):
    """Denoise Bria latents using a Bria Pipeline."""

    num_steps: int = InputField(
        default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
    )
    guidance_scale: float = InputField(
        default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser"
    )
    transformer: TransformerField = InputField(
        description="Bria model (Transformer) to load",
        input=Input.Connection,
        title="Transformer",
    )
    t5_encoder: T5EncoderField = InputField(
        title="T5Encoder",
        description=FieldDescriptions.t5_encoder,
        input=Input.Connection,
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
        title="VAE",
    )
    height: int = InputField(
        default=1024,
        title="Height",
        description="The height of the output image",
    )
    width: int = InputField(
        default=1024,
        title="Width",
        description="The width of the output image",
    )
    pos_embeds: FluxConditioningField = InputField(
        description="Positive Prompt Embeds",
        input=Input.Connection,
        title="Positive Prompt Embeds",
    )
    neg_embeds: FluxConditioningField = InputField(
        description="Negative Prompt Embeds",
        input=Input.Connection,
        title="Negative Prompt Embeds",
    )
    latents: LatentsField = InputField(
        description="Latent noise with latent image ids to denoise",
        input=Input.Connection,
        title="Latent Noise",
    )
    latent_image_ids: LatentsField = InputField(
        description="Latent image ids to denoise",
        input=Input.Connection,
        title="Latent Image IDs",
    )
    control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
        description="ControlNet",
        input=Input.Connection,
        title="ControlNet",
        default=None,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
        latents = context.tensors.load(self.latents.latents_name)
        latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
        pos_embeds = context.tensors.load(self.pos_embeds.conditioning_name)
        neg_embeds = context.tensors.load(self.neg_embeds.conditioning_name)
        scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})

        device = None
        dtype = None
        with (
            context.models.load(self.transformer.transformer) as transformer,
            context.models.load(scheduler_identifier) as scheduler,
            context.models.load(self.vae.vae) as vae,
            context.models.load(self.t5_encoder.text_encoder) as t5_encoder,
            context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
        ):
            assert isinstance(transformer, BriaTransformer2DModel)
            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
            assert isinstance(vae, AutoencoderKL)
            dtype = transformer.dtype
            device = transformer.device
            latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds))

            control_model, control_images, control_modes, control_scales = None, None, None, None
            if self.control is not None:
                control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
                    context=context,
                    vae=vae,
                    width=self.width,
                    height=self.height,
                    device=vae.device,
                )

            pipeline = BriaControlNetPipeline(
                transformer=transformer,
                scheduler=scheduler,
                vae=vae,
                text_encoder=t5_encoder,
                tokenizer=t5_tokenizer,
                controlnet=control_model,
            )
            pipeline.to(device=transformer.device, dtype=transformer.dtype)

            output_latents = pipeline(
                control_image=control_images,
                control_mode=control_modes,
                width=self.width,
                height=self.height,
                controlnet_conditioning_scale=control_scales,
                num_inference_steps=self.num_steps,
                guidance_scale=self.guidance_scale,
                latents=latents,
                latent_image_ids=latent_image_ids,
                prompt_embeds=pos_embeds,
                negative_prompt_embeds=neg_embeds,
                output_type="latent",
                step_callback=_build_step_callback(context),
            )[0]

        assert isinstance(output_latents, torch.Tensor)
        saved_input_latents_tensor = context.tensors.save(output_latents)
        return BriaDenoiseInvocationOutput(
            latents=LatentsField(latents_name=saved_input_latents_tensor), height=self.height, width=self.width
        )

    def _prepare_multi_control(
        self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device
    ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[int], List[float]]:
        control = self.control if isinstance(self.control, list) else [self.control]
        control_images, control_models, control_modes, control_scales = [], [], [], []
        for controlnet in control:
            if controlnet is not None:
                control_models.append(context.models.load(controlnet.model).model)
                control_modes.append(BriaControlModes[controlnet.mode].value)
                control_scales.append(controlnet.conditioning_scale)
                try:
                    control_images.append(context.images.get_pil(controlnet.image.image_name))
                except Exception:
                    raise FileNotFoundError(
                        f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline."
                    )

        control_model = BriaMultiControlNetModel(control_models).to(device)
        tensored_control_images, tensored_control_modes = prepare_control_images(
            vae=vae,
            control_images=control_images,
            control_modes=control_modes,
            width=width,
            height=height,
            device=device,
        )
        return control_model, tensored_control_images, tensored_control_modes, control_scales


def _build_step_callback(context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
    def step_callback(state: PipelineIntermediateState) -> None:
        context.util.sd_step_callback(state, BaseModelType.Bria)

    return step_callback
```
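One detail worth calling out: the denoise node takes no scheduler input; it derives a scheduler identifier from the transformer identifier via pydantic's `model_copy(update=...)`, so both submodels resolve against the same installed main model. A standalone sketch of that pattern, using hypothetical stand-in types (`ModelIdentifierSketch` is not an InvokeAI class, just an illustration of the pydantic mechanics):

```python
from enum import Enum

from pydantic import BaseModel


class SubModelType(str, Enum):
    Transformer = "transformer"
    Scheduler = "scheduler"


class ModelIdentifierSketch(BaseModel):
    # Hypothetical stand-in for InvokeAI's ModelIdentifierField.
    key: str
    submodel_type: SubModelType | None = None


transformer_id = ModelIdentifierSketch(key="bria-3.2", submodel_type=SubModelType.Transformer)

# model_copy(update=...) clones the identifier with one field swapped, which is
# how the node addresses the sibling scheduler without a separate input field.
scheduler_id = transformer_id.model_copy(update={"submodel_type": SubModelType.Scheduler})

assert scheduler_id.key == "bria-3.2"
assert scheduler_id.submodel_type is SubModelType.Scheduler
```

The same swap would address any other submodel of the main model, which keeps the node's input surface small.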