diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 2d7e09838e76..f4a5d97c8f2c 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -810,7 +810,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No @torch.no_grad() -def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint', cfg_pp=False): """DPM-Solver++(2M) SDE.""" if len(sigmas) <= 1: return x @@ -830,6 +830,16 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl old_denoised = None h, h_last = None, None + uncond_denoised = None + + def post_cfg_function(args): + nonlocal uncond_denoised + uncond_denoised = args["uncond_denoised"] + return args["denoised"] + + if cfg_pp: + model_options = extra_args.get("model_options", {}).copy() + extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -846,28 +856,43 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl alpha_t = sigmas[i + 1] * lambda_t.exp() + current_denoised = uncond_denoised if cfg_pp else denoised + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised + if cfg_pp: + x = x + alpha_t * (-h_eta).exp().neg() * (current_denoised - denoised) + if old_denoised is not None: r = h_last / h if solver_type == 'heun': - x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (current_denoised - old_denoised) elif solver_type == 'midpoint': - x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (current_denoised - old_denoised) if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise - old_denoised = denoised + old_denoised = current_denoised h_last = h return x +@torch.no_grad() +def sample_dpmpp_2m_sde_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type, cfg_pp=True) + + @torch.no_grad() def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) +@torch.no_grad() +def sample_dpmpp_2m_sde_heun_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): + return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type, cfg_pp=True) + + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -940,6 +965,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) +@torch.no_grad() +def sample_dpmpp_2m_sde_heun_cfg_pp_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): + if len(sigmas) <= 1: + return x + extra_args = {} if extra_args is None else extra_args + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler + return sample_dpmpp_2m_sde_heun_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + + @torch.no_grad() def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): if len(sigmas) <= 1: @@ -950,6 +985,16 @@ def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=Non return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) +@torch.no_grad() +def sample_dpmpp_2m_sde_cfg_pp_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + if len(sigmas) <= 1: + return x + extra_args = {} if extra_args is None else extra_args + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler + return sample_dpmpp_2m_sde_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: diff --git a/comfy/samplers.py b/comfy/samplers.py index b3202cec6f2c..4c47f6271945 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -729,8 +729,9 @@ def max_denoise(self, model_wrap, sigmas): KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", - "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", + "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_cfg_pp", "dpmpp_2m_sde_cfg_pp_gpu", + "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_2m_sde_heun_cfg_pp", "dpmpp_2m_sde_heun_cfg_pp_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", + "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] class KSAMPLER(Sampler):