# Copyright 2024 SLAPaper
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from tqdm.auto import trange


@torch.no_grad()
def sample_dpmpp_2m_alt(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M) alt.

    Multistep DPM-Solver++(2M) sampling loop. The first step, and any step
    whose target sigma is 0, uses the first-order update. Every other step
    combines the current and previous denoised estimates.

    The "alt" twist: the denoised estimate carried over to the next step is
    scaled by ``1 + 0.15 * (i / len(sigmas)) ** 2``. This weights the history
    term more heavily as sampling progresses.

    Args:
        model: Denoiser, called as ``model(x, sigma_batch, **extra_args)``.
            Its return value is used as the denoised estimate.
        x: Starting sample. ``x.shape[0]`` is treated as the batch size; the
            per-step sigma is broadcast to that many entries.
        sigmas: Tensor of noise levels. Step ``i`` moves from ``sigmas[i]`` to
            ``sigmas[i + 1]``. Presumably positive and descending, ending at
            0 (TODO confirm with callers).
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional callable, invoked once per step with a dict that
            has the keys ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``tqdm.auto.trange`` to hide the progress bar.

    Returns:
        The sample after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # The solver works in t = -log(sigma) space; these convert back and forth.
    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigmas[i],
                    "denoised": denoised,
                }
            )
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First-order step: no history yet, or this is the final step to
            # sigma == 0, where sigma_fn(t_next) == 0 and the result is `denoised`.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Second-order step: r is the ratio of the previous step size to the
            # current one (both in t space). It extrapolates from the previous
            # denoised estimate.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        # "alt" variant: scale the remembered estimate by a factor that grows
        # quadratically with progress through the schedule, from 1 up to ~1.15.
        sigma_progress = i / len(sigmas)
        adjustment_factor = 1 + (0.15 * (sigma_progress * sigma_progress))
        old_denoised = denoised * adjustment_factor
    return x


def add_sample_dpmpp_2m_alt_comfy() -> None:
    """Register ``sample_dpmpp_2m_alt`` with ComfyUI as the ``dpmpp_2m_alt`` sampler.

    Does nothing when ComfyUI (``comfy.samplers``) cannot be imported, or when
    ``"dpmpp_2m_alt"`` is already listed in ``KSampler.SAMPLERS``.

    Otherwise, the function is exposed on ComfyUI's k-diffusion sampling module
    as ``sample_dpmpp_2m_alt``. The name is then inserted right after
    ``"dpmpp_2m"``, or appended at the end if ``"dpmpp_2m"`` is not listed.
    """
    try:
        from comfy.samplers import KSampler, k_diffusion_sampling  # type: ignore
    except ImportError:
        return

    samplers = KSampler.SAMPLERS
    if "dpmpp_2m_alt" in samplers:
        return

    # Attach the implementation before advertising the name, so the name is
    # never listed without a backing function. The module is deliberately not
    # reloaded: the attribute is already visible on the module object, and
    # `importlib.reload` would re-execute ComfyUI's module code and rebind its
    # functions/classes, leaving stale references elsewhere.
    setattr(k_diffusion_sampling, "sample_dpmpp_2m_alt", sample_dpmpp_2m_alt)

    try:
        idx = samplers.index("dpmpp_2m")
    except ValueError:
        # Unexpected sampler list layout: still register, just at the end.
        samplers.append("dpmpp_2m_alt")
    else:
        samplers.insert(idx + 1, "dpmpp_2m_alt")


def add_custom_samplers() -> None:
    """Run every custom-sampler registration hook defined in this module."""
    samplers = [
        add_sample_dpmpp_2m_alt_comfy,
    ]
    for add_sampler in samplers:
        add_sampler()