Implement Frequency-Decoupled Guidance (FDG) as a Guider #11976
Conversation
Some notes on the initial implementation:
Thank you for the quick implementation. Regarding your question, I believe it's cleaner to use tuples for the weights, as it allows users to seamlessly apply multiple levels when finer control over the generation is needed.
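For example (hypothetical values, not from the thread), a multi-level call could then look like `FrequencyDecoupledGuidance(guidance_scales=(10.0, 7.5, 5.0))`, with one scale per frequency level.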
@dg845 Thanks for taking it up, implementation looks great!
What you suggested about tuples sounds good, let's do that. We can always update the implementation later if needed to simplify, since modular guiders are experimental at the moment (plus, users can pass their own guider implementations, so if someone wants to simplify, it will be quite easy to take your implementation and make the necessary modifications).
Let's not add kornia as a dependency. Instead, we can do the same thing done in the attention dispatcher (import only if the package is available):
```python
if _CAN_USE_FLASH_ATTN_3:
    ...
```

The flagged import in this PR:

```python
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import kornia
```
Could we add an `is_kornia_available` to `diffusers.utils.import_utils` and import only if the user already has it installed? A check could exist in `__init__` as well, so that if a user tries to instantiate the FDG guider, it fails if kornia isn't available.
I have added an `is_kornia_available` function to `utils` and added logic in the FDG guider to only import from `kornia` if it is available, following the Flash Attention 3 example above.
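A minimal sketch of that availability-check pattern, assuming the helper just probes for the installed package (the exact diffusers implementation may differ):

```python
import importlib.util

def is_kornia_available() -> bool:
    # True if the kornia package can be found in the current environment.
    return importlib.util.find_spec("kornia") is not None

_CAN_USE_KORNIA = is_kornia_available()

if _CAN_USE_KORNIA:
    import kornia
```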
Hi @Msadat97, quick question: how should FDG interact with guidance rescaling (from https://arxiv.org/pdf/2305.08891)? Currently, I'm rescaling in frequency space for each frequency level, with different …
It seems more natural to perform a single rescaling at the end (after the FDG prediction), since FDG is meant to replace the CFG output. Rescaling in the frequency domain is also possible, but I can't comment further as we haven't tested FDG with guidance rescaling. Do you have any output comparisons for this?
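For reference, the data-space variant suggested here would amount to applying the standard rescaling step from the paper once, to the final FDG prediction; a sketch (this mirrors the `rescale_noise_cfg` helper used elsewhere in diffusers):

```python
import torch

def rescale_guidance(pred_guided: torch.Tensor, pred_cond: torch.Tensor, guidance_rescale: float = 0.0) -> torch.Tensor:
    # Rescale the guided prediction so its per-sample std matches the conditional
    # prediction's std, then interpolate (Sec. 3.4 of arXiv:2305.08891).
    dims = list(range(1, pred_cond.ndim))
    std_cond = pred_cond.std(dim=dims, keepdim=True)
    std_guided = pred_guided.std(dim=dims, keepdim=True)
    rescaled = pred_guided * (std_cond / std_guided)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * pred_guided
```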
Can you share a code snippet for how to use FDG? @dg845
@dg845 I noticed a mistake in the implementation. [screenshot]
Here is a code sample for running the new FDG guider with an SD-XL modular pipeline:

```python
import torch

from diffusers.guiders import FrequencyDecoupledGuidance
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS

modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
num_inference_steps = 50

device = "cuda"
dtype = torch.float16
seed = 42

generator = torch.Generator(device=device)
generator.manual_seed(seed)
init_generator_state = generator.get_state()

# Create default SD-XL text-to-image ModularPipeline (with CFG guider)
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
sdxl_pipeline = t2i_blocks.init_pipeline(modular_repo_id)

# Load pretrained components
sdxl_pipeline.load_default_components(torch_dtype=dtype)
sdxl_pipeline.to(device)

# Create CFG baseline image
cfg_guider_spec = sdxl_pipeline.get_component_spec("guider")
cfg_guidance_scale = cfg_guider_spec.config["guidance_scale"]
cfg_image = sdxl_pipeline(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    generator=generator,
    output="images",
)[0]
cfg_image.save(f"cfg_image_{cfg_guidance_scale}.png")

# Swap in new FDG guider
# Guidance scales listed from high frequency to low frequency
fdg_guidance_scales = [10.0, 5.0]
fdg_guider = FrequencyDecoupledGuidance(guidance_scales=fdg_guidance_scales)
sdxl_pipeline.update_components(guider=fdg_guider)
# TODO: is this necessary to instantiate a new guider?
sdxl_pipeline.load_components(names=["guider"], torch_dtype=dtype)

# Create FDG image
# Reset generator state
generator.set_state(init_generator_state)
fdg_image = sdxl_pipeline(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    generator=generator,
    output="images",
)[0]
fdg_guidance_scale_str = "_".join(f"{scale:.1f}" for scale in fdg_guidance_scales)
fdg_image.save(f"fdg_image_{fdg_guidance_scale_str}.png")
```

Quick question for @a-r-r-o-w or @yiyixuxu: for a …
Here are some samples for the prompt above.

CFG with guidance scale 7.5 (pipeline default):

[image]

FDG with guidance scales …:

[image]

FDG with guidance scales …:

[image]
With the same prompt and number of inference steps, here are some samples using a guidance rescale value of ….

CFG with guidance scale …:

[image]

FDG with guidance scales …:

[image]

FDG with guidance scales …:

[image]

It looks like rescaling in frequency space still produces coherent images, and may preserve high-frequency details better than rescaling in data space (for example, the extra details on the astronaut's square pack).
@dg845 Probably @yiyixuxu will be better able to answer your question here since I haven't played around with the loader much or fully read through the refactored ….

```python
pipeline.update_components(
    guider=ComponentSpec(
        name="cfg",
        type_hint=<GUIDER_CLASS>,
        config=<GUIDER_INIT_KWARGS>,
        default_creation_method="from_config",
    )
)
```

From what I understand, the …

Nice explorations too! In a blind test, I think my preference would definitely lean towards FDG generations :)
Feel free to add another knob (…).
Oh lol, I just looked at the code after writing the above comment and saw you've already added …
Thanks again, the implementation is very clean and easy to understand!
```python
if not _CAN_USE_KORNIA:
    raise ImportError(
        "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which"
        " it depends is not available in the current environment."
    )
```
Could you add a simple instruction like `pip install kornia` to the message here?
Added.
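Presumably the updated message then reads something like this (exact wording assumed, not copied from the PR):

```python
if not _CAN_USE_KORNIA:
    raise ImportError(
        "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which"
        " it depends is not available in the current environment. You can install `kornia` with `pip install kornia`."
    )
```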
"it depends is not available in the current environment." | ||
) | ||
|
||
```python
# Set start to earliest start for any freq component and stop to latest stop for any freq component
```
Nice!
```python
dtype = v0.dtype
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
all_dims_but_first = list(range(1, len(v0.shape)))
v0, v1 = v0.double(), v1.double()
```
@dg845 @Msadat97 Just curious whether this must be float64, and whether you've tested the same with float32 or a lower dtype and found it harmful? The operations here are very few, but fp64 is extremely slow, and I wonder if this has any impact on the overall runtime (maybe negligible for images, but it might be worth understanding for when the number of tokens is larger, as in video models, and whether the dtype here could potentially be user-configurable).
The `project` function is called when `parallel_weights` is set (and is not the default value of `1.0`), so the upcasted operations will only be performed sometimes.

For now, I have added an `upcast_to_double` argument which controls whether `project` will upcast to fp64.
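A rough sketch of what such a projection helper might look like, based on the quoted lines above and the similar logic in the APG guider (names and signature are assumptions, not the exact PR code):

```python
import torch

def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True):
    # Split v0 into components parallel and orthogonal to v1.
    dtype = v0.dtype
    # Assume first dim is a batch dim and all other dims are channel or "spatial" dims
    all_dims_but_first = list(range(1, len(v0.shape)))
    if upcast_to_double:
        v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
    v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
```

The parallel component can then be reweighted by `parallel_weights` before recombining.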
Here are some FDG samples which use a guidance scale of `[10.0, 5.0]` and a `parallel_weights` of `1.5`. The `1.5` value is somewhat arbitrary; @Msadat97, what is a reasonable range of values for `parallel_weights`?
FDG with guidance scales `[10.0, 5.0]`, `parallel_weights=1.5`, upcast to double:

[image]

FDG with guidance scales `[10.0, 5.0]`, `parallel_weights=1.5`, no upcast to double (with pipeline at fp16):

[image]
In this case, the images look of similar quality with and without upcasting (with perhaps a slight reduction in quality for the non-upcasted version).
We haven’t specifically tested the FP32 projection part, but I’m not sure how much it affects performance in this case, as the operations involved are quite lightweight and the model still runs in FP16. I just felt it might be safer to use double for normalization and projection to improve numerical accuracy a bit.
Regarding the parallel component, I think it’s best to keep the weight below 1. A value like 0.5 should give a good balance. That said, we used 1 in most parts of the paper and treated it as optional.
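Following that suggestion, a typical configuration might then be something like `FrequencyDecoupledGuidance(guidance_scales=[10.0, 5.0], parallel_weights=0.5)` (assuming `parallel_weights` also accepts a single float applied to all levels).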
@dg845 One last question: are you using the noise prediction (i.e., the model output) for FDG, or the `x_0` prediction? Perhaps using `x_0` might be better, since frequency decomposition is likely more meaningful there.
Currently, I am using the raw model output, whatever prediction type that happens to be.

I believe it would be difficult to use the `x_0` prediction, since we would need the `prediction_type` and the `beta`/`sigma`/etc. schedule to calculate it inside the guider. Moreover, the scheduler's `step` method usually expects a raw model output and will convert to an `x_0` prediction internally based on the `prediction_type`, so working on raw model outputs means the FDG prediction can be used as expected in the scheduler. @yiyixuxu thoughts?
That’s how we implement FDG as well, and it’s similar to how Adaptive Projected Guidance (APG) was handled in the guiders. So I assume it should also be compatible with FDG?
P.S.: btw, this conversion is mainly useful for projection to be more meaningful. Otherwise, it's almost the same for all prediction types, since the frequency operations are linear.
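To spell out the linearity point (notation ours): at a fixed $x_t$, the $x_0$ prediction is an affine function of the model output, e.g. $\hat{x}_0 = (x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon)/\sqrt{\bar{\alpha}_t}$ for `epsilon` prediction, so the guidance difference satisfies $\hat{x}_0^{cond} - \hat{x}_0^{uncond} = -\tfrac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}}\,(\epsilon^{cond} - \epsilon^{uncond})$. Since the frequency decomposition is linear, it commutes with this scaling, and per-band CFG updates agree across prediction types up to the shared affine map; only the projection step, which is nonlinear, is sensitive to the space it operates in.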
I think the `AdaptiveProjectedGuidance` guider is implemented in the same way the FDG guider is currently implemented: the `forward` method takes in `pred_cond` and `pred_uncond` arguments but is agnostic as to whether these inputs are raw model outputs or `x_0` predictions.
My statement above that the FDG guider uses the raw model output is probably a little misleading, in the sense that this assumes that the calling code will supply the denoising model's output to the FDG guider. This is the case in e.g. `StableDiffusionXLLoopDenoiser`:

diffusers/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py (lines 232 to 243 in e46e139)
```python
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
    block_state.scaled_latents,
    t,
    encoder_hidden_states=prompt_embeds,
    timestep_cond=block_state.timestep_cond,
    cross_attention_kwargs=block_state.cross_attention_kwargs,
    added_cond_kwargs=cond_kwargs,
    return_dict=False,
)[0]
components.guider.cleanup_models(components.unet)
```
but we could imagine that the calling `PipelineBlock` (such as `StableDiffusionXLLoopDenoiser`) could instead do the conversion to an `x_0` prediction (it has access to the scheduler through its `expected_components`), and in this case we'd probably want the guider to expose a config like `should_convert_to_sample_prediction` and the scheduler to expose `convert_to_sample_prediction`/`convert_to_prediction_type` methods.
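For concreteness, such a scheduler method might look roughly like this for an `epsilon`-prediction DDPM-style scheduler (`convert_to_sample_prediction` is the hypothetical name from the discussion above, not an existing diffusers API):

```python
import torch

def convert_to_sample_prediction(scheduler, model_output: torch.Tensor, sample: torch.Tensor, timestep: int) -> torch.Tensor:
    # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    return (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
```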
In general, I think it may make more sense to do the conversion in the `PipelineBlock`, since in the current design `PipelineBlock`s can have access to the scheduler, whereas the guider itself shouldn't be coupled to the scheduler.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks @a-r-r-o-w. My current understanding of how …

So my understanding is that the following code from above

```python
fdg_guider = FrequencyDecoupledGuidance(guidance_scales=fdg_guidance_scales)
sdxl_pipeline.update_components(guider=fdg_guider)
```

should correctly tell the pipeline to use the supplied FDG guider instance, and …
@bot /style
Style bot fixed some files and pushed the changes.
@dg845 Could you run …
What does this PR do?
This PR implements frequency-decoupled guidance (FDG) (paper), a new guidance strategy, as a guider. The idea behind FDG is to decompose the CFG prediction into low- and high-frequency components and apply guidance separately to each via a CFG-style update (with separate guidance scales $w_{low}$ and $w_{high}$). The authors find that low guidance scales work better for the low-frequency components while high guidance scales work better for the high-frequency components (i.e., you should set $w_{low} < w_{high}$).
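Schematically (our notation, for the two-level case): given conditional and unconditional predictions decomposed into frequency bands $k \in \{low, high\}$, each band gets its own CFG-style update before the bands are recombined:

$$\tilde{p}_k = p_k^{uncond} + w_k \left( p_k^{cond} - p_k^{uncond} \right), \qquad \tilde{p} = \mathrm{recombine}\left( \{ \tilde{p}_k \}_k \right).$$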
Fixes #11956.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@a-r-r-o-w
@yiyixuxu
@Msadat97