From f9ca75020ed8fe83e5a234dcc535b2215738c6be Mon Sep 17 00:00:00 2001 From: Balladie Date: Fri, 12 Sep 2025 15:17:25 +0900 Subject: [PATCH 1/3] add cfg++ to dpmpp 2m sde samplers --- comfy/k_diffusion/sampling.py | 51 ++++++++++++++++++++++++++++++++--- comfy/samplers.py | 5 ++-- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 2d7e09838e76..dadfb28cedf1 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -810,7 +810,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No @torch.no_grad() -def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint', cfg_pp=False): """DPM-Solver++(2M) SDE.""" if len(sigmas) <= 1: return x @@ -830,6 +830,16 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl old_denoised = None h, h_last = None, None + uncond_denoised = None + + def post_cfg_function(args): + nonlocal uncond_denoised + uncond_denoised = args["uncond_denoised"] + return args["denoised"] + + if cfg_pp: + model_options = extra_args.get("model_options", {}).copy() + extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -846,28 +856,41 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl alpha_t = sigmas[i + 1] * lambda_t.exp() + current_denoised = uncond_denoised if cfg_pp else denoised + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised + x = x + alpha_t * (-h_eta).exp().neg() * (current_denoised - denoised) if old_denoised is not None: r = h_last / h if solver_type == 'heun': - x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (current_denoised - old_denoised) elif solver_type == 'midpoint': - x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (current_denoised - old_denoised) if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise - old_denoised = denoised + old_denoised = current_denoised h_last = h return x +@torch.no_grad() +def sample_dpmpp_2m_sde_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type, cfg_pp=True) + + @torch.no_grad() def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) +@torch.no_grad() +def sample_dpmpp_2m_sde_heun_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): + return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type, cfg_pp=True) + + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -940,6 +963,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) +@torch.no_grad() +def sample_dpmpp_2m_sde_heun_cfg_pp_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): + if len(sigmas) <= 1: + return x + extra_args = {} if extra_args is None else extra_args + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler + return sample_dpmpp_2m_sde_heun_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + + @torch.no_grad() def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): if len(sigmas) <= 1: @@ -950,6 +983,16 @@ def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=Non return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) +@torch.no_grad() +def sample_dpmpp_2m_sde_cfg_pp_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + if len(sigmas) <= 1: + return x + extra_args = {} if extra_args is None else extra_args + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler + return sample_dpmpp_2m_sde_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: diff --git a/comfy/samplers.py b/comfy/samplers.py index b3202cec6f2c..4c47f6271945 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -729,8 +729,9 @@ def max_denoise(self, model_wrap, sigmas): KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", - "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", + "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_cfg_pp", "dpmpp_2m_sde_cfg_pp_gpu", + "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_2m_sde_heun_cfg_pp", "dpmpp_2m_sde_heun_cfg_pp_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", + "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] class KSAMPLER(Sampler): From 3de7d3bb25a3c60dc59e88114a97d3e0c19545a9 Mon Sep 17 00:00:00 2001 From: Balladie Date: Sun, 14 Sep 2025 15:30:45 +0900 Subject: [PATCH 2/3] remove redundant op when cfg_pp=False --- comfy/k_diffusion/sampling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index dadfb28cedf1..2a9561b80eb4 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -859,7 +859,9 @@ def post_cfg_function(args): current_denoised = uncond_denoised if cfg_pp else denoised x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised - x = x + alpha_t * (-h_eta).exp().neg() * (current_denoised - denoised) + + if cfg_pp: + x = x + alpha_t * (-h_eta).exp().neg() * (current_denoised - denoised) if old_denoised is not None: r = h_last / h From cae6be59881b8404efc609e76f21daac37ae8668 Mon Sep 17 00:00:00 2001 From: Balladie Date: Sun, 14 Sep 2025 15:32:37 +0900 Subject: [PATCH 3/3] apply ruff --- comfy/k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 2a9561b80eb4..f4a5d97c8f2c 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -859,7 +859,7 @@ def post_cfg_function(args): current_denoised = uncond_denoised if cfg_pp else denoised x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised - + if cfg_pp: x = x + alpha_t * (-h_eta).exp().neg() * (current_denoised - denoised)