From 7d5901d2ad083b0be99272d5316e44e716d7e7ed Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 22 Jul 2025 19:28:42 -0700 Subject: [PATCH 01/14] Initial commit implementing frequency-decoupled guidance (FDG) as a guider --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 2 + .../guiders/frequency_decoupled_guidance.py | 215 ++++++++++++++++++ 3 files changed, 219 insertions(+) create mode 100644 src/diffusers/guiders/frequency_decoupled_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 30d497892ff5..7aa2ff0ed217 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -139,6 +139,7 @@ "AutoGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", + "FrequencyDecoupledGuidance", "PerturbedAttentionGuidance", "SkipLayerGuidance", "SmoothedEnergyGuidance", @@ -797,6 +798,7 @@ AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, + FrequencyDecoupledGuidance, PerturbedAttentionGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 1c288f00f084..23cb7a0a7157 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -22,6 +22,7 @@ from .auto_guidance import AutoGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance + from .frequency_decoupled_guidance import FrequencyDecoupledGuidance from .perturbed_attention_guidance import PerturbedAttentionGuidance from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance @@ -32,6 +33,7 @@ AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, + FrequencyDecoupledGuidance, PerturbedAttentionGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py 
b/src/diffusers/guiders/frequency_decoupled_guidance.py new file mode 100644 index 000000000000..daa7164c5cf8 --- /dev/null +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -0,0 +1,215 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import kornia +import torch + +from ..configuration_utils import register_to_config +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +def project(v0: torch.Tensor, v1: torch.Tensor): + """ + Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from + paper (Algorithm 2). + """ + # v0 shape: [B, C, H, W] + # v1 shape: [B, C, H, W] + dtype = v0.dtype + v0, v1 = v0.double(), v1.double() + v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) + v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + return v0_parallel.to(dtype), v0_orthogonal.to(dtype) + + +def build_image_from_pyramid(pyramid): + """ + Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper + (Algorihtm 2). 
+ """ + img = pyramid[-1] + for i in range(len(pyramid) - 2, -1, -1): + img = kornia.geometry.pyrup(img) + pyramid[i] + return img + + +class FrequencyDecoupledGuidance(BaseGuidance): + """ + Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713 + + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper + proposes scaling and shifting the conditional distribution based on the difference between conditional and + unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + + Args: + guidance_scale_low (`float`, defaults to `5.0`): + The scale parameter for frequency-decoupled guidance for low-frequency components. 
Higher values result in + stronger conditioning on the text prompt, while lower values allow for more freedom in generation. Higher + values may lead to saturation and deterioration of image quality. The FDG authors recommend + `guidance_scale_low < guidance_scale_high`. + guidance_scale_high (`float`, defaults to `10.0`): + The scale parameter for frequency-decoupled guidance for high-frequency components. Higher values result in + stronger conditioning on the text prompt, while lower values allow for more freedom in generation. Higher + values may lead to saturation and deterioration of image quality. The FDG authors recommend + `guidance_scale_low < guidance_scale_high`. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + parallel_weights_low (`float`, defaults to `1.0`): + Optional weight for the parallel component of the low-frequency component of the projected CFG shift. + The default value of `1.0` corresponds to using the normal CFG shift (that is, equal weights for the + parallel and orthogonal components). + parallel_weights_high (`float`, defaults to `1.0`): + Optional weight for the parallel component of the high-frequency component of the projected CFG shift. + The default value of `1.0` corresponds to using the normal CFG shift (that is, equal weights for the + parallel and orthogonal components). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. 
+ start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale_low: float = 5.0, + guidance_scale_high: float = 10.0, + guidance_rescale: float = 0.0, + parallel_weights_low: float = 1.0, + parallel_weights_high: float = 1.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale_low = guidance_scale_low + self.guidance_scale_high = guidance_scale_high + self.guidance_rescale = guidance_rescale + # Split the frequency components into 2 levels: low-frequency and high-frequency + # For now, hardcoded + self.levels = 2 + + self.parallel_weights_low = parallel_weights_low + self.parallel_weights_high = parallel_weights_high + + self.use_original_formulation = use_original_formulation + + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_fdg_enabled(): + pred = pred_cond + else: + # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional components. 
+ pred_cond_pyramid = kornia.geometry.transform.build_laplacian_pyramid(pred_cond, self.levels) + pred_uncond_pyramid = kornia.geometry.transform.build_laplacian_pyramid(pred_uncond, self.levels) + + # From high freq to low, following the paper implementation + pred_guided_pyramid = [] + guidance_scales = [self.guidance_scale_high, self.guidance_scale_low] + parallel_weights = [self.parallel_weights_high, self.parallel_weights_low] + parameters = zip(guidance_scales, parallel_weights) + for level, (guidance_scale, parallel_weight) in enumerate(parameters): + shift = pred_cond_pyramid[level] - pred_uncond_pyramid[level] + + # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift) + shift_parallel, shift_orthogonal = project(shift, pred_cond) + shift = parallel_weight * shift_parallel + shift_orthogonal + + # Apply CFG for the current frequency level + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond_pyramid[level], self.guidance_rescale) + + # Add the current FDG guided level to the guided pyramid + pred_guided_pyramid.append(pred) + + # Convert from frequency space back to data (e.g. 
pixel) space by applying inverse freq transform + pred = build_image_from_pyramid(pred_guided_pyramid) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_fdg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_fdg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale_low, 0.0) and math.isclose(self.guidance_scale_high, 0.0) + else: + is_close = math.isclose(self.guidance_scale_low, 1.0) and math.isclose(self.guidance_scale_high, 1.0) + + return is_within_range and not is_close From fe824a88920e980a87b9e8830144bedec37645eb Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 22 Jul 2025 20:47:56 -0700 Subject: [PATCH 02/14] Update FrequencyDecoupledGuidance docstring to describe FDG --- .../guiders/frequency_decoupled_guidance.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index daa7164c5cf8..fbeea091020d 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -26,7 +26,7 @@ from ..modular_pipelines.modular_pipeline import BlockState -def project(v0: torch.Tensor, v1: torch.Tensor): +def project(v0: torch.Tensor, v1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper (Algorithm 2). 
@@ -41,7 +41,7 @@ def project(v0: torch.Tensor, v1: torch.Tensor): return v0_parallel.to(dtype), v0_orthogonal.to(dtype) -def build_image_from_pyramid(pyramid): +def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor: """ Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper (Algorihtm 2). @@ -56,21 +56,27 @@ class FrequencyDecoupledGuidance(BaseGuidance): """ Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713 - CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by - jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during - inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper - proposes scaling and shifting the conditional distribution based on the difference between conditional and - unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] - - Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen - paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation + quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both + conditional and unconditional data, and use a combination of the two during inference. (If you want more details + on how CFG works, you can check out the CFG guider.) + + FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency + components using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in + frequency space separately for the low- and high-frequency components with different guidance scales. 
Finally, the + inverse frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for + images) to form the final FDG prediction. + + For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample + diversity and realistic color composition, while using high guidance scales for high-frequency components enhances + sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) + for the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an + example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper). + + As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] - The intution behind the original formulation can be thought of as moving the conditional distribution estimates - further away from the unconditional distribution estimates, while the diffusers-native implementation can be - thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of - the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) - The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. @@ -154,11 +160,11 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if not self._is_fdg_enabled(): pred = pred_cond else: - # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional components. 
+ # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions. pred_cond_pyramid = kornia.geometry.transform.build_laplacian_pyramid(pred_cond, self.levels) pred_uncond_pyramid = kornia.geometry.transform.build_laplacian_pyramid(pred_uncond, self.levels) - # From high freq to low, following the paper implementation + # From high frequencies to low frequencies, following the paper implementation pred_guided_pyramid = [] guidance_scales = [self.guidance_scale_high, self.guidance_scale_low] parallel_weights = [self.parallel_weights_high, self.parallel_weights_low] From 6949eceb5cd327f75e3b25b675ce489f3c73e5a3 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 22 Jul 2025 21:03:31 -0700 Subject: [PATCH 03/14] Update project so that it accepts any number of non-batch dims --- src/diffusers/guiders/frequency_decoupled_guidance.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index fbeea091020d..11c8f466796e 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -31,12 +31,14 @@ def project(v0: torch.Tensor, v1: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper (Algorithm 2). """ - # v0 shape: [B, C, H, W] - # v1 shape: [B, C, H, W] + # v0 shape: [B, ...] + # v1 shape: [B, ...] 
dtype = v0.dtype + # Assume first dim is a batch dim and all other dims are channel or "spatial" dims + all_dims_but_first = list(range(1, len(v0.shape))) v0, v1 = v0.double(), v1.double() - v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) - v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 + v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first) + v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel return v0_parallel.to(dtype), v0_orthogonal.to(dtype) From 8c05d64fa41a21fd6dd0de0a9ffeadb099016494 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 23 Jul 2025 16:47:45 -0700 Subject: [PATCH 04/14] Change guidance_scale and other params to accept a list of params for each freq level --- .../guiders/frequency_decoupled_guidance.py | 172 ++++++++++++------ 1 file changed, 113 insertions(+), 59 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index 11c8f466796e..e8da8efba57e 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -83,36 +83,33 @@ class FrequencyDecoupledGuidance(BaseGuidance): paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. Args: - guidance_scale_low (`float`, defaults to `5.0`): - The scale parameter for frequency-decoupled guidance for low-frequency components. Higher values result in - stronger conditioning on the text prompt, while lower values allow for more freedom in generation. Higher - values may lead to saturation and deterioration of image quality. The FDG authors recommend - `guidance_scale_low < guidance_scale_high`. - guidance_scale_high (`float`, defaults to `10.0`): - The scale parameter for frequency-decoupled guidance for high-frequency components. 
Higher values result in - stronger conditioning on the text prompt, while lower values allow for more freedom in generation. Higher - values may lead to saturation and deterioration of image quality. The FDG authors recommend - `guidance_scale_low < guidance_scale_high`. - guidance_rescale (`float`, defaults to `0.0`): + guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`): + The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest + frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower + values allow for more freedom in generation. Higher values may lead to saturation and deterioration of + image quality. The FDG authors recommend using higher guidance scales for higher frequency components and + lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in + descending order). + guidance_rescale (`float` or `List[float]`, defaults to `0.0`): The rescale factor applied to the noise predictions. This is used to improve image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://huggingface.co/papers/2305.08891). - parallel_weights_low (`float`, defaults to `1.0`): - Optional weight for the parallel component of the low-frequency component of the projected CFG shift. - The default value of `1.0` corresponds to using the normal CFG shift (that is, equal weights for the - parallel and orthogonal components). - parallel_weights_high (`float`, defaults to `1.0`): - Optional weight for the parallel component of the high-frequency component of the projected CFG shift. - The default value of `1.0` corresponds to using the normal CFG shift (that is, equal weights for the - parallel and orthogonal components). + Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as + `guidance_scales`. 
+ parallel_weights (`float` or `List[float]`, *optional*): + Optional weights for the parallel component of each frequency component of the projected CFG shift. If not + set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift + (that is, equal weights for the parallel and orthogonal components). If a list is supplied, it should be + the same length as `guidance_scales`. use_original_formulation (`bool`, defaults to `False`): Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. See [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. - start (`float`, defaults to `0.0`): - The fraction of the total number of denoising steps after which guidance starts. - stop (`float`, defaults to `1.0`): - The fraction of the total number of denoising steps after which guidance stops. + start (`float` or `List[float]`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it + should be the same length as `guidance_scales`. + stop (`float` or `List[float]`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it + should be the same length as `guidance_scales`. 
""" _input_predictions = ["pred_cond", "pred_uncond"] @@ -120,29 +117,65 @@ class FrequencyDecoupledGuidance(BaseGuidance): @register_to_config def __init__( self, - guidance_scale_low: float = 5.0, - guidance_scale_high: float = 10.0, - guidance_rescale: float = 0.0, - parallel_weights_low: float = 1.0, - parallel_weights_high: float = 1.0, + guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0], + guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0, + parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None, use_original_formulation: bool = False, - start: float = 0.0, - stop: float = 1.0, + start: Union[float, List[float], Tuple[float]] = 0.0, + stop: Union[float, List[float], Tuple[float]] = 1.0, ): - super().__init__(start, stop) - - self.guidance_scale_low = guidance_scale_low - self.guidance_scale_high = guidance_scale_high - self.guidance_rescale = guidance_rescale - # Split the frequency components into 2 levels: low-frequency and high-frequency - # For now, hardcoded - self.levels = 2 - - self.parallel_weights_low = parallel_weights_low - self.parallel_weights_high = parallel_weights_high + # Set start to earliest start for any freq component and stop to latest stop for any freq component + min_start = start if isinstance(start, float) else min(start) + max_stop = stop if isinstance(stop, float) else max(stop) + super().__init__(min_start, max_stop) + + self.guidance_scales = guidance_scales + self.levels = len(guidance_scales) + + if isinstance(guidance_rescale, float): + self.guidance_rescale = [guidance_rescale] * self.levels + elif len(guidance_rescale) == self.levels: + self.guidance_rescale = guidance_rescale + else: + raise ValueError( + f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as " + f"`guidance_scales` ({len(self.guidance_scales)})" + ) + + if parallel_weights is None: + # Use normal CFG shift (equal weights for parallel and orthogonal components) + 
self.parallel_weights = [1.0] * self.levels + elif isinstance(parallel_weights, float): + self.parallel_weights = [parallel_weights] * self.levels + elif len(parallel_weights) == self.levels: + self.parallel_weights = parallel_weights + else: + raise ValueError( + f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as " + f"`guidance_scales` ({len(self.guidance_scales)})" + ) self.use_original_formulation = use_original_formulation + if isinstance(start, float): + self.guidance_start = [start] * self.levels + elif len(start) == self.levels: + self.guidance_start = start + else: + raise ValueError( + f"`start` has length {len(start)} but should have the same length as `guidance_scales` " + f"({len(self.guidance_scales)})" + ) + if isinstance(stop, float): + self.guidance_stop = [stop] * self.levels + elif len(stop) == self.levels: + self.guidance_stop = stop + else: + raise ValueError( + f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` " + f"({len(self.guidance_scales)})" + ) + def prepare_inputs( self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None ) -> List["BlockState"]: @@ -168,25 +201,28 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = # From high frequencies to low frequencies, following the paper implementation pred_guided_pyramid = [] - guidance_scales = [self.guidance_scale_high, self.guidance_scale_low] - parallel_weights = [self.parallel_weights_high, self.parallel_weights_low] - parameters = zip(guidance_scales, parallel_weights) - for level, (guidance_scale, parallel_weight) in enumerate(parameters): - shift = pred_cond_pyramid[level] - pred_uncond_pyramid[level] + parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale) + for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters): + if self._is_fdg_enabled_for_level(level): + shift = 
pred_cond_pyramid[level] - pred_uncond_pyramid[level] - # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift) - shift_parallel, shift_orthogonal = project(shift, pred_cond) - shift = parallel_weight * shift_parallel + shift_orthogonal + # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift) + if not math.isclose(parallel_weight, 1.0): + shift_parallel, shift_orthogonal = project(shift, pred_cond) + shift = parallel_weight * shift_parallel + shift_orthogonal - # Apply CFG for the current frequency level - pred = pred_cond if self.use_original_formulation else pred_uncond - pred = pred + guidance_scale * shift + # Apply CFG update for the current frequency level + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + guidance_scale * shift - if self.guidance_rescale > 0.0: - pred = rescale_noise_cfg(pred, pred_cond_pyramid[level], self.guidance_rescale) + if guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond_pyramid[level], guidance_rescale) - # Add the current FDG guided level to the guided pyramid - pred_guided_pyramid.append(pred) + # Add the current FDG guided level to the FDG prediction pyramid + pred_guided_pyramid.append(pred) + else: + # Add the current pred_cond_pyramid level as the "non-FDG" prediction + pred_guided_pyramid.append(pred_cond_pyramid[level]) # Convert from frequency space back to data (e.g. 
pixel) space by applying inverse freq transform pred = build_image_from_pyramid(pred_guided_pyramid) @@ -216,8 +252,26 @@ def _is_fdg_enabled(self) -> bool: is_close = False if self.use_original_formulation: - is_close = math.isclose(self.guidance_scale_low, 0.0) and math.isclose(self.guidance_scale_high, 0.0) + is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales) + else: + is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales) + + return is_within_range and not is_close + + def _is_fdg_enabled_for_level(self, level: int) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.guidance_start[level] * self._num_inference_steps) + skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scales[level], 0.0) else: - is_close = math.isclose(self.guidance_scale_low, 1.0) and math.isclose(self.guidance_scale_high, 1.0) + is_close = math.isclose(self.guidance_scales[level], 1.0) return is_within_range and not is_close From 33822e80732ea793d3ba344ac0b2c9ce3820850d Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 23 Jul 2025 17:00:32 -0700 Subject: [PATCH 05/14] Add comment with Laplacian pyramid shapes --- src/diffusers/guiders/frequency_decoupled_guidance.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index e8da8efba57e..d004b98fdf66 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -48,6 +48,7 @@ def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor: Recovers the data space latents from the Laplacian pyramid 
frequency space. Implementation from the paper (Algorihtm 2). """ + # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...] img = pyramid[-1] for i in range(len(pyramid) - 2, -1, -1): img = kornia.geometry.pyrup(img) + pyramid[i] From 565ce2a589f3daff319dcb7ec7555300837ba0f2 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 23 Jul 2025 17:10:45 -0700 Subject: [PATCH 06/14] Add function to import_utils to check if the kornia package is available --- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cadcedb98a14..c08d245cb44e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -81,6 +81,7 @@ is_invisible_watermark_available, is_k_diffusion_available, is_k_diffusion_version, + is_kornia_available, is_librosa_available, is_matplotlib_available, is_nltk_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index a27c2da648f4..d0d3121ce942 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -223,6 +223,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _sageattention_available, _sageattention_version = _is_package_available("sageattention") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") +_kornia_available, _kornia_version = _is_package_available("kornia") def is_torch_available(): @@ -393,6 +394,10 @@ def is_flash_attn_3_available(): return _flash_attn_3_available +def is_kornia_available(): + return _kornia_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. 
Checkout the instructions on the From f608c5f0eb034232e5e4931fcf8a389cd4857314 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 23 Jul 2025 17:26:20 -0700 Subject: [PATCH 07/14] Only import from kornia if package is available --- .../guiders/frequency_decoupled_guidance.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index d004b98fdf66..79929ed443ad 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -15,10 +15,10 @@ import math from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union -import kornia import torch from ..configuration_utils import register_to_config +from ..utils import is_kornia_available from .guider_utils import BaseGuidance, rescale_noise_cfg @@ -26,6 +26,17 @@ from ..modular_pipelines.modular_pipeline import BlockState +_CAN_USE_KORNIA = is_kornia_available() + + +if _CAN_USE_KORNIA: + from kornia.geometry import pyrup as upsample_and_blur_func + from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func +else: + upsample_and_blur_func = None + build_laplacian_pyramid_func = None + + def project(v0: torch.Tensor, v1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from @@ -51,7 +62,7 @@ def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor: # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...] 
img = pyramid[-1] for i in range(len(pyramid) - 2, -1, -1): - img = kornia.geometry.pyrup(img) + pyramid[i] + img = upsample_and_blur_func(img) + pyramid[i] return img @@ -125,6 +136,12 @@ def __init__( start: Union[float, List[float], Tuple[float]] = 0.0, stop: Union[float, List[float], Tuple[float]] = 1.0, ): + if not _CAN_USE_KORNIA: + raise ImportError( + "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which" + "it depends is not available in the current environment." + ) + # Set start to earliest start for any freq component and stop to latest stop for any freq component min_start = start if isinstance(start, float) else min(start) max_stop = stop if isinstance(stop, float) else max(stop) @@ -197,8 +214,8 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = pred = pred_cond else: # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions. - pred_cond_pyramid = kornia.geometry.transform.build_laplacian_pyramid(pred_cond, self.levels) - pred_uncond_pyramid = kornia.geometry.transform.build_laplacian_pyramid(pred_uncond, self.levels) + pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels) + pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels) # From high frequencies to low frequencies, following the paper implementation pred_guided_pyramid = [] From c5070e0031e13869e3a3141478767fe14aecc418 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 25 Jul 2025 14:36:01 -0700 Subject: [PATCH 08/14] Fix bug: use pred_cond/uncond in freq space rather than data space --- .../guiders/frequency_decoupled_guidance.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index 79929ed443ad..f7cc452e0176 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ 
b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -222,25 +222,29 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale) for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters): if self._is_fdg_enabled_for_level(level): - shift = pred_cond_pyramid[level] - pred_uncond_pyramid[level] + # Get the cond/uncond preds (in freq space) at the current frequency level + pred_cond_freq = pred_cond_pyramid[level] + pred_uncond_freq = pred_uncond_pyramid[level] + + shift = pred_cond_freq - pred_uncond_freq # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift) if not math.isclose(parallel_weight, 1.0): - shift_parallel, shift_orthogonal = project(shift, pred_cond) + shift_parallel, shift_orthogonal = project(shift, pred_cond_freq) shift = parallel_weight * shift_parallel + shift_orthogonal # Apply CFG update for the current frequency level - pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq pred = pred + guidance_scale * shift if guidance_rescale > 0.0: - pred = rescale_noise_cfg(pred, pred_cond_pyramid[level], guidance_rescale) + pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale) # Add the current FDG guided level to the FDG prediction pyramid pred_guided_pyramid.append(pred) else: # Add the current pred_cond_pyramid level as the "non-FDG" prediction - pred_guided_pyramid.append(pred_cond_pyramid[level]) + pred_guided_pyramid.append(pred_cond_freq) # Convert from frequency space back to data (e.g. 
pixel) space by applying inverse freq transform pred = build_image_from_pyramid(pred_guided_pyramid) From 149c91539c71546e52d6f6a66e8172b20b8ed678 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 25 Jul 2025 14:58:11 -0700 Subject: [PATCH 09/14] Allow guidance rescaling to be done in data space or frequency space (speculative) --- .../guiders/frequency_decoupled_guidance.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index f7cc452e0176..1359046ccca9 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -122,6 +122,11 @@ class FrequencyDecoupledGuidance(BaseGuidance): stop (`float` or `List[float]`, defaults to `1.0`): The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it should be the same length as `guidance_scales`. + guidance_rescale_space (`str`, defaults to `"data"`): + Whether to perform guidance rescaling in `"data"` space (after the full FDG update in data space) or in + `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is + speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value + will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
""" _input_predictions = ["pred_cond", "pred_uncond"] @@ -135,6 +140,7 @@ def __init__( use_original_formulation: bool = False, start: Union[float, List[float], Tuple[float]] = 0.0, stop: Union[float, List[float], Tuple[float]] = 1.0, + guidance_rescale_space: str = "data", ): if not _CAN_USE_KORNIA: raise ImportError( @@ -159,6 +165,13 @@ def __init__( f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as " f"`guidance_scales` ({len(self.guidance_scales)})" ) + # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after + # transforming from frequency space back to data space) + if guidance_rescale_space not in ["data", "freq"]: + raise ValueError( + f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`." + ) + self.guidance_rescale_space = guidance_rescale_space if parallel_weights is None: # Use normal CFG shift (equal weights for parallel and orthogonal components) @@ -237,7 +250,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq pred = pred + guidance_scale * shift - if guidance_rescale > 0.0: + if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale) # Add the current FDG guided level to the FDG prediction pyramid @@ -249,6 +262,11 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = # Convert from frequency space back to data (e.g. 
pixel) space by applying inverse freq transform pred = build_image_from_pyramid(pred_guided_pyramid) + # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value + # across all freq levels + if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0]) + return pred, {} @property From 0a3f90853ca830d47747457cd142239a1e2f65b6 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 30 Jul 2025 20:49:05 -0700 Subject: [PATCH 10/14] Add kornia install instructions to kornia import error message --- src/diffusers/guiders/frequency_decoupled_guidance.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index 1359046ccca9..dcba3e1dc601 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -144,8 +144,9 @@ def __init__( ): if not _CAN_USE_KORNIA: raise ImportError( - "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which" - "it depends is not available in the current environment." + "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which " + "it depends is not available in the current environment. You can install `kornia` with `pip install " + "kornia`." 
) # Set start to earliest start for any freq component and stop to latest stop for any freq component From 259952ab108ec9f5d1363527715a7a9a59018320 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 30 Jul 2025 21:34:19 -0700 Subject: [PATCH 11/14] Add config to control whether operations are upcast to fp64 --- .../guiders/frequency_decoupled_guidance.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index dcba3e1dc601..936f7a4ad053 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -37,21 +37,25 @@ build_laplacian_pyramid_func = None -def project(v0: torch.Tensor, v1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: """ Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper (Algorithm 2). """ # v0 shape: [B, ...] # v1 shape: [B, ...] - dtype = v0.dtype # Assume first dim is a batch dim and all other dims are channel or "spatial" dims all_dims_but_first = list(range(1, len(v0.shape))) - v0, v1 = v0.double(), v1.double() + if upcast_to_double: + dtype = v0.dtype + v0, v1 = v0.double(), v1.double() v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first) v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel - return v0_parallel.to(dtype), v0_orthogonal.to(dtype) + if upcast_to_double: + v0_parallel = v0_parallel.to(dtype) + v0_orthogonal = v0_orthogonal.to(dtype) + return v0_parallel, v0_orthogonal def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor: @@ -127,6 +131,9 @@ class FrequencyDecoupledGuidance(BaseGuidance): `"freq"` space (right after the CFG update, for each freq level). 
Note that frequency space rescaling is speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value will be used; otherwise, per-frequency-level guidance rescale values will be used if available. + upcast_to_double (`bool`, defaults to `True`): + Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to + float64 when performing guidance. This may result in better performance at the cost of increased runtime. """ _input_predictions = ["pred_cond", "pred_uncond"] @@ -141,6 +148,7 @@ def __init__( start: Union[float, List[float], Tuple[float]] = 0.0, stop: Union[float, List[float], Tuple[float]] = 1.0, guidance_rescale_space: str = "data", + upcast_to_double: bool = True, ): if not _CAN_USE_KORNIA: raise ImportError( @@ -188,6 +196,7 @@ def __init__( ) self.use_original_formulation = use_original_formulation + self.upcast_to_double = upcast_to_double if isinstance(start, float): self.guidance_start = [start] * self.levels @@ -244,7 +253,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift) if not math.isclose(parallel_weight, 1.0): - shift_parallel, shift_orthogonal = project(shift, pred_cond_freq) + shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double) shift = parallel_weight * shift_parallel + shift_orthogonal # Apply CFG update for the current frequency level From 9c94aef6e9f35d73812811b60eef1731817606dd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 31 Jul 2025 01:21:54 -0700 Subject: [PATCH 12/14] Add parallel_weights recommended values to docstring --- src/diffusers/guiders/frequency_decoupled_guidance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index 936f7a4ad053..713a4fda55c0 100644 
--- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -114,8 +114,8 @@ class FrequencyDecoupledGuidance(BaseGuidance): parallel_weights (`float` or `List[float]`, *optional*): Optional weights for the parallel component of each frequency component of the projected CFG shift. If not set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift - (that is, equal weights for the parallel and orthogonal components). If a list is supplied, it should be - the same length as `guidance_scales`. + (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is + recommended. If a list is supplied, it should be the same length as `guidance_scales`. use_original_formulation (`bool`, defaults to `False`): Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. See From 4c379a45aa0fb1b865f837f83923f27f60bcb46f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 2 Aug 2025 13:34:29 +0000 Subject: [PATCH 13/14] Apply style fixes --- .../guiders/frequency_decoupled_guidance.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index 713a4fda55c0..35bc99ac4dde 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -39,8 +39,8 @@ def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: """ - Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from - paper (Algorithm 2). 
+ Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper + (Algorithm 2). """ # v0 shape: [B, ...] # v1 shape: [B, ...] @@ -76,19 +76,19 @@ class FrequencyDecoupledGuidance(BaseGuidance): FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both - conditional and unconditional data, and use a combination of the two during inference. (If you want more details - on how CFG works, you can check out the CFG guider.) + conditional and unconditional data, and use a combination of the two during inference. (If you want more details on + how CFG works, you can check out the CFG guider.) - FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency - components using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in - frequency space separately for the low- and high-frequency components with different guidance scales. Finally, the - inverse frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for - images) to form the final FDG prediction. + FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components + using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space + separately for the low- and high-frequency components with different guidance scales. Finally, the inverse + frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images) + to form the final FDG prediction. 
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample diversity and realistic color composition, while using high guidance scales for high-frequency components enhances - sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) - for the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an + sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for + the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper). As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen From 5d1652108f794939aa385a9f3e9b6a8f28b3ad97 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 5 Aug 2025 17:01:48 -0700 Subject: [PATCH 14/14] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 35df559ce4dd..08a816ce4b3c 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -62,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FrequencyDecoupledGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PerturbedAttentionGuidance(metaclass=DummyObject): _backends = ["torch"]