from multiprocessing.sharedctypes import Value import statistics import sys import os # from tkinter import Ec # sys.path.append('/home/changli/Adan') import torch import torch.nn as nn import numpy as np import pytorch_lightning as pl from torch.optim.lr_scheduler import LambdaLR # from adan import Adan from einops import rearrange, repeat from contextlib import contextmanager from functools import partial from tqdm import tqdm from torchvision.utils import make_grid from pytorch_lightning.utilities.rank_zero import rank_zero_only from qa_mdt.audioldm_train.conditional_models import * import datetime from qa_mdt.audioldm_train.utilities.model_util import ( exists, default, mean_flat, count_params, instantiate_from_config, ) from qa_mdt.audioldm_train.utilities.diffusion_util import ( make_beta_schedule, extract_into_tensor, noise_like, ) from qa_mdt.audioldm_train.modules.diffusionmodules.ema import LitEma from qa_mdt.audioldm_train.modules.diffusionmodules.distributions import ( normal_kl, DiagonalGaussianDistribution, ) # from audioldm_train.modules.diffusionmodules.transport import from qa_mdt.audioldm_train.modules.latent_diffusion.ddim import DDIMSampler from qa_mdt.audioldm_train.modules.latent_diffusion.plms import PLMSSampler import soundfile as sf import os __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} import json with open('offset_pretrained_checkpoints.json', 'r') as config_file: config_data = json.load(config_file) def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self def uniform_on_device(r1, r2, shape, device): return (r1 - r2) * torch.rand(*shape, device=device) + r2 class DDPM(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space def __init__( self, unet_config, sampling_rate=None, timesteps=1000, beta_schedule="linear", loss_type="l2", ckpt_path=None, ignore_keys=[], load_only_unet=False, monitor="val/loss", use_ema=True, first_stage_key="image", latent_t_size=256, latent_f_size=16, channels=3, log_every_t=100, clip_denoised=True, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, given_betas=None, original_elbo_weight=0.0, v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta l_simple_weight=1.0, conditioning_key=None, parameterization="eps", # all assuming fixed variance schedules scheduler_config=None, use_positional_encodings=False, learn_logvar=False, logvar_init=0.0, evaluator=None, ): super().__init__() assert parameterization in [ "eps", "x0", "v", ], 'currently only supporting "eps" and "x0" and "v"' self.parameterization = parameterization self.state = None print( f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" ) assert sampling_rate is not None self.validation_folder_name = "temp_name" self.clip_denoised = clip_denoised self.log_every_t = log_every_t self.first_stage_key = first_stage_key self.sampling_rate = sampling_rate self.clap = CLAPAudioEmbeddingClassifierFreev2( pretrained_path=config_data["clap_music"], sampling_rate=self.sampling_rate, embed_mode="audio", amodel="HTSAT-base", ) if self.global_rank == 0: self.evaluator = evaluator self.initialize_param_check_toolkit() self.latent_t_size = latent_t_size self.latent_f_size = latent_f_size self.channels = channels self.use_positional_encodings = use_positional_encodings self.model = DiffusionWrapper(unet_config, conditioning_key) count_params(self.model, verbose=True) 
self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") self.use_scheduler = scheduler_config is not None if self.use_scheduler: self.scheduler_config = scheduler_config self.v_posterior = v_posterior self.original_elbo_weight = original_elbo_weight self.l_simple_weight = l_simple_weight if monitor is not None: self.monitor = monitor if ckpt_path is not None: self.init_from_ckpt( ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet ) self.register_schedule( given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, ) self.loss_type = loss_type self.learn_logvar = learn_logvar self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) else: self.logvar = nn.Parameter(self.logvar, requires_grad=False) self.logger_save_dir = None self.logger_exp_name = None self.logger_exp_group_name = None self.logger_version = None self.label_indices_total = None # To avoid the system cannot find metric value for checkpoint self.metrics_buffer = { "val/kullback_leibler_divergence_sigmoid": 15.0, "val/kullback_leibler_divergence_softmax": 10.0, "val/psnr": 0.0, "val/ssim": 0.0, "val/inception_score_mean": 1.0, "val/inception_score_std": 0.0, "val/kernel_inception_distance_mean": 0.0, "val/kernel_inception_distance_std": 0.0, "val/frechet_inception_distance": 133.0, "val/frechet_audio_distance": 32.0, } self.initial_learning_rate = None self.test_data_subset_path = None def get_log_dir(self): return os.path.join( self.logger_save_dir, self.logger_exp_group_name, self.logger_exp_name ) def set_log_dir(self, save_dir, exp_group_name, exp_name): self.logger_save_dir = save_dir self.logger_exp_group_name = exp_group_name self.logger_exp_name = exp_name def register_schedule( self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, ): if exists(given_betas): betas = given_betas else: betas = make_beta_schedule( beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, ) alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end assert ( alphas_cumprod.shape[0] == self.num_timesteps ), "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer("betas", to_torch(betas)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) self.register_buffer( "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) ) self.register_buffer( "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) ) self.register_buffer( "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) ) self.register_buffer( "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) ) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * ( 1.0 - alphas_cumprod_prev ) / (1.0 - alphas_cumprod) + self.v_posterior * betas # 
above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

        if self.parameterization == "eps":
            lvlb_weights = self.betas**2 / (
                2
                * self.posterior_variance
                * to_torch(alphas)
                * (1 - self.alphas_cumprod)
            )
        elif self.parameterization == "x0":
            lvlb_weights = (
                0.5
                * np.sqrt(torch.Tensor(alphas_cumprod))
                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
            )
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(
                self.betas**2
                / (
                    2
                    * self.posterior_variance
                    * to_torch(alphas)
                    * (1 - self.alphas_cumprod)
                )
            )
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
            if not only_model
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        elif self.parameterization == "v":
            # v-prediction: recover x0 from the predicted velocity
            x_recon = self.predict_start_from_z_and_v(x, t, model_out)
        if clip_denoised:
            x_recon.clamp_(-1.0, 1.0)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x, t=t, clip_denoised=clip_denoised
        )
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (
            (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
        )
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_intermediates=False):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(
            reversed(range(0, self.num_timesteps)),
            desc="Sampling t",
            total=self.num_timesteps,
        ):
            img = self.p_sample(
                img,
                torch.full((b,), i, device=device, dtype=torch.long),
                clip_denoised=self.clip_denoised,
            )
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, batch_size=16, return_intermediates=False):
        channels = self.channels
        shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
        return self.p_sample_loop(shape, return_intermediates=return_intermediates)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == "l1":
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == "l2":
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
        return loss

    def predict_start_from_z_and_v(self, x_t, t, v):
        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1.
- alphas_cumprod))) return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v ) def predict_eps_from_z_and_v(self, x_t, t, v): return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t ) def get_v(self, x, noise, t): return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x ) def p_losses(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_out = self.model(x_noisy, t) mse_loss_weight = None alpha = extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape) sigma = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape) snr = (alpha / sigma) ** 2 # velocity = (alpha[:, None, None, None] * x_noisy - x_start) / sigma[:, None, None, None] # get loss weight if self.parameterization != "x0": mse_loss_weight = torch.ones_like(t) k = 5.0 # min{snr, k} mse_loss_weight = torch.stack([snr, k * torch.ones_like(t)], dim=1).min(dim=1)[0] / snr else: k = 5.0 # min{snr, k} mse_loss_weight = torch.stack([snr, k * torch.ones_like(t)], dim=1).min(dim=1)[0] loss_dict = {} if self.parameterization == "eps": target = noise elif self.parameterization == "x0": target = x_start elif self.parameterization == "v": target = self.get_v(x_start, noise, t) else: raise NotImplementedError( f"Paramterization {self.parameterization} not yet supported" ) loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) loss = mse_loss_weight * loss log_prefix = "train" if self.training else "val" loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb loss_dict.update({f"{log_prefix}/loss": loss}) return loss, loss_dict def forward(self, x, *args, **kwargs): # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' t = torch.randint( 0, self.num_timesteps, (x.shape[0],), device=self.device ).long() return self.p_losses(x, t, *args, **kwargs) def get_input(self, batch, k): # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch # fbank, stft, label_indices, fname, waveform, text = batch # a = 1/0 fname, text, label_indices, waveform, stft, fbank, mos = ( batch["fname"], batch["text"], batch["label_vector"], batch["waveform"], batch["stft"], batch["log_mel_spec"], batch["mos"], # batch ) # for i in range(fbank.size(0)): # fb = fbank[i].numpy() # seg_lb = seg_label[i].numpy() # logits = np.mean(seg_lb, axis=0) # index = np.argsort(logits)[::-1][:5] # plt.imshow(seg_lb[:,index], aspect="auto") # plt.title(index) # plt.savefig("%s_label.png" % i) # plt.close() # plt.imshow(fb, aspect="auto") # plt.savefig("%s_fb.png" % i) # plt.close() ret = {} ret["fbank"] = ( fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() ) ret["stft"] = stft.to(memory_format=torch.contiguous_format).float() # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float() ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() ret["text"] = list(text) ret["fname"] = 
fname ret["mos"] = list(mos) for key in batch.keys(): if key not in ret.keys(): ret[key] = batch[key] return ret[k] def shared_step(self, batch): x = self.get_input(batch, self.first_stage_key) loss, loss_dict = self(x) return loss, loss_dict def warmup_step(self): if self.initial_learning_rate is None: self.initial_learning_rate = self.learning_rate # Only the first parameter group if self.global_step <= self.warmup_steps: if self.global_step == 0: print( "Warming up learning rate start with %s" % self.initial_learning_rate ) self.trainer.optimizers[0].param_groups[0]["lr"] = ( self.global_step / self.warmup_steps ) * self.initial_learning_rate else: # TODO set learning rate here self.trainer.optimizers[0].param_groups[0][ "lr" ] = self.initial_learning_rate def training_step(self, batch, batch_idx): # You instantiate a optimizer for the scheduler # But later you overwrite the optimizer by reloading its states from a checkpoint # So you need to replace the optimizer with the checkpoint one # if(self.lr_schedulers().optimizer.param_groups[0]['lr'] != self.trainer.optimizers[0].param_groups[0]['lr']): # self.lr_schedulers().optimizer = self.trainer.optimizers[0] # if(self.ckpt is not None): # self.reload_everything() # self.ckpt = None self.random_clap_condition() self.warmup_step() # if ( # self.state is None # and len(self.trainer.optimizers[0].state_dict()["state"].keys()) > 0 # ): # self.state = ( # self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"].clone() # ) # elif self.state is not None and batch_idx % 1000 == 0: # assert ( # torch.sum( # torch.abs( # self.state # - self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"] # ) # ) # > 1e-7 # ), "Optimizer is not working" if len(self.metrics_buffer.keys()) > 0: for k in self.metrics_buffer.keys(): self.log( k, self.metrics_buffer[k], prog_bar=False, logger=True, on_step=True, on_epoch=False, ) # print(k, self.metrics_buffer[k]) self.metrics_buffer = {} loss, loss_dict = self.shared_step(batch) self.log_dict( {k: float(v) for k, v in loss_dict.items()}, prog_bar=True, logger=True, on_step=True, on_epoch=True, ) self.log( "global_step", float(self.global_step), prog_bar=True, logger=True, on_step=True, on_epoch=False, ) lr = self.trainer.optimizers[0].param_groups[0]["lr"] self.log( "lr_abs", float(lr), prog_bar=True, logger=True, on_step=True, on_epoch=False, ) return loss def random_clap_condition(self): # This function is only used during training, let the CLAP model to use both text and audio as condition assert self.training == True for key in self.cond_stage_model_metadata.keys(): metadata = self.cond_stage_model_metadata[key] model_idx, cond_stage_key, conditioning_key = ( metadata["model_idx"], metadata["cond_stage_key"], metadata["conditioning_key"], ) # If we use CLAP as condition, we might use audio for training, but we also must use text for evaluation if isinstance( self.cond_stage_models[model_idx], CLAPAudioEmbeddingClassifierFreev2 ): self.cond_stage_model_metadata[key][ "cond_stage_key_orig" ] = self.cond_stage_model_metadata[key]["cond_stage_key"] self.cond_stage_model_metadata[key][ "embed_mode_orig" ] = self.cond_stage_models[model_idx].embed_mode if torch.randn(1).item() < 0.5: self.cond_stage_model_metadata[key]["cond_stage_key"] = "text" self.cond_stage_models[model_idx].embed_mode = "text" else: self.cond_stage_model_metadata[key]["cond_stage_key"] = "waveform" self.cond_stage_models[model_idx].embed_mode = "audio" def on_validation_epoch_start(self) -> None: # Use text as condition during 
validation for key in self.cond_stage_model_metadata.keys(): metadata = self.cond_stage_model_metadata[key] model_idx, cond_stage_key, conditioning_key = ( metadata["model_idx"], metadata["cond_stage_key"], metadata["conditioning_key"], ) # If we use CLAP as condition, we might use audio for training, but we also must use text for evaluation if isinstance( self.cond_stage_models[model_idx], CLAPAudioEmbeddingClassifierFreev2 ): self.cond_stage_model_metadata[key][ "cond_stage_key_orig" ] = self.cond_stage_model_metadata[key]["cond_stage_key"] self.cond_stage_model_metadata[key][ "embed_mode_orig" ] = self.cond_stage_models[model_idx].embed_mode print( "Change the model original cond_keyand embed_mode %s, %s to text during evaluation" % ( self.cond_stage_model_metadata[key]["cond_stage_key_orig"], self.cond_stage_model_metadata[key]["embed_mode_orig"], ) ) self.cond_stage_model_metadata[key]["cond_stage_key"] = "text" self.cond_stage_models[model_idx].embed_mode = "text" if isinstance( self.cond_stage_models[model_idx], CLAPGenAudioMAECond ) or isinstance(self.cond_stage_models[model_idx], SequenceGenAudioMAECond): self.cond_stage_model_metadata[key][ "use_gt_mae_output_orig" ] = self.cond_stage_models[model_idx].use_gt_mae_output self.cond_stage_model_metadata[key][ "use_gt_mae_prob_orig" ] = self.cond_stage_models[model_idx].use_gt_mae_prob print("Change the model condition to the predicted AudioMAE tokens") self.cond_stage_models[model_idx].use_gt_mae_output = False self.cond_stage_models[model_idx].use_gt_mae_prob = 0.0 self.validation_folder_name = self.get_validation_folder_name() return super().on_validation_epoch_start() @torch.no_grad() def validation_step(self, batch, batch_idx): self.generate_sample( [batch], name=self.validation_folder_name, unconditional_guidance_scale=self.evaluation_params[ "unconditional_guidance_scale" ], ddim_steps=self.evaluation_params["ddim_sampling_steps"], n_gen=self.evaluation_params["n_candidates_per_samples"], ) def get_validation_folder_name(self): now = datetime.datetime.now() timestamp = now.strftime("%m-%d-%H:%M") return "val_%s_%s_cfg_scale_%s_ddim_%s_n_cand_%s" % ( self.global_step, timestamp, self.evaluation_params["unconditional_guidance_scale"], self.evaluation_params["ddim_sampling_steps"], self.evaluation_params["n_candidates_per_samples"], ) def on_validation_epoch_end(self) -> None: if self.global_rank == 0 and self.evaluator is not None: assert ( self.test_data_subset_path is not None ), "Please set test_data_subset_path before validation so that model have a target folder" try: name = self.validation_folder_name # import pdb # pdb.set_trace() waveform_save_path = os.path.join(self.get_log_dir(), name) if ( os.path.exists(waveform_save_path) and len(os.listdir(waveform_save_path)) > 0 ): metrics = self.evaluator.main( waveform_save_path, self.test_data_subset_path, ) self.metrics_buffer = { ("val/" + k): float(v) for k, v in metrics.items() } else: print( "The target folder for evaluation does not exist: %s" % waveform_save_path ) except Exception as e: print("Error encountered during evaluation: ", e) # Very important or the program may fail torch.cuda.synchronize() for key in self.cond_stage_model_metadata.keys(): metadata = self.cond_stage_model_metadata[key] model_idx, cond_stage_key, conditioning_key = ( metadata["model_idx"], metadata["cond_stage_key"], metadata["conditioning_key"], ) if isinstance( self.cond_stage_models[model_idx], CLAPAudioEmbeddingClassifierFreev2 ): self.cond_stage_model_metadata[key][ "cond_stage_key" ] = 
self.cond_stage_model_metadata[key]["cond_stage_key_orig"] self.cond_stage_models[ model_idx ].embed_mode = self.cond_stage_model_metadata[key]["embed_mode_orig"] print( "Change back the embedding mode to %s %s" % ( self.cond_stage_model_metadata[key]["cond_stage_key"], self.cond_stage_models[model_idx].embed_mode, ) ) if isinstance( self.cond_stage_models[model_idx], CLAPGenAudioMAECond ) or isinstance(self.cond_stage_models[model_idx], SequenceGenAudioMAECond): self.cond_stage_models[ model_idx ].use_gt_mae_output = self.cond_stage_model_metadata[key][ "use_gt_mae_output_orig" ] self.cond_stage_models[ model_idx ].use_gt_mae_prob = self.cond_stage_model_metadata[key][ "use_gt_mae_prob_orig" ] print( "Change the AudioMAE condition setting to %s (Use gt) %s (gt prob)" % ( self.cond_stage_models[model_idx].use_gt_mae_output, self.cond_stage_models[model_idx].use_gt_mae_prob, ) ) return super().on_validation_epoch_end() def on_train_epoch_start(self, *args, **kwargs): print("Log directory: ", self.get_log_dir()) def on_train_batch_end(self, *args, **kwargs): # Does this affect speed? if self.use_ema: self.model_ema(self.model) def _get_rows_from_list(self, samples): n_imgs_per_row = len(samples) denoise_grid = rearrange(samples, "n b c h w -> b n c h w") denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): log = dict() x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) x = x.to(self.device)[:N] log["inputs"] = x # get diffusion row diffusion_row = list() x_start = x[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) diffusion_row.append(x_noisy) log["diffusion_row"] = self._get_rows_from_list(diffusion_row) if sample: # get denoise row with self.ema_scope("Plotting"): samples, denoise_row = self.sample( batch_size=N, return_intermediates=True ) log["samples"] = samples log["denoise_row"] = self._get_rows_from_list(denoise_row) if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log else: return {key: log[key] for key in return_keys} return log def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.learn_logvar: params = params + [self.logvar] opt = torch.optim.AdamW(params, lr=lr) # opt = Adan(params, lr=lr, max_grad_norm=1, fused=True) return opt def initialize_param_check_toolkit(self): self.tracked_steps = 0 self.param_dict = {} def statistic_require_grad_tensor_number(self, module, name=None): requires_grad_num = 0 total_num = 0 require_grad_tensor = None for p in module.parameters(): if p.requires_grad: requires_grad_num += 1 if require_grad_tensor is None: require_grad_tensor = p total_num += 1 print( "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)" % (name, requires_grad_num, total_num, requires_grad_num / total_num) ) return require_grad_tensor def check_module_param_update(self): if self.tracked_steps == 0: for name, module in self.named_children(): try: require_grad_tensor = self.statistic_require_grad_tensor_number( module, name=name ) if require_grad_tensor is not None: self.param_dict[name] = 
require_grad_tensor.clone() else: print("==> %s does not requires grad" % name) except Exception as e: print("%s does not have trainable parameters: %s" % (name, e)) continue if self.tracked_steps % 5000 == 0: for name, module in self.named_children(): try: require_grad_tensor = self.statistic_require_grad_tensor_number( module, name=name ) if require_grad_tensor is not None: print( "===> Param diff %s: %s; Size: %s" % ( name, torch.sum( torch.abs( self.param_dict[name] - require_grad_tensor ) ), require_grad_tensor.size(), ) ) else: print("%s does not requires grad" % name) except Exception as e: print("%s does not have trainable parameters: %s" % (name, e)) continue self.tracked_steps += 1 class LatentDiffusion(DDPM): """main class""" def __init__( self, first_stage_config, cond_stage_config=None, num_timesteps_cond=None, cond_stage_key="image", optimize_ddpm_parameter=True, unconditional_prob_cfg=0.1, warmup_steps=10000, cond_stage_trainable=False, concat_mode=True, cond_stage_forward=None, conditioning_key=None, scale_factor=1.0, batchsize=None, evaluation_params={}, scale_by_std=False, base_learning_rate=None, *args, **kwargs, ): self.learning_rate = base_learning_rate self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std self.warmup_steps = warmup_steps if optimize_ddpm_parameter: if unconditional_prob_cfg == 0.0: "You choose to optimize DDPM. The classifier free guidance scale should be 0.1" unconditional_prob_cfg = 0.1 else: if unconditional_prob_cfg == 0.1: "You choose not to optimize DDPM. The classifier free guidance scale should be 0.0" unconditional_prob_cfg = 0.0 self.evaluation_params = evaluation_params assert self.num_timesteps_cond <= kwargs["timesteps"] # for backwards compatibility after implementation of DiffusionWrapper # if conditioning_key is None: # conditioning_key = "concat" if concat_mode else "crossattn" # if cond_stage_config == "__is_unconditional__": # conditioning_key = None conditioning_key = list(cond_stage_config.keys()) self.conditioning_key = conditioning_key ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.optimize_ddpm_parameter = optimize_ddpm_parameter # if(not optimize_ddpm_parameter): # print("Warning: Close the optimization of the latent diffusion model") # for p in self.model.parameters(): # p.requires_grad=False self.concat_mode = concat_mode self.cond_stage_key = cond_stage_key self.cond_stage_key_orig = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor else: self.register_buffer("scale_factor", torch.tensor(scale_factor)) self.instantiate_first_stage(first_stage_config) self.unconditional_prob_cfg = unconditional_prob_cfg self.cond_stage_models = nn.ModuleList([]) self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None self.conditional_dry_run_finished = False self.restarted_from_ckpt = False if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) for each in self.cond_stage_models: params = params + list( each.parameters() ) # Add the parameter from the conditional stage if self.learn_logvar: print("Diffusion model optimizing logvar") 
params.append(self.logvar) # opt = Adan(params, lr=lr, max_grad_norm=1, fused=True) opt = torch.optim.AdamW(params, lr=lr) # if self.use_scheduler: # assert "target" in self.scheduler_config # scheduler = instantiate_from_config(self.scheduler_config) # print("Setting up LambdaLR scheduler...") # scheduler = [ # { # "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), # "interval": "step", # "frequency": 1, # } # ] # return [opt], scheduler return opt def make_cond_schedule( self, ): self.cond_ids = torch.full( size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long, ) ids = torch.round( torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) ).long() self.cond_ids[: self.num_timesteps_cond] = ids @rank_zero_only @torch.no_grad() def on_train_batch_start(self, batch, batch_idx): # only for very first batch if ( self.scale_factor == 1 and self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt ): # assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' # set rescale weight to 1./std of encodings print("### USING STD-RESCALING ###") x = super().get_input(batch, self.first_stage_key) x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor self.register_buffer("scale_factor", 1.0 / z.flatten().std()) print(f"setting self.scale_factor to {self.scale_factor}") print("### USING STD-RESCALING ###") def register_schedule( self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, ): super().register_schedule( given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s ) self.shorten_cond_schedule = self.num_timesteps_cond > 1 if self.shorten_cond_schedule: self.make_cond_schedule() def instantiate_first_stage(self, config): model = instantiate_from_config(config) self.first_stage_model = model.eval() self.first_stage_model.train = disabled_train for param in self.first_stage_model.parameters(): param.requires_grad = False def make_decision(self, probability): if float(torch.rand(1)) < probability: return True else: return False def instantiate_cond_stage(self, config): self.cond_stage_model_metadata = {} for i, cond_model_key in enumerate(config.keys()): model = instantiate_from_config(config[cond_model_key]) self.cond_stage_models.append(model) self.cond_stage_model_metadata[cond_model_key] = { "model_idx": i, "cond_stage_key": config[cond_model_key]["cond_stage_key"], "conditioning_key": config[cond_model_key]["conditioning_key"], } def get_first_stage_encoding(self, encoder_posterior): if isinstance(encoder_posterior, DiagonalGaussianDistribution): z = encoder_posterior.sample() elif isinstance(encoder_posterior, torch.Tensor): z = encoder_posterior else: raise NotImplementedError( f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" ) return self.scale_factor * z def get_learned_conditioning(self, c, key, unconditional_cfg): assert key in self.cond_stage_model_metadata.keys() # Classifier-free guidance if not unconditional_cfg: c = self.cond_stage_models[ self.cond_stage_model_metadata[key]["model_idx"] ](c) else: # when the cond_stage_key is "all", pick one random element out if isinstance(c, dict): c = c[list(c.keys())[0]] if isinstance(c, torch.Tensor): batchsize = c.size(0) elif isinstance(c, list): batchsize = len(c) else: raise NotImplementedError() c 
= self.cond_stage_models[ self.cond_stage_model_metadata[key]["model_idx"] ].get_unconditional_condition(batchsize) return c def get_input( self, batch, k, return_first_stage_encode=True, return_decoding_output=False, return_encoder_input=False, return_encoder_output=False, unconditional_prob_cfg=0.1, ): # print(self.cond_stage_model_metadata.keys()) x = super().get_input(batch, k) x = x.to(self.device) if return_first_stage_encode: encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() else: z = None cond_dict = {} if len(self.cond_stage_model_metadata.keys()) > 0: unconditional_cfg = False if self.conditional_dry_run_finished and self.make_decision( unconditional_prob_cfg ): unconditional_cfg = True for cond_model_key in self.cond_stage_model_metadata.keys(): cond_stage_key = self.cond_stage_model_metadata[cond_model_key][ "cond_stage_key" ] if cond_model_key in cond_dict.keys(): continue if not self.training: if isinstance( self.cond_stage_models[ self.cond_stage_model_metadata[cond_model_key]["model_idx"] ], CLAPAudioEmbeddingClassifierFreev2, ): print( "Warning: CLAP model normally should use text for evaluation" ) # The original data for conditioning # If cond_model_key is "all", that means the conditional model need all the information from a batch if cond_stage_key != "all": xc = super().get_input(batch, cond_stage_key) if type(xc) == torch.Tensor: xc = xc.to(self.device) else: xc = batch # batch inference BUG #if cond_stage_key == 'text': # xc = xc[0] # if cond_stage_key is "all", xc will be a dictionary containing all keys # Otherwise xc will be an entry of the dictionary c = self.get_learned_conditioning( xc, key=cond_model_key, unconditional_cfg=unconditional_cfg ) # cond_dict will be used to condition the diffusion model # If one conditional model return multiple conditioning signal if isinstance(c, dict): for k in c.keys(): cond_dict[k] = c[k] else: cond_dict[cond_model_key] = c # If the key is accidently added to the dictionary and not in the condition list, remove the condition # for k in list(cond_dict.keys()): # if(k not in self.cond_stage_model_metadata.keys()): # del cond_dict[k] cond_dict['mos'] = batch['mos'] out = [z, cond_dict] if return_decoding_output: xrec = self.decode_first_stage(z) out += [xrec] if return_encoder_input: out += [x] if return_encoder_output: out += [encoder_posterior] if not self.conditional_dry_run_finished: self.conditional_dry_run_finished = True # Output is a dictionary, where the value could only be tensor or tuple return out def decode_first_stage(self, z): with torch.no_grad(): z = 1.0 / self.scale_factor * z decoding = self.first_stage_model.decode(z) return decoding def mel_spectrogram_to_waveform( self, mel, savepath=".", bs=None, name="outwav", save=True, n_gen=1 ): # Mel: [bs, 1, t-steps, fbins] if len(mel.size()) == 4: mel = mel.squeeze(1) mel = mel.permute(0, 2, 1) waveform = self.first_stage_model.vocoder(mel) waveform = waveform.cpu().detach().numpy() if save: self.save_waveform(waveform, savepath, name, n_gen) return waveform def encode_first_stage(self, x): with torch.no_grad(): return self.first_stage_model.encode(x) def extract_possible_loss_in_cond_dict(self, cond_dict): # This function enable the conditional module to return loss function that can optimize them assert isinstance(cond_dict, dict) losses = {} for cond_key in cond_dict.keys(): if "loss" in cond_key and "noncond" in cond_key: assert cond_key not in losses.keys() losses[cond_key] = cond_dict[cond_key] return 
losses def filter_useful_cond_dict(self, cond_dict): new_cond_dict = {} for key in cond_dict.keys(): if key in self.cond_stage_model_metadata.keys(): new_cond_dict[key] = cond_dict[key] # All the conditional key in the metadata should be used for key in self.cond_stage_model_metadata.keys(): assert key in new_cond_dict.keys(), "%s, %s" % ( key, str(new_cond_dict.keys()), ) try: new_cond_dict['mos'] = cond_dict['mos'] except: pass return new_cond_dict def shared_step(self, batch, **kwargs): # self.check_module_param_update() if self.training: # Classifier-free guidance unconditional_prob_cfg = self.unconditional_prob_cfg else: unconditional_prob_cfg = 0.0 # TODO possible bug here x, c = self.get_input( batch, self.first_stage_key, unconditional_prob_cfg=unconditional_prob_cfg ) if self.optimize_ddpm_parameter: loss, loss_dict = self(x, self.filter_useful_cond_dict(c)) else: loss_dict = {} loss = None additional_loss_for_cond_modules = self.extract_possible_loss_in_cond_dict(c) assert isinstance(additional_loss_for_cond_modules, dict) loss_dict.update(additional_loss_for_cond_modules) if len(additional_loss_for_cond_modules.keys()) > 0: for k in additional_loss_for_cond_modules.keys(): if loss is None: loss = additional_loss_for_cond_modules[k] else: loss = loss + additional_loss_for_cond_modules[k] # for k,v in additional_loss_for_cond_modules.items(): # self.log( # "cond_stage/"+k, # float(v), # prog_bar=True, # logger=True, # on_step=True, # on_epoch=True, # ) if self.training: assert loss is not None return loss, loss_dict def forward(self, x, c, *args, **kwargs): t = torch.randint( 0, self.num_timesteps, (x.shape[0],), device=self.device ).long() # assert c is not None # c = self.get_learned_conditioning(c) loss, loss_dict = self.p_losses(x, c, t, *args, **kwargs) return loss, loss_dict def reorder_cond_dict(self, cond_dict): # To make sure the order is correct new_cond_dict = {} for key in self.conditioning_key: new_cond_dict[key] = cond_dict[key] new_cond_dict['mos'] = cond_dict['mos'] return new_cond_dict def apply_model(self, x_noisy, t, cond, return_ids=False): cond = self.reorder_cond_dict(cond) # import pdb; pdb.set_trace() x_recon = self.model(x_noisy, t, cond_dict=cond) if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] else: return x_recon def p_losses(self, x_start, cond, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_output = self.apply_model(x_noisy, t, cond) loss_dict = {} prefix = "train" if self.training else "val" if self.parameterization == "x0": target = x_start elif self.parameterization == "eps": target = noise elif self.parameterization == "v": target = self.get_v(x_start, noise, t) else: raise NotImplementedError() # print(model_output.size(), target.size()) mse_loss_weight = None alpha = extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape) sigma = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape) snr = (alpha / sigma) ** 2 # velocity = (alpha[:, None, None, None] * x_t - x_start) / sigma[:, None, None, None] # get loss weight if self.parameterization != "x0": mse_loss_weight = torch.ones_like(t) k = 5.0 # min{snr, k} mse_loss_weight = torch.stack([snr, k * torch.ones_like(t)], dim=1).min(dim=1)[0] / snr else: k = 5.0 # min{snr, k} mse_loss_weight = torch.stack([snr, k * torch.ones_like(t)], dim=1).min(dim=1)[0] loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) loss_simple = loss_simple * mse_loss_weight # 
import pdb # pdb.set_trace() loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) loss_dict.update({"logvar": self.logvar.data.mean()}) loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) loss += self.original_elbo_weight * loss_vlb loss_dict.update({f"{prefix}/loss": loss}) return loss, loss_dict def p_mean_variance( self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, return_x0=False, score_corrector=None, corrector_kwargs=None, ): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) if score_corrector is not None: assert self.parameterization == "eps" model_out = score_corrector.modify_score( self, model_out, x, t, c, **corrector_kwargs ) if return_codebook_ids: model_out, logits = model_out if self.parameterization == "eps": x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) elif self.parameterization == "x0": x_recon = model_out else: raise NotImplementedError() if clip_denoised: x_recon.clamp_(-1.0, 1.0) if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) model_mean, posterior_variance, posterior_log_variance = self.q_posterior( x_start=x_recon, x_t=x, t=t ) if return_codebook_ids: return model_mean, posterior_variance, posterior_log_variance, logits elif return_x0: return model_mean, posterior_variance, posterior_log_variance, x_recon else: return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample( self, x, c, t, clip_denoised=False, repeat_noise=False, return_codebook_ids=False, quantize_denoised=False, return_x0=False, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, ): b, *_, device = *x.shape, x.device outputs = self.p_mean_variance( x=x, c=c, t=t, clip_denoised=clip_denoised, return_codebook_ids=return_codebook_ids, quantize_denoised=quantize_denoised, return_x0=return_x0, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, ) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs elif return_x0: model_mean, _, model_log_variance, x0 = outputs else: model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = ( (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() ) # if return_codebook_ids: # return model_mean + nonzero_mask * ( # 0.5 * model_log_variance # ).exp() * noise, logits.argmax(dim=1) if return_x0: return ( model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0, ) else: return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() def progressive_denoising( self, cond, shape, verbose=True, callback=None, quantize_denoised=False, img_callback=None, mask=None, x0=None, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, log_every_t=None, ): if not log_every_t: log_every_t = self.log_every_t 
timesteps = self.num_timesteps if batch_size is not None: b = batch_size if batch_size is not None else shape[0] shape = [batch_size] + list(shape) else: b = batch_size = shape[0] if x_T is None: img = torch.randn(shape, device=self.device) else: img = x_T intermediates = [] if cond is not None: if isinstance(cond, dict): cond = { key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond } else: cond = ( [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] ) if start_T is not None: timesteps = min(timesteps, start_T) iterator = ( tqdm( reversed(range(0, timesteps)), desc="Progressive Generation", total=timesteps, ) if verbose else reversed(range(0, timesteps)) ) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img, x0_partial = self.p_sample( img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised, return_x0=True, temperature=temperature[i], noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, ) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) if callback: callback(i) if img_callback: img_callback(img, i) return img, intermediates @torch.no_grad() def p_sample_loop( self, cond, shape, return_intermediates=False, x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, start_T=None, log_every_t=None, ): if not log_every_t: log_every_t = self.log_every_t device = self.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T intermediates = [img] if timesteps is None: timesteps = self.num_timesteps if start_T is not None: timesteps = min(timesteps, start_T) iterator = ( tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) if verbose else reversed(range(0, timesteps)) ) if mask is not None: assert x0 is not None assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) # import pdb # pdb.set_trace() img = self.p_sample( img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised, ) if mask is not None: img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) if callback: callback(i) if img_callback: img_callback(img, i) if return_intermediates: return img, intermediates return img @torch.no_grad() def sample( self, cond, batch_size=16, return_intermediates=False, x_T=None, verbose=True, timesteps=None, quantize_denoised=False, mask=None, x0=None, shape=None, **kwargs, ): if shape is None: shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size) if cond is not None: if isinstance(cond, dict): cond = { key: cond[key][:batch_size] if not isinstance(cond[key], list) else 
list(map(lambda x: x[:batch_size], cond[key])) for key in cond } else: cond = ( [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] ) return self.p_sample_loop( cond, shape, return_intermediates=return_intermediates, x_T=x_T, verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, mask=mask, x0=x0, **kwargs, ) def save_waveform(self, waveform, savepath, name="outwav", n_gen=1): print(f'debug_name : {name}') if type(name) != str and len(name[0][1]) > 1: name = list(name[0][1]) name = [_.decode() if type(_) is bytes else _ for _ in name] n_gen = int(waveform.shape[0] / len(name)) assert len(name) * n_gen == waveform.shape[0] lenn = len(name) for i in range(n_gen - 1): for x in range(lenn): name.append(name[x]) assert len(name) == waveform.shape[0] for i in range(waveform.shape[0]): if type(name) is str: path = os.path.join(savepath, "%s_%s_%s.wav" % (self.global_step, i, name)) elif type(name) is list: path = os.path.join( savepath, "%s.wav" % ( os.path.basename(name[i]) if (not ".wav" in name[i]) else os.path.basename(name[i]).split(".")[0] ), ) else: # import pdb # pdb.set_trace() raise NotImplementedError todo_waveform = waveform[i, 0] todo_waveform = ( todo_waveform / np.max(np.abs(todo_waveform)) ) * 0.8 # Normalize the energy of the generation output try: sf.write(path, todo_waveform, samplerate=self.sampling_rate) except: print('waveform name ERROR!!!!!!!!!!!!') @torch.no_grad() def sample_log( self, cond, batch_size, ddim, ddim_steps, unconditional_guidance_scale=1.0, unconditional_conditioning=None, use_plms=False, mask=None, **kwargs, ): if mask is not None: shape = (self.channels, mask.size()[-2], mask.size()[-1]) else: shape = (self.channels, self.latent_t_size, self.latent_f_size) intermediate = None if ddim and not use_plms: print("Use ddim sampler") ddim_sampler = DDIMSampler(self) samples, intermediates = ddim_sampler.sample( ddim_steps, batch_size, shape, cond, verbose=False, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, mask=mask, **kwargs, ) elif use_plms: print("Use plms sampler") plms_sampler = PLMSSampler(self) samples, intermediates = plms_sampler.sample( ddim_steps, batch_size, shape, cond, verbose=False, unconditional_guidance_scale=unconditional_guidance_scale, mask=mask, unconditional_conditioning=unconditional_conditioning, **kwargs, ) else: print("Use DDPM sampler") samples, intermediates = self.sample( cond=cond, batch_size=batch_size, return_intermediates=True, unconditional_guidance_scale=unconditional_guidance_scale, mask=mask, unconditional_conditioning=unconditional_conditioning, **kwargs, ) return samples, intermediate @torch.no_grad() def generate_sample( self, batchs, ddim_steps=200, ddim_eta=1.0, x_T=None, n_gen=1, unconditional_guidance_scale=1.0, unconditional_conditioning=None, name=None, use_plms=False, limit_num=None, **kwargs, ): # Generate n_gen times and select the best # Batch: audio, text, fnames # import pdb # pdb.set_trace() assert x_T is None try: batchs = iter(batchs) except TypeError: raise ValueError("The first input argument should be an iterable object") if use_plms: assert ddim_steps is not None use_ddim = ddim_steps is not None if name is None: name = self.get_validation_folder_name() waveform_save_path = os.path.join(self.get_log_dir(), name) os.makedirs(waveform_save_path, exist_ok=True) print("Waveform save path: ", waveform_save_path) # if ( # "audiocaps" in waveform_save_path # and 
len(os.listdir(waveform_save_path)) >= 964 # ): # print("The evaluation has already been done at %s" % waveform_save_path) # return waveform_save_path with self.ema_scope("Plotting"): for i, batch in enumerate(batchs): #print(batch) z, c = self.get_input( batch, self.first_stage_key, unconditional_prob_cfg=0.0, # Do not output unconditional information in the c ) # import pdb; pdb.set_trace() if limit_num is not None and i * z.size(0) > limit_num: break c = self.filter_useful_cond_dict(c) text = super().get_input(batch, "text") mos = super().get_input(batch, "mos") # for cond_key in c.keys(): # c[cond_key] = self.cond_stage_models[self.cond_stage_model_metadata[cond_key]["model_idx"]](text[0]) # Generate multiple samples batch_size = z.shape[0] * n_gen # Generate multiple samples at a time and filter out the best # The condition to the diffusion wrapper can have many format # import pdb # pdb.set_trace() for cond_key in c.keys(): if isinstance(c[cond_key], list): for i in range(len(c[cond_key])): c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0) elif isinstance(c[cond_key], dict): for k in c[cond_key].keys(): c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0) else: c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0) text = text * n_gen mos = mos * n_gen c['mos'] = torch.stack(mos).unsqueeze(1) if unconditional_guidance_scale != 1.0: unconditional_conditioning = {} for key in self.cond_stage_model_metadata: model_idx = self.cond_stage_model_metadata[key]["model_idx"] unconditional_conditioning[key] = self.cond_stage_models[ model_idx ].get_unconditional_condition(batch_size) fnames = list(super().get_input(batch, "fname")) # import pdb; pdb.set_trace() samples, _ = self.sample_log( cond=c, batch_size=batch_size, x_T=x_T, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, use_plms=use_plms, ) mel = self.decode_first_stage(samples) # mel = super().get_input(batch, "log_mel_spec") waveform = self.mel_spectrogram_to_waveform( mel, savepath=waveform_save_path, bs=None, name=fnames, save=False, n_gen=n_gen ) if n_gen > 1: try: best_index = [] similarity = self.clap.cos_similarity( torch.FloatTensor(waveform).squeeze(1), text ) for i in range(z.shape[0]): candidates = similarity[i :: z.shape[0]] max_index = torch.argmax(candidates).item() best_index.append(i + max_index * z.shape[0]) waveform = waveform[best_index] print("Similarity between generated audio and text", similarity) print("Choose the following indexes:", best_index) except Exception as e: print("Warning: while calculating CLAP score (not fatal), ", e) self.save_waveform(waveform, waveform_save_path, name=fnames, n_gen=n_gen) return waveform_save_path class DiffusionWrapper(pl.LightningModule): def __init__(self, diff_model_config, conditioning_key): super().__init__() self.diffusion_model = instantiate_from_config(diff_model_config) self.conditioning_key = conditioning_key for key in self.conditioning_key: if ( "concat" in key or "crossattn" in key or "hybrid" in key or "film" in key or "noncond" in key ): continue else: raise Value("The conditioning key %s is illegal" % key) self.being_verbosed_once = False def forward(self, x, t, cond_dict: dict = {}): x = x.contiguous() t = t.contiguous() # import pdb # pdb.set_trace() # x with condition (or maybe not) xc = x y = None context_list, attn_mask_list = [], [] conditional_keys = cond_dict.keys() for key in conditional_keys: if "concat" in key: xc = 
torch.cat([x, cond_dict[key].unsqueeze(1)], dim=1) elif "film" in key: if y is None: y = cond_dict[key].squeeze(1) else: y = torch.cat([y, cond_dict[key].squeeze(1)], dim=-1) elif "crossattn" in key: # assert context is None, "You can only have one context matrix, got %s" % (cond_dict.keys()) if isinstance(cond_dict[key], dict): for k in cond_dict[key].keys(): if "crossattn" in k: context, attn_mask = cond_dict[key][ k ] # crossattn_audiomae_pooled: torch.Size([12, 128, 768]) else: assert len(cond_dict[key]) == 2, ( "The context condition for %s you returned should have two element, one context one mask" % (key) ) context, attn_mask = cond_dict[key] # The input to the UNet model is a list of context matrix context_list.append(context) attn_mask_list.append(attn_mask) elif ( "noncond" in key ): # If you use loss function in the conditional module, include the keyword "noncond" in the return dictionary continue elif "mos" in key: mos = cond_dict['mos'] else: raise NotImplementedError() if not self.being_verbosed_once: print("The input shape to the diffusion model is as follows:") print("xc", xc.size()) print("t", t.size()) for i in range(len(context_list)): print( "context_%s" % i, context_list[i].size(), attn_mask_list[i].size() ) if y is not None: print("y", y.size()) self.being_verbosed_once = True # try: # out = self.diffusion_model.forward_with_dpmsolver( # xc, timestep=t, y=context_list, mask=attn_mask_list # ) # except: out = self.diffusion_model.forward( xc, timestep=t, context_list=context_list, context_mask_list=attn_mask_list, mos=mos ) return out @torch.no_grad() def forward_with_cfg(self, x, t, cond_dict: dict = {}, cfg_scale=4.0, **model_kwargs): x = x.contiguous() t = t.contiguous() # x with condition (or maybe not) xc = x y = None context_list, attn_mask_list = [], [] conditional_keys = cond_dict.keys() for key in conditional_keys: if "concat" in key: xc = torch.cat([x, cond_dict[key].unsqueeze(1)], dim=1) elif "film" in key: if y is None: y = cond_dict[key].squeeze(1) else: y = torch.cat([y, cond_dict[key].squeeze(1)], dim=-1) elif "crossattn" in key: # assert context is None, "You can only have one context matrix, got %s" % (cond_dict.keys()) if isinstance(cond_dict[key], dict): for k in cond_dict[key].keys(): if "crossattn" in k: context, attn_mask = cond_dict[key][ k ] # crossattn_audiomae_pooled: torch.Size([12, 128, 768]) else: assert len(cond_dict[key]) == 2, ( "The context condition for %s you returned should have two element, one context one mask" % (key) ) context, attn_mask = cond_dict[key] # The input to the UNet model is a list of context matrix context_list.append(context) attn_mask_list.append(attn_mask) elif ( "noncond" in key ): # If you use loss function in the conditional module, include the keyword "noncond" in the return dictionary continue else: raise NotImplementedError() if not self.being_verbosed_once: print("The input shape to the diffusion model is as follows:") print("xc", xc.size()) print("t", t.size()) for i in range(len(context_list)): print( "context_%s" % i, context_list[i].size(), attn_mask_list[i].size() ) if y is not None: print("y", y.size()) self.being_verbosed_once = True # try: # out = self.diffusion_model.forward_with_dpmsolver( # xc, timestep=t, y=context_list, mask=attn_mask_list # ) # except: out = self.diffusion_model.forward_with_cfg( xc, timestep=t, context_list=context_list, context_mask_list=attn_mask_list, cfg_scale=cfg_scale, **model_kwargs ) # import pdb # pdb.set_trace() return out class 


class LatentDiffusionSpeedTest(pl.LightningModule):
    """main class"""

    def __init__(
        self,
        first_stage_config,
        cond_stage_config=None,
        num_timesteps_cond=None,
        cond_stage_key="image",
        cond_stage_trainable=False,
        concat_mode=True,
        cond_stage_forward=None,
        conditioning_key=None,
        scale_factor=1.0,
        batchsize=None,
        evaluation_params={},
        scale_by_std=False,
        base_learning_rate=None,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.l1 = nn.Linear(1, 1)
        self.logger_save_dir = None
        self.logger_exp_group_name = None
        self.logger_exp_name = None
        self.test_data_subset_path = None

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_group_name = exp_group_name
        self.logger_exp_name = exp_name

    def forward(self, x):
        return self.l1(x.permute(0, 2, 1)).permute(0, 2, 1)

    def training_step(self, batch, batch_idx):
        x = batch["waveform"]
        loss = self(x)
        return torch.mean(loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


class LatentDiffusionVAELearnable(LatentDiffusion):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())

        for each in self.cond_stage_models:
            params = params + list(
                each.parameters()
            )  # Add the parameters from the conditional stage

        if self.learn_logvar:
            print("Diffusion model optimizing logvar")
            params.append(self.logvar)

        ldm_opt = torch.optim.AdamW(params, lr=lr)

        opt_autoencoder, opt_scheduler = self.first_stage_model.configure_optimizers()
        opt_ae, opt_disc = opt_autoencoder

        return [ldm_opt, opt_ae, opt_disc], []

    def encode_first_stage(self, x):
        # with torch.no_grad():
        encoding = self.first_stage_model.encode(x)
        return encoding

    def decode_first_stage(self, z):
        # with torch.no_grad():
        z = 1.0 / self.scale_factor * z
        decoding = self.first_stage_model.decode(z)
        return decoding

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.train()
        # self.first_stage_model.train = disabled_train
        # for param in self.first_stage_model.parameters():
        #     param.requires_grad = False
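
    # Editorial note: because automatic_optimization is disabled in __init__, shared_step
    # below drives all three optimizers by hand: the LDM optimizer and the VAE generator
    # optimizer share one backward pass over (diffusion loss + aeloss), while the
    # discriminator gets its own backward pass over discloss.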
    def shared_step(self, batch, **kwargs):
        ldm_opt, g_opt, d_opt = self.optimizers()

        if self.training:
            # Classifier-free guidance
            unconditional_prob_cfg = self.unconditional_prob_cfg
        else:
            unconditional_prob_cfg = 0.0

        x, c, decoder_xrec, encoder_x, encoder_posterior = self.get_input(
            batch,
            self.first_stage_key,
            unconditional_prob_cfg=unconditional_prob_cfg,
            return_decoding_output=True,
            return_encoder_input=True,
            return_encoder_output=True,
        )

        loss, loss_dict = self(x, self.filter_useful_cond_dict(c))

        additional_loss_for_cond_modules = self.extract_possible_loss_in_cond_dict(c)
        assert isinstance(additional_loss_for_cond_modules, dict)

        loss_dict.update(additional_loss_for_cond_modules)

        if len(additional_loss_for_cond_modules.keys()) > 0:
            for k in additional_loss_for_cond_modules.keys():
                loss = loss + additional_loss_for_cond_modules[k]

        for k, v in additional_loss_for_cond_modules.items():
            self.log(
                "cond_stage/" + k,
                float(v),
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )

        aeloss, log_dict_ae = self.first_stage_model.loss(
            encoder_x,
            decoder_xrec,
            encoder_posterior,
            optimizer_idx=0,
            global_step=self.first_stage_model.global_step,
            last_layer=self.first_stage_model.get_last_layer(),
            split="train",
        )

        self.manual_backward(loss + aeloss)
        ldm_opt.step()
        ldm_opt.zero_grad()
        g_opt.step()
        g_opt.zero_grad()

        discloss, log_dict_disc = self.first_stage_model.loss(
            encoder_x,
            decoder_xrec,
            encoder_posterior,
            optimizer_idx=1,
            global_step=self.first_stage_model.global_step,
            last_layer=self.first_stage_model.get_last_layer(),
            split="train",
        )

        self.manual_backward(discloss)
        d_opt.step()
        d_opt.zero_grad()

        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )
        self.log(
            "posterior_std",
            torch.mean(encoder_posterior.var),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        loss_dict.update(log_dict_disc)
        loss_dict.update(log_dict_ae)

        return None, loss_dict

    def training_step(self, batch, batch_idx):
        self.warmup_step()
        self.check_module_param_update()

        if (
            self.state is None
            and len(self.trainer.optimizers[0].state_dict()["state"].keys()) > 0
        ):
            self.state = (
                self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"].clone()
            )
        elif self.state is not None and batch_idx % 1000 == 0:
            assert (
                torch.sum(
                    torch.abs(
                        self.state
                        - self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"]
                    )
                )
                > 1e-7
            ), "Optimizer is not working"

        if len(self.metrics_buffer.keys()) > 0:
            for k in self.metrics_buffer.keys():
                self.log(
                    k,
                    self.metrics_buffer[k],
                    prog_bar=False,
                    logger=True,
                    on_step=True,
                    on_epoch=False,
                )
                print(k, self.metrics_buffer[k])
            self.metrics_buffer = {}

        loss, loss_dict = self.shared_step(batch)

        self.log_dict(
            {k: float(v) for k, v in loss_dict.items()},
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )

        self.log(
            "global_step",
            float(self.global_step),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log(
            "lr_abs",
            float(lr),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )