Skip to content

Commit d45199a

Browse files
dg845github-actions[bot]a-r-r-o-w
authored
Implement Frequency-Decoupled Guidance (FDG) as a Guider (#11976)
* Initial commit implementing frequency-decoupled guidance (FDG) as a guider * Update FrequencyDecoupledGuidance docstring to describe FDG * Update project so that it accepts any number of non-batch dims * Change guidance_scale and other params to accept a list of params for each freq level * Add comment with Laplacian pyramid shapes * Add function to import_utils to check if the kornia package is available * Only import from kornia if package is available * Fix bug: use pred_cond/uncond in freq space rather than data space * Allow guidance rescaling to be done in data space or frequency space (speculative) * Add kornia install instructions to kornia import error message * Add config to control whether operations are upcast to fp64 * Add parallel_weights recommended values to docstring * Apply style fixes * make fix-copies --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Aryan <[email protected]>
1 parent 0611631 commit d45199a

File tree

6 files changed

+352
-0
lines changed

6 files changed

+352
-0
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
"AutoGuidance",
140140
"ClassifierFreeGuidance",
141141
"ClassifierFreeZeroStarGuidance",
142+
"FrequencyDecoupledGuidance",
142143
"PerturbedAttentionGuidance",
143144
"SkipLayerGuidance",
144145
"SmoothedEnergyGuidance",
@@ -804,6 +805,7 @@
804805
AutoGuidance,
805806
ClassifierFreeGuidance,
806807
ClassifierFreeZeroStarGuidance,
808+
FrequencyDecoupledGuidance,
807809
PerturbedAttentionGuidance,
808810
SkipLayerGuidance,
809811
SmoothedEnergyGuidance,

src/diffusers/guiders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .auto_guidance import AutoGuidance
2323
from .classifier_free_guidance import ClassifierFreeGuidance
2424
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
25+
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
2526
from .perturbed_attention_guidance import PerturbedAttentionGuidance
2627
from .skip_layer_guidance import SkipLayerGuidance
2728
from .smoothed_energy_guidance import SmoothedEnergyGuidance
@@ -32,6 +33,7 @@
3233
AutoGuidance,
3334
ClassifierFreeGuidance,
3435
ClassifierFreeZeroStarGuidance,
36+
FrequencyDecoupledGuidance,
3537
PerturbedAttentionGuidance,
3638
SkipLayerGuidance,
3739
SmoothedEnergyGuidance,
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17+
18+
import torch
19+
20+
from ..configuration_utils import register_to_config
21+
from ..utils import is_kornia_available
22+
from .guider_utils import BaseGuidance, rescale_noise_cfg
23+
24+
25+
if TYPE_CHECKING:
26+
from ..modular_pipelines.modular_pipeline import BlockState
27+
28+
29+
_CAN_USE_KORNIA = is_kornia_available()
30+
31+
32+
if _CAN_USE_KORNIA:
33+
from kornia.geometry import pyrup as upsample_and_blur_func
34+
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
35+
else:
36+
upsample_and_blur_func = None
37+
build_laplacian_pyramid_func = None
38+
39+
40+
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
41+
"""
42+
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
43+
(Algorithm 2).
44+
"""
45+
# v0 shape: [B, ...]
46+
# v1 shape: [B, ...]
47+
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
48+
all_dims_but_first = list(range(1, len(v0.shape)))
49+
if upcast_to_double:
50+
dtype = v0.dtype
51+
v0, v1 = v0.double(), v1.double()
52+
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
53+
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
54+
v0_orthogonal = v0 - v0_parallel
55+
if upcast_to_double:
56+
v0_parallel = v0_parallel.to(dtype)
57+
v0_orthogonal = v0_orthogonal.to(dtype)
58+
return v0_parallel, v0_orthogonal
59+
60+
61+
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
62+
"""
63+
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
64+
(Algorihtm 2).
65+
"""
66+
# pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
67+
img = pyramid[-1]
68+
for i in range(len(pyramid) - 2, -1, -1):
69+
img = upsample_and_blur_func(img) + pyramid[i]
70+
return img
71+
72+
73+
class FrequencyDecoupledGuidance(BaseGuidance):
74+
"""
75+
Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
76+
77+
FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
78+
quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
79+
conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
80+
how CFG works, you can check out the CFG guider.)
81+
82+
FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
83+
using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
84+
separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
85+
frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
86+
to form the final FDG prediction.
87+
88+
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
89+
diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
90+
sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
91+
the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
92+
example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
93+
94+
As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
95+
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
96+
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
97+
98+
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
99+
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
100+
101+
Args:
102+
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
103+
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
104+
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
105+
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
106+
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
107+
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
108+
descending order).
109+
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
110+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
111+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
112+
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
113+
`guidance_scales`.
114+
parallel_weights (`float` or `List[float]`, *optional*):
115+
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
116+
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
117+
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
118+
recommended. If a list is supplied, it should be the same length as `guidance_scales`.
119+
use_original_formulation (`bool`, defaults to `False`):
120+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
121+
we use the diffusers-native implementation that has been in the codebase for a long time. See
122+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
123+
start (`float` or `List[float]`, defaults to `0.0`):
124+
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
125+
should be the same length as `guidance_scales`.
126+
stop (`float` or `List[float]`, defaults to `1.0`):
127+
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
128+
should be the same length as `guidance_scales`.
129+
guidance_rescale_space (`str`, defaults to `"data"`):
130+
Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
131+
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
132+
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
133+
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
134+
upcast_to_double (`bool`, defaults to `True`):
135+
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
136+
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
137+
"""
138+
139+
_input_predictions = ["pred_cond", "pred_uncond"]
140+
141+
@register_to_config
142+
def __init__(
143+
self,
144+
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
145+
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
146+
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
147+
use_original_formulation: bool = False,
148+
start: Union[float, List[float], Tuple[float]] = 0.0,
149+
stop: Union[float, List[float], Tuple[float]] = 1.0,
150+
guidance_rescale_space: str = "data",
151+
upcast_to_double: bool = True,
152+
):
153+
if not _CAN_USE_KORNIA:
154+
raise ImportError(
155+
"The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
156+
"it depends is not available in the current environment. You can install `kornia` with `pip install "
157+
"kornia`."
158+
)
159+
160+
# Set start to earliest start for any freq component and stop to latest stop for any freq component
161+
min_start = start if isinstance(start, float) else min(start)
162+
max_stop = stop if isinstance(stop, float) else max(stop)
163+
super().__init__(min_start, max_stop)
164+
165+
self.guidance_scales = guidance_scales
166+
self.levels = len(guidance_scales)
167+
168+
if isinstance(guidance_rescale, float):
169+
self.guidance_rescale = [guidance_rescale] * self.levels
170+
elif len(guidance_rescale) == self.levels:
171+
self.guidance_rescale = guidance_rescale
172+
else:
173+
raise ValueError(
174+
f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
175+
f"`guidance_scales` ({len(self.guidance_scales)})"
176+
)
177+
# Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
178+
# transforming from frequency space back to data space)
179+
if guidance_rescale_space not in ["data", "freq"]:
180+
raise ValueError(
181+
f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
182+
)
183+
self.guidance_rescale_space = guidance_rescale_space
184+
185+
if parallel_weights is None:
186+
# Use normal CFG shift (equal weights for parallel and orthogonal components)
187+
self.parallel_weights = [1.0] * self.levels
188+
elif isinstance(parallel_weights, float):
189+
self.parallel_weights = [parallel_weights] * self.levels
190+
elif len(parallel_weights) == self.levels:
191+
self.parallel_weights = parallel_weights
192+
else:
193+
raise ValueError(
194+
f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
195+
f"`guidance_scales` ({len(self.guidance_scales)})"
196+
)
197+
198+
self.use_original_formulation = use_original_formulation
199+
self.upcast_to_double = upcast_to_double
200+
201+
if isinstance(start, float):
202+
self.guidance_start = [start] * self.levels
203+
elif len(start) == self.levels:
204+
self.guidance_start = start
205+
else:
206+
raise ValueError(
207+
f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
208+
f"({len(self.guidance_scales)})"
209+
)
210+
if isinstance(stop, float):
211+
self.guidance_stop = [stop] * self.levels
212+
elif len(stop) == self.levels:
213+
self.guidance_stop = stop
214+
else:
215+
raise ValueError(
216+
f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
217+
f"({len(self.guidance_scales)})"
218+
)
219+
220+
def prepare_inputs(
221+
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
222+
) -> List["BlockState"]:
223+
if input_fields is None:
224+
input_fields = self._input_fields
225+
226+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
227+
data_batches = []
228+
for i in range(self.num_conditions):
229+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
230+
data_batches.append(data_batch)
231+
return data_batches
232+
233+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
234+
pred = None
235+
236+
if not self._is_fdg_enabled():
237+
pred = pred_cond
238+
else:
239+
# Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
240+
pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
241+
pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
242+
243+
# From high frequencies to low frequencies, following the paper implementation
244+
pred_guided_pyramid = []
245+
parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
246+
for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
247+
if self._is_fdg_enabled_for_level(level):
248+
# Get the cond/uncond preds (in freq space) at the current frequency level
249+
pred_cond_freq = pred_cond_pyramid[level]
250+
pred_uncond_freq = pred_uncond_pyramid[level]
251+
252+
shift = pred_cond_freq - pred_uncond_freq
253+
254+
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
255+
if not math.isclose(parallel_weight, 1.0):
256+
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
257+
shift = parallel_weight * shift_parallel + shift_orthogonal
258+
259+
# Apply CFG update for the current frequency level
260+
pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
261+
pred = pred + guidance_scale * shift
262+
263+
if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
264+
pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
265+
266+
# Add the current FDG guided level to the FDG prediction pyramid
267+
pred_guided_pyramid.append(pred)
268+
else:
269+
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
270+
pred_guided_pyramid.append(pred_cond_freq)
271+
272+
# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
273+
pred = build_image_from_pyramid(pred_guided_pyramid)
274+
275+
# If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
276+
# across all freq levels
277+
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
278+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
279+
280+
return pred, {}
281+
282+
@property
283+
def is_conditional(self) -> bool:
284+
return self._count_prepared == 1
285+
286+
@property
287+
def num_conditions(self) -> int:
288+
num_conditions = 1
289+
if self._is_fdg_enabled():
290+
num_conditions += 1
291+
return num_conditions
292+
293+
def _is_fdg_enabled(self) -> bool:
294+
if not self._enabled:
295+
return False
296+
297+
is_within_range = True
298+
if self._num_inference_steps is not None:
299+
skip_start_step = int(self._start * self._num_inference_steps)
300+
skip_stop_step = int(self._stop * self._num_inference_steps)
301+
is_within_range = skip_start_step <= self._step < skip_stop_step
302+
303+
is_close = False
304+
if self.use_original_formulation:
305+
is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
306+
else:
307+
is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
308+
309+
return is_within_range and not is_close
310+
311+
def _is_fdg_enabled_for_level(self, level: int) -> bool:
312+
if not self._enabled:
313+
return False
314+
315+
is_within_range = True
316+
if self._num_inference_steps is not None:
317+
skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
318+
skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
319+
is_within_range = skip_start_step <= self._step < skip_stop_step
320+
321+
is_close = False
322+
if self.use_original_formulation:
323+
is_close = math.isclose(self.guidance_scales[level], 0.0)
324+
else:
325+
is_close = math.isclose(self.guidance_scales[level], 1.0)
326+
327+
return is_within_range and not is_close

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
is_k_diffusion_available,
8383
is_k_diffusion_version,
8484
is_kernels_available,
85+
is_kornia_available,
8586
is_librosa_available,
8687
is_matplotlib_available,
8788
is_nltk_available,

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs):
6262
requires_backends(cls, ["torch"])
6363

6464

65+
class FrequencyDecoupledGuidance(metaclass=DummyObject):
66+
_backends = ["torch"]
67+
68+
def __init__(self, *args, **kwargs):
69+
requires_backends(self, ["torch"])
70+
71+
@classmethod
72+
def from_config(cls, *args, **kwargs):
73+
requires_backends(cls, ["torch"])
74+
75+
@classmethod
76+
def from_pretrained(cls, *args, **kwargs):
77+
requires_backends(cls, ["torch"])
78+
79+
6580
class PerturbedAttentionGuidance(metaclass=DummyObject):
6681
_backends = ["torch"]
6782

0 commit comments

Comments
 (0)