Spaces:
Runtime error
Runtime error
File size: 4,718 Bytes
947db12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from typing import Any
import torch
from .scheduler_utils import SNR_to_betas, compute_snr
class ShiftSNRScheduler:
    """Build noise schedulers whose SNR curve is divided by a constant factor.

    Use :meth:`from_scheduler`; it returns a *new scheduler object* created via
    ``scheduler_class.from_config`` (not a ``ShiftSNRScheduler``), whose
    ``trained_betas`` are derived from the shifted SNR.
    """

    def __init__(
        self,
        noise_scheduler: Any,
        timesteps: Any,
        shift_scale: float,
        scheduler_class: Any,
    ):
        """
        :param noise_scheduler: Base scheduler; its ``config`` is reused for the
            new scheduler and it is passed to ``compute_snr``.
        :param timesteps: Timesteps at which the SNR is evaluated (a tensor in
            the ``from_scheduler`` default path).
        :param shift_scale: Strictly positive factor the SNR is divided by.
        :param scheduler_class: Class whose ``from_config`` builds the result.
        :raises ValueError: If ``shift_scale`` is not strictly positive.
        """
        # Dividing SNR by 0 or a negative number does not give a meaningful
        # SNR curve (inf / negative values), so reject it early.
        if not shift_scale > 0:
            raise ValueError(f"shift_scale must be > 0, got {shift_scale}")
        self.noise_scheduler = noise_scheduler
        self.timesteps = timesteps
        self.shift_scale = shift_scale
        self.scheduler_class = scheduler_class

    def _build_scheduler(self, betas: torch.Tensor) -> Any:
        """Instantiate ``scheduler_class`` from the base config with ``betas``.

        NOTE(review): the number of betas follows ``self.timesteps``; presumably
        it should equal the config's ``num_train_timesteps`` — confirm when
        passing custom timesteps.
        """
        # detach()/cpu() make .numpy() safe for GPU or autograd tensors; both
        # are no-ops for a plain CPU tensor.
        return self.scheduler_class.from_config(
            self.noise_scheduler.config,
            trained_betas=betas.detach().cpu().numpy(),
        )

    def _get_shift_scheduler(self) -> Any:
        """
        Prepare a scheduler whose SNR is divided by ``shift_scale`` at every timestep.
        :return: A scheduler object configured with shifted betas
        """
        snr = compute_snr(self.timesteps, self.noise_scheduler)
        return self._build_scheduler(SNR_to_betas(snr / self.shift_scale))

    def _get_interpolated_shift_scheduler(self) -> Any:
        """
        Prepare a scheduler whose SNR is interpolated in log space between the
        original SNR (at t = 0) and the fully shifted SNR (at t = T - 1).
        :return: A scheduler object configured with interpolated shifted betas
        """
        snr = compute_snr(self.timesteps, self.noise_scheduler)
        # max(..., 1) avoids a zero denominator when num_train_timesteps == 1.
        last_step = max(self.noise_scheduler.config.num_train_timesteps - 1, 1)
        weighting = self.timesteps.float() / last_step
        # exp((1 - w) * log(snr) + w * log(snr / s)) == snr / s**w.
        # The closed form is identical for snr > 0 but stays 0 (instead of NaN
        # from log(0) * 0) when the base schedule has a zero-SNR step.
        interpolated_snr = snr / torch.pow(self.shift_scale, weighting)
        return self._build_scheduler(SNR_to_betas(interpolated_snr))

    @classmethod
    def from_scheduler(
        cls,
        noise_scheduler: Any,
        shift_mode: str = "default",
        timesteps: Any = None,
        shift_scale: float = 1.0,
        scheduler_class: Any = None,
    ) -> Any:
        """Create a new scheduler with a shifted SNR schedule.

        :param noise_scheduler: Base scheduler to derive the new one from.
        :param shift_mode: ``"default"`` divides the SNR by ``shift_scale`` at
            every timestep; ``"interpolated"`` divides it by
            ``shift_scale ** (t / (T - 1))``, leaving t = 0 unchanged.
        :param timesteps: Timesteps to evaluate; defaults to
            ``torch.arange(0, noise_scheduler.config.num_train_timesteps)``.
        :param shift_scale: Strictly positive SNR divisor.
        :param scheduler_class: Class used to build the result; defaults to the
            class of ``noise_scheduler``.
        :return: The scheduler returned by ``scheduler_class.from_config``.
        :raises ValueError: On an unknown ``shift_mode`` or a non-positive
            ``shift_scale``.
        """
        # Validate the mode before doing any work.
        if shift_mode not in ("default", "interpolated"):
            raise ValueError(f"Unknown shift_mode: {shift_mode}")
        if timesteps is None:
            timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps)
        if scheduler_class is None:
            scheduler_class = noise_scheduler.__class__
        shift_scheduler = cls(
            noise_scheduler=noise_scheduler,
            timesteps=timesteps,
            shift_scale=shift_scale,
            scheduler_class=scheduler_class,
        )
        if shift_mode == "default":
            return shift_scheduler._get_shift_scheduler()
        return shift_scheduler._get_interpolated_shift_scheduler()
if __name__ == "__main__":
    # Compare the alpha curves (as returned by compute_alpha) of several
    # noise-schedule variants and save the plot to "check_alpha.png".
    #
    # NOTE: this module uses relative imports, so run it as a module
    # (``python -m <package>.<module>``), not as a plain script.
    import matplotlib.pyplot as plt
    from diffusers import DDPMScheduler

    from .scheduler_utils import compute_alpha

    def plot_alpha(steps, scheduler, label):
        """Compute alpha for ``scheduler`` at ``steps`` and add it to the plot."""
        alpha = compute_alpha(steps, scheduler)
        plt.plot(steps.numpy(), alpha.numpy(), label=label)

    shift_scale = 8.0

    # Base: scheduler config downloaded from the Hugging Face Hub.
    noise_scheduler_base = DDPMScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
    )
    # Derive the timestep grid from the config rather than hard-coding it.
    timesteps = torch.arange(0, noise_scheduler_base.config.num_train_timesteps)
    plot_alpha(timesteps, noise_scheduler_base, "Base")

    # Kolors: same config with 1100 train timesteps and beta_end=0.014.
    num_train_timesteps_ = 1100
    timesteps_ = torch.arange(0, num_train_timesteps_)
    noise_kwargs = {"beta_end": 0.014, "num_train_timesteps": num_train_timesteps_}
    noise_scheduler_kolors = DDPMScheduler.from_config(
        noise_scheduler_base.config, **noise_kwargs
    )
    plot_alpha(timesteps_, noise_scheduler_kolors, "Kolors")

    # Shift betas: SNR divided by shift_scale everywhere.
    noise_scheduler_shift = ShiftSNRScheduler.from_scheduler(
        noise_scheduler_base, shift_mode="default", shift_scale=shift_scale
    )
    plot_alpha(timesteps, noise_scheduler_shift, f"Shift Noise (scale {shift_scale})")

    # Shift betas, interpolated in log space between base and shifted SNR.
    noise_scheduler_inter = ShiftSNRScheduler.from_scheduler(
        noise_scheduler_base, shift_mode="interpolated", shift_scale=shift_scale
    )
    plot_alpha(timesteps, noise_scheduler_inter, f"Interpolated (scale {shift_scale})")

    # ZeroSNR: diffusers' rescale_betas_zero_snr option on the base config.
    noise_scheduler = DDPMScheduler.from_config(
        noise_scheduler_base.config, rescale_betas_zero_snr=True
    )
    plot_alpha(timesteps, noise_scheduler, "ZeroSNR")

    plt.xlabel("timestep")
    plt.ylabel("alpha")
    plt.legend()
    plt.grid()
    plt.savefig("check_alpha.png")
    plt.close()
|