diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 0ef1d59f4d65..b981117c10da 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1972,6 +1972,8 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = """ if state is None: state = PipelineState() + else: + state = deepcopy(state) # Make a copy of the input kwargs passed_kwargs = kwargs.copy() diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index f2fc015e948f..2547360aa290 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -91,7 +91,10 @@ class ComponentSpec: type_hint: Optional[Type] = None description: Optional[str] = None config: Optional[FrozenDict] = None - # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name + # YiYi TODO: currently required is only used to mark optional components that the block can run without, in the future: + # 1. the spec for an optional component should has lower priority when combined in sequential/auto blocks + # 2. should not need to define default_creation_method for optional components + required: bool = True repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) subfolder: Optional[str] = field(default="", metadata={"loading": True}) variant: Optional[str] = field(default=None, metadata={"loading": True}) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index c56f4af1b8a5..df9e913de3f9 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -15,13 +15,11 @@ import inspect from typing import Any, List, Optional, Tuple, Union -import PIL import torch from ...configuration_utils import FrozenDict -from ...guiders import ClassifierFreeGuidance from ...image_processor import VaeImageProcessor -from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel +from ...models import ControlNetModel, ControlNetUnionModel, UNet2DConditionModel from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel from ...schedulers import EulerDiscreteScheduler from ...utils import logging @@ -117,84 +115,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -def prepare_latents_img2img( - vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True -): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}") - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: - latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: - latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is 
in float32 mode, as it overflows in float16 - if vae.config.force_upcast: - image = image.float() - vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(vae.encode(image), generator=generator) - - if vae.config.force_upcast: - vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std - else: - init_latents = vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - class StableDiffusionXLInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -222,31 +142,37 @@ def intermediate_inputs(self) -> List[str]: "prompt_embeds", required=True, type_hint=torch.Tensor, + kwargs_type="guider_input_fields", description="Pre-generated text embeddings. Can be generated from text_encoder step.", ), InputParam( "negative_prompt_embeds", type_hint=torch.Tensor, + kwargs_type="guider_input_fields", description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", ), InputParam( "pooled_prompt_embeds", required=True, type_hint=torch.Tensor, + kwargs_type="guider_input_fields", description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", ), InputParam( "negative_pooled_prompt_embeds", + kwargs_type="guider_input_fields", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step.", ), InputParam( "ip_adapter_embeds", type_hint=List[torch.Tensor], + kwargs_type="guider_input_fields", description="Pre-generated image embeddings for IP-Adapter. 
Can be generated from ip_adapter step.", ), InputParam( "negative_ip_adapter_embeds", type_hint=List[torch.Tensor], + kwargs_type="guider_input_fields", description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.", ), ] @@ -264,42 +190,6 @@ def intermediate_outputs(self) -> List[str]: type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)", ), - OutputParam( - "prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields - description="text embeddings used to guide the image generation", - ), - OutputParam( - "negative_prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields - description="negative text embeddings used to guide the image generation", - ), - OutputParam( - "pooled_prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields - description="pooled text embeddings used to guide the image generation", - ), - OutputParam( - "negative_pooled_prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields - description="negative pooled text embeddings used to guide the image generation", - ), - OutputParam( - "ip_adapter_embeds", - type_hint=List[torch.Tensor], - kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields - description="image embeddings for IP-Adapter", - ), - OutputParam( - "negative_ip_adapter_embeds", - type_hint=List[torch.Tensor], - kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields - description="negative image embeddings for IP-Adapter", - ), ] def check_inputs(self, components, block_state): @@ -419,8 +309,6 @@ def inputs(self) -> List[InputParam]: InputParam("denoising_end"), InputParam("strength", default=0.3), InputParam("denoising_start"), - # YiYi TODO: do we need num_images_per_prompt here? 
- InputParam("num_images_per_prompt", default=1), ] @property @@ -495,31 +383,29 @@ def get_timesteps(components, num_inference_steps, strength, device, denoising_s def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.device = components._execution_device + device = components._execution_device block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, - block_state.num_inference_steps, - block_state.device, - block_state.timesteps, - block_state.sigmas, + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + device=device, + timesteps=block_state.timesteps, + sigmas=block_state.sigmas, ) def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( - components, - block_state.num_inference_steps, - block_state.strength, - block_state.device, + components=components, + num_inference_steps=block_state.num_inference_steps, + strength=block_state.strength, + device=device, denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, ) - block_state.latent_timestep = block_state.timesteps[:1].repeat( - block_state.batch_size * block_state.num_images_per_prompt - ) + block_state.latent_timestep = block_state.timesteps[:1] if ( block_state.denoising_end is not None @@ -527,14 +413,14 @@ def denoising_value_valid(dnv): and block_state.denoising_end > 0 and block_state.denoising_end < 1 ): - block_state.discrete_timestep_cutoff = int( + discrete_timestep_cutoff = int( round( components.scheduler.config.num_train_timesteps - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) ) ) block_state.num_inference_steps = len( - list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + list(filter(lambda ts: ts >= discrete_timestep_cutoff, block_state.timesteps)) ) block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] @@ -580,14 +466,14 @@ def intermediate_outputs(self) -> List[OutputParam]: def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.device = components._execution_device + device = components._execution_device block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, - block_state.num_inference_steps, - block_state.device, - block_state.timesteps, - block_state.sigmas, + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + device=device, + timesteps=block_state.timesteps, + sigmas=block_state.sigmas, ) if ( @@ -596,14 +482,14 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline and block_state.denoising_end > 0 and block_state.denoising_end < 1 ): - block_state.discrete_timestep_cutoff = int( + discrete_timestep_cutoff = int( round( components.scheduler.config.num_train_timesteps - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) ) ) block_state.num_inference_steps = len( - list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + list(filter(lambda ts: ts >= discrete_timestep_cutoff, block_state.timesteps)) ) block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] @@ -627,7 +513,6 @@ def 
description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("latents"), InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), InputParam( @@ -654,7 +539,6 @@ def intermediate_inputs(self) -> List[str]: ), InputParam( "latent_timestep", - required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", ), @@ -675,7 +559,11 @@ def intermediate_inputs(self) -> List[str]: type_hint=torch.Tensor, description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.", ), - InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs, can be generated in input step.", + ), ] @property @@ -691,208 +579,98 @@ def intermediate_outputs(self) -> List[str]: ), ] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from @staticmethod - def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument - def prepare_latents_inpaint( - self, - components, - batch_size, - num_channels_latents, - height, - width, + def prepare_latents( + image_latents, + scheduler, dtype, device, generator, - latents=None, - image=None, timestep=None, is_strength_max=True, add_noise=True, - return_noise=False, - return_image_latents=False, ): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) + batch_size = image_latents.shape[0] + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You 
have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - "However, either the image or the noise timestep has not been provided." - ) - - if image.shape[1] == 4: - image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(components, image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - - if latents is None and add_noise: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if add_noise: + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + latents = noise if is_strength_max else scheduler.add_noise(image_latents, noise, timestep) # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents - elif add_noise: - noise = latents.to(device) - latents = noise * components.scheduler.init_noise_sigma - else: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = image_latents.to(device) - - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) - - return outputs - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) + latents = latents * scheduler.init_noise_sigma if is_strength_max else latents - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." 
- ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image else: - masked_image_latents = None + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + return latents, noise - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) + def check_inputs(self, batch_size, image_latents, mask, masked_image_latents): + if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size): + raise ValueError( + f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}" + ) - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + if not (mask.shape[0] == 1 or mask.shape[0] == batch_size): + raise ValueError(f"mask should have have batch size 1 or {batch_size}, but got {mask.shape[0]}") - return mask, masked_image_latents + if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size): + raise ValueError( + f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}" + ) @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device + self.check_inputs( + batch_size=block_state.batch_size, + image_latents=block_state.image_latents, + mask=block_state.mask, + masked_image_latents=block_state.masked_image_latents, + ) - block_state.is_strength_max = block_state.strength == 1.0 + dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype + device = components._execution_device - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(components, "unet") and components.unet is not None: - if components.unet.config.in_channels == 4: - block_state.masked_image_latents = None + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt - block_state.add_noise = True if block_state.denoising_start is None else False + block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype) + block_state.image_latents = block_state.image_latents.repeat( + final_batch_size // block_state.image_latents.shape[0], 1, 1, 1 + ) - block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor - block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + # 7. 
Prepare mask latent variables + block_state.mask = block_state.mask.to(device=device, dtype=dtype) + block_state.mask = block_state.mask.repeat(final_batch_size // block_state.mask.shape[0], 1, 1, 1) - block_state.latents, block_state.noise = self.prepare_latents_inpaint( - components, - block_state.batch_size * block_state.num_images_per_prompt, - components.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, - image=block_state.image_latents, - timestep=block_state.latent_timestep, - is_strength_max=block_state.is_strength_max, - add_noise=block_state.add_noise, - return_noise=True, - return_image_latents=False, + block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype) + block_state.masked_image_latents = block_state.masked_image_latents.repeat( + final_batch_size // block_state.masked_image_latents.shape[0], 1, 1, 1 ) - # 7. Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image_latents, - block_state.batch_size * block_state.num_images_per_prompt, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, + if block_state.latent_timestep is not None: + block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size) + block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype) + + is_strength_max = block_state.strength == 1.0 + add_noise = True if block_state.denoising_start is None else False + + block_state.latents, block_state.noise = self.prepare_latents( + image_latents=block_state.image_latents, + scheduler=components.scheduler, + dtype=dtype, + device=device, + generator=block_state.generator, + timestep=block_state.latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, ) self.set_block_state(state, block_state) @@ -906,7 +684,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", EulerDiscreteScheduler), ] @@ -917,7 +694,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("latents"), InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), ] @@ -928,7 +704,6 @@ def intermediate_inputs(self) -> List[InputParam]: InputParam("generator"), InputParam( "latent_timestep", - required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", ), @@ -944,7 +719,7 @@ def intermediate_inputs(self) -> List[InputParam]: type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", ), - InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), ] @property @@ -955,26 +730,61 @@ def intermediate_outputs(self) -> List[OutputParam]: ) ] + def check_inputs(self, batch_size, image_latents): + if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size): + raise ValueError( + f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}" + ) + + @staticmethod + def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None): + if isinstance(generator, list) and len(generator) != image_latents.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {image_latents.shape[0]}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype) + latents = scheduler.add_noise(image_latents, noise, timestep) + + return latents + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - block_state.add_noise = True if block_state.denoising_start is None else False - if block_state.latents is None: - block_state.latents = prepare_latents_img2img( - components.vae, - components.scheduler, - block_state.image_latents, - block_state.latent_timestep, - block_state.batch_size, - block_state.num_images_per_prompt, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.add_noise, + self.check_inputs( + batch_size=block_state.batch_size, + image_latents=block_state.image_latents, + ) + + dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype + device = components._execution_device + + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt + + block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype) + block_state.image_latents = block_state.image_latents.repeat( + final_batch_size // block_state.image_latents.shape[0], 1, 1, 1 + ) + + if block_state.latent_timestep is not None: + block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size) + block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype) + + add_noise = True if block_state.denoising_start is None else False + + if add_noise: + block_state.latents = self.prepare_latents( + image_latents=block_state.image_latents, + scheduler=components.scheduler, + timestep=block_state.latent_timestep, + dtype=dtype, + device=device, + generator=block_state.generator, ) + else: + block_state.latents = block_state.image_latents self.set_block_state(state, block_state) @@ -988,7 +798,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("vae", AutoencoderKL), ] @property @@ -1026,15 +835,15 @@ def intermediate_outputs(self) -> List[OutputParam]: ] @staticmethod - def check_inputs(components, block_state): + def check_inputs(components, height, width): if ( - 
block_state.height is not None - and block_state.height % components.vae_scale_factor != 0 - or block_state.width is not None - and block_state.width % components.vae_scale_factor != 0 + height is not None + and height % components.vae_scale_factor != 0 + or width is not None + and width % components.vae_scale_factor != 0 ): raise ValueError( - f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {height} and {width}." ) @staticmethod @@ -1065,26 +874,27 @@ def prepare_latents(comp, batch_size, num_channels_latents, height, width, dtype def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - if block_state.dtype is None: - block_state.dtype = components.vae.dtype + dtype = block_state.dtype + if dtype is None: + dtype = components.unet.dtype if hasattr(components, "unet") else torch.float32 - block_state.device = components._execution_device + device = components._execution_device - self.check_inputs(components, block_state) + self.check_inputs(components, block_state.height, block_state.width) + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt + height = block_state.height or components.default_sample_size * components.vae_scale_factor + width = block_state.width or components.default_sample_size * components.vae_scale_factor - block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor - block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor - block_state.num_channels_latents = components.num_channels_latents block_state.latents = self.prepare_latents( - components, - block_state.batch_size * block_state.num_images_per_prompt, - block_state.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, + comp=components, + batch_size=final_batch_size, + num_channels_latents=components.num_channels_latents, + height=height, + width=width, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, ) self.set_block_state(state, block_state) @@ -1105,12 +915,6 @@ def expected_configs(self) -> List[ConfigSpec]: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config", - ), ] @property @@ -1121,14 +925,14 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("original_size"), - InputParam("target_size"), InputParam("negative_original_size"), + InputParam("target_size"), InputParam("negative_target_size"), InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), InputParam("aesthetic_score", default=6.0), InputParam("negative_aesthetic_score", default=2.0), + InputParam("num_images_per_prompt", default=1), ] @property @@ -1152,6 +956,11 @@ def intermediate_inputs(self) -> List[InputParam]: type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs, can be generated in input step.", + ), ] @property @@ -1169,7 +978,6 @@ def intermediate_outputs(self) -> List[OutputParam]: kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process", ), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), ] @staticmethod @@ -1225,52 +1033,24 @@ def _get_add_time_ids( return add_time_ids, add_neg_time_ids - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.device = components._execution_device - block_state.vae_scale_factor = components.vae_scale_factor + device = components._execution_device + dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * block_state.vae_scale_factor - block_state.width = block_state.width * block_state.vae_scale_factor + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt + text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + # define original_size/negative_original_size/target_size/negative_target_size + # - they are all defaulted to None + _, _, height_latents, width_latents = block_state.latents.shape + height = height_latents * components.vae_scale_factor + width = width_latents * components.vae_scale_factor - block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + block_state.original_size = block_state.original_size or (height, width) + block_state.target_size = block_state.target_size or (height, width) if block_state.negative_original_size is None: block_state.negative_original_size = block_state.original_size @@ -1287,30 +1067,13 @@ def __call__(self, 
components: StableDiffusionXLModularPipeline, state: Pipeline block_state.negative_original_size, block_state.negative_crops_coords_top_left, block_state.negative_target_size, - dtype=block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, + dtype=dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to( + device=device ) - block_state.add_time_ids = block_state.add_time_ids.repeat( - block_state.batch_size * block_state.num_images_per_prompt, 1 - ).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( - block_state.batch_size * block_state.num_images_per_prompt, 1 - ).to(device=block_state.device) - - # Optionally get Guidance Scale Embedding for LCM - block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( - block_state.batch_size * block_state.num_images_per_prompt - ) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) self.set_block_state(state, block_state) return components, state @@ -1327,12 +1090,6 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config", - ), ] @property @@ -1368,6 +1125,11 @@ def intermediate_inputs(self) -> List[InputParam]: type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs. 
Can be generated in input step.", + ), ] @property @@ -1385,7 +1147,6 @@ def intermediate_outputs(self) -> List[OutputParam]: kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process", ), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), ] @staticmethod @@ -1408,6 +1169,92 @@ def _get_add_time_ids( add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype + text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt + + _, _, height_latents, width_latents = block_state.latents.shape + height = height_latents * components.vae_scale_factor + width = width_latents * components.vae_scale_factor + original_size = block_state.original_size or (height, width) + target_size = block_state.target_size or (height, width) + + block_state.add_time_ids = self._get_add_time_ids( + components, + original_size, + block_state.crops_coords_top_left, + target_size, + dtype=dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to( + device=device + ) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLLCMStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "Step that prepares the timestep cond input for latent consistency models" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam("embedded_guidance_scale"), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs. 
Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 @@ -1439,61 +1286,31 @@ def get_guidance_scale_embedding( assert emb.shape == (w.shape[0], embedding_dim) return emb + def check_input(self, unet, embedded_guidance_scale): + if embedded_guidance_scale is not None and unet.config.time_cond_proj_dim is None: + raise ValueError( + f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None" + ) + + if embedded_guidance_scale is None and unet.config.time_cond_proj_dim is not None: + raise ValueError("unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None") + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.device = components._execution_device - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) - - block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) - - block_state.add_time_ids = self._get_add_time_ids( - components, - block_state.original_size, - block_state.crops_coords_top_left, - block_state.target_size, - block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - if block_state.negative_original_size is not None and block_state.negative_target_size is not None: - block_state.negative_add_time_ids = self._get_add_time_ids( - components, - block_state.negative_original_size, - block_state.negative_crops_coords_top_left, - block_state.negative_target_size, - block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - else: - block_state.negative_add_time_ids = block_state.add_time_ids + device = components._execution_device + dtype = block_state.dtype if block_state.dtype is not None else components.unet.dtype - block_state.add_time_ids = block_state.add_time_ids.repeat( - block_state.batch_size * block_state.num_images_per_prompt, 1 - ).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( - block_state.batch_size * block_state.num_images_per_prompt, 1 - ).to(device=block_state.device) + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt # Optionally get Guidance Scale Embedding for LCM block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! 
- block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( - block_state.batch_size * block_state.num_images_per_prompt - ) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) + + guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size) + block_state.timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=device, dtype=dtype) self.set_block_state(state, block_state) return components, state @@ -1613,14 +1430,18 @@ def prepare_control_image( def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - # (1) prepare controlnet inputs - block_state.device = components._execution_device - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - controlnet = unwrap_module(components.controlnet) + device = components._execution_device + dtype = components.controlnet.dtype + + _, _, height_latents, width_latents = block_state.latents.shape + height = height_latents * components.vae_scale_factor + width = width_latents * components.vae_scale_factor + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt + + # (1) prepare controlnet inputs + # (1.1) # control_guidance_start/control_guidance_end (align format) if not isinstance(block_state.control_guidance_start, list) and isinstance( @@ -1645,37 +1466,35 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline ) # (1.2) - # controlnet_conditioning_scale (align format) + # conditioning_scale (align format) if isinstance(controlnet, MultiControlNetModel) and isinstance( block_state.controlnet_conditioning_scale, float ): - block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len( - controlnet.nets - ) + block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) + else: + block_state.conditioning_scale = block_state.controlnet_conditioning_scale # (1.3) - # global_pool_conditions - block_state.global_pool_conditions = ( + # guess_mode + global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) - # (1.4) - # guess_mode - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or global_pool_conditions - # (1.5) - # control_image + # (1.4) + # controlnet_cond if isinstance(controlnet, ControlNetModel): - block_state.control_image = self.prepare_control_image( + block_state.controlnet_cond = self.prepare_control_image( components, image=block_state.control_image, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, + width=width, + height=height, + batch_size=final_batch_size, num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, + device=device, + dtype=dtype, crops_coords=block_state.crops_coords, ) elif isinstance(controlnet, 
MultiControlNetModel): @@ -1685,18 +1504,18 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline control_image = self.prepare_control_image( components, image=control_image_, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, + width=width, + height=height, + batch_size=final_batch_size, num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, + device=device, + dtype=dtype, crops_coords=block_state.crops_coords, ) control_images.append(control_image) - block_state.control_image = control_images + block_state.controlnet_cond = control_images else: assert False @@ -1710,9 +1529,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline ] block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - block_state.controlnet_cond = block_state.control_image - block_state.conditioning_scale = block_state.controlnet_conditioning_scale - self.set_block_state(state, block_state) return components, state @@ -1852,9 +1668,10 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline device = components._execution_device dtype = block_state.dtype or components.controlnet.dtype - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor + _, _, height_latents, width_latents = block_state.latents.shape + height = height_latents * components.vae_scale_factor + width = width_latents * components.vae_scale_factor + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt # control_guidance_start/control_guidance_end (align format) if not isinstance(block_state.control_guidance_start, list) and isinstance( @@ -1871,8 +1688,8 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline ] # guess_mode - block_state.global_pool_conditions = controlnet.config.global_pool_conditions - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or global_pool_conditions # control_image if not isinstance(block_state.control_image, list): @@ -1885,30 +1702,32 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline raise ValueError("Expected len(control_image) == len(control_type)") # control_type - block_state.num_control_type = controlnet.config.num_control_type - block_state.control_type = [0 for _ in range(block_state.num_control_type)] + num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(num_control_type)] for control_idx in block_state.control_mode: block_state.control_type[control_idx] = 1 block_state.control_type = torch.Tensor(block_state.control_type) - block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=dtype) repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) - # prepare control_image + # prepare controlnet_cond + block_state.controlnet_cond = [] for idx, _ in 
enumerate(block_state.control_image): - block_state.control_image[idx] = self.prepare_control_image( + control_image = self.prepare_control_image( components, image=block_state.control_image[idx], - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, + width=width, + height=height, + batch_size=final_batch_size, num_images_per_prompt=block_state.num_images_per_prompt, device=device, dtype=dtype, crops_coords=block_state.crops_coords, ) - block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + _, _, height, width = control_image.shape + block_state.controlnet_cond.append(control_image) # controlnet_keep block_state.controlnet_keep = [] @@ -1921,7 +1740,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline ) ) block_state.control_type_idx = block_state.control_mode - block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index e9f627636e8c..f68b0be4b60e 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -105,9 +105,9 @@ def __call__(self, components, state: PipelineState) -> PipelineState: if not block_state.output_type == "latent": latents = block_state.latents # make sure the VAE is in float32 mode, as it overflows in float16 - block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast - if block_state.needs_upcasting: + if needs_upcasting: self.upcast_vae(components) latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != components.vae.dtype: @@ -117,29 +117,27 @@ def __call__(self, components, state: PipelineState) -> PipelineState: # unscale/denormalize the latents # denormalize with the mean and std if available and not None - block_state.has_latents_mean = ( + has_latents_mean = ( hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None ) - block_state.has_latents_std = ( + has_latents_std = ( hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None ) - if block_state.has_latents_mean and block_state.has_latents_std: - block_state.latents_mean = ( + if has_latents_mean and has_latents_std: + latents_mean = ( torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) - block_state.latents_std = ( + latents_std = ( torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) - latents = ( - latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean - ) + latents = latents * latents_std / components.vae.config.scaling_factor + latents_mean else: latents = latents / components.vae.config.scaling_factor block_state.images = components.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed - if block_state.needs_upcasting: + if needs_upcasting: components.vae.to(dtype=torch.float16) else: block_state.images = block_state.latents diff --git 
a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 7fe4a472eec3..871fafd0248b 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -67,7 +67,7 @@ def intermediate_inputs(self) -> List[str]: @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t) return components, block_state @@ -134,10 +134,10 @@ def check_inputs(components, block_state): def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): self.check_inputs(components, block_state) - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t) if components.num_channels_unet == 9: - block_state.scaled_latents = torch.cat( - [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1 + block_state.latent_model_input = torch.cat( + [block_state.latent_model_input, block_state.mask, block_state.masked_image_latents], dim=1 ) return components, block_state @@ -232,7 +232,7 @@ def __call__( # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, + block_state.latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=block_state.timestep_cond, @@ -410,7 +410,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl mid_block_res_sample = block_state.mid_block_res_sample_zeros else: down_block_res_samples, mid_block_res_sample = components.controlnet( - block_state.scaled_latents, + block_state.latent_model_input, t, encoder_hidden_states=guider_state_batch.prompt_embeds, controlnet_cond=block_state.controlnet_cond, @@ -430,7 +430,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl # Predict the noise # store the noise_pred in guider_state_batch so we can apply guidance across all batches guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, + block_state.latent_model_input, t, encoder_hidden_states=guider_state_batch.prompt_embeds, timestep_cond=block_state.timestep_cond, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index bd0e962140e8..2c27bb8ad868 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -15,6 +15,7 @@ from typing import List, Optional, Tuple import torch +from PIL import Image from transformers import ( CLIPImageProcessor, CLIPTextModel, @@ -57,6 +58,91 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +def get_clip_prompt_embeds( + prompt, + text_encoder, + tokenizer, + device, + clip_skip=None, + max_length=None, +): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length if max_length is not None else tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) 
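+    # The prompt is re-tokenized without truncation ("longest" padding) below so that
+    # inputs longer than the CLIP context window can be detected and the dropped text
+    # surfaced in a warning.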
+ + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only using the pooled output of the text_encoder_2, which has 2 dimensions + # (pooled output for text_encoder has 3 dimensions) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + return prompt_embeds, pooled_prompt_embeds + + +# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components +def encode_vae_image( + image: torch.Tensor, vae: AutoencoderKL, generator: torch.Generator, dtype: torch.dtype, device: torch.device +): + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + + image = image.to(device=device, dtype=dtype) + + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != image.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {image.shape[0]}. Make sure the batch size matches the length of the generators." 
+ ) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + image_latents = (image_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + image_latents = vae.config.scaling_factor * image_latents + + return image_latents + + class StableDiffusionXLIPAdapterStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -86,6 +172,7 @@ def expected_components(self) -> List[ComponentSpec]: ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config", + required=False, ), ] @@ -103,10 +190,16 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam( + "ip_adapter_embeds", + type_hint=List[torch.Tensor], + kwargs_type="guider_input_fields", + description="IP adapter image embeddings", + ), OutputParam( "negative_ip_adapter_embeds", - type_hint=torch.Tensor, + type_hint=List[torch.Tensor], + kwargs_type="guider_input_fields", description="Negative IP adapter image embeddings", ), ] @@ -137,79 +230,35 @@ def encode_image(components, image, device, num_images_per_prompt, output_hidden return image_embeds, uncond_image_embeds - # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, - components, - ip_adapter_image, - ip_adapter_image_embeds, - device, - num_images_per_prompt, - prepare_unconditional_embeds, - ): - image_embeds = [] - if prepare_unconditional_embeds: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
- ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - components, single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if prepare_unconditional_embeds: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if prepare_unconditional_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if prepare_unconditional_embeds: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 - block_state.device = components._execution_device + device = components._execution_device - block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( - components, - ip_adapter_image=block_state.ip_adapter_image, - ip_adapter_image_embeds=None, - device=block_state.device, - num_images_per_prompt=1, - prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, - ) - if block_state.prepare_unconditional_embeds: + block_state.ip_adapter_embeds = [] + if components.requires_unconditional_embeds: block_state.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(block_state.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - block_state.negative_ip_adapter_embeds.append(negative_image_embeds) - block_state.ip_adapter_embeds[i] = image_embeds + + if not isinstance(block_state.ip_adapter_image, list): + block_state.ip_adapter_image = [block_state.ip_adapter_image] + + if len(block_state.ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(block_state.ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + for single_ip_adapter_image, image_proj_layer in zip( + block_state.ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + block_state.ip_adapter_embeds.append(single_image_embeds[None, :]) + if components.requires_unconditional_embeds: + block_state.negative_ip_adapter_embeds.append(single_negative_image_embeds[None, :]) self.set_block_state(state, block_state) return components, state @@ -225,15 +274,16 @@ def description(self) -> str: @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder", CLIPTextModel, required=False), ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer", CLIPTokenizer, required=False), ComponentSpec("tokenizer_2", CLIPTokenizer), ComponentSpec( "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config", + required=False, ), ] @@ -244,7 +294,7 @@ def expected_configs(self) -> List[ConfigSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("prompt"), + InputParam("prompt", required=True), InputParam("prompt_2"), InputParam("negative_prompt"), InputParam("negative_prompt_2"), @@ -282,15 +332,22 @@ def intermediate_outputs(self) -> List[OutputParam]: ] @staticmethod - def check_inputs(block_state): - if block_state.prompt is not None and ( - not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + def check_inputs(prompt, prompt_2, negative_prompt, negative_prompt_2): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") - elif block_state.prompt_2 is not None and ( - not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list) + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if negative_prompt_2 is not None and ( + not isinstance(negative_prompt_2, str) and not isinstance(negative_prompt_2, list) ): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + raise ValueError(f"`negative_prompt_2` has to be of type `str` or `list` but is {type(negative_prompt_2)}") @staticmethod def encode_prompt( @@ -298,14 +355,9 @@ def encode_prompt( prompt: str, prompt_2: Optional[str] = None, device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prepare_unconditional_embeds: bool = True, + requires_unconditional_embeds: bool = True, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - 
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): @@ -331,52 +383,17 @@ def encode_prompt( negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ - device = device or components._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): - components._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if components.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) - else: - scale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) - else: - scale_lora_layers(components.text_encoder_2, lora_scale) + dtype = components.text_encoder_2.dtype prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] + batch_size = len(prompt) # Define tokenizers and text encoders tokenizers = ( @@ -389,58 +406,56 @@ def encode_prompt( if components.text_encoder is not None else [components.text_encoder_2] ) + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - prompt = 
components.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] + # dynamically adjust the LoRA scale + for text_encoder in text_encoders: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(text_encoder, lora_scale) else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + scale_lora_layers(text_encoder, lora_scale) + + # Define prompts + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + prompts = [prompt, prompt_2] + + # generate prompt_embeds & pooled_prompt_embeds + prompt_embeds_list = [] + pooled_prompt_embeds_list = [] + + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds( + prompt=prompt, + text_encoder=text_encoder, + tokenizer=tokenizer, + device=device, + clip_skip=clip_skip, + max_length=tokenizer.model_max_length, + ) + + prompt_embeds_list.append(prompt_embeds) + if pooled_prompt_embeds.ndim == 2: + pooled_prompt_embeds_list.append(pooled_prompt_embeds) - prompt_embeds_list.append(prompt_embeds) + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = torch.concat(pooled_prompt_embeds_list, dim=0) - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + negative_prompt_embeds = None + negative_pooled_prompt_embeds = None - # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt - if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + # generate negative_prompt_embeds & negative_pooled_prompt_embeds + if requires_unconditional_embeds and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif prepare_unconditional_embeds and negative_prompt_embeds is None: + elif requires_unconditional_embeds: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt @@ -451,87 +466,52 @@ def encode_prompt( ) uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." 
- ) - elif batch_size != len(negative_prompt): + if batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] + if batch_size != len(negative_prompt_2): + raise ValueError( + f"`negative_prompt_2`: {negative_prompt_2} has batch size {len(negative_prompt_2)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt_2` matches" + " the batch size of `prompt`." + ) + uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] + negative_pooled_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(components, TextualInversionLoaderMixin): negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", + negative_prompt_embeds, negative_pooled_prompt_embeds = get_clip_prompt_embeds( + prompt=negative_prompt, + text_encoder=text_encoder, + tokenizer=tokenizer, + device=device, + clip_skip=None, max_length=max_length, - truncation=True, - return_tensors="pt", ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - negative_prompt_embeds_list.append(negative_prompt_embeds) + if negative_pooled_prompt_embeds.ndim == 2: + negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0) - if components.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device) + if requires_unconditional_embeds: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype=dtype, device=device) - if prepare_unconditional_embeds: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if components.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=components.text_encoder_2.dtype, device=device - ) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * 
num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if prepare_unconditional_embeds: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if components.text_encoder is not None: + for text_encoder in text_encoders: if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder_2, lora_scale) + unscale_lora_layers(text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -539,13 +519,15 @@ def encode_prompt( def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: # Get inputs and intermediates block_state = self.get_block_state(state) - self.check_inputs(block_state) - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 - block_state.device = components._execution_device + self.check_inputs( + block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2 + ) + + device = components._execution_device # Encode input prompt - block_state.text_encoder_lora_scale = ( + lora_scale = ( block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None @@ -557,18 +539,13 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline block_state.negative_pooled_prompt_embeds, ) = self.encode_prompt( components, - block_state.prompt, - block_state.prompt_2, - block_state.device, - 1, - block_state.prepare_unconditional_embeds, - block_state.negative_prompt, - block_state.negative_prompt_2, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - lora_scale=block_state.text_encoder_lora_scale, + prompt=block_state.prompt, + prompt_2=block_state.prompt_2, + device=device, + requires_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + negative_prompt_2=block_state.negative_prompt_2, + lora_scale=lora_scale, clip_skip=block_state.clip_skip, ) # Add outputs @@ -599,8 +576,6 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("image", required=True), - InputParam("height"), - InputParam("width"), ] @property @@ -608,11 +583,6 @@ def intermediate_inputs(self) -> List[InputParam]: return [ InputParam("generator"), InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam( - "preprocess_kwargs", - type_hint=Optional[dict], - description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]", - ), ] @property @@ -625,65 +595,18 @@ def intermediate_outputs(self) -> List[OutputParam]: ) ] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with 
self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} - block_state.device = components._execution_device - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.image = components.image_processor.preprocess( - block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs - ) - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + device = components._execution_device + dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.batch_size = block_state.image.shape[0] + image = components.image_processor.preprocess(block_state.image) - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" - f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - block_state.image_latents = self._encode_vae_image( - components, image=block_state.image, generator=block_state.generator + # Encode image into latents + block_state.image_latents = encode_vae_image( + image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device ) self.set_block_state(state, block_state) @@ -741,7 +664,6 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam( "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image" ), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), OutputParam( "masked_image_latents", type_hint=torch.Tensor, @@ -752,150 +674,82 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask", ), + OutputParam( + "mask", + type_hint=torch.Tensor, + description="The mask to apply on the latents for the inpainting generation.", + ), ] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps 
friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + def check_inputs(self, image, mask_image, padding_mask_crop): + if padding_mask_crop is not None and not isinstance(image, Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) - return mask, masked_image_latents + if padding_mask_crop is not None and not isinstance(mask_image, Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type {type(mask_image)}." 
+ ) @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device + self.check_inputs(block_state.image, block_state.mask_image, block_state.padding_mask_crop) - if block_state.height is None: - block_state.height = components.default_height - if block_state.width is None: - block_state.width = components.default_width + dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + device = components._execution_device + + height = block_state.height if block_state.height is not None else components.default_height + width = block_state.width if block_state.width is not None else components.default_width if block_state.padding_mask_crop is not None: block_state.crops_coords = components.mask_processor.get_crop_region( - block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop + mask_image=block_state.mask_image, width=width, height=height, pad=block_state.padding_mask_crop ) - block_state.resize_mode = "fill" + resize_mode = "fill" else: block_state.crops_coords = None - block_state.resize_mode = "default" + resize_mode = "default" - block_state.image = components.image_processor.preprocess( + image = components.image_processor.preprocess( block_state.image, - height=block_state.height, - width=block_state.width, + height=height, + width=width, crops_coords=block_state.crops_coords, - resize_mode=block_state.resize_mode, + resize_mode=resize_mode, ) - block_state.image = block_state.image.to(dtype=torch.float32) - block_state.mask = components.mask_processor.preprocess( + image = image.to(dtype=torch.float32) + + mask_image = components.mask_processor.preprocess( block_state.mask_image, - height=block_state.height, - width=block_state.width, - resize_mode=block_state.resize_mode, + height=height, + width=width, + resize_mode=resize_mode, crops_coords=block_state.crops_coords, ) - block_state.masked_image = block_state.image * (block_state.mask < 0.5) - block_state.batch_size = block_state.image.shape[0] - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - block_state.image_latents = self._encode_vae_image( - components, image=block_state.image, generator=block_state.generator + masked_image = image * (mask_image < 0.5) + + # Prepare image latent variables + block_state.image_latents = encode_vae_image( + image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device ) - # 7. 
Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image, - block_state.batch_size, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, + # Prepare masked image latent variables + block_state.masked_image_latents = encode_vae_image( + image=masked_image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device + ) + + # resize mask to match the image latents + _, _, height_latents, width_latents = block_state.image_latents.shape + block_state.mask = torch.nn.functional.interpolate( + mask_image, + size=(height_latents, width_latents), ) + block_state.mask = block_state.mask.to(dtype=dtype, device=device) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py index c9033856bcc0..93998ab6cda0 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py @@ -23,6 +23,7 @@ StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLInputStep, + StableDiffusionXLLCMStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLSetTimestepsStep, @@ -79,6 +80,16 @@ def description(self): return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" +class StableDiffusionXLAutoLCMStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLLCMStep] + block_names = ["lcm"] + block_trigger_inputs = ["embedded_guidance_scale"] + + @property + def description(self): + return "Run LCM step if `embedded_guidance_scale` is provided. 
This step should be placed before the 'input' step.\n" + + # before_denoise: text2img class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [ @@ -262,6 +273,7 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks): StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLAutoLCMStep, StableDiffusionXLAutoControlNetInputStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep, @@ -271,6 +283,7 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks): "ip_adapter", "image_encoder", "before_denoise", + "lcm", "controlnet_input", "denoise", "decoder", @@ -286,6 +299,7 @@ def description(self): + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + "- for text-to-image generation, all you need to provide is `prompt`" + + "\n- to run the latent consistency models workflow, you need to provide `embedded_guidance_scale`" ) @@ -357,6 +371,12 @@ def description(self): ] ) +LCM_BLOCKS = InsertableDict( + [ + ("lcm", StableDiffusionXLAutoLCMStep), + ] +) + AUTO_BLOCKS = InsertableDict( [ ("text_encoder", StableDiffusionXLTextEncoderStep), @@ -376,5 +396,6 @@ def description(self): "inpaint": INPAINT_BLOCKS, "controlnet": CONTROLNET_BLOCKS, "ip_adapter": IP_ADAPTER_BLOCKS, + "lcm": LCM_BLOCKS, "auto": AUTO_BLOCKS, } diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py index fc030fae56fb..c169786f1661 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -90,6 +90,23 @@ def num_channels_latents(self): num_channels_latents = self.vae.config.latent_channels return num_channels_latents + @property + def requires_unconditional_embeds(self): + # by default, always prepare unconditional embeddings + requires_unconditional_embeds = True + + if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is not None: + # LCM + requires_unconditional_embeds = False + + elif hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider.num_conditions > 1 + + elif not hasattr(self, "guider") or self.guider is None: + requires_unconditional_embeds = False + + return requires_unconditional_embeds + # YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks # auto_docstring