[WIP Modular] More Updates for Custom Code Loading #11969
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I see, do you think maybe we should drop the concept difference between "inputs" and "intermediate inputs" altogether? All inputs could just be "intermediates", and they can all be modified. https://huggingface.co/docs/diffusers/main/en/modular_diffusers/modular_diffusers_states

Taking the example below: if we do this, we need to be a little bit more mindful when modifying variables, and recommend always using a new variable name, or not adding to block_state unless it's intended to replace the existing value. E.g. if a downstream block needs the raw image, a block like this would mean that it would not be able to access the raw image before the processing:

```python
def __call__(...):
    ...
    block_state.image = process(block_state.image)
    block_state.image_latent = prepare_latents(block_state.image)
    ...
```

but this might be fine (depends on whether the raw input is still needed):

```python
def __call__(...):
    ...
    image = process(block_state.image)
    block_state.image_latent = prepare_latents(block_state.image)
    ...
```

On the other hand, the system would be more flexible, and it also simplifies things a bit conceptually (I found it not easy to explain the difference between these two states). Let me know what you think!
IMO this makes sense since we more or less put them into the same group when fetching the block state and they are accessed in the block in the same way. I can't think of any edge cases where this might lead to issues.
We could do this, but suppose you want to insert a custom block that manipulates an input value into a set of existing blocks. You would have to update all subsequent blocks to point to the intermediate input.

```python
def __call__(...):
    ...
    image = process(block_state.image)
    block_state.image_latent = prepare_latents(block_state.image)
    ...
```

I think this works well as best practice if we want to leave the input unchanged. IMO the less restrictive we are the better, since there isn't a very strong reason to keep the input types separated?
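To make this concrete, here is a minimal sketch of such a custom block, following the test-block pattern shown later in this thread. The block name, the PIL blur step, and the assumption that `image` is a PIL image are placeholders for illustration, not part of this PR:

```python
from PIL import ImageFilter

from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam


class ImageBlurBlock(ModularPipelineBlocks):
    # hypothetical block: overwrites the `image` input in place, so with merged
    # inputs/intermediates every downstream block sees the processed value
    model_name = "test"

    @property
    def inputs(self):
        return [InputParam(name="image")]

    @property
    def intermediate_outputs(self):
        return [OutputParam(name="image")]

    @property
    def description(self):
        return "Blurs the `image` input in place (illustrative placeholder)."

    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        # placeholder processing step; assumes `image` is a PIL image
        block_state.image = block_state.image.filter(ImageFilter.GaussianBlur(2))
        self.set_block_state(state, block_state)
        return components, state
```

If inputs and intermediates are merged, inserting a block like this into an existing set of blocks would not require remapping downstream blocks to a new variable name.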
```diff
@@ -322,7 +322,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
     </Tip>
     """

-    config_name = "config.json"
+    config_name = "modular_config.json"
```
For consistency with `modular_model_index.json`. Also, there could be cases where a repo contains model weights/config files and a modular pipeline block to load the model. We can avoid conflicts between the configs this way.
Sounds good! Let's do this then! I think it will simplify the code a lot. Basically now we'll just have …

Another thing is: should we remove the concept of single …?
@yiyixuxu
Then you can review and refactor as you see fit?
@DN6 sounds good!
```python
values: Dict[str, Any] = field(default_factory=dict)
kwargs_mapping: Dict[str, List[str]] = field(default_factory=dict)
```
Since we've removed the distinction between inputs and intermediates, we can perhaps simplify `PipelineState` as well. All values are stored under `values` and all `kwargs_type` entries under `kwargs_mapping`. Setting and getting go through `set`/`get` methods which handle both single string and list inputs for setting/fetching.
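As a rough illustration of that consolidation (a sketch only; everything beyond the `values`/`kwargs_mapping` attributes and `set`/`get` names is an assumption, not the actual diffusers implementation):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class SimplePipelineState:
    # hypothetical, simplified stand-in for PipelineState, for illustration only
    values: Dict[str, Any] = field(default_factory=dict)
    kwargs_mapping: Dict[str, List[str]] = field(default_factory=dict)

    def set(self, name: str, value: Any, kwargs_type: Optional[str] = None) -> None:
        # every value lives in one dict; kwargs_type only records a grouping
        self.values[name] = value
        if kwargs_type is not None:
            self.kwargs_mapping.setdefault(kwargs_type, []).append(name)

    def get(self, names: Union[str, List[str]], default: Any = None):
        # accept a single name or a list of names
        if isinstance(names, str):
            return self.values.get(names, default)
        return {name: self.values.get(name, default) for name in names}
```

With a single dict, a block no longer needs to know whether a value was provided by the user or produced by an earlier block.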
```python
@property
def intermediate_outputs(self) -> List[OutputParam]:
    """List of intermediate output parameters. Must be implemented by subclasses."""
    return []

def _get_outputs(self):
    return self.intermediate_outputs

@property
def outputs(self) -> List[OutputParam]:
    return self._get_outputs()
```
Think these can also be consolidated into just `outputs`? Didn't do it here to keep the PR scope limited.
yes to both:
- we should only have `outputs`
- can do it in a next PR: currently `outputs` is not actually used, so we just have to remove `outputs` from the code base and then change `intermediate_outputs` to `outputs`
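For illustration, a rough sketch of what a block definition could look like after that follow-up rename (an assumption about a future API, not something in this PR):

```python
from typing import List

from diffusers.modular_pipelines import ModularPipelineBlocks, OutputParam


class MyBlock(ModularPipelineBlocks):
    # hypothetical: after the rename, a block declares `outputs` directly,
    # with no separate intermediate_outputs / _get_outputs indirection
    model_name = "test"

    @property
    def outputs(self) -> List[OutputParam]:
        return [OutputParam(name="image_latents")]
```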
@yiyixuxu made the changes as discussed. LMK your thoughts.
so "image" is one of the inputs that we override e.g. here we update it with processed image
but may need the raw input later. e.g.
should we rename it to something else? |
btw this is a slow test I ran for sdxl, currently 12, 14, 15, 16 fail, should be easy to fix (missed a few). Can you fix them and make sure these tests are able to run functionally? Don't worry about generation, I can double check before merge.

Click to show test script:

```python
# test modular pipeline (slower test)
import os
import shutil
import torch
from diffusers import (
ControlNetModel,
UNet2DConditionModel,
AutoencoderKL,
ControlNetUnionModel,
AdaptiveProjectedGuidance,
ClassifierFreeGuidance,
PerturbedAttentionGuidance,
LayerSkipConfig,
ModularPipeline,
)
from diffusers import StableDiffusionXLAutoBlocks, ComponentsManager, ComponentSpec
from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)
# define device and dtype
device = "cuda:3"
dtype = torch.float16
num_images_per_prompt = 1
# test related parameters
test_lora = False
tests_to_run = [1,2,3,4,5,6,7,8,9,10,11,12,13,14, 15,16]
# define output folder
out_folder = "modular_test_outputs"
os.makedirs(out_folder, exist_ok=True)
# functions for memory info
def reset_memory():
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
def clear_memory():
torch.cuda.empty_cache()
def print_memory(message=None):
"""
Print detailed GPU memory statistics for a specific device.
Args:
device_id (int): GPU device ID
"""
def print_mem(mem_size, name):
mem_gb = mem_size / 1024**3
mem_mb = mem_size / 1024**2
print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")
allocated_mem = torch.cuda.memory_allocated(device)
reserved_mem = torch.cuda.memory_reserved(device)
mem_on_device = torch.cuda.mem_get_info(device)[0]
peak_mem = torch.cuda.max_memory_allocated(device)
print(f"\nGPU:{device} Memory Status {message}:")
print_mem(allocated_mem, "allocated memory")
print_mem(reserved_mem, "reserved memory")
print_mem(peak_mem, "peak memory")
print_mem(mem_on_device, "mem on device")
# (1)Define inputs
# prompts
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"
# image urls
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
inpaint_img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
inpaint_mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
ip_adapter_image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png"
# strength/scale etc
strength = 0.9 #img2img strength
inpaint_strength = 0.99 #inpainting strength
controlnet_conditioning_scale = 0.5 # recommended for good generalization
# get all the image inputs(use a custom block to get images to prepare them)
get_image_step = ModularPipeline.from_pretrained("YiYiXu/image_inputs_0803", trust_remote_code=True)
init_image = get_image_step(image_url=url,output="image")
control_image = get_image_step(image_url=url, processor_id="canny",output="image")
controlnet_union_image = get_image_step(image_url=url, processor_id="lineart_anime",output="image")
inpaint_image = get_image_step(image_url=inpaint_img_url, size=(1024, 1024),output="image")
inpaint_mask = get_image_step(image_url=inpaint_mask_url, size=(1024, 1024),output="image")
ip_adapter_image = get_image_step(image_url=ip_adapter_image_url,output="image")
# (2) create pipelines
auto_blocks = StableDiffusionXLAutoBlocks()
refiner_blocks = StableDiffusionXLAutoBlocks()
# (3) define model components needed for the tests
# specs
refiner_spec = ComponentSpec(name="refiner", type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-refiner-1.0", subfolder="unet")
inpaint_spec = ComponentSpec(name="inpaint", type_hint=UNet2DConditionModel, repo="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="unet")
controlnet_union_spec = ComponentSpec(name="controlnet_union", type_hint=ControlNetUnionModel, repo="brad-twinkl/controlnet-union-sdxl-1.0-promax")
# repos
ip_adapter_repo = "h94/IP-Adapter"
modular_repo = "YiYiXu/modular_demo"
# create guiders: pag/cfg/apg
pag_guider_spec_config = {
"guidance_scale": 5.0,
"perturbed_guidance_scale": 3.0,
"perturbed_guidance_config": LayerSkipConfig(
indices=[2, 3, 7, 8],
fqn="mid_block.attentions.0.transformer_blocks",
skip_attention=False,
skip_ff=False,
skip_attention_scores=True,
),
"start": 0.0,
"stop": 1.0,
}
pag_guider_spec = ComponentSpec(name="guider", type_hint=PerturbedAttentionGuidance, config=pag_guider_spec_config, default_creation_method="from_config")
cfg_guider_spec = ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config")
apg_guider_spec = ComponentSpec(name="guider", type_hint=AdaptiveProjectedGuidance, config={"guidance_scale": 15.0, "adaptive_projected_guidance_momentum": -0.3, "adaptive_projected_guidance_rescale": 12.0, "start": 0.01}, default_creation_method="from_config")
# code to push them to the hub
# pag_guider_spec.create().push_to_hub(modular_repo, subfolder="pag_guider")
# cfg_guider_spec.create().push_to_hub(modular_repo, subfolder="cfg_guider")
# apg_guider_spec.create().push_to_hub(modular_repo, subfolder="apg_guider")
# (4) create components manager and load the pipeline
components = ComponentsManager()
auto_pipeline = auto_blocks.init_pipeline(modular_repo, components_manager=components, collection="sdxl_auto")
#auto_pipeline.save_pretrained(modular_repo, push_to_hub=True)
auto_pipeline.load_default_components(torch_dtype=dtype)
print(f" ")
print(f"auto_pipeline:")
print(auto_pipeline)
print(f" loader components:")
for key, value in auto_pipeline.components.items():
if isinstance(value, torch.nn.Module):
print(f" {key}: {value.__class__.__name__}, dtype: {value.dtype}, device: {value.device}")
# enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()
# using auto_pipeline to generate images
# to get info about auto_pipeline and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
print(f" ")
print(f" auto_pipeline:")
print(auto_pipeline)
print(" ")
print(f" ")
print(f" auto_pipeline.blocks:")
print(auto_pipeline.blocks)
print(" ")
# since we want to use text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" auto_pipeline info (default use case: text2img)")
print(auto_pipeline.blocks.get_execution_blocks())
print(" ")
# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None
# assert False
if 1 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test1_out_text2img.png")
clear_memory()
# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
auto_pipeline.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
if 2 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test2_out_text2img_lora.png")
# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
auto_pipeline.unload_lora_weights()
auto_pipeline.update_components(guider=pag_guider_spec)
if 3 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test3_out_text2img_pag.png")
clear_memory()
# check out the components if you want; the models used are moved to device, some might get offloaded to cpu
# print(components)
# test4: SDXL(text2img) with ip_adapter+ pag?
print(f" ")
print(f" running test4: SDXL(text2img) with ip_adapter")
auto_pipeline.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipeline.set_ip_adapter_scale(0.6)
if 4 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
ip_adapter_image=ip_adapter_image,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test4_out_text2img_ip_adapter_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test4_out_text2img_ip_adapter.png")
auto_pipeline.unload_ip_adapter()
clear_memory()
# test5: SDXL(text2img) with controlnet
# we are going to pass a new input now `control_image` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet use case)")
print(auto_pipeline.blocks.get_execution_blocks("control_image"))
print(" ")
print(f" ")
print(f" running test5: SDXL(text2img) with controlnet")
if 5 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
control_image=control_image,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test5_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test5_out_text2img_control.png")
clear_memory()
# test6: SDXL(img2img)
print(f" ")
print(f" running test6: SDXL(img2img)")
# let's checkout the sdxl_node info for img2img use case
print(f" auto_pipeline info (img2img use case)")
print(auto_pipeline.blocks.get_execution_blocks("image"))
print(" ")
if 6 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
image=init_image,
strength=strength,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test6_out_img2img_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test6_out_img2img.png")
clear_memory()
# test7: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(auto_pipeline.blocks.get_execution_blocks("image", "control_image"))
print(" ")
print(f" ")
print(f" running test7: SDXL(img2img) with controlnet")
if 7 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
image=init_image,
strength=strength,
num_images_per_prompt=num_images_per_prompt,
control_image=control_image,
controlnet_conditioning_scale=controlnet_conditioning_scale,
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
print(f"image: {image.size}")
image.save(f"{out_folder}/test7_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test7_out_img2img_control.png")
clear_memory()
# test8: img2img with refiner
# test refiner pipeline but not using a repo
refiner_pipeline = refiner_blocks.init_pipeline(components_manager=components, collection="refiner")
print(f" ")
print(f" after setup refiner loader (initial setup, should be empty)")
print(refiner_pipeline)
print(f" ")
refiner_components = components.search_components("!unet|text_encoder|tokenizer|guider", collection="sdxl_auto")
print(f" reuse these components for refiner pipeline:")
for name, component in refiner_components.items():
print(f" {name}: {component.__class__.__name__}")
print(f" ")
refiner_pipeline.update_components(**refiner_components, unet=refiner_spec.load(torch_dtype=dtype), force_zeros_for_empty_prompt=False, requires_aesthetics_score=True)
print(f" ")
print(f" refiner loader after update")
print(refiner_pipeline)
print(f" ")
print(f" ")
print(f" ")
print(f" components info")
print(components)
print(f" ")
print(f" running test8: img2img with refiner (reuse components from components manager)")
if 8 in tests_to_run:
print(f" ")
print(f" step1 run auto pipeline to get latents")
generator = torch.Generator(device="cuda").manual_seed(0)
latents = auto_pipeline(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
denoising_end=0.8,
output="images",
output_type="latent",
)
print(f" ")
print(f" step2 run refiner pipeline to get images")
images_output = refiner_pipeline(
image_latents=latents,
prompt=prompt,
denoising_start=0.8,
generator=generator,
num_images_per_prompt=num_images_per_prompt,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test8_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test8_out_img2img_refiner.png")
clear_memory()
# test9: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" auto_pipeline info (inpainting use case)")
print(auto_pipeline.blocks.get_execution_blocks("mask_image", "image"))
print(" ")
print(f" ")
print(f" running test9: SDXL(inpainting)")
if 9 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
image=inpaint_image,
mask_image=inpaint_mask,
height=1024,
width=1024,
generator=generator,
num_images_per_prompt=num_images_per_prompt,
strength=inpaint_strength, # make sure to use `strength` below 1.0
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test9_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test9_out_inpainting.png")
clear_memory()
# test10: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" auto_pipeline info (inpainting + controlnet use case)")
print(auto_pipeline.blocks.get_execution_blocks("mask_image", "control_image"))
print(" ")
print(f" ")
print(f" running test10: SDXL(inpainting) with controlnet")
if 10 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
control_image=control_image,
image=inpaint_image,
height=1024,
width=1024,
mask_image=inpaint_mask,
num_images_per_prompt=num_images_per_prompt,
controlnet_conditioning_scale=controlnet_conditioning_scale,
strength=inpaint_strength, # make sure to use `strength` below 1.0
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test10_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test10_out_inpainting_control.png")
clear_memory()
# test11: SDXL(inpainting) with inpaint_unet
print(f" ")
print(f" running test11: SDXL(inpainting) with inpaint_unet")
inpaint_unet = inpaint_spec.load(torch_dtype=dtype)
# make a backup to switch back later
sdxl_unet_spec = ComponentSpec.from_component("unet", auto_pipeline.unet)
auto_pipeline.update_components(unet=inpaint_unet)
if 11 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
image=inpaint_image,
mask_image=inpaint_mask,
height=1024,
width=1024,
generator=generator,
num_images_per_prompt=num_images_per_prompt,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test11_out_inpainting_inpaint_unet.png")
clear_memory()
print(f" after update with inpaint_unet")
print(components)
# test12: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ")
print(f" running test12: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")
if 12 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
image=inpaint_image,
mask_image=inpaint_mask,
height=1024,
width=1024,
generator=generator,
padding_mask_crop=33,
num_images_per_prompt=num_images_per_prompt,
strength=inpaint_strength, # make sure to use `strength` below 1.0
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop.png")
clear_memory()
# test13: apg
print(f" ")
print(f" running test13: apg")
auto_pipeline.update_components(guider=apg_guider_spec, unet=sdxl_unet_spec.load(torch_dtype=dtype))
print(f" autopipeline loader after update with apg guider and unet")
print(auto_pipeline)
print(f" ")
print(f" ")
print(f" components info")
print(components)
print(f" ")
if 13 in tests_to_run:
generator = torch.Generator().manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
generator=generator,
num_inference_steps=20,
num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
height=896,
width=768,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test13_out_apg_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test13_out_apg.png")
clear_memory()
# test14: SDXL(text2img) with controlnet_union
auto_pipeline.update_components(
controlnet=controlnet_union_spec.load(torch_dtype=dtype),
guider=pag_guider_spec
)
print(f" autopipeline loader after update with controlnet (controlnet_union), unet (sdxl_auto), and guider (pag_guider)")
print(auto_pipeline)
print(f" ")
print(f" ")
print(f" components info")
print(components)
print(f" ")
# we are going to pass a new input now `control_mode` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet union use case)")
print(auto_pipeline.blocks.get_execution_blocks("control_mode"))
print(" ")
print(f" ")
print(f" running test14: SDXL(text2img) with controlnet_union")
if 14 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
control_mode=[3],
control_image=[controlnet_union_image],
num_images_per_prompt=num_images_per_prompt,
height=1024,
width=1024,
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test14_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test14_out_text2img_control_union.png")
clear_memory()
# test15: SDXL(img2img) with controlnet_union
print(f" ")
print(f" auto_pipeline info (img2img controlnet union use case)")
print(auto_pipeline.blocks.get_execution_blocks("image", "control_mode"))
print(" ")
print(f" ")
print(f" running test15: SDXL(img2img) with controlnet_union")
if 15 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
image=init_image,
generator=generator,
control_mode=[3],
control_image=[controlnet_union_image],
num_images_per_prompt=num_images_per_prompt,
height=1024,
width=1024,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test15_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test15_out_img2img_control_union.png")
clear_memory()
# test16: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" auto_pipeline info (inpainting controlnet union use case)")
print(auto_pipeline.blocks.get_execution_blocks("mask", "control_mode"))
print(" ")
print(f" ")
print(f" running test16: SDXL(inpainting) with controlnet_union")
if 16 in tests_to_run:
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
prompt=prompt,
image=init_image,
mask_image=inpaint_mask,
control_image=controlnet_union_image,
control_mode=[3],
height=1024,
width=1024,
generator=generator,
output="images"
)
for i, image in enumerate(images_output):
image.save(f"{out_folder}/test16_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output)} images) to {out_folder}/test16_out_inpainting_control_union.png")
clear_memory()
print_memory("the end")
print(f" components info after the end")
print(components)
```
```python
@property
def required_inputs(self) -> List[str]:
```
should we keep this in `ModularPipelineBlocks`?
```python
@property
def doc(self):
```
we should keep the `doc` on `ModularPipelineBlocks` too?
```python
if name not in intermediate_inputs:
    state.set(name, passed_kwargs.pop(name), kwargs_type)
else:
    state.set(name, passed_kwargs[name], kwargs_type)
```
Suggested change:

```python
state.set(name, passed_kwargs.pop(name), kwargs_type)
```
`name in intermediate_inputs` will always be True here; we pop, otherwise expected inputs trigger the "unexpected input" warning. E.g. with the script below, you get:

```
/home/yiyi/diffusers/src/diffusers/modular_pipelines/modular_pipeline.py:2417: UserWarning: Unexpected input dict_keys(['x']) provided. This input will be ignored.
```
```python
from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
import torch


def make_block(inputs=[], intermediate_outputs=[], block_fn=None, description=None):
    class TestBlock(ModularPipelineBlocks):
        model_name = "test"

        @property
        def inputs(self):
            return inputs

        @property
        def intermediate_outputs(self):
            return intermediate_outputs

        @property
        def description(self):
            return description if description is not None else ""

        def __call__(self, components, state):
            block_state = self.get_block_state(state)
            if block_fn is not None:
                block_state = block_fn(block_state, state)
            self.set_block_state(state, block_state)
            return components, state

    return TestBlock


inputs = [InputParam(name="x", default=0)]


def block_fn(block_state, state):
    block_state.x = block_state.x + 1
    return block_state


block_cls = make_block(inputs=inputs, block_fn=block_fn, description="test")
pipe = block_cls().init_pipeline()
output = pipe(x=3)
```
thanks so much @DN6
this is MUCH BETTER!!!
the change looks very nice to me, I left a few comments.
What does this PR do?
In order to support custom code for any block type, we have to invoke downloading/loading of custom code from the `ModularPipelineBlocks` object. This means some of the properties/methods defined in `PipelineBlock`, `SequentialPipelineBlock`, etc. have to be moved into the `ModularPipelineBlocks` class.

Additionally, this includes a potential change to consolidate how the block state is created from inputs and intermediate inputs. The change is necessary to support the case where a pipeline input might be created in an intermediate step, e.g. using a segmentation model to create an inpainting mask. With the existing approach, simply setting the output of the mask creation step to the desired input value `mask_image` wouldn't allow downstream steps to access it in the block state, because `get_block_state` runs over required inputs first and errors out because it's missing. Propose changing this to check for required values in both inputs and intermediates before erroring out. There could be edge cases here that I might be missing, but it seems safe to consolidate in this way.

Code to test
Fixes # (issue)
Before submitting
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.