
Commit 1ac5a24

Ilan Tchenak authored and committed
ruff fix
1 parent 282df32 commit 1ac5a24

22 files changed: +998 -753 lines changed

invokeai/app/invocations/bria_controlnet.py

Lines changed: 32 additions & 20 deletions

@@ -1,31 +1,41 @@
-from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
+import cv2
+import numpy as np
+from PIL import Image
 from pydantic import BaseModel, Field
-from invokeai.invocation_api import ImageOutput, Classification
+
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType, WithBoard, WithMetadata
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    ImageField,
+    InputField,
+    OutputField,
+    UIType,
+    WithBoard,
+    WithMetadata,
+)
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.shared.invocation_context import InvocationContext
-import numpy as np
-import cv2
-from PIL import Image
-
+from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector
+from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
 from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
-from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector, Body, Hand, Face
+from invokeai.invocation_api import Classification, ImageOutput

 DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"
 HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/"

+
 class BriaControlNetField(BaseModel):
     image: ImageField = Field(description="The control image")
     model: ModelIdentifierField = Field(description="The ControlNet model to use")
     mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
     conditioning_scale: float = Field(description="The weight given to the ControlNet")

+
 @invocation_output("bria_controlnet_output")
 class BriaControlNetOutput(BaseInvocationOutput):
     """Bria ControlNet info"""
@@ -49,12 +59,8 @@ class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
     control_model: ModelIdentifierField = InputField(
         description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
     )
-    control_mode: BRIA_CONTROL_MODES = InputField(
-        default="depth", description="The mode of the ControlNet"
-    )
-    control_weight: float = InputField(
-        default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
-    )
+    control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet")
+    control_weight: float = InputField(default=1.0, ge=-1, le=2, description="The weight given to the ControlNet")

     def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
         image_in = resize_img(context.images.get_pil(self.control_image.image_name))
@@ -70,7 +76,7 @@ def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
             control_image = convert_to_grayscale(image_in)
         elif self.control_mode == "tile":
             control_image = tile(16, image_in)
-
+
         control_image = resize_img(control_image)
         image_dto = context.images.save(image=control_image)
         image_output = ImageOutput.build(image_dto)
@@ -99,6 +105,7 @@ def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
     1.7708333333333333: {"width": 1360, "height": 768},
 }

+
 def extract_depth(image: Image.Image, context: InvocationContext):
     loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model)

@@ -107,6 +114,7 @@ def extract_depth(image: Image.Image, context: InvocationContext):
     depth_map = depth_anything_detector.generate_depth(image)
     return depth_map

+
 def extract_openpose(image: Image.Image, context: InvocationContext):
     body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body)
     hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand)
@@ -115,10 +123,10 @@ def extract_openpose(image: Image.Image, context: InvocationContext):
     with body_model as body_model, hand_model as hand_model, face_model as face_model:
         open_pose_model = OpenposeDetector(body_model, hand_model, face_model)
         processed_image_open_pose = open_pose_model(image, hand_and_face=True)
-
+
     processed_image_open_pose = processed_image_open_pose.resize(image.size)
     return processed_image_open_pose
-
+

 def extract_canny(input_image):
     image = np.array(input_image)
@@ -130,13 +138,17 @@ def extract_canny(input_image):


 def convert_to_grayscale(image):
-    gray_image = image.convert('L').convert('RGB')
+    gray_image = image.convert("L").convert("RGB")
     return gray_image

+
 def tile(downscale_factor, input_image):
-    control_image = input_image.resize((input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)).resize(input_image.size, Image.Resampling.NEAREST)
+    control_image = input_image.resize(
+        (input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
+    ).resize(input_image.size, Image.Resampling.NEAREST)
     return control_image
-
+
+
 def resize_img(control_image):
     image_ratio = control_image.width / control_image.height
     ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
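Most of this hunk is mechanical: ruff's import-sorting rules put standard-library, third-party, and first-party imports in separate alphabetized blocks, long import lists get wrapped into parentheses, and E302/E303 enforce exactly two blank lines before top-level definitions. The collapsed InputField(...) calls fit on one line because they stay under the project's configured line length (apparently around 120 characters, judging by the lines left intact). A minimal sketch of the layout these rules settle on, using stand-in functions rather than anything from the commit:

    import json  # standard-library block comes first

    import numpy as np  # third-party block, alphabetized
    from PIL import Image


    def to_array(image: Image.Image) -> np.ndarray:  # E302: two blank lines before a top-level def
        return np.asarray(image.convert("RGB"))


    def describe(image: Image.Image) -> str:
        return json.dumps({"size": image.size, "mode": image.mode})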

invokeai/app/invocations/bria_decoder.py

Lines changed: 4 additions & 4 deletions

@@ -30,15 +30,15 @@ class BriaDecoderInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
         latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
-
+
         with context.models.load(self.vae.vae) as vae:
             assert isinstance(vae, AutoencoderKL)
-            latents = (latents / vae.config.scaling_factor)
+            latents = latents / vae.config.scaling_factor
             latents = latents.to(device=vae.device, dtype=vae.dtype)
-
+
             decoded_output = vae.decode(latents)
             image = decoded_output.sample
-
+
         # Convert to numpy with proper gradient handling
         image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
         img = Image.fromarray(image)
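The only changes here are dropping redundant parentheses and clearing whitespace from blank lines, but the surrounding context is the whole decode path: unpack the packed latents into a 1x4x128x128 grid, undo the VAE scaling factor, decode, and rescale from [-1, 1] to 8-bit RGB. A rough sketch of that flow outside the invocation framework; the function name and the torch.no_grad wrapper are mine, not from the commit:

    import torch
    from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
    from PIL import Image


    def decode_bria_latents(vae: AutoencoderKL, latents: torch.Tensor) -> Image.Image:
        # Undo the 2x2 patch packing so the tensor matches the VAE's expected (1, 4, 128, 128) layout.
        latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
        latents = latents / vae.config.scaling_factor
        latents = latents.to(device=vae.device, dtype=vae.dtype)
        with torch.no_grad():
            image = vae.decode(latents).sample
        # Map the [-1, 1] model output to uint8 RGB and drop the batch dimension.
        image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
        return Image.fromarray(image)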

invokeai/app/invocations/bria_denoiser.py

Lines changed: 20 additions & 25 deletions

@@ -1,19 +1,20 @@
 from typing import List, Tuple
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
-from invokeai.backend.bria.controlnet_utils import prepare_control_images
-from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
-from invokeai.app.invocations.bria_controlnet import BriaControlNetField

 import torch
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

+from invokeai.app.invocations.bria_controlnet import BriaControlNetField
 from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
 from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
 from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
+from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
+from invokeai.backend.bria.controlnet_utils import prepare_control_images
+from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
 from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
+from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
+

 @invocation_output("bria_denoise_output")
 class BriaDenoiseInvocationOutput(BaseInvocationOutput):
@@ -80,7 +81,7 @@ class BriaDenoiseInvocation(BaseInvocation):
         description="ControlNet",
         input=Input.Connection,
         title="ControlNet",
-        default = None,
+        default=None,
     )

     @torch.no_grad()
@@ -106,7 +107,7 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
             assert isinstance(vae, AutoencoderKL)
             dtype = transformer.dtype
             device = transformer.device
-            latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds))
+            latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds))

             control_model, control_images, control_modes, control_scales = None, None, None, None
             if self.control is not None:
@@ -134,7 +135,7 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                 width=1024,
                 height=1024,
                 controlnet_conditioning_scale=control_scales,
-                num_inference_steps=self.num_steps,
+                num_inference_steps=self.num_steps,
                 max_sequence_length=128,
                 guidance_scale=self.guidance_scale,
                 latents=latents,
@@ -150,36 +151,30 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
         latents_output = LatentsField(latents_name=saved_input_latents_tensor)
         return BriaDenoiseInvocationOutput(latents=latents_output)

-
     def _prepare_multi_control(
-        self,
-        context: InvocationContext,
-        vae: AutoencoderKL,
-        width: int,
-        height: int,
-        device: torch.device
+        self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device
     ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
-
         control = self.control if isinstance(self.control, list) else [self.control]
         control_images, control_models, control_modes, control_scales = [], [], [], []
         for controlnet in control:
             if controlnet is not None:
                 control_models.append(context.models.load(controlnet.model).model)
-                control_modes.append(BriaControlModes[controlnet.mode].value)
+                control_modes.append(BriaControlModes[controlnet.mode].value)
                 control_scales.append(controlnet.conditioning_scale)
                 try:
                     control_images.append(context.images.get_pil(controlnet.image.image_name))
-                except:
-                    raise FileNotFoundError(f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline.")
+                except Exception:
+                    raise FileNotFoundError(
+                        f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline."
+                    )

         control_model = BriaMultiControlNetModel(control_models).to(device)
         tensored_control_images, tensored_control_modes = prepare_control_images(
             vae=vae,
-            control_images=control_images,
-            control_modes=control_modes,
+            control_images=control_images,
+            control_modes=control_modes,
             width=width,
             height=height,
-            device=device,
-        )
+            device=device,
+        )
         return control_model, tensored_control_images, tensored_control_modes, control_scales
-
invokeai/app/invocations/bria_latent_sampler.py

Lines changed: 3 additions & 6 deletions

@@ -1,19 +1,16 @@
 import torch

-from invokeai.app.invocations.fields import Input, InputField
+from invokeai.app.invocations.fields import Input, InputField, OutputField
 from invokeai.app.invocations.model import TransformerField
 from invokeai.app.invocations.primitives import (
     BaseInvocationOutput,
     FieldDescriptions,
-    Input,
     LatentsField,
-    OutputField,
 )
 from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents
 from invokeai.invocation_api import (
     BaseInvocation,
     Classification,
-    InputField,
     InvocationContext,
     invocation,
     invocation_output,
@@ -56,7 +53,7 @@ def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:

         height, width = 1024, 1024
         generator = torch.Generator(device=device).manual_seed(self.seed)
-
+
         num_channels_latents = 4
         latents, latent_image_ids = prepare_latents(
             batch_size=1,
@@ -66,7 +63,7 @@ def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
             dtype=dtype,
             device=device,
             generator=generator,
-        )
+        )

         saved_latents_tensor = context.tensors.save(latents)
         saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids)

invokeai/app/invocations/bria_model_loader.py

Lines changed: 0 additions & 2 deletions

@@ -10,9 +10,7 @@
     BaseInvocation,
     BaseInvocationOutput,
     Classification,
-    InputField,
     InvocationContext,
-    OutputField,
     invocation,
     invocation_output,
 )
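Both this file and the latent sampler above only lose imports: names that are never referenced (ruff F401) or that are imported a second time from another module and silently re-bound (F811). A compact illustration using invokeai's own field helper; the seed field exists only to keep the surviving import in use:

    # What ruff flags, sketched in comments:
    #   from invokeai.app.invocations.fields import InputField, OutputField   # OutputField unused -> F401
    #   from invokeai.invocation_api import InputField                        # InputField re-bound -> F811
    # After `ruff check --fix`, only imports that are actually used remain:
    from invokeai.app.invocations.fields import InputField

    seed = InputField(default=0, description="A stand-in field that uses the import")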

invokeai/app/invocations/bria_text_encoder.py

Lines changed: 2 additions & 4 deletions

@@ -19,8 +19,6 @@
     invocation_output,
 )

-from invokeai.backend.bria.bria_utils import get_t5_prompt_embeds, is_ng_none
-

 @invocation_output("bria_text_encoder_output")
 class BriaTextEncoderInvocationOutput(BaseInvocationOutput):
@@ -70,7 +68,7 @@ def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
         ):
             assert isinstance(tokenizer, T5TokenizerFast)
             assert isinstance(text_encoder, T5EncoderModel)
-
+
             (prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt(
                 prompt=self.prompt,
                 tokenizer=tokenizer,
@@ -81,7 +79,7 @@ def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
                 max_sequence_length=self.max_length,
                 lora_scale=1.0,
             )
-
+
             saved_pos_tensor = context.tensors.save(prompt_embeds)
             saved_neg_tensor = context.tensors.save(negative_prompt_embeds)
             saved_text_ids_tensor = context.tensors.save(text_ids)

invokeai/backend/bria/bria_utils.py

Lines changed: 1 addition & 1 deletion

@@ -87,7 +87,7 @@ def is_ng_none(negative_prompt):
         negative_prompt is None
         or negative_prompt == ""
         or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
-        or (type(negative_prompt) == list and negative_prompt[0] == "")
+        or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
     )
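The single changed line swaps a type(...) == list comparison for isinstance(), which is what ruff's E721 check asks for: isinstance() states the intent and also accepts subclasses. A tiny illustration; the list subclass is invented for the example:

    class PromptList(list):
        """A list subclass, standing in for any container a caller might pass."""


    prompts = PromptList([""])

    print(type(prompts) == list)      # False: exact-type comparison rejects subclasses (ruff E721)
    print(isinstance(prompts, list))  # True: isinstance() accepts list and its subclasses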

Lines changed: 3 additions & 2 deletions

@@ -1,5 +1,6 @@
 __version__ = "0.0.9"

-from .canny import CannyDetector
-from .open_pose import OpenposeDetector
+from invokeai.backend.bria.controlnet_aux.canny import CannyDetector as CannyDetector
+from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector as OpenposeDetector

+__all__ = ["CannyDetector", "OpenposeDetector"]
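The relative imports in this package __init__ become absolute imports written in the redundant-alias form `name as name`, and an explicit __all__ is added. Both are conventions ruff recognises for intentional re-exports, so its F401 unused-import check leaves them alone. The same pattern for a made-up package:

    # mypackage/__init__.py -- illustrative only, not part of the commit
    __version__ = "0.1.0"

    # "import X as X" marks X as a deliberate re-export rather than an unused import.
    from mypackage.detectors.canny import CannyDetector as CannyDetector

    # __all__ makes the public surface explicit for "from mypackage import *" and for type checkers.
    __all__ = ["CannyDetector"]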
Lines changed: 21 additions & 9 deletions

@@ -1,15 +1,27 @@
 import warnings
+
 import cv2
 import numpy as np
 from PIL import Image
-from ..util import HWC3, resize_image
+
+from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
+

 class CannyDetector:
-    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs):
+    def __call__(
+        self,
+        input_image=None,
+        low_threshold=100,
+        high_threshold=200,
+        detect_resolution=512,
+        image_resolution=512,
+        output_type=None,
+        **kwargs,
+    ):
         if "img" in kwargs:
-            warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
+            warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning, stacklevel=2)
             input_image = kwargs.pop("img")
-
+
         if input_image is None:
             raise ValueError("input_image must be defined.")

@@ -18,19 +30,19 @@ def __call__(self, input_image=None, low_threshold=100, high_threshold=200, dete
             output_type = output_type or "pil"
         else:
             output_type = output_type or "np"
-
+
         input_image = HWC3(input_image)
         input_image = resize_image(input_image, detect_resolution)

         detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
-        detected_map = HWC3(detected_map)
-
+        detected_map = HWC3(detected_map)
+
         img = resize_image(input_image, image_resolution)
         H, W, C = img.shape

         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
-
+
         if output_type == "pil":
             detected_map = Image.fromarray(detected_map)
-
+
         return detected_map
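The one behavioural nudge in this file is warnings.warn(..., stacklevel=2), which is ruff rule B028: with the default stacklevel=1 the DeprecationWarning is attributed to the library line that calls warn(), while stacklevel=2 points it at the caller that actually passed the deprecated img= argument. A standalone sketch, with a made-up detect() helper:

    import warnings


    def detect(input_image=None, **kwargs):
        if "img" in kwargs:
            # stacklevel=2 reports the warning against the caller's line, not this one.
            warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning, stacklevel=2)
            input_image = kwargs.pop("img")
        return input_image


    detect(img="example.png")  # the warning is reported against this call site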
