[WIP][modular] sdxl methods refactor #12067

Open · wants to merge 9 commits into base: main
2 changes: 2 additions & 0 deletions src/diffusers/modular_pipelines/modular_pipeline.py
@@ -1972,6 +1972,8 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
"""
if state is None:
state = PipelineState()
else:
state = deepcopy(state)

# Make a copy of the input kwargs
passed_kwargs = kwargs.copy()
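The new `else` branch means a caller-supplied `PipelineState` is copied before the blocks run, so a call no longer mutates the state object the caller still holds. A minimal sketch of the behaviour this buys, using a toy stand-in rather than the real `PipelineState` API:

```python
from copy import deepcopy

class ToyState:
    """Stand-in for PipelineState: just a bag of intermediate values."""
    def __init__(self):
        self.values = {}

def run_block(state=None, **kwargs):
    # Same pattern as the diff: create a fresh state, or work on a copy of the one passed in.
    if state is None:
        state = ToyState()
    else:
        state = deepcopy(state)
    state.values.update(kwargs)
    return state

caller_state = ToyState()
result = run_block(caller_state, latents="...")
assert "latents" in result.values
assert "latents" not in caller_state.values  # the caller's object is left untouched
```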
5 changes: 4 additions & 1 deletion src/diffusers/modular_pipelines/modular_pipeline_utils.py
@@ -91,7 +91,10 @@ class ComponentSpec:
type_hint: Optional[Type] = None
description: Optional[str] = None
config: Optional[FrozenDict] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
# YiYi TODO: currently required is only used to mark optional components that the block can run without, in the future:
# 1. the spec for an optional component should have lower priority when combined in sequential/auto blocks
# 2. should not need to define default_creation_method for optional components
required: bool = True
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})
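The new `required` flag lets a block declare components it can run without; the TODO notes such specs should eventually get lower priority when blocks are combined and should not need a `default_creation_method`. A hedged sketch of how specs might be declared; the `name` field is an assumption, while `required`, `repo`, and `subfolder` are taken from the hunk:

```python
# Assumed import path, matching the file shown in this diff.
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec

# A component the block cannot run without (required defaults to True).
unet_spec = ComponentSpec(
    name="unet",
    repo="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
)

# An optional component: the block should still run if it is never loaded.
image_encoder_spec = ComponentSpec(
    name="image_encoder",
    required=False,
)
```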
872 changes: 345 additions & 527 deletions src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py

Large diffs are not rendered by default.

20 changes: 9 additions & 11 deletions src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
@@ -105,9 +105,9 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
if not block_state.output_type == "latent":
latents = block_state.latents
# make sure the VAE is in float32 mode, as it overflows in float16
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast

if block_state.needs_upcasting:
if needs_upcasting:
self.upcast_vae(components)
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != components.vae.dtype:
@@ -117,29 +117,27 @@ def __call__(self, components, state: PipelineState) -> PipelineState:

# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
block_state.has_latents_mean = (
has_latents_mean = (
hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
)
block_state.has_latents_std = (
has_latents_std = (
hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
)
if block_state.has_latents_mean and block_state.has_latents_std:
block_state.latents_mean = (
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
block_state.latents_std = (
latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = (
latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
)
latents = latents * latents_std / components.vae.config.scaling_factor + latents_mean
else:
latents = latents / components.vae.config.scaling_factor

block_state.images = components.vae.decode(latents, return_dict=False)[0]

# cast back to fp16 if needed
if block_state.needs_upcasting:
if needs_upcasting:
components.vae.to(dtype=torch.float16)
else:
block_state.images = block_state.latents
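The refactor keeps the decode math identical but stops stashing throwaway values (`needs_upcasting`, `has_latents_mean`, ...) on `block_state`, holding them in locals instead. The denormalization it performs is `latents * latents_std / scaling_factor + latents_mean` when the VAE config carries `latents_mean`/`latents_std`, and a plain `latents / scaling_factor` otherwise. A self-contained sketch of that step with toy values (0.13025 is the usual SDXL `scaling_factor`; the helper name is made up):

```python
import torch

def unscale_latents(latents, scaling_factor, latents_mean=None, latents_std=None):
    """Undo the VAE latent scaling before decoding, mirroring the branch in the hunk."""
    if latents_mean is not None and latents_std is not None:
        mean = torch.tensor(latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
        std = torch.tensor(latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
        return latents * std / scaling_factor + mean
    return latents / scaling_factor

latents = torch.randn(1, 4, 128, 128)
decoded_input = unscale_latents(latents, scaling_factor=0.13025)
```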
14 changes: 7 additions & 7 deletions src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
@@ -67,7 +67,7 @@ def intermediate_inputs(self) -> List[str]:

@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t)

return components, block_state

@@ -134,10 +134,10 @@ def check_inputs(components, block_state):
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
self.check_inputs(components, block_state)

block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t)
if components.num_channels_unet == 9:
block_state.scaled_latents = torch.cat(
[block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
block_state.latent_model_input = torch.cat(
[block_state.latent_model_input, block_state.mask, block_state.masked_image_latents], dim=1
)

return components, block_state
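For the inpainting-specific UNet (`components.num_channels_unet == 9`), the model input is the scaled latents concatenated with the mask and the masked-image latents along the channel axis: 4 + 1 + 4 = 9 channels. A toy shape check, with tensor shapes assumed for a 1024x1024 SDXL image:

```python
import torch

latent_model_input = torch.randn(2, 4, 128, 128)    # scaled noisy latents
mask = torch.randn(2, 1, 128, 128)                   # inpainting mask at latent resolution
masked_image_latents = torch.randn(2, 4, 128, 128)   # VAE-encoded masked image

unet_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
assert unet_input.shape == (2, 9, 128, 128)  # matches the 9-channel inpainting UNet
```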
@@ -232,7 +232,7 @@ def __call__(
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
block_state.latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=block_state.timestep_cond,
@@ -410,7 +410,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
mid_block_res_sample = block_state.mid_block_res_sample_zeros
else:
down_block_res_samples, mid_block_res_sample = components.controlnet(
block_state.scaled_latents,
block_state.latent_model_input,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
controlnet_cond=block_state.controlnet_cond,
@@ -430,7 +430,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
# Predict the noise
# store the noise_pred in guider_state_batch so we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
block_state.latent_model_input,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
timestep_cond=block_state.timestep_cond,
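The remaining hunks are the same rename: every consumer of the old `block_state.scaled_latents` (the UNet calls, the ControlNet call) now reads `block_state.latent_model_input`, matching the variable name used in the non-modular SDXL pipelines. The underlying scheduler call is unchanged; a small standalone example of what it does, using an `EulerDiscreteScheduler` with default settings as a stand-in for the pipeline's configured scheduler:

```python
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(30)

latents = torch.randn(1, 4, 128, 128) * scheduler.init_noise_sigma
t = scheduler.timesteps[0]

# Same call the blocks make each step: some schedulers rescale the sample
# per-timestep (Euler divides by sqrt(sigma**2 + 1)), others return it unchanged.
latent_model_input = scheduler.scale_model_input(latents, t)
```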