Spaces:
Running
on
Zero
Running
on
Zero
from functools import partial | |
import torch | |
from ...util import default, instantiate_from_config | |
class VanillaCFG:
    """
    Implements parallelized classifier-free guidance (CFG).

    ``prepare_inputs`` stacks the unconditional and conditional inputs into a
    single batch (unconditional first), and ``__call__`` splits the resulting
    prediction back into its two halves and hands them, together with the
    guidance scale, to ``self.dyn_thresh``.
    """

    # Conditioning keys whose unconditional/conditional values are concatenated
    # along dim 0; every other key must hold the same value in `c` and `uc`.
    _BATCHED_KEYS = ("vector", "crossattn", "concat", "control", "control_vector", "mask_x")

    @staticmethod
    def _constant_scale(scale, sigma):
        """Return ``scale`` regardless of ``sigma`` (independent of step)."""
        return scale

    def __init__(self, scale, dyn_thresh_config=None):
        """
        Args:
            scale: guidance scale returned by ``self.scale_schedule`` for any sigma.
            dyn_thresh_config: config passed to ``instantiate_from_config`` to build
                ``self.dyn_thresh``. NOTE(review): ``default`` presumably substitutes
                the ``NoDynamicThresholding`` target when this is None - confirm.
        """
        self.scale_schedule = partial(self._constant_scale, scale)
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
                },
            )
        )

    def __call__(self, x, sigma):
        """
        Split ``x`` into two chunks (unconditional, conditional) and combine them.

        Args:
            x: stacked prediction; ``x.chunk(2)`` must yield exactly two chunks.
            sigma: value forwarded to ``self.scale_schedule``.

        Returns:
            Whatever ``self.dyn_thresh(x_u, x_c, scale_value)`` returns.
        """
        x_u, x_c = x.chunk(2)
        scale_value = self.scale_schedule(sigma)
        x_pred = self.dyn_thresh(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        """
        Build the doubled batch evaluated in a single pass.

        Args:
            x: tensor that is duplicated along dim 0.
            s: tensor that is duplicated along dim 0.
            c: conditional conditioning, indexed by key.
            uc: unconditional conditioning; must provide every key found in ``c``.

        Returns:
            Tuple ``(x_doubled, s_doubled, c_out)`` where, for keys in
            ``_BATCHED_KEYS``, ``c_out[k]`` is ``torch.cat((uc[k], c[k]), 0)`` and
            for all other keys it is ``c[k]``.

        Raises:
            AssertionError: a key outside ``_BATCHED_KEYS`` has different values
                in ``c`` and ``uc``.
            KeyError: ``uc`` lacks a key that is present in ``c``.
        """
        c_out = dict()
        for k in c:
            if k in self._BATCHED_KEYS:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                matches = c[k] == uc[k]
                # An element-wise tensor comparison has no single truth value
                # unless it is reduced first.
                if torch.is_tensor(matches):
                    matches = matches.all()
                # Raised explicitly (same exception type as the former `assert`)
                # so the check is not stripped when running under `python -O`.
                if not matches:
                    raise AssertionError(
                        f"conditioning key {k!r} differs between c and uc"
                    )
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class LinearCFG:
    """
    Parallelized classifier-free guidance whose scale varies linearly with sigma.

    The scale equals ``scale_min`` at ``sigma == 0`` and reaches ``scale`` at
    ``sigma == 14.6146``. ``prepare_inputs`` stacks the unconditional and
    conditional inputs into one batch (unconditional first); ``__call__`` splits
    the prediction back into two halves and delegates to ``self.dyn_thresh``.
    """

    # Conditioning keys whose unconditional/conditional values are concatenated
    # along dim 0; every other key must hold the same value in `c` and `uc`.
    _BATCHED_KEYS = ("vector", "crossattn", "concat", "control", "control_vector", "mask_x")

    @staticmethod
    def _linear_scale(scale, scale_min, sigma):
        """Interpolate linearly in ``sigma`` between ``scale_min`` and ``scale``."""
        # NOTE(review): 14.6146 is presumably the largest sigma of the sampling
        # schedule this guider is used with - confirm against the discretization.
        return (scale - scale_min) * sigma / 14.6146 + scale_min

    def __init__(self, scale, scale_min=None, dyn_thresh_config=None):
        """
        Args:
            scale: guidance scale reached at ``sigma == 14.6146``.
            scale_min: guidance scale at ``sigma == 0``; when None it is set to
                ``scale``, which makes the slope of the schedule zero.
            dyn_thresh_config: config passed to ``instantiate_from_config`` to build
                ``self.dyn_thresh``. NOTE(review): ``default`` presumably substitutes
                the ``NoDynamicThresholding`` target when this is None - confirm.
        """
        if scale_min is None:
            scale_min = scale
        self.scale_schedule = partial(self._linear_scale, scale, scale_min)
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
                },
            )
        )

    def __call__(self, x, sigma):
        """
        Split ``x`` into two chunks (unconditional, conditional) and combine them.

        Args:
            x: stacked prediction; ``x.chunk(2)`` must yield exactly two chunks.
            sigma: value forwarded to ``self.scale_schedule``.

        Returns:
            Whatever ``self.dyn_thresh(x_u, x_c, scale_value)`` returns.
        """
        x_u, x_c = x.chunk(2)
        scale_value = self.scale_schedule(sigma)
        x_pred = self.dyn_thresh(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        """
        Build the doubled batch evaluated in a single pass.

        Args:
            x: tensor that is duplicated along dim 0.
            s: tensor that is duplicated along dim 0.
            c: conditional conditioning, indexed by key.
            uc: unconditional conditioning; must provide every key found in ``c``.

        Returns:
            Tuple ``(x_doubled, s_doubled, c_out)`` where, for keys in
            ``_BATCHED_KEYS``, ``c_out[k]`` is ``torch.cat((uc[k], c[k]), 0)`` and
            for all other keys it is ``c[k]``.

        Raises:
            AssertionError: a key outside ``_BATCHED_KEYS`` has different values
                in ``c`` and ``uc``.
            KeyError: ``uc`` lacks a key that is present in ``c``.
        """
        c_out = dict()
        for k in c:
            if k in self._BATCHED_KEYS:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                matches = c[k] == uc[k]
                # An element-wise tensor comparison has no single truth value
                # unless it is reduced first.
                if torch.is_tensor(matches):
                    matches = matches.all()
                # Raised explicitly (same exception type as the former `assert`)
                # so the check is not stripped when running under `python -O`.
                if not matches:
                    raise AssertionError(
                        f"conditioning key {k!r} differs between c and uc"
                    )
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class IdentityGuider:
    """Guider that applies no guidance: everything is passed through as given."""

    def __call__(self, x, sigma):
        """Return ``x`` untouched; ``sigma`` is accepted but not used."""
        return x

    def prepare_inputs(self, x, s, c, uc):
        """Return ``x`` and ``s`` as given plus a shallow copy of ``c``; ``uc`` is not used."""
        conditioning = {key: c[key] for key in c}
        return x, s, conditioning