jadechoghari's picture
Update audioldm_train/modules/latent_diffusion/ddpm.py
ee4e5da verified
from multiprocessing.sharedctypes import Value
import statistics
import sys
import os
# from tkinter import Ec
# sys.path.append('/home/changli/Adan')
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
# from adan import Adan
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from qa_mdt.audioldm_train.conditional_models import *
import datetime
from qa_mdt.audioldm_train.utilities.model_util import (
exists,
default,
mean_flat,
count_params,
instantiate_from_config,
)
from qa_mdt.audioldm_train.utilities.diffusion_util import (
make_beta_schedule,
extract_into_tensor,
noise_like,
)
from qa_mdt.audioldm_train.modules.diffusionmodules.ema import LitEma
from qa_mdt.audioldm_train.modules.diffusionmodules.distributions import (
normal_kl,
DiagonalGaussianDistribution,
)
# from audioldm_train.modules.diffusionmodules.transport import
from qa_mdt.audioldm_train.modules.latent_diffusion.ddim import DDIMSampler
from qa_mdt.audioldm_train.modules.latent_diffusion.plms import PLMSSampler
import soundfile as sf
import os
# Conditioning mode name -> key string. Presumably the keyword names used by
# the diffusion wrapper defined elsewhere in this file -- TODO confirm.
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
import json

# Loaded once at import time. The path is relative to the process working
# directory. The file is decoded as UTF-8 explicitly so that parsing does not
# depend on the platform's default locale encoding.
with open('./qa_mdt/offset_pretrained_checkpoints.json', 'r', encoding='utf-8') as config_file:
    config_data = json.load(config_file)
def disabled_train(self, mode=True):
    """No-op stand-in for a model's ``train`` method.

    Assigned over ``model.train`` so that later train/eval switches no longer
    change the model's mode. ``mode`` is accepted and ignored; the object
    itself is returned.
    """
    unchanged = self
    return unchanged
def uniform_on_device(r1, r2, shape, device):
    """Sample a tensor of ``shape`` on ``device`` as ``(r1 - r2) * u + r2``.

    ``u`` is drawn with ``torch.rand``, so the values lie between ``r2`` and
    ``r1``.
    """
    span = r1 - r2
    unit_samples = torch.rand(*shape, device=device)
    return span * unit_samples + r2
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(
    self,
    unet_config,
    sampling_rate=None,
    timesteps=1000,
    beta_schedule="linear",
    loss_type="l2",
    ckpt_path=None,
    ignore_keys=[],
    load_only_unet=False,
    monitor="val/loss",
    use_ema=True,
    first_stage_key="image",
    latent_t_size=256,
    latent_f_size=16,
    channels=3,
    log_every_t=100,
    clip_denoised=True,
    linear_start=1e-4,
    linear_end=2e-2,
    cosine_s=8e-3,
    given_betas=None,
    original_elbo_weight=0.0,
    v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
    l_simple_weight=1.0,
    conditioning_key=None,
    parameterization="eps",  # all assuming fixed variance schedules
    scheduler_config=None,
    use_positional_encodings=False,
    learn_logvar=False,
    logvar_init=0.0,
    evaluator=None,
):
    """Build the diffusion model, its noise schedule and bookkeeping state.

    ``sampling_rate`` is required (asserted non-None). ``parameterization``
    must be one of "eps", "x0" or "v". A CLAP audio embedder is created using
    the "clap_music" entry of the module-level ``config_data``. When
    ``ckpt_path`` is given, weights are restored before the schedule buffers
    are registered. ``evaluator`` is stored only on global rank 0.
    """
    super().__init__()

    supported_parameterizations = ("eps", "x0", "v")
    assert (
        parameterization in supported_parameterizations
    ), 'currently only supporting "eps" and "x0" and "v"'
    self.parameterization = parameterization
    self.state = None
    print(
        f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
    )
    assert sampling_rate is not None

    self.validation_folder_name = "temp_name"
    self.clip_denoised = clip_denoised
    self.log_every_t = log_every_t
    self.first_stage_key = first_stage_key
    self.sampling_rate = sampling_rate

    self.clap = CLAPAudioEmbeddingClassifierFreev2(
        pretrained_path=config_data["clap_music"],
        sampling_rate=self.sampling_rate,
        embed_mode="audio",
        amodel="HTSAT-base",
    )

    # The evaluator attribute only exists on rank 0.
    if self.global_rank == 0:
        self.evaluator = evaluator
    self.initialize_param_check_toolkit()

    self.latent_t_size = latent_t_size
    self.latent_f_size = latent_f_size
    self.channels = channels
    self.use_positional_encodings = use_positional_encodings

    self.model = DiffusionWrapper(unet_config, conditioning_key)
    count_params(self.model, verbose=True)

    self.use_ema = use_ema
    if self.use_ema:
        self.model_ema = LitEma(self.model)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    self.use_scheduler = scheduler_config is not None
    if self.use_scheduler:
        self.scheduler_config = scheduler_config

    self.v_posterior = v_posterior
    self.original_elbo_weight = original_elbo_weight
    self.l_simple_weight = l_simple_weight

    if monitor is not None:
        self.monitor = monitor

    if ckpt_path is not None:
        self.init_from_ckpt(
            ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
        )

    self.register_schedule(
        given_betas=given_betas,
        beta_schedule=beta_schedule,
        timesteps=timesteps,
        linear_start=linear_start,
        linear_end=linear_end,
        cosine_s=cosine_s,
    )

    self.loss_type = loss_type
    self.learn_logvar = learn_logvar
    # One log-variance value per timestep; trainable only when learn_logvar is set.
    initial_logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
    self.logvar = nn.Parameter(
        initial_logvar, requires_grad=bool(self.learn_logvar)
    )

    self.logger_save_dir = None
    self.logger_exp_name = None
    self.logger_exp_group_name = None
    self.logger_version = None
    self.label_indices_total = None

    # To avoid the system cannot find metric value for checkpoint
    self.metrics_buffer = {
        "val/kullback_leibler_divergence_sigmoid": 15.0,
        "val/kullback_leibler_divergence_softmax": 10.0,
        "val/psnr": 0.0,
        "val/ssim": 0.0,
        "val/inception_score_mean": 1.0,
        "val/inception_score_std": 0.0,
        "val/kernel_inception_distance_mean": 0.0,
        "val/kernel_inception_distance_std": 0.0,
        "val/frechet_inception_distance": 133.0,
        "val/frechet_audio_distance": 32.0,
    }

    self.initial_learning_rate = None
    self.test_data_subset_path = None
def get_log_dir(self):
    """Return ``<save_dir>/<exp_group_name>/<exp_name>`` from the stored logger settings."""
    path_parts = (
        self.logger_save_dir,
        self.logger_exp_group_name,
        self.logger_exp_name,
    )
    return os.path.join(*path_parts)
def set_log_dir(self, save_dir, exp_group_name, exp_name):
    """Store the three path components that ``get_log_dir`` joins together."""
    (
        self.logger_save_dir,
        self.logger_exp_group_name,
        self.logger_exp_name,
    ) = (save_dir, exp_group_name, exp_name)
def register_schedule(
    self,
    given_betas=None,
    beta_schedule="linear",
    timesteps=1000,
    linear_start=1e-4,
    linear_end=2e-2,
    cosine_s=8e-3,
):
    """Compute the beta schedule and register all derived tables as buffers.

    Uses ``given_betas`` when provided, otherwise calls ``make_beta_schedule``.
    ``betas`` must be one-dimensional (its shape is unpacked into a single
    length), and that length becomes ``self.num_timesteps``. Reads
    ``self.v_posterior`` and ``self.parameterization``, so both must be set
    before this is called. ``lvlb_weights`` is registered as a non-persistent
    buffer.

    Raises:
        NotImplementedError: for a parameterization other than "eps", "x0", "v".
    """
    if exists(given_betas):
        betas = given_betas
    else:
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    # Shifted by one step, with 1.0 as the value preceding the first timestep.
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    (timesteps,) = betas.shape
    self.num_timesteps = int(timesteps)
    self.linear_start = linear_start
    self.linear_end = linear_end
    assert (
        alphas_cumprod.shape[0] == self.num_timesteps
    ), "alphas have to be defined for each timestep"
    # Every table is stored as a float32 tensor.
    to_torch = partial(torch.tensor, dtype=torch.float32)
    self.register_buffer("betas", to_torch(betas))
    self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
    self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
    # calculations for diffusion q(x_t | x_{t-1}) and others
    self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
    self.register_buffer(
        "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
    )
    self.register_buffer(
        "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
    )
    self.register_buffer(
        "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
    )
    self.register_buffer(
        "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
    )
    # calculations for posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = (1 - self.v_posterior) * betas * (
        1.0 - alphas_cumprod_prev
    ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
    # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    self.register_buffer("posterior_variance", to_torch(posterior_variance))
    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
    self.register_buffer(
        "posterior_log_variance_clipped",
        to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
    )
    self.register_buffer(
        "posterior_mean_coef1",
        to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
    )
    self.register_buffer(
        "posterior_mean_coef2",
        to_torch(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
        ),
    )
    if self.parameterization == "eps":
        lvlb_weights = self.betas**2 / (
            2
            * self.posterior_variance
            * to_torch(alphas)
            * (1 - self.alphas_cumprod)
        )
    elif self.parameterization == "x0":
        # NOTE(review): `2.0 * 1 - x` parses as `2.0 - x`; confirm the
        # denominator was not meant to be `2.0 * (1 - x)`.
        lvlb_weights = (
            0.5
            * np.sqrt(torch.Tensor(alphas_cumprod))
            / (2.0 * 1 - torch.Tensor(alphas_cumprod))
        )
    elif self.parameterization == "v":
        # Only the shape/dtype of the eps expression is used: every weight is 1.
        lvlb_weights = torch.ones_like(
            self.betas**2
            / (
                2
                * self.posterior_variance
                * to_torch(alphas)
                * (1 - self.alphas_cumprod)
            )
        )
    else:
        raise NotImplementedError("mu not supported")
    # TODO how to choose this term
    # The first weight is overwritten with the second one. With v_posterior == 0
    # posterior_variance[0] is 0 (alphas_cumprod_prev[0] is 1), so the eps
    # expression divides by zero at index 0.
    lvlb_weights[0] = lvlb_weights[1]
    self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
    # NOTE(review): this only fails when *every* weight is NaN; `.any()` may
    # have been intended -- confirm.
    assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
    """Run the ``with`` body with EMA weights copied into ``self.model``.

    When ``self.use_ema`` is true, the current parameters are stored and the
    EMA weights copied into the model on entry, and the stored parameters are
    restored on exit (also when the body raises). When it is false, the
    manager does nothing. If ``context`` is given, a message is printed on
    each switch. Yields ``None``.
    """
    announce = context is not None
    if self.use_ema:
        live_params = self.model.parameters()
        self.model_ema.store(live_params)
        self.model_ema.copy_to(self.model)
        if announce:
            print(f"{context}: Switched to EMA weights")
    try:
        yield None
    finally:
        if self.use_ema:
            self.model_ema.restore(self.model.parameters())
            if announce:
                print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
    """Restore weights from a checkpoint file, skipping ignored key prefixes.

    Args:
        path: checkpoint file passed to ``torch.load`` (mapped to CPU).
        ignore_keys: iterable of string prefixes; every state-dict key that
            starts with one of them is dropped before loading.
        only_model: when true, load into ``self.model`` instead of ``self``.

    The state dict is loaded with ``strict=False``; missing and unexpected
    keys are printed rather than raised.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    sd = torch.load(path, map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]

    # str.startswith accepts a tuple, so each key is tested (and deleted) at
    # most once. Looping over the prefixes per key deleted a key twice when
    # two prefixes matched it, which raised KeyError.
    ignored_prefixes = tuple(ignore_keys)
    for k in list(sd.keys()):
        if k.startswith(ignored_prefixes):
            print("Deleting key {} from state_dict.".format(k))
            del sd[k]

    target = self.model if only_model else self
    missing, unexpected = target.load_state_dict(sd, strict=False)
    print(
        f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
    )
    if len(missing) > 0:
        print(f"Missing Keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
    """
    Get the distribution q(x_t | x_0).
    :param x_start: the [N x C x ...] tensor of noiseless inputs.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    """
    target_shape = x_start.shape
    mean_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, target_shape)
    mean = mean_scale * x_start
    variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, target_shape)
    log_variance = extract_into_tensor(
        self.log_one_minus_alphas_cumprod, t, target_shape
    )
    return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
    """Estimate x_0 from ``x_t`` and predicted noise.

    Returns ``sqrt_recip_alphas_cumprod[t] * x_t -
    sqrt_recipm1_alphas_cumprod[t] * noise``, with both coefficients gathered
    for ``t`` through ``extract_into_tensor``.
    """
    shape = x_t.shape
    x_t_coef = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, shape)
    noise_coef = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, shape)
    return x_t_coef * x_t - noise_coef * noise
def q_posterior(self, x_start, x_t, t):
    """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0).

    The mean is ``posterior_mean_coef1[t] * x_start + posterior_mean_coef2[t]
    * x_t``; the two variance terms are looked up from the registered buffers.
    """
    shape = x_t.shape
    start_coef = extract_into_tensor(self.posterior_mean_coef1, t, shape)
    current_coef = extract_into_tensor(self.posterior_mean_coef2, t, shape)
    posterior_mean = start_coef * x_start + current_coef * x_t
    posterior_variance = extract_into_tensor(self.posterior_variance, t, shape)
    posterior_log_variance_clipped = extract_into_tensor(
        self.posterior_log_variance_clipped, t, shape
    )
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
    """Return the posterior mean, variance and log-variance predicted for ``x``.

    The model output is turned into an x_0 estimate according to
    ``self.parameterization`` ("eps", "x0" or "v"), optionally clamped
    in place to [-1, 1], and passed to ``q_posterior``.

    The "v" case used to be missing even though the constructor accepts it,
    which left ``x_recon`` unbound and raised UnboundLocalError.

    Raises:
        NotImplementedError: for any other parameterization.
    """
    model_out = self.model(x, t)
    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    elif self.parameterization == "v":
        x_recon = self.predict_start_from_z_and_v(x, t, model_out)
    else:
        raise NotImplementedError(
            f"Parameterization {self.parameterization} not yet supported"
        )
    if clip_denoised:
        x_recon.clamp_(-1.0, 1.0)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
        x_start=x_recon, x_t=x, t=t
    )
    return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    """Take one reverse-diffusion step from ``x`` at timestep ``t``.

    Adds ``exp(0.5 * log_variance) * noise`` to the predicted mean, except
    for batch entries whose ``t`` is 0, which receive no noise.
    """
    batch_size = x.shape[0]
    device = x.device
    model_mean, _, model_log_variance = self.p_mean_variance(
        x=x, t=t, clip_denoised=clip_denoised
    )
    noise = noise_like(x.shape, device, repeat_noise)
    # no noise when t == 0; the mask has one entry per batch item and
    # singleton trailing dimensions so it broadcasts against x
    mask_shape = (batch_size,) + (1,) * (len(x.shape) - 1)
    nonzero_mask = (1 - (t == 0).float()).reshape(*mask_shape).contiguous()
    std = (0.5 * model_log_variance).exp()
    return model_mean + nonzero_mask * std * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
    """Run the full reverse chain starting from Gaussian noise of ``shape``.

    Iterates timesteps from ``num_timesteps - 1`` down to 0. The initial
    noise, every ``log_every_t``-th step and the first processed step are
    collected as intermediates. Returns the final sample, or
    ``(sample, intermediates)`` when ``return_intermediates`` is true.
    """
    device = self.betas.device
    batch_size = shape[0]
    img = torch.randn(shape, device=device)
    intermediates = [img]
    descending_steps = reversed(range(0, self.num_timesteps))
    for step in tqdm(
        descending_steps, desc="Sampling t", total=self.num_timesteps
    ):
        timestep = torch.full((batch_size,), step, device=device, dtype=torch.long)
        img = self.p_sample(img, timestep, clip_denoised=self.clip_denoised)
        if step % self.log_every_t == 0 or step == self.num_timesteps - 1:
            intermediates.append(img)
    if not return_intermediates:
        return img
    return img, intermediates
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
    """Sample ``batch_size`` items via ``p_sample_loop``.

    The sampled shape is ``(batch_size, self.channels, self.latent_t_size,
    self.latent_f_size)``. The return value is whatever ``p_sample_loop``
    returns for the given ``return_intermediates``.
    """
    # `channels` must be bound before the shape tuple is built; it used to
    # be assigned one line too late, which raised UnboundLocalError.
    channels = self.channels
    shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
    return self.p_sample_loop(shape, return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
    """Diffuse ``x_start`` to timestep ``t``.

    Returns ``sqrt_alphas_cumprod[t] * x_start +
    sqrt_one_minus_alphas_cumprod[t] * noise``. When ``noise`` is not
    supplied, standard normal noise shaped like ``x_start`` is drawn.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    shape = x_start.shape
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, shape)
    noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_scale * x_start + noise_scale * noise
def get_loss(self, pred, target, mean=True):
    """Compute the L1 or L2 loss selected by ``self.loss_type``.

    Args:
        pred: predicted tensor.
        target: target tensor.
        mean: when true, reduce to a scalar mean; otherwise return the
            element-wise loss.

    Raises:
        NotImplementedError: if ``self.loss_type`` is neither "l1" nor "l2".
    """
    if self.loss_type == "l1":
        loss = (target - pred).abs()
        if mean:
            loss = loss.mean()
    elif self.loss_type == "l2":
        reduction = "mean" if mean else "none"
        loss = torch.nn.functional.mse_loss(target, pred, reduction=reduction)
    else:
        # The message lacked the f-prefix and named a non-existent variable,
        # so the offending loss type was never reported.
        raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
    return loss
def predict_start_from_z_and_v(self, x_t, t, v):
    """Estimate x_0 from ``x_t`` and a v-prediction.

    Returns ``sqrt_alphas_cumprod[t] * x_t -
    sqrt_one_minus_alphas_cumprod[t] * v``.
    """
    shape = x_t.shape
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, shape)
    noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_scale * x_t - noise_scale * v
def predict_eps_from_z_and_v(self, x_t, t, v):
    """Estimate the noise from ``x_t`` and a v-prediction.

    Returns ``sqrt_alphas_cumprod[t] * v +
    sqrt_one_minus_alphas_cumprod[t] * x_t``.
    """
    shape = x_t.shape
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, shape)
    noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_scale * v + noise_scale * x_t
def get_v(self, x, noise, t):
    """Build the v-prediction target.

    Returns ``sqrt_alphas_cumprod[t] * noise -
    sqrt_one_minus_alphas_cumprod[t] * x``.
    """
    shape = x.shape
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, shape)
    noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_scale * noise - noise_scale * x
def p_losses(self, x_start, t, noise=None):
    """Compute the weighted training loss for ``x_start`` at timesteps ``t``.

    ``x_start`` is noised with ``q_sample`` and the model prediction is
    compared with the target selected by ``self.parameterization``. The
    per-sample loss (mean over dimensions 1-3) is scaled by ``min(snr, 5)``
    for "x0" and by ``min(snr, 5) / snr`` otherwise, where
    ``snr = (sqrt_alphas_cumprod[t] / sqrt_one_minus_alphas_cumprod[t]) ** 2``.

    Returns ``(loss, loss_dict)``; the dict keys are prefixed with "train" or
    "val" depending on ``self.training``. The logged ``loss_simple`` already
    includes the SNR weight.

    Raises:
        NotImplementedError: for an unsupported parameterization.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_out = self.model(x_noisy, t)

    alpha = extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape)
    sigma = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape)
    snr = (alpha / sigma) ** 2

    # min{snr, k}
    k = 5.0
    clipped_snr = torch.stack([snr, k * torch.ones_like(t)], dim=1).min(dim=1)[0]
    if self.parameterization == "x0":
        mse_loss_weight = clipped_snr
    else:
        mse_loss_weight = clipped_snr / snr

    if self.parameterization == "eps":
        target = noise
    elif self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "v":
        target = self.get_v(x_start, noise, t)
    else:
        raise NotImplementedError(
            f"Paramterization {self.parameterization} not yet supported"
        )

    per_sample = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
    weighted = mse_loss_weight * per_sample

    log_prefix = "train" if self.training else "val"
    loss_dict = {}
    loss_dict[f"{log_prefix}/loss_simple"] = weighted.mean()
    loss_simple = weighted.mean() * self.l_simple_weight

    loss_vlb = (self.lvlb_weights[t] * weighted).mean()
    loss_dict[f"{log_prefix}/loss_vlb"] = loss_vlb

    loss = loss_simple + self.original_elbo_weight * loss_vlb
    loss_dict[f"{log_prefix}/loss"] = loss
    return loss, loss_dict
def forward(self, x, *args, **kwargs):
    """Draw one random timestep per batch item and return ``p_losses`` for it.

    Timesteps are integers in ``[0, num_timesteps)`` created on
    ``self.device``; extra arguments are forwarded to ``p_losses``.
    """
    batch_size = x.shape[0]
    timesteps = torch.randint(
        0, self.num_timesteps, (batch_size,), device=self.device
    ).long()
    return self.p_losses(x, timesteps, *args, **kwargs)
def get_input(self, batch, k):
    """Return entry ``k`` of a normalized view of ``batch``.

    ``batch`` must provide "fname", "text", "label_vector", "waveform",
    "stft", "log_mel_spec" and "mos". The view exposes:

    - "fbank": ``batch["log_mel_spec"]`` with a new axis at position 1, as float
    - "stft", "waveform": the batch tensors as contiguous float tensors
    - "text", "mos": converted to lists
    - "fname": passed through

    Every other batch key is passed through unchanged.
    """
    required = (
        "fname",
        "text",
        "label_vector",
        "waveform",
        "stft",
        "log_mel_spec",
        "mos",
    )
    # Looked up in this order so a missing key fails before any conversion.
    fname, text, _label_indices, waveform, stft, fbank, mos = (
        batch[name] for name in required
    )

    contiguous = torch.contiguous_format
    ret = {
        "fbank": fbank.unsqueeze(1).to(memory_format=contiguous).float(),
        "stft": stft.to(memory_format=contiguous).float(),
        "waveform": waveform.to(memory_format=contiguous).float(),
        "text": list(text),
        "fname": fname,
        "mos": list(mos),
    }
    # Keys already produced above take precedence over same-named batch keys.
    for key in batch.keys():
        ret.setdefault(key, batch[key])
    return ret[k]
def shared_step(self, batch):
    """Fetch the ``first_stage_key`` input from ``batch`` and run the model on it.

    Returns the ``(loss, loss_dict)`` pair produced by calling ``self``.
    """
    model_input = self.get_input(batch, self.first_stage_key)
    outputs = self(model_input)
    loss, loss_dict = outputs
    return (loss, loss_dict)
def warmup_step(self):
    """Apply linear learning-rate warm-up to the first optimizer's first param group.

    On first use the target rate is captured from ``self.learning_rate``.
    While ``global_step <= warmup_steps`` the rate is
    ``global_step / warmup_steps * initial_learning_rate``; afterwards it is
    held at ``initial_learning_rate``.

    A non-positive ``warmup_steps`` disables warm-up. Previously
    ``warmup_steps == 0`` divided 0 by 0 at step 0 and raised
    ZeroDivisionError.
    """
    if self.initial_learning_rate is None:
        self.initial_learning_rate = self.learning_rate

    # Only the first parameter group
    param_group = self.trainer.optimizers[0].param_groups[0]
    step = self.global_step
    if self.warmup_steps > 0 and step <= self.warmup_steps:
        if step == 0:
            print(
                "Warming up learning rate start with %s"
                % self.initial_learning_rate
            )
        param_group["lr"] = (step / self.warmup_steps) * self.initial_learning_rate
    else:
        # TODO set learning rate here
        param_group["lr"] = self.initial_learning_rate
def training_step(self, batch, batch_idx):
    """Run one optimisation step and log losses, step counter and learning rate.

    Before computing the loss, the CLAP conditioning mode is re-drawn and
    the warm-up learning rate applied. Any values waiting in
    ``self.metrics_buffer`` are logged once and the buffer is then emptied.
    Returns the loss from ``shared_step``.
    """
    self.random_clap_condition()
    self.warmup_step()

    # Flush buffered validation metrics exactly once.
    if len(self.metrics_buffer.keys()) > 0:
        for metric_name in self.metrics_buffer.keys():
            self.log(
                metric_name,
                self.metrics_buffer[metric_name],
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=False,
            )
        self.metrics_buffer = {}

    loss, loss_dict = self.shared_step(batch)

    scalar_losses = {name: float(value) for name, value in loss_dict.items()}
    self.log_dict(
        scalar_losses,
        prog_bar=True,
        logger=True,
        on_step=True,
        on_epoch=True,
    )

    step_only = dict(prog_bar=True, logger=True, on_step=True, on_epoch=False)
    self.log("global_step", float(self.global_step), **step_only)

    lr = self.trainer.optimizers[0].param_groups[0]["lr"]
    self.log("lr_abs", float(lr), **step_only)

    return loss
def random_clap_condition(self):
    """Randomly switch each CLAP conditioner between text and audio input.

    Training only (asserted). For every conditioning model that is a
    ``CLAPAudioEmbeddingClassifierFreev2``, the current ``cond_stage_key``
    and ``embed_mode`` are saved under ``*_orig`` metadata keys, then either
    ("text", "text") or ("waveform", "audio") is selected.
    """
    # This function is only used during training, let the CLAP model to use both text and audio as condition
    assert self.training == True

    for key in self.cond_stage_model_metadata.keys():
        metadata = self.cond_stage_model_metadata[key]
        model_idx = metadata["model_idx"]
        # Looked up for their KeyError side effect only; the values are unused.
        metadata["cond_stage_key"], metadata["conditioning_key"]

        cond_model = self.cond_stage_models[model_idx]
        # If we use CLAP as condition, we might use audio for training, but we also must use text for evaluation
        if not isinstance(cond_model, CLAPAudioEmbeddingClassifierFreev2):
            continue

        entry = self.cond_stage_model_metadata[key]
        entry["cond_stage_key_orig"] = entry["cond_stage_key"]
        entry["embed_mode_orig"] = cond_model.embed_mode

        # NOTE(review): torch.randn samples a standard normal, so this picks
        # "text" with probability of roughly 0.69, not 0.5. If an even split
        # was intended, torch.rand is the uniform sampler -- confirm.
        use_text = torch.randn(1).item() < 0.5
        if use_text:
            entry["cond_stage_key"] = "text"
            cond_model.embed_mode = "text"
        else:
            entry["cond_stage_key"] = "waveform"
            cond_model.embed_mode = "audio"
def on_validation_epoch_start(self) -> None:
    """Switch the conditioning models into evaluation settings.

    CLAP conditioners are forced to text input, and AudioMAE conditioners
    are set to use predicted tokens (``use_gt_mae_output = False``,
    ``use_gt_mae_prob = 0.0``). The previous settings are saved under
    ``*_orig`` metadata keys. Finally a fresh validation folder name is
    generated.
    """
    # Use text as condition during validation
    for key in self.cond_stage_model_metadata.keys():
        metadata = self.cond_stage_model_metadata[key]
        model_idx = metadata["model_idx"]
        # Looked up for their KeyError side effect only; the values are unused.
        metadata["cond_stage_key"], metadata["conditioning_key"]

        entry = self.cond_stage_model_metadata[key]
        cond_model = self.cond_stage_models[model_idx]

        # If we use CLAP as condition, we might use audio for training, but we also must use text for evaluation
        if isinstance(cond_model, CLAPAudioEmbeddingClassifierFreev2):
            entry["cond_stage_key_orig"] = entry["cond_stage_key"]
            entry["embed_mode_orig"] = cond_model.embed_mode
            print(
                "Change the model original cond_keyand embed_mode %s, %s to text during evaluation"
                % (entry["cond_stage_key_orig"], entry["embed_mode_orig"])
            )
            entry["cond_stage_key"] = "text"
            cond_model.embed_mode = "text"

        if isinstance(cond_model, CLAPGenAudioMAECond) or isinstance(
            cond_model, SequenceGenAudioMAECond
        ):
            entry["use_gt_mae_output_orig"] = cond_model.use_gt_mae_output
            entry["use_gt_mae_prob_orig"] = cond_model.use_gt_mae_prob
            print("Change the model condition to the predicted AudioMAE tokens")
            cond_model.use_gt_mae_output = False
            cond_model.use_gt_mae_prob = 0.0

    self.validation_folder_name = self.get_validation_folder_name()
    return super().on_validation_epoch_start()
@torch.no_grad()
def validation_step(self, batch, batch_idx):
    """Generate samples for one validation batch into the current validation folder.

    Guidance scale, number of DDIM steps and number of candidates are read
    from ``self.evaluation_params``.
    """
    params = self.evaluation_params
    guidance_scale = params["unconditional_guidance_scale"]
    ddim_steps = params["ddim_sampling_steps"]
    n_candidates = params["n_candidates_per_samples"]
    self.generate_sample(
        [batch],
        name=self.validation_folder_name,
        unconditional_guidance_scale=guidance_scale,
        ddim_steps=ddim_steps,
        n_gen=n_candidates,
    )
def get_validation_folder_name(self):
    """Build a folder name from the global step, a timestamp and the evaluation settings.

    The timestamp uses the local time formatted as ``%m-%d-%H:%M``.
    """
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    params = self.evaluation_params
    pieces = [
        "val",
        str(self.global_step),
        str(timestamp),
        "cfg_scale",
        str(params["unconditional_guidance_scale"]),
        "ddim",
        str(params["ddim_sampling_steps"]),
        "n_cand",
        str(params["n_candidates_per_samples"]),
    ]
    return "_".join(pieces)
def on_validation_epoch_end(self) -> None:
    """Evaluate the generated audio and restore the training-time conditioning.

    On global rank 0 with an evaluator set, the evaluator is run on the
    validation folder (if it exists and is non-empty) and its metrics are
    stored, prefixed with "val/", in ``self.metrics_buffer``. Evaluation is
    best-effort: any exception is printed and swallowed. Afterwards the
    settings saved by ``on_validation_epoch_start`` are written back to the
    CLAP and AudioMAE conditioners.

    The CUDA synchronisation is skipped when CUDA is unavailable; calling
    it unconditionally raised on CPU-only hosts.
    """
    if self.global_rank == 0 and self.evaluator is not None:
        assert (
            self.test_data_subset_path is not None
        ), "Please set test_data_subset_path before validation so that model have a target folder"
        try:
            name = self.validation_folder_name
            waveform_save_path = os.path.join(self.get_log_dir(), name)
            if (
                os.path.exists(waveform_save_path)
                and len(os.listdir(waveform_save_path)) > 0
            ):
                metrics = self.evaluator.main(
                    waveform_save_path,
                    self.test_data_subset_path,
                )
                self.metrics_buffer = {
                    ("val/" + k): float(v) for k, v in metrics.items()
                }
            else:
                print(
                    "The target folder for evaluation does not exist: %s"
                    % waveform_save_path
                )
        except Exception as e:
            # Deliberately best-effort: a failed evaluation must not stop training.
            print("Error encountered during evaluation: ", e)

    # Very important or the program may fail
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    for key in self.cond_stage_model_metadata.keys():
        metadata = self.cond_stage_model_metadata[key]
        model_idx, cond_stage_key, conditioning_key = (
            metadata["model_idx"],
            metadata["cond_stage_key"],
            metadata["conditioning_key"],
        )
        cond_model = self.cond_stage_models[model_idx]

        if isinstance(cond_model, CLAPAudioEmbeddingClassifierFreev2):
            self.cond_stage_model_metadata[key][
                "cond_stage_key"
            ] = self.cond_stage_model_metadata[key]["cond_stage_key_orig"]
            cond_model.embed_mode = self.cond_stage_model_metadata[key][
                "embed_mode_orig"
            ]
            print(
                "Change back the embedding mode to %s %s"
                % (
                    self.cond_stage_model_metadata[key]["cond_stage_key"],
                    cond_model.embed_mode,
                )
            )

        if isinstance(cond_model, CLAPGenAudioMAECond) or isinstance(
            cond_model, SequenceGenAudioMAECond
        ):
            cond_model.use_gt_mae_output = self.cond_stage_model_metadata[key][
                "use_gt_mae_output_orig"
            ]
            cond_model.use_gt_mae_prob = self.cond_stage_model_metadata[key][
                "use_gt_mae_prob_orig"
            ]
            print(
                "Change the AudioMAE condition setting to %s (Use gt) %s (gt prob)"
                % (
                    cond_model.use_gt_mae_output,
                    cond_model.use_gt_mae_prob,
                )
            )
    return super().on_validation_epoch_end()
def on_train_epoch_start(self, *args, **kwargs):
    """Print the log directory at the start of each training epoch."""
    log_dir = self.get_log_dir()
    print("Log directory: ", log_dir)
def on_train_batch_end(self, *args, **kwargs):
    """Update the EMA weights from the live model after each batch, when EMA is enabled."""
    # Does this affect speed?
    if not self.use_ema:
        return
    self.model_ema(self.model)
def _get_rows_from_list(self, samples):
    """Arrange a list of image batches into one grid.

    ``samples`` is treated as ``n`` batches of shape ``b c h w``. They are
    flattened batch-major into ``(b n) c h w`` and passed to ``make_grid``
    with ``n`` images per row.
    """
    images_per_row = len(samples)
    # Single-step equivalent of "n b c h w -> b n c h w" followed by
    # "b n c h w -> (b n) c h w".
    flattened = rearrange(samples, "n b c h w -> (b n) c h w")
    return make_grid(flattened, nrow=images_per_row)
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
    """Collect tensors for visual logging.

    The result always contains "inputs" (up to ``N`` items) and
    "diffusion_row" (the first ``n_row`` inputs noised at every
    ``log_every_t``-th timestep and at the last one). With ``sample`` set
    it also contains "samples" and "denoise_row", drawn under EMA weights.

    When ``return_keys`` is given and shares at least one key with the
    result, only the keys listed in ``return_keys`` are returned.
    """
    log = {}
    x = self.get_input(batch, self.first_stage_key)
    N = min(x.shape[0], N)
    n_row = min(x.shape[0], n_row)
    x = x.to(self.device)[:N]
    log["inputs"] = x

    # get diffusion row
    diffusion_row = []
    x_start = x[:n_row]
    for step in range(self.num_timesteps):
        if not (step % self.log_every_t == 0 or step == self.num_timesteps - 1):
            continue
        t = repeat(torch.tensor([step]), "1 -> b", b=n_row)
        t = t.to(self.device).long()
        noise = torch.randn_like(x_start)
        diffusion_row.append(self.q_sample(x_start=x_start, t=t, noise=noise))
    log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

    if sample:
        # get denoise row
        with self.ema_scope("Plotting"):
            samples, denoise_row = self.sample(
                batch_size=N, return_intermediates=True
            )
        log["samples"] = samples
        log["denoise_row"] = self._get_rows_from_list(denoise_row)

    if not return_keys:
        return log
    overlap = np.intersect1d(list(log.keys()), return_keys)
    if overlap.shape[0] == 0:
        return log
    return {key: log[key] for key in return_keys}
def configure_optimizers(self):
    """Return an AdamW optimizer over the diffusion model's parameters.

    ``self.logvar`` is added to the optimised parameters when
    ``self.learn_logvar`` is set. The learning rate is ``self.learning_rate``.
    """
    lr = self.learning_rate
    trainable = list(self.model.parameters())
    if self.learn_logvar:
        trainable.append(self.logvar)
    return torch.optim.AdamW(trainable, lr=lr)
def initialize_param_check_toolkit(self):
    """Reset the state used by ``check_module_param_update``: step counter and parameter snapshots."""
    self.tracked_steps, self.param_dict = 0, {}
def statistic_require_grad_tensor_number(self, module, name=None):
    """Print how many of ``module``'s parameters are trainable.

    Returns the first parameter with ``requires_grad`` set, or ``None`` when
    there is none.

    Raises:
        ZeroDivisionError: when ``module`` has no parameters at all (the
            printed ratio divides by the total). ``check_module_param_update``
            catches and reports this.
    """
    all_params = list(module.parameters())
    trainable = [p for p in all_params if p.requires_grad]
    requires_grad_num = len(trainable)
    total_num = len(all_params)
    print(
        "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
        % (name, requires_grad_num, total_num, requires_grad_num / total_num)
    )
    return trainable[0] if trainable else None
def check_module_param_update(self):
    """Report how far each child module's first trainable parameter has moved.

    On the first call a clone of that parameter is stored per child module.
    On every call where ``tracked_steps`` is a multiple of 5000 (including
    the first) the summed absolute difference from the stored clone is
    printed. Errors for an individual child are printed and skipped. The
    step counter is incremented on every call.
    """
    if self.tracked_steps == 0:
        # Take the reference snapshot.
        for name, module in self.named_children():
            try:
                first_trainable = self.statistic_require_grad_tensor_number(
                    module, name=name
                )
                if first_trainable is None:
                    print("==> %s does not requires grad" % name)
                else:
                    self.param_dict[name] = first_trainable.clone()
            except Exception as e:
                print("%s does not have trainable parameters: %s" % (name, e))
                continue

    if self.tracked_steps % 5000 == 0:
        # Compare the live parameter with the snapshot.
        for name, module in self.named_children():
            try:
                first_trainable = self.statistic_require_grad_tensor_number(
                    module, name=name
                )
                if first_trainable is None:
                    print("%s does not requires grad" % name)
                else:
                    drift = torch.sum(
                        torch.abs(self.param_dict[name] - first_trainable)
                    )
                    print(
                        "===> Param diff %s: %s; Size: %s"
                        % (name, drift, first_trainable.size())
                    )
            except Exception as e:
                print("%s does not have trainable parameters: %s" % (name, e))
                continue

    self.tracked_steps += 1
class LatentDiffusion(DDPM):
"""main class"""
def __init__(
    self,
    first_stage_config,
    cond_stage_config=None,
    num_timesteps_cond=None,
    cond_stage_key="image",
    optimize_ddpm_parameter=True,
    unconditional_prob_cfg=0.1,
    warmup_steps=10000,
    cond_stage_trainable=False,
    concat_mode=True,
    cond_stage_forward=None,
    conditioning_key=None,
    scale_factor=1.0,
    batchsize=None,
    evaluation_params={},
    scale_by_std=False,
    base_learning_rate=None,
    *args,
    **kwargs,
):
    """Build the latent diffusion model around a frozen first stage.

    The ``conditioning_key`` argument is overwritten with the list of keys
    of ``cond_stage_config``, so ``cond_stage_config`` must be a mapping.
    ``ckpt_path`` and ``ignore_keys`` are removed from ``kwargs`` before the
    parent constructor runs, and the checkpoint is loaded at the very end,
    after the first-stage and conditioning models exist. Remaining
    ``args``/``kwargs`` go to the parent constructor.

    ``unconditional_prob_cfg`` is forced to 0.1 when it is 0.0 and
    ``optimize_ddpm_parameter`` is set, and to 0.0 when it is 0.1 and
    ``optimize_ddpm_parameter`` is not set; a message is printed in both
    cases.
    """
    self.learning_rate = base_learning_rate
    self.num_timesteps_cond = default(num_timesteps_cond, 1)
    self.scale_by_std = scale_by_std
    self.warmup_steps = warmup_steps

    # These messages used to be bare string statements and were never shown.
    if optimize_ddpm_parameter:
        if unconditional_prob_cfg == 0.0:
            print(
                "You choose to optimize DDPM. The classifier free guidance scale should be 0.1"
            )
            unconditional_prob_cfg = 0.1
    else:
        if unconditional_prob_cfg == 0.1:
            print(
                "You choose not to optimize DDPM. The classifier free guidance scale should be 0.0"
            )
            unconditional_prob_cfg = 0.0

    self.evaluation_params = evaluation_params
    # Fall back to the parent's default of 1000 steps when the caller did not
    # pass `timesteps`; indexing kwargs directly raised KeyError in that case.
    assert self.num_timesteps_cond <= kwargs.get("timesteps", 1000)

    conditioning_key = list(cond_stage_config.keys())
    self.conditioning_key = conditioning_key

    ckpt_path = kwargs.pop("ckpt_path", None)
    ignore_keys = kwargs.pop("ignore_keys", [])
    super().__init__(conditioning_key=conditioning_key, *args, **kwargs)

    self.optimize_ddpm_parameter = optimize_ddpm_parameter
    self.concat_mode = concat_mode
    self.cond_stage_key = cond_stage_key
    self.cond_stage_key_orig = cond_stage_key
    try:
        self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
    except Exception:
        # Configs without params.ddconfig.ch_mult count as zero downsamplings.
        self.num_downs = 0

    if not scale_by_std:
        self.scale_factor = scale_factor
    else:
        self.register_buffer("scale_factor", torch.tensor(scale_factor))

    self.instantiate_first_stage(first_stage_config)
    self.unconditional_prob_cfg = unconditional_prob_cfg
    self.cond_stage_models = nn.ModuleList([])
    self.instantiate_cond_stage(cond_stage_config)
    self.cond_stage_forward = cond_stage_forward
    self.clip_denoised = False
    self.bbox_tokenizer = None
    self.conditional_dry_run_finished = False
    self.restarted_from_ckpt = False

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys)
        self.restarted_from_ckpt = True
def configure_optimizers(self):
    """Return an AdamW optimizer over the diffusion model and all conditioning models.

    ``self.logvar`` is added when ``self.learn_logvar`` is set. No
    learning-rate scheduler is returned.
    """
    lr = self.learning_rate
    params = list(self.model.parameters())
    # Add the parameter from the conditional stage
    for cond_model in self.cond_stage_models:
        params.extend(cond_model.parameters())
    if self.learn_logvar:
        print("Diffusion model optimizing logvar")
        params.append(self.logvar)
    return torch.optim.AdamW(params, lr=lr)
def make_cond_schedule(
    self,
):
    """Build ``self.cond_ids``, one long-typed index per timestep.

    Every entry defaults to ``num_timesteps - 1``; the first
    ``num_timesteps_cond`` entries are replaced by rounded, evenly spaced
    values between 0 and ``num_timesteps - 1``.
    """
    n_steps = self.num_timesteps
    n_cond = self.num_timesteps_cond
    cond_ids = torch.full(
        size=(n_steps,),
        fill_value=n_steps - 1,
        dtype=torch.long,
    )
    spaced = torch.linspace(0, n_steps - 1, n_cond)
    cond_ids[:n_cond] = torch.round(spaced).long()
    self.cond_ids = cond_ids
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx):
    """Set ``scale_factor`` to 1/std of the first batch's latents.

    Runs only when ``scale_by_std`` is set, ``scale_factor`` is still 1,
    training was not restarted from a checkpoint, and this is batch 0 of
    epoch 0 at global step 0. ``scale_factor`` is then re-registered as a
    buffer holding the reciprocal standard deviation of the flattened
    first-stage encoding.
    """
    # only for very first batch
    is_first_fresh_batch = (
        self.scale_factor == 1
        and self.scale_by_std
        and self.current_epoch == 0
        and self.global_step == 0
        and batch_idx == 0
        and not self.restarted_from_ckpt
    )
    if not is_first_fresh_batch:
        return

    # set rescale weight to 1./std of encodings
    print("### USING STD-RESCALING ###")
    first_stage_input = super().get_input(batch, self.first_stage_key)
    first_stage_input = first_stage_input.to(self.device)
    encoder_posterior = self.encode_first_stage(first_stage_input)
    z = self.get_first_stage_encoding(encoder_posterior).detach()
    # The existing attribute must be removed before a buffer of the same name
    # can be registered.
    del self.scale_factor
    self.register_buffer("scale_factor", 1.0 / z.flatten().std())
    print(f"setting self.scale_factor to {self.scale_factor}")
    print("### USING STD-RESCALING ###")
def register_schedule(
    self,
    given_betas=None,
    beta_schedule="linear",
    timesteps=1000,
    linear_start=1e-4,
    linear_end=2e-2,
    cosine_s=8e-3,
):
    """Register the noise schedule, then the condition schedule if needed.

    All arguments are forwarded positionally to the parent class. Afterwards
    ``self.shorten_cond_schedule`` is set to whether
    ``self.num_timesteps_cond`` exceeds 1, and ``make_cond_schedule`` is
    called when it does.
    """
    super().register_schedule(
        given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
    )
    needs_cond_schedule = self.num_timesteps_cond > 1
    self.shorten_cond_schedule = needs_cond_schedule
    if needs_cond_schedule:
        self.make_cond_schedule()
def instantiate_first_stage(self, config):
    """Instantiate the first stage model from ``config`` and freeze it.

    The model is put in eval mode, its ``train`` attribute is replaced by
    ``disabled_train`` (so the mode is not switched back later), and
    ``requires_grad`` is turned off on every parameter.
    """
    self.first_stage_model = instantiate_from_config(config).eval()
    self.first_stage_model.train = disabled_train
    for weight in self.first_stage_model.parameters():
        weight.requires_grad = False
def make_decision(self, probability):
    """Return ``True`` when a uniform draw from [0, 1) falls below ``probability``."""
    draw = float(torch.rand(1))
    return bool(draw < probability)
def instantiate_cond_stage(self, config):
    """Instantiate every conditional stage model described in ``config``.

    ``config`` maps a model name to that model's configuration. Each model is
    appended to ``self.cond_stage_models`` and described in
    ``self.cond_stage_model_metadata[name]`` by its position in the
    enumeration (``model_idx``) plus the ``cond_stage_key`` and
    ``conditioning_key`` read from its configuration.
    """
    self.cond_stage_model_metadata = {}
    for position, model_name in enumerate(config.keys()):
        model_config = config[model_name]
        self.cond_stage_models.append(instantiate_from_config(model_config))
        self.cond_stage_model_metadata[model_name] = {
            "model_idx": position,
            "cond_stage_key": model_config["cond_stage_key"],
            "conditioning_key": model_config["conditioning_key"],
        }
def get_first_stage_encoding(self, encoder_posterior):
    """Turn the first stage encoder output into a scaled latent.

    A ``DiagonalGaussianDistribution`` is sampled, a tensor is used as is;
    either way the result is multiplied by ``self.scale_factor``.

    Raises:
        NotImplementedError: for any other type of ``encoder_posterior``.
    """
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        latent = encoder_posterior.sample()
        return self.scale_factor * latent
    if isinstance(encoder_posterior, torch.Tensor):
        return self.scale_factor * encoder_posterior
    raise NotImplementedError(
        f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
    )
def get_learned_conditioning(self, c, key, unconditional_cfg):
    """Run the conditional stage model registered under ``key``.

    Args:
        c: Raw conditioning input: a tensor, a list, or a dict of those.
        key: Name of the conditional model in ``cond_stage_model_metadata``.
        unconditional_cfg: When true, the input content is ignored and the
            model's unconditional condition is requested for the batch size
            of ``c`` (classifier-free guidance).

    Raises:
        NotImplementedError: in the unconditional case, when the batch size
            cannot be read because ``c`` is neither a tensor nor a list.
    """
    assert key in self.cond_stage_model_metadata.keys()
    cond_model = self.cond_stage_models[
        self.cond_stage_model_metadata[key]["model_idx"]
    ]

    if not unconditional_cfg:
        return cond_model(c)

    # Classifier-free guidance: only the batch size of the input matters.
    # For a dict input (cond_stage_key "all") the first entry is inspected.
    if isinstance(c, dict):
        c = c[list(c.keys())[0]]

    if isinstance(c, torch.Tensor):
        batchsize = c.size(0)
    elif isinstance(c, list):
        batchsize = len(c)
    else:
        raise NotImplementedError()

    return cond_model.get_unconditional_condition(batchsize)
def get_input(
self,
batch,
k,
return_first_stage_encode=True,
return_decoding_output=False,
return_encoder_input=False,
return_encoder_output=False,
unconditional_prob_cfg=0.1,
):
# Build the diffusion inputs for one batch.
#
# The entry stored under key `k` is fetched through the parent class, moved to
# this module's device and (optionally) encoded by the first stage. Every
# registered conditional stage model is then run to fill `cond_dict`.
#
# Returns a list: [z, cond_dict], followed -- in this order, each only when the
# matching flag is set -- by the decoded latent, the encoder input `x`, and the
# encoder posterior.
#
# NOTE(review): `encoder_posterior` is only assigned when
# `return_first_stage_encode` is true, and `z` is None otherwise, so
# `return_encoder_output` / `return_decoding_output` are only usable together
# with `return_first_stage_encode=True`.
# print(self.cond_stage_model_metadata.keys())
x = super().get_input(batch, k)
x = x.to(self.device)
if return_first_stage_encode:
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
else:
z = None
cond_dict = {}
if len(self.cond_stage_model_metadata.keys()) > 0:
# Classifier-free guidance: the unconditional branch can only be drawn once
# `conditional_dry_run_finished` is set, i.e. not on the very first call.
unconditional_cfg = False
if self.conditional_dry_run_finished and self.make_decision(
unconditional_prob_cfg
):
unconditional_cfg = True
for cond_model_key in self.cond_stage_model_metadata.keys():
cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
"cond_stage_key"
]
# Skip models whose key was already filled in by an earlier model's dict output.
if cond_model_key in cond_dict.keys():
continue
if not self.training:
if isinstance(
self.cond_stage_models[
self.cond_stage_model_metadata[cond_model_key]["model_idx"]
],
CLAPAudioEmbeddingClassifierFreev2,
):
print(
"Warning: CLAP model normally should use text for evaluation"
)
# The original data for conditioning
# If cond_model_key is "all", that means the conditional model need all the information from a batch
if cond_stage_key != "all":
xc = super().get_input(batch, cond_stage_key)
if type(xc) == torch.Tensor:
xc = xc.to(self.device)
else:
xc = batch
# batch inference BUG
#if cond_stage_key == 'text':
# xc = xc[0]
# if cond_stage_key is "all", xc will be a dictionary containing all keys
# Otherwise xc will be an entry of the dictionary
c = self.get_learned_conditioning(
xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
)
# cond_dict will be used to condition the diffusion model
# If one conditional model return multiple conditioning signal
if isinstance(c, dict):
# NOTE: this loop rebinds the parameter name `k`; `k` is not read again afterwards.
for k in c.keys():
cond_dict[k] = c[k]
else:
cond_dict[cond_model_key] = c
# If the key is accidentally added to the dictionary and not in the condition list, remove the condition
# for k in list(cond_dict.keys()):
# if(k not in self.cond_stage_model_metadata.keys()):
# del cond_dict[k]
# The raw 'mos' entry of the batch is passed through unchanged; `batch` is
# indexed directly here, so a batch without a 'mos' key raises at this line.
cond_dict['mos'] = batch['mos']
out = [z, cond_dict]
if return_decoding_output:
xrec = self.decode_first_stage(z)
out += [xrec]
if return_encoder_input:
out += [x]
if return_encoder_output:
out += [encoder_posterior]
# Mark the first ("dry run") call as done so later calls may drop the condition.
if not self.conditional_dry_run_finished:
self.conditional_dry_run_finished = True
# Output is a dictionary, where the value could only be tensor or tuple
return out
def decode_first_stage(self, z):
    """Undo the latent scaling and decode ``z`` with the first stage model.

    Runs without gradient tracking.
    """
    with torch.no_grad():
        inverse_scale = 1.0 / self.scale_factor
        unscaled = inverse_scale * z
        return self.first_stage_model.decode(unscaled)
def mel_spectrogram_to_waveform(
    self, mel, savepath=".", bs=None, name="outwav", save=True, n_gen=1
):
    """Convert a mel spectrogram to a waveform with the first stage vocoder.

    Args:
        mel: Tensor of rank 3 or 4; a rank-4 input has axis 1 squeezed away
            first. The last two axes are swapped before the vocoder is called.
        save: When true, the waveform is handed to ``save_waveform``.

    Returns:
        The vocoder output as a NumPy array.

    NOTE: ``savepath``, ``bs``, ``name`` and ``n_gen`` are accepted but not
    used here; saving always calls ``save_waveform`` with ``savepath="./"``.
    """
    # Mel: [bs, 1, t-steps, fbins]
    has_channel_axis = len(mel.size()) == 4
    if has_channel_axis:
        mel = mel.squeeze(1)
    vocoder_input = mel.permute(0, 2, 1)
    audio = self.first_stage_model.vocoder(vocoder_input)
    audio = audio.cpu().detach().numpy()
    if save:
        self.save_waveform(audio, savepath="./")
    return audio
def encode_first_stage(self, x):
    """Encode ``x`` with the first stage model, without gradient tracking."""
    with torch.no_grad():
        posterior = self.first_stage_model.encode(x)
    return posterior
def extract_possible_loss_in_cond_dict(self, cond_dict):
    """Pick out the loss terms that conditional modules put in ``cond_dict``.

    A conditional module can return its own loss so that it gets optimized
    too; such entries are recognised by a key containing both ``"loss"`` and
    ``"noncond"``.

    Returns:
        A new dict holding only those entries.
    """
    assert isinstance(cond_dict, dict)
    return {
        name: value
        for name, value in cond_dict.items()
        if "loss" in name and "noncond" in name
    }
def filter_useful_cond_dict(self, cond_dict):
    """Keep only the conditions the diffusion model is configured to use.

    Args:
        cond_dict: Mapping produced by the conditional stage models.

    Returns:
        A new dict with the entries whose key is registered in
        ``self.cond_stage_model_metadata`` plus, when present, the ``'mos'``
        entry.

    Raises:
        AssertionError: if a registered conditional key is missing from
            ``cond_dict``.
    """
    registered_keys = self.cond_stage_model_metadata.keys()
    new_cond_dict = {
        key: value for key, value in cond_dict.items() if key in registered_keys
    }
    # All the conditional key in the metadata should be used
    for key in registered_keys:
        assert key in new_cond_dict.keys(), "%s, %s" % (
            key,
            str(new_cond_dict.keys()),
        )
    # 'mos' is optional: copy it only when present instead of swallowing
    # every exception with a bare ``except``.
    if "mos" in cond_dict:
        new_cond_dict["mos"] = cond_dict["mos"]
    return new_cond_dict
def shared_step(self, batch, **kwargs):
    """Compute the loss for one batch.

    The diffusion loss is computed only when ``self.optimize_ddpm_parameter``
    is set; losses returned by the conditional modules are always added on
    top and recorded in the loss dictionary.

    Returns:
        ``(loss, loss_dict)``. ``loss`` may be ``None`` outside of training
        when no loss term exists at all.
    """
    # Classifier-free guidance dropout is only applied while training.
    if self.training:
        unconditional_prob_cfg = self.unconditional_prob_cfg
    else:
        unconditional_prob_cfg = 0.0  # TODO possible bug here

    x, c = self.get_input(
        batch, self.first_stage_key, unconditional_prob_cfg=unconditional_prob_cfg
    )

    loss, loss_dict = None, {}
    if self.optimize_ddpm_parameter:
        loss, loss_dict = self(x, self.filter_useful_cond_dict(c))

    cond_module_losses = self.extract_possible_loss_in_cond_dict(c)
    assert isinstance(cond_module_losses, dict)
    loss_dict.update(cond_module_losses)

    for extra_loss in cond_module_losses.values():
        loss = extra_loss if loss is None else loss + extra_loss

    if self.training:
        assert loss is not None
    return loss, loss_dict
def forward(self, x, c, *args, **kwargs):
    """Draw one random timestep per sample and return the diffusion loss.

    Timesteps are integers in ``[0, self.num_timesteps)``, one for each entry
    along the first axis of ``x``. Everything else is delegated to
    ``p_losses``.

    Returns:
        The ``(loss, loss_dict)`` pair produced by ``p_losses``.
    """
    num_samples = x.shape[0]
    t = torch.randint(0, self.num_timesteps, (num_samples,), device=self.device)
    t = t.long()
    loss, loss_dict = self.p_losses(x, c, t, *args, **kwargs)
    return loss, loss_dict
def reorder_cond_dict(self, cond_dict):
    """Return the conditions in the order given by ``self.conditioning_key``.

    Only the keys listed in ``self.conditioning_key`` are kept, followed by
    the ``'mos'`` entry. A missing key raises ``KeyError``.
    """
    ordered = {name: cond_dict[name] for name in self.conditioning_key}
    ordered['mos'] = cond_dict['mos']
    return ordered
def apply_model(self, x_noisy, t, cond, return_ids=False):
    """Run the wrapped diffusion model on ``x_noisy`` at timestep ``t``.

    The conditions are reordered with ``reorder_cond_dict`` first. When the
    model returns a tuple, only its first element is returned unless
    ``return_ids`` is set.
    """
    ordered_cond = self.reorder_cond_dict(cond)
    prediction = self.model(x_noisy, t, cond_dict=ordered_cond)
    keep_first_only = isinstance(prediction, tuple) and not return_ids
    return prediction[0] if keep_first_only else prediction
def p_losses(self, x_start, cond, t, noise=None):
    """Compute the training loss for latents ``x_start`` at timesteps ``t``.

    The latents are noised with ``q_sample`` and the model prediction is
    compared with the target selected by ``self.parameterization``
    (``"x0"``, ``"eps"`` or ``"v"``). The per-sample loss is weighted by the
    signal-to-noise ratio capped at 5 (``min(snr, 5)`` for ``"x0"``,
    ``min(snr, 5) / snr`` otherwise), then combined with the log-variance
    term and the variational lower-bound term.

    Returns:
        ``(loss, loss_dict)`` where the dictionary keys are prefixed with
        ``"train"`` or ``"val"`` depending on ``self.training``.

    Raises:
        NotImplementedError: for an unknown parameterization.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_output = self.apply_model(x_noisy, t, cond)

    loss_dict = {}
    prefix = "train" if self.training else "val"

    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    elif self.parameterization == "v":
        target = self.get_v(x_start, noise, t)
    else:
        raise NotImplementedError()

    # Signal-to-noise ratio at every sampled timestep.
    alpha = extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape)
    sigma = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape)
    snr = (alpha / sigma) ** 2

    # min{snr, k}
    k = 5.0
    capped_snr = torch.stack([snr, k * torch.ones_like(t)], dim=1).min(dim=1)[0]
    if self.parameterization != "x0":
        mse_loss_weight = capped_snr / snr
    else:
        mse_loss_weight = capped_snr

    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
    loss_simple = loss_simple * mse_loss_weight
    loss_dict[f"{prefix}/loss_simple"] = loss_simple.mean()

    logvar_t = self.logvar[t].to(self.device)
    loss = loss_simple / torch.exp(logvar_t) + logvar_t
    if self.learn_logvar:
        loss_dict[f"{prefix}/loss_gamma"] = loss.mean()
        loss_dict["logvar"] = self.logvar.data.mean()

    loss = self.l_simple_weight * loss.mean()

    # The lower-bound term uses the unweighted per-sample loss.
    loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
    loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
    loss_dict[f"{prefix}/loss_vlb"] = loss_vlb

    loss += self.original_elbo_weight * loss_vlb
    loss_dict[f"{prefix}/loss"] = loss
    return loss, loss_dict
def p_mean_variance(
    self,
    x,
    c,
    t,
    clip_denoised: bool,
    return_codebook_ids=False,
    quantize_denoised=False,
    return_x0=False,
    score_corrector=None,
    corrector_kwargs=None,
):
    """Compute the posterior mean and variance for one reverse step.

    The model output is turned into an estimate of the clean sample
    (``"eps"`` and ``"x0"`` parameterizations only), optionally clamped to
    [-1, 1] and quantized, and then passed to ``q_posterior``.

    Returns:
        ``(mean, variance, log_variance)``, extended by the logits when
        ``return_codebook_ids`` is set, or otherwise by the clean-sample
        estimate when ``return_x0`` is set.

    Raises:
        NotImplementedError: for any other parameterization.
    """
    model_out = self.apply_model(x, t, c, return_ids=return_codebook_ids)

    if score_corrector is not None:
        assert self.parameterization == "eps"
        model_out = score_corrector.modify_score(
            self, model_out, x, t, c, **corrector_kwargs
        )

    if return_codebook_ids:
        model_out, logits = model_out

    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    else:
        raise NotImplementedError()

    if clip_denoised:
        x_recon.clamp_(-1.0, 1.0)
    if quantize_denoised:
        x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
        x_start=x_recon, x_t=x, t=t
    )
    posterior_stats = (model_mean, posterior_variance, posterior_log_variance)

    if return_codebook_ids:
        return posterior_stats + (logits,)
    if return_x0:
        return posterior_stats + (x_recon,)
    return posterior_stats
@torch.no_grad()
def p_sample(
    self,
    x,
    c,
    t,
    clip_denoised=False,
    repeat_noise=False,
    return_codebook_ids=False,
    quantize_denoised=False,
    return_x0=False,
    temperature=1.0,
    noise_dropout=0.0,
    score_corrector=None,
    corrector_kwargs=None,
):
    """Draw the next sample of the reverse process from ``x`` at step ``t``.

    The posterior mean and log-variance come from ``p_mean_variance``; noise
    scaled by ``temperature`` (and optionally thinned with dropout) is added
    for every sample whose timestep is not 0.

    Returns:
        The new sample, or ``(sample, x0)`` when ``return_x0`` is set.

    Raises:
        DeprecationWarning: when ``return_codebook_ids`` is set (support was
            dropped); this happens after the model has been evaluated.
    """
    b, *_, device = *x.shape, x.device
    outputs = self.p_mean_variance(
        x=x,
        c=c,
        t=t,
        clip_denoised=clip_denoised,
        return_codebook_ids=return_codebook_ids,
        quantize_denoised=quantize_denoised,
        return_x0=return_x0,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
    )
    if return_codebook_ids:
        raise DeprecationWarning("Support dropped.")

    x0 = None
    if return_x0:
        model_mean, _, model_log_variance, x0 = outputs
    else:
        model_mean, _, model_log_variance = outputs

    noise = noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.0:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)

    # no noise when t == 0
    broadcast_shape = (1,) * (len(x.shape) - 1)
    nonzero_mask = (
        (1 - (t == 0).float()).reshape(b, *broadcast_shape).contiguous()
    )

    sample = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    if return_x0:
        return (sample, x0)
    return sample
@torch.no_grad()
def progressive_denoising(
    self,
    cond,
    shape,
    verbose=True,
    callback=None,
    quantize_denoised=False,
    img_callback=None,
    mask=None,
    x0=None,
    temperature=1.0,
    noise_dropout=0.0,
    score_corrector=None,
    corrector_kwargs=None,
    batch_size=None,
    x_T=None,
    start_T=None,
    log_every_t=None,
):
    """Run the reverse process and collect the intermediate x0 estimates.

    Args:
        cond: Conditioning (dict, list, tensor or ``None``); it is cut down
            to the first ``batch_size`` entries.
        shape: Sample shape. When ``batch_size`` is given it is prepended to
            ``shape``; otherwise ``shape[0]`` is the batch size.
        x_T: Optional starting sample; random noise is drawn when omitted.
        start_T: Optional upper bound on the number of steps.
        mask, x0: When ``mask`` is given, the masked region is replaced at
            every step by a noised copy of ``x0``.

    Returns:
        ``(img, intermediates)``: the final sample and the x0 estimates
        recorded every ``log_every_t`` steps and at the first step.
    """
    log_every_t = log_every_t or self.log_every_t
    timesteps = self.num_timesteps

    if batch_size is None:
        b = batch_size = shape[0]
    else:
        b = batch_size
        shape = [batch_size] + list(shape)

    img = torch.randn(shape, device=self.device) if x_T is None else x_T
    intermediates = []

    if cond is not None:
        if isinstance(cond, dict):
            trimmed = {}
            for key in cond:
                entry = cond[key]
                if isinstance(entry, list):
                    trimmed[key] = [item[:batch_size] for item in entry]
                else:
                    trimmed[key] = entry[:batch_size]
            cond = trimmed
        elif isinstance(cond, list):
            cond = [c[:batch_size] for c in cond]
        else:
            cond = cond[:batch_size]

    if start_T is not None:
        timesteps = min(timesteps, start_T)

    steps = reversed(range(0, timesteps))
    if verbose:
        iterator = tqdm(steps, desc="Progressive Generation", total=timesteps)
    else:
        iterator = steps

    # A scalar float temperature is expanded to one value per step.
    if type(temperature) is float:
        temperature = [temperature] * timesteps

    for i in iterator:
        ts = torch.full((b,), i, device=self.device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != "hybrid"
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img, x0_partial = self.p_sample(
            img,
            cond,
            ts,
            clip_denoised=self.clip_denoised,
            quantize_denoised=quantize_denoised,
            return_x0=True,
            temperature=temperature[i],
            noise_dropout=noise_dropout,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
        )

        if mask is not None:
            assert x0 is not None
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1.0 - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(x0_partial)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)

    return img, intermediates
@torch.no_grad()
def p_sample_loop(
    self,
    cond,
    shape,
    return_intermediates=False,
    x_T=None,
    verbose=True,
    callback=None,
    timesteps=None,
    quantize_denoised=False,
    mask=None,
    x0=None,
    img_callback=None,
    start_T=None,
    log_every_t=None,
):
    """Run the full reverse diffusion loop.

    Args:
        cond: Conditioning passed to ``p_sample`` at every step.
        shape: Shape of the sample; ``shape[0]`` is the batch size.
        x_T: Optional starting sample; random noise is drawn when omitted.
        timesteps: Number of steps (defaults to ``self.num_timesteps``),
            further limited by ``start_T`` when given.
        mask, x0: When ``mask`` is given, the masked region is replaced at
            every step by a noised copy of ``x0``.

    Returns:
        The final sample, or ``(sample, intermediates)`` when
        ``return_intermediates`` is set. The list starts with the initial
        sample.
    """
    log_every_t = log_every_t or self.log_every_t
    device = self.betas.device
    b = shape[0]

    img = torch.randn(shape, device=device) if x_T is None else x_T
    intermediates = [img]

    if timesteps is None:
        timesteps = self.num_timesteps
    if start_T is not None:
        timesteps = min(timesteps, start_T)

    steps = reversed(range(0, timesteps))
    iterator = tqdm(steps, desc="Sampling t", total=timesteps) if verbose else steps

    if mask is not None:
        assert x0 is not None
        assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

    for i in iterator:
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != "hybrid"
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img = self.p_sample(
            img,
            cond,
            ts,
            clip_denoised=self.clip_denoised,
            quantize_denoised=quantize_denoised,
        )

        if mask is not None:
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1.0 - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(img)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)

    if return_intermediates:
        return img, intermediates
    return img
@torch.no_grad()
def sample(
    self,
    cond,
    batch_size=16,
    return_intermediates=False,
    x_T=None,
    verbose=True,
    timesteps=None,
    quantize_denoised=False,
    mask=None,
    x0=None,
    shape=None,
    **kwargs,
):
    """Sample with the DDPM loop after trimming ``cond`` to ``batch_size``.

    When ``shape`` is omitted it defaults to
    ``(batch_size, channels, latent_t_size, latent_f_size)``. All remaining
    arguments, including ``**kwargs``, are forwarded to ``p_sample_loop``.
    """
    if shape is None:
        shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)

    if cond is not None:
        if isinstance(cond, dict):
            trimmed = {}
            for key in cond:
                entry = cond[key]
                if isinstance(entry, list):
                    trimmed[key] = [item[:batch_size] for item in entry]
                else:
                    trimmed[key] = entry[:batch_size]
            cond = trimmed
        elif isinstance(cond, list):
            cond = [c[:batch_size] for c in cond]
        else:
            cond = cond[:batch_size]

    return self.p_sample_loop(
        cond,
        shape,
        return_intermediates=return_intermediates,
        x_T=x_T,
        verbose=verbose,
        timesteps=timesteps,
        quantize_denoised=quantize_denoised,
        mask=mask,
        x0=x0,
        **kwargs,
    )
def save_waveform(self, waveform, savepath="./", name="awesome.wav", n_gen=1):
    """Peak-normalize the first waveform of a batch and write it to disk.

    Args:
        waveform: Array indexed as ``waveform[0, 0]`` to obtain the signal
            that is saved; other entries are ignored.
        savepath: Directory the file is written to.
        name: File name, or a list of strings that is joined with ``"_"`` and
            suffixed with ``".wav"``.
        n_gen: Unused.

    Raises:
        TypeError: if ``name`` is neither a string nor a list.

    Write failures are reported on stdout and not raised (best effort).
    """
    print(f'debug_name : {name}')

    # If `name` is a list, join the elements into a single file name.
    if isinstance(name, list):
        name = "_".join(name)  # Joins the list elements with an underscore
        name += ".wav"  # Ensures the file has a `.wav` extension
    elif not isinstance(name, str):
        raise TypeError("Name must be a string or list")

    # Normalize the peak of the waveform to 0.8. A silent (all-zero) signal
    # has a zero peak; it is left untouched instead of being divided by zero,
    # which would fill the output with NaN.
    todo_waveform = waveform[0, 0]  # Only the first waveform is saved
    peak = np.max(np.abs(todo_waveform))
    if peak > 0:
        todo_waveform = (todo_waveform / peak) * 0.8

    # Define the path where to save the file
    path = os.path.join(savepath, name)

    try:
        # Save the waveform to the specified path
        sf.write(path, todo_waveform, samplerate=self.sampling_rate)
        print(f'Waveform saved at -> {path}')
    except Exception as e:
        print(f'Error saving waveform: {e}')
@torch.no_grad()
def sample_log(
    self,
    cond,
    batch_size,
    ddim,
    ddim_steps,
    unconditional_guidance_scale=1.0,
    unconditional_conditioning=None,
    use_plms=False,
    mask=None,
    **kwargs,
):
    """Sample latents with the DDIM, PLMS or DDPM sampler.

    Args:
        cond: Conditioning passed to the sampler.
        batch_size: Number of samples to draw.
        ddim: Use the DDIM sampler (unless ``use_plms`` is set).
        ddim_steps: Number of sampler steps for DDIM / PLMS.
        unconditional_guidance_scale: Classifier-free guidance scale.
        unconditional_conditioning: Conditioning for the unconditional branch.
        use_plms: Use the PLMS sampler; takes precedence over ``ddim``.
        mask: Optional mask; when given, its last two dimensions define the
            sample shape instead of the configured latent size.
        **kwargs: Forwarded to the selected sampler.

    Returns:
        ``(samples, intermediates)`` as produced by the selected sampler.
    """
    if mask is not None:
        shape = (self.channels, mask.size()[-2], mask.size()[-1])
    else:
        shape = (self.channels, self.latent_t_size, self.latent_f_size)

    if ddim and not use_plms:
        print("Use ddim sampler")
        ddim_sampler = DDIMSampler(self)
        samples, intermediates = ddim_sampler.sample(
            ddim_steps,
            batch_size,
            shape,
            cond,
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            mask=mask,
            **kwargs,
        )
    elif use_plms:
        print("Use plms sampler")
        plms_sampler = PLMSSampler(self)
        samples, intermediates = plms_sampler.sample(
            ddim_steps,
            batch_size,
            shape,
            cond,
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            mask=mask,
            unconditional_conditioning=unconditional_conditioning,
            **kwargs,
        )
    else:
        print("Use DDPM sampler")
        # NOTE(review): `sample` forwards unknown keyword arguments to
        # `p_sample_loop`, whose signature does not list the guidance
        # arguments passed here -- confirm this branch is still exercised.
        samples, intermediates = self.sample(
            cond=cond,
            batch_size=batch_size,
            return_intermediates=True,
            unconditional_guidance_scale=unconditional_guidance_scale,
            mask=mask,
            unconditional_conditioning=unconditional_conditioning,
            **kwargs,
        )

    # Return the intermediates collected by the sampler; the previous code
    # returned a separate, never-assigned variable and therefore always None.
    return samples, intermediates
@torch.no_grad()
def generate_sample(
self,
batchs,
ddim_steps=200,
ddim_eta=1.0,
x_T=None,
n_gen=1,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
name=None,
use_plms=False,
limit_num=None,
**kwargs,
):
# Generate audio for every batch of `batchs` and return the directory computed
# for the generated waveforms.
#
# For each batch: the conditions are computed with `get_input`, repeated
# `n_gen` times, latents are sampled with `sample_log`, decoded to a mel
# spectrogram and converted to a waveform. With `n_gen > 1` the candidate with
# the highest CLAP similarity to the text is kept for each item.
#
# `x_T` must be None (asserted below). `batchs` must be iterable, otherwise a
# ValueError is raised. `**kwargs` is accepted but not used in this body.
#
# NOTE(review): the indentation of this file was lost, so whether the final
# `save_waveform` call runs once per batch or once after the loop cannot be
# determined from this text -- confirm against the repository.
# Generate n_gen times and select the best
# Batch: audio, text, fnames
# import pdb
# pdb.set_trace()
assert x_T is None
try:
batchs = iter(batchs)
except TypeError:
raise ValueError("The first input argument should be an iterable object")
if use_plms:
assert ddim_steps is not None
# DDIM sampling is selected whenever a step count is given.
use_ddim = ddim_steps is not None
if name is None:
name = self.get_validation_folder_name()
waveform_save_path = os.path.join(self.get_log_dir(), name)
os.makedirs(waveform_save_path, exist_ok=True)
print("Waveform save path: ", waveform_save_path)
# if (
# "audiocaps" in waveform_save_path
# and len(os.listdir(waveform_save_path)) >= 964
# ):
# print("The evaluation has already been done at %s" % waveform_save_path)
# return waveform_save_path
with self.ema_scope("Plotting"):
for i, batch in enumerate(batchs):
#print(batch)
z, c = self.get_input(
batch,
self.first_stage_key,
unconditional_prob_cfg=0.0, # Do not output unconditional information in the c
)
# import pdb; pdb.set_trace()
# Stop once the number of processed items exceeds `limit_num`.
if limit_num is not None and i * z.size(0) > limit_num:
break
c = self.filter_useful_cond_dict(c)
text = super().get_input(batch, "text")
mos = super().get_input(batch, "mos")
# for cond_key in c.keys():
# c[cond_key] = self.cond_stage_models[self.cond_stage_model_metadata[cond_key]["model_idx"]](text[0])
# Generate multiple samples
batch_size = z.shape[0] * n_gen
# Generate multiple samples at a time and filter out the best
# The condition to the diffusion wrapper can have many format
# import pdb
# pdb.set_trace()
# Repeat every condition n_gen times along the batch axis.
for cond_key in c.keys():
if isinstance(c[cond_key], list):
# NOTE: this inner loop reuses the name `i` of the batch index; `enumerate`
# rebinds it at the start of the next batch.
for i in range(len(c[cond_key])):
c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
elif isinstance(c[cond_key], dict):
for k in c[cond_key].keys():
c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
else:
c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
# NOTE(review): `* n_gen` repeats a list but multiplies the values of a
# tensor; presumably both are lists here (torch.stack below expects a
# sequence of tensors) -- confirm what the parent get_input returns.
text = text * n_gen
mos = mos * n_gen
# This overwrites the 'mos' entry that the repeat loop above already handled.
c['mos'] = torch.stack(mos).unsqueeze(1)
# With guidance enabled, build the unconditional condition of every model.
if unconditional_guidance_scale != 1.0:
unconditional_conditioning = {}
for key in self.cond_stage_model_metadata:
model_idx = self.cond_stage_model_metadata[key]["model_idx"]
unconditional_conditioning[key] = self.cond_stage_models[
model_idx
].get_unconditional_condition(batch_size)
fnames = list(super().get_input(batch, "fname"))
# import pdb; pdb.set_trace()
samples, _ = self.sample_log(
cond=c,
batch_size=batch_size,
x_T=x_T,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
use_plms=use_plms,
)
mel = self.decode_first_stage(samples)
# mel = super().get_input(batch, "log_mel_spec")
# save=False: nothing is written by this call; saving happens further below.
waveform = self.mel_spectrogram_to_waveform(
mel, savepath=waveform_save_path, bs=None, name=fnames, save=False, n_gen=n_gen
)
if n_gen > 1:
# Best-of-n selection; any failure here is reported and ignored.
# NOTE(review): `self.clap` is not assigned anywhere in the visible part of
# this file -- confirm it exists.
try:
best_index = []
similarity = self.clap.cos_similarity(
torch.FloatTensor(waveform).squeeze(1), text
)
for i in range(z.shape[0]):
# Candidates of item i sit at positions i, i + bs, i + 2 * bs, ...
candidates = similarity[i :: z.shape[0]]
max_index = torch.argmax(candidates).item()
best_index.append(i + max_index * z.shape[0])
waveform = waveform[best_index]
print("Similarity between generated audio and text", similarity)
print("Choose the following indexes:", best_index)
except Exception as e:
print("Warning: while calculating CLAP score (not fatal), ", e)
# NOTE: the waveform is saved with savepath="./" and the default file name,
# not into `waveform_save_path`, although that directory is what is returned.
self.save_waveform(waveform, savepath="./")
return waveform_save_path
class DiffusionWrapper(pl.LightningModule):
    """Adapter between a conditioning dictionary and the diffusion model.

    The keys of the conditioning dictionary select how each entry is used:

    * ``concat``: concatenated to the input along dimension 1,
    * ``film``: collected into ``y`` (only reported in the one-time shape
      printout; it is not passed to the diffusion model here),
    * ``crossattn``: a ``(context, mask)`` pair appended to the context lists,
    * ``noncond``: ignored (used for losses of the conditional modules),
    * ``mos``: forwarded as ``mos`` (``forward`` only).
    """

    # A conditioning key is legal when it contains at least one of these markers.
    _LEGAL_KEY_MARKERS = ("concat", "crossattn", "hybrid", "film", "noncond")
    # Sentinel telling "no 'mos' entry" apart from an explicit ``None`` value.
    _MISSING = object()

    def __init__(self, diff_model_config, conditioning_key):
        """Instantiate the diffusion model and validate the conditioning keys.

        Raises:
            ValueError: if a conditioning key contains none of the legal
                markers.
        """
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        for key in self.conditioning_key:
            if not any(marker in key for marker in self._LEGAL_KEY_MARKERS):
                # ``Value`` (multiprocessing.sharedctypes) was raised here
                # before; it is not an exception type.
                raise ValueError("The conditioning key %s is illegal" % key)
        self.being_verbosed_once = False

    def _gather_conditions(self, x, cond_dict, accept_mos):
        """Split ``cond_dict`` into the inputs of the diffusion model.

        Returns:
            ``(xc, y, context_list, attn_mask_list, mos)`` where ``mos`` is
            ``self._MISSING`` when no ``mos`` key was seen.

        Raises:
            NotImplementedError: for a key of an unknown kind (this includes
                ``mos`` keys when ``accept_mos`` is false).
        """
        # x with condition (or maybe not)
        xc = x
        y = None
        mos = self._MISSING
        context_list, attn_mask_list = [], []

        for key in cond_dict.keys():
            if "concat" in key:
                xc = torch.cat([x, cond_dict[key].unsqueeze(1)], dim=1)
            elif "film" in key:
                if y is None:
                    y = cond_dict[key].squeeze(1)
                else:
                    y = torch.cat([y, cond_dict[key].squeeze(1)], dim=-1)
            elif "crossattn" in key:
                if isinstance(cond_dict[key], dict):
                    # The last sub-entry whose name contains "crossattn" wins.
                    for k in cond_dict[key].keys():
                        if "crossattn" in k:
                            context, attn_mask = cond_dict[key][
                                k
                            ]  # crossattn_audiomae_pooled: torch.Size([12, 128, 768])
                else:
                    assert len(cond_dict[key]) == 2, (
                        "The context condition for %s you returned should have two element, one context one mask"
                        % (key)
                    )
                    context, attn_mask = cond_dict[key]
                # The input to the UNet model is a list of context matrix
                context_list.append(context)
                attn_mask_list.append(attn_mask)
            elif (
                "noncond" in key
            ):  # If you use loss function in the conditional module, include the keyword "noncond" in the return dictionary
                continue
            elif accept_mos and "mos" in key:
                mos = cond_dict['mos']
            else:
                raise NotImplementedError()

        return xc, y, context_list, attn_mask_list, mos

    def _report_input_shapes_once(self, xc, t, y, context_list, attn_mask_list):
        """Print the shapes of the model inputs on the first call only."""
        if self.being_verbosed_once:
            return
        print("The input shape to the diffusion model is as follows:")
        print("xc", xc.size())
        print("t", t.size())
        for i in range(len(context_list)):
            print(
                "context_%s" % i, context_list[i].size(), attn_mask_list[i].size()
            )
        if y is not None:
            print("y", y.size())
        self.being_verbosed_once = True

    def forward(self, x, t, cond_dict: dict = None):
        """Run the diffusion model on ``x`` at timestep ``t``.

        Args:
            x: Noisy input.
            t: Timesteps.
            cond_dict: Conditioning dictionary; must contain a ``mos`` entry.

        Raises:
            ValueError: if ``cond_dict`` has no ``mos`` entry (this used to
                surface as an ``UnboundLocalError``).
            NotImplementedError: for a conditioning key of an unknown kind.
        """
        cond_dict = {} if cond_dict is None else cond_dict
        x = x.contiguous()
        t = t.contiguous()

        xc, y, context_list, attn_mask_list, mos = self._gather_conditions(
            x, cond_dict, accept_mos=True
        )
        if mos is self._MISSING:
            raise ValueError(
                "cond_dict must contain a 'mos' entry for DiffusionWrapper.forward"
            )

        self._report_input_shapes_once(xc, t, y, context_list, attn_mask_list)

        out = self.diffusion_model.forward(
            xc, timestep=t, context_list=context_list, context_mask_list=attn_mask_list, mos=mos
        )
        return out

    @torch.no_grad()
    def forward_with_cfg(self, x, t, cond_dict: dict = None, cfg_scale=4.0, **model_kwargs):
        """Run the diffusion model with classifier-free guidance.

        Same condition handling as ``forward`` except that ``mos`` keys are
        not accepted in ``cond_dict``; extra keyword arguments are forwarded
        to the model's ``forward_with_cfg``.

        Raises:
            NotImplementedError: for a conditioning key of an unknown kind.
        """
        cond_dict = {} if cond_dict is None else cond_dict
        x = x.contiguous()
        t = t.contiguous()

        xc, y, context_list, attn_mask_list, _ = self._gather_conditions(
            x, cond_dict, accept_mos=False
        )

        self._report_input_shapes_once(xc, t, y, context_list, attn_mask_list)

        out = self.diffusion_model.forward_with_cfg(
            xc, timestep=t, context_list=context_list, context_mask_list=attn_mask_list, cfg_scale=cfg_scale, **model_kwargs
        )
        return out
class LatentDiffusionSpeedTest(pl.LightningModule):
    """Minimal Lightning module built around a single ``nn.Linear(1, 1)``.

    The constructor accepts the same configuration arguments as the latent
    diffusion model but ignores all of them.
    """

    def __init__(
        self,
        first_stage_config,
        cond_stage_config=None,
        num_timesteps_cond=None,
        cond_stage_key="image",
        cond_stage_trainable=False,
        concat_mode=True,
        cond_stage_forward=None,
        conditioning_key=None,
        scale_factor=1.0,
        batchsize=None,
        evaluation_params={},
        scale_by_std=False,
        base_learning_rate=None,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.l1 = nn.Linear(1, 1)
        # Logging locations; filled in later through set_log_dir.
        self.logger_save_dir = None
        self.logger_exp_group_name = None
        self.logger_exp_name = None
        self.test_data_subset_path = None

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        """Remember where logs of this experiment are written."""
        self.logger_save_dir = save_dir
        self.logger_exp_group_name = exp_group_name
        self.logger_exp_name = exp_name

    def forward(self, x):
        """Apply the linear layer along axis 1 of a rank-3 input."""
        swapped = x.permute(0, 2, 1)
        projected = self.l1(swapped)
        return projected.permute(0, 2, 1)

    def training_step(self, batch, batch_idx):
        """Return the mean of the model output for ``batch["waveform"]``."""
        output = self(batch["waveform"])
        return torch.mean(output)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at a learning rate of 0.02."""
        return torch.optim.Adam(self.parameters(), lr=0.02)
class LatentDiffusionVAELearnable(LatentDiffusion):
def __init__(self, *args, **kwargs):
    """Build the latent diffusion model and switch to manual optimization.

    All arguments are passed to the parent constructor unchanged.
    """
    super().__init__(*args, **kwargs)
    # The optimizers are stepped by hand in shared_step.
    self.automatic_optimization = False
def configure_optimizers(self):
    """Return the optimizers of the diffusion model and of the autoencoder.

    Returns:
        ``([ldm_opt, opt_ae, opt_disc], [])``: an AdamW optimizer over the
        diffusion model, the conditional stage models and (optionally)
        ``logvar``, followed by the two optimizers supplied by the first
        stage model. No scheduler is returned.
    """
    base_lr = self.learning_rate
    trainable = list(self.model.parameters())
    # The conditional stage models are trained jointly with the denoiser.
    for cond_model in self.cond_stage_models:
        trainable.extend(cond_model.parameters())
    if self.learn_logvar:
        print("Diffusion model optimizing logvar")
        trainable.append(self.logvar)
    ldm_opt = torch.optim.AdamW(trainable, lr=base_lr)

    # The schedulers returned by the first stage model are not used.
    (opt_ae, opt_disc), _ = self.first_stage_model.configure_optimizers()
    return [ldm_opt, opt_ae, opt_disc], []
def encode_first_stage(self, x):
    """Encode ``x`` with the first stage model, keeping gradients enabled."""
    return self.first_stage_model.encode(x)
def decode_first_stage(self, z):
    """Undo the latent scaling and decode ``z``, keeping gradients enabled."""
    inverse_scale = 1.0 / self.scale_factor
    unscaled = inverse_scale * z
    return self.first_stage_model.decode(unscaled)
def instantiate_first_stage(self, config):
    """Instantiate the first stage model from ``config`` in training mode.

    ``train`` is not overridden and no parameter is frozen, so the first
    stage stays trainable.
    """
    first_stage = instantiate_from_config(config)
    self.first_stage_model = first_stage.train()
def shared_step(self, batch, **kwargs):
# One manual-optimization step that updates the diffusion model, the
# autoencoder (generator) and the discriminator.
#
# Returns (None, loss_dict): no loss is handed back because the optimizers are
# stepped here; loss_dict collects the diffusion, conditional-module,
# autoencoder and discriminator log values.
#
# NOTE(review): backward passes and optimizer steps are executed regardless of
# self.training, and the autoencoder losses always use split="train" --
# confirm this method is only reached during training.
# Order matches the list returned by configure_optimizers.
ldm_opt, g_opt, d_opt = self.optimizers()
if self.training:
# Classifier-free guidance
unconditional_prob_cfg = self.unconditional_prob_cfg
else:
unconditional_prob_cfg = 0.0
# With these flags get_input also returns the decoded latent, the encoder
# input and the encoder posterior, in that order.
x, c, decoder_xrec, encoder_x, encoder_posterior = self.get_input(
batch,
self.first_stage_key,
unconditional_prob_cfg=unconditional_prob_cfg,
return_decoding_output=True,
return_encoder_input=True,
return_encoder_output=True,
)
# Diffusion loss on the latent.
loss, loss_dict = self(x, self.filter_useful_cond_dict(c))
# Losses returned by the conditional modules are added to the diffusion loss.
additional_loss_for_cond_modules = self.extract_possible_loss_in_cond_dict(c)
assert isinstance(additional_loss_for_cond_modules, dict)
loss_dict.update(additional_loss_for_cond_modules)
if len(additional_loss_for_cond_modules.keys()) > 0:
for k in additional_loss_for_cond_modules.keys():
loss = loss + additional_loss_for_cond_modules[k]
for k, v in additional_loss_for_cond_modules.items():
self.log(
"cond_stage/" + k,
float(v),
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
# Autoencoder (generator) loss: optimizer_idx=0.
aeloss, log_dict_ae = self.first_stage_model.loss(
encoder_x,
decoder_xrec,
encoder_posterior,
optimizer_idx=0,
global_step=self.first_stage_model.global_step,
last_layer=self.first_stage_model.get_last_layer(),
split="train",
)
# A single backward pass covers the diffusion and autoencoder losses; both
# optimizers are stepped afterwards.
self.manual_backward(loss + aeloss)
ldm_opt.step()
ldm_opt.zero_grad()
g_opt.step()
g_opt.zero_grad()
# Discriminator loss: optimizer_idx=1, with its own backward pass and step.
discloss, log_dict_disc = self.first_stage_model.loss(
encoder_x,
decoder_xrec,
encoder_posterior,
optimizer_idx=1,
global_step=self.first_stage_model.global_step,
last_layer=self.first_stage_model.get_last_layer(),
split="train",
)
self.manual_backward(discloss)
d_opt.step()
d_opt.zero_grad()
self.log(
"aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)
# NOTE: the value logged as "posterior_std" is the mean of the posterior's
# `var` attribute.
self.log(
"posterior_std",
torch.mean(encoder_posterior.var),
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)
loss_dict.update(log_dict_disc)
loss_dict.update(log_dict_ae)
return None, loss_dict
def training_step(self, batch, batch_idx):
self.warmup_step()
self.check_module_param_update()
if (
self.state is None
and len(self.trainer.optimizers[0].state_dict()["state"].keys()) > 0
):
self.state = (
self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"].clone()
)
elif self.state is not None and batch_idx % 1000 == 0:
assert (
torch.sum(
torch.abs(
self.state
- self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"]
)
)
> 1e-7
), "Optimizer is not working"
if len(self.metrics_buffer.keys()) > 0:
for k in self.metrics_buffer.keys():
self.log(
k,
self.metrics_buffer[k],
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
print(k, self.metrics_buffer[k])
self.metrics_buffer = {}
loss, loss_dict = self.shared_step(batch)
self.log_dict(
{k: float(v) for k, v in loss_dict.items()},
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log(
"global_step",
float(self.global_step),
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)
lr = self.trainer.optimizers[0].param_groups[0]["lr"]
self.log(
"lr_abs",
float(lr),
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)