# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchvision import transforms as T, utils

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

from PIL import Image
from tqdm.auto import tqdm
from typing import Any, Dict, List, Optional, Tuple, Union

# constants

ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start"])

# helpers functions


def exists(x):
    """Return True if ``x`` is not None."""
    return x is not None


def default(val, d):
    """Return ``val`` if it is not None, otherwise ``d`` (called first if callable)."""
    if exists(val):
        return val
    return d() if callable(d) else d


def extract(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and reshape them for broadcasting.

    Args:
        a: schedule tensor, indexed along its last dimension.
        t: integer index tensor whose first dimension is the batch size ``b``.
        x_shape: shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(b, 1, ..., 1)`` with ``len(x_shape) - 1`` trailing
        singleton dimensions.
    """
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def linear_beta_schedule(timesteps):
    """Linear beta schedule in float64, rescaled so 1000 steps span [1e-4, 0.02]."""
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = (
        torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    )
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # Clip so that no single step destroys the signal completely.
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion(nn.Module):
    """Gaussian diffusion process with precomputed schedule buffers.

    Provides forward noising (``q_sample``), the reverse posterior
    (``q_posterior`` / ``p_mean_variance``), ancestral sampling
    (``p_sample`` / ``p_sample_loop`` / ``sample``) and the training loss
    (``p_losses`` / ``forward``).

    NOTE(review): ``self.model`` is called as ``self.model(x, t, z)`` but is
    never assigned in this class; it must be attached by a subclass or the
    caller before use — confirm.
    """

    def __init__(
        self,
        timesteps=100,
        sampling_timesteps=None,
        beta_1=0.0001,
        beta_T=0.1,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="custom",
        p2_loss_weight_gamma=0.0,
        p2_loss_weight_k=1,
    ):
        """Store hyperparameters and register all schedule buffers.

        Args:
            timesteps: number of diffusion steps used during training.
            sampling_timesteps: must be <= ``timesteps``; defaults to it.
            beta_1, beta_T: endpoints of the ``"custom"`` linear schedule.
            loss_type: ``"l1"`` or ``"l2"`` (see ``loss_fn``).
            objective: ``"pred_noise"`` or ``"pred_x0"``.
            beta_schedule: ``"linear"``, ``"cosine"`` or ``"custom"``.
            p2_loss_weight_gamma, p2_loss_weight_k: P2 loss reweighting terms.
        """
        super().__init__()

        assert objective in {"pred_noise", "pred_x0"}, (
            "objective must be either pred_noise (predict noise) "
            "or pred_x0 (predict image start)"
        )

        self.timesteps = timesteps
        self.sampling_timesteps = sampling_timesteps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.loss_type = loss_type
        self.objective = objective
        self.beta_schedule = beta_schedule
        self.p2_loss_weight_gamma = p2_loss_weight_gamma
        self.p2_loss_weight_k = p2_loss_weight_k

        self.init_diff_hyper(
            self.timesteps,
            self.sampling_timesteps,
            self.beta_1,
            self.beta_T,
            self.loss_type,
            self.objective,
            self.beta_schedule,
            self.p2_loss_weight_gamma,
            self.p2_loss_weight_k,
        )

    def init_diff_hyper(
        self,
        timesteps,
        sampling_timesteps,
        beta_1,
        beta_T,
        loss_type,
        objective,
        beta_schedule,
        p2_loss_weight_gamma,
        p2_loss_weight_k,
    ):
        """Compute the beta schedule and register every derived buffer.

        Sets ``num_timesteps``, ``loss_type`` and ``sampling_timesteps`` on
        ``self``. ``objective`` is accepted but not used here;
        ``self.objective`` is only set by ``__init__``.

        Raises:
            ValueError: if ``beta_schedule`` is not recognised.
        """
        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "custom":
            betas = torch.linspace(
                beta_1, beta_T, timesteps, dtype=torch.float64
            )
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 for the first step.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # sampling related parameters
        # default num sampling timesteps to number of timesteps at training
        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        assert self.sampling_timesteps <= timesteps

        # Schedules are computed in float64 for accuracy, stored as float32.
        def register_f32(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_f32("betas", betas)
        register_f32("alphas_cumprod", alphas_cumprod)
        register_f32("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_f32("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_f32(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_f32(
            "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
        )
        register_f32(
            "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
        )
        register_f32(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        register_f32("posterior_variance", posterior_variance)

        # log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        register_f32(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_f32(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_f32(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - alphas_cumprod),
        )

        # calculate p2 reweighting
        register_f32(
            "p2_loss_weight",
            (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
            ** -p2_loss_weight_gamma,
        )

    # helper functions

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and the noise that produced it."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise from x_t and a predicted x_0."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return (
            posterior_mean,
            posterior_variance,
            posterior_log_variance_clipped,
        )

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0); draws fresh Gaussian noise if none is given."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def model_predictions(self, x, t, z, x_self_cond=None):
        """Run ``self.model`` and express its output as both noise and x_0.

        ``x_self_cond`` is accepted for API compatibility and ignored.

        Raises:
            ValueError: if ``self.objective`` is not a supported objective.
        """
        model_output = self.model(x, t, z)

        if self.objective == "pred_noise":
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)
        elif self.objective == "pred_x0":
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output
        else:
            raise ValueError(f"unknown objective {self.objective}")

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: torch.Tensor,  # per-sample timestep indices (batched)
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
    ):
        """Return (mean, variance, log-variance, predicted x_0) of p(x_{t-1} | x_t).

        Raises:
            NotImplementedError: if ``clip_denoised`` is True.
        """
        preds = self.model_predictions(x, t, z)
        x_start = preds.pred_x_start

        if clip_denoised:
            raise NotImplementedError(
                "We don't clip the output because "
                "pose does not have a clear bound."
            )

        (
            model_mean,
            posterior_variance,
            posterior_log_variance,
        ) = self.q_posterior(x_start=x_start, x_t=x, t=t)

        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: int,
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
        cond_fn=None,
        cond_start_step=0,
    ):
        """Take one reverse-diffusion step from x_t at integer timestep ``t``.

        When ``cond_fn`` is given and ``t < cond_start_step``, the mean is
        passed through ``cond_fn(model_mean, t)`` and no noise is added.

        Returns:
            ``(x_{t-1}, predicted x_0)``.
        """
        # Broadcast the scalar timestep to one index per batch element.
        batched_times = torch.full(
            (x.shape[0],), t, device=x.device, dtype=torch.long
        )
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x,
            t=batched_times,
            z=z,
            x_self_cond=x_self_cond,
            clip_denoised=clip_denoised,
        )

        if cond_fn is not None and t < cond_start_step:
            model_mean = cond_fn(model_mean, t)
            noise = 0.0
        else:
            noise = torch.randn_like(x) if t > 0 else 0.0  # no noise if t == 0

        pred = model_mean + (0.5 * model_log_variance).exp() * noise

        return pred, x_start

    @torch.no_grad()
    def p_sample_loop(
        self,
        shape,
        z: torch.Tensor,
        cond_fn=None,
        cond_start_step=0,
    ):
        """Run the full reverse chain from Gaussian noise over all ``num_timesteps``.

        Note: iterates every training timestep; ``sampling_timesteps`` is not
        consulted here.

        Returns:
            ``(final sample, trajectory)`` where the trajectory stacks the
            initial noise and every intermediate sample along a new leading
            dimension of size ``num_timesteps + 1``.
        """
        device = self.betas.device

        # Init here
        pose = torch.randn(shape, device=device)
        pose_process = [pose]

        for t in reversed(range(0, self.num_timesteps)):
            pose, _ = self.p_sample(
                x=pose,
                t=t,
                z=z,
                cond_fn=cond_fn,
                cond_start_step=cond_start_step,
            )
            pose_process.append(pose)

        return pose, torch.stack(pose_process)

    @torch.no_grad()
    def sample(self, shape, z, cond_fn=None, cond_start_step=0):
        """Draw samples of ``shape`` conditioned on ``z`` (delegates to ``p_sample_loop``)."""
        # TODO: add more variants
        sample_fn = self.p_sample_loop
        return sample_fn(
            shape, z=z, cond_fn=cond_fn, cond_start_step=cond_start_step
        )

    def p_losses(
        self,
        x_start,
        t,
        z=None,
        noise=None,
    ):
        """Compute the training loss at timesteps ``t``.

        Returns:
            dict with:
              - ``"loss"``: per-element loss reshaped to ``(batch, -1)`` and
                multiplied by the P2 weight. NOTE(review): the einops
                pattern keeps every axis, so no averaging seems to happen
                here; the caller presumably reduces.
              - ``"noise"``: the noise used.
              - ``"x_0_pred"``: predicted x_0.
              - ``"x_t"``: the noised input.
              - ``"t"``: the timesteps.

        Raises:
            ValueError: if ``self.objective`` is not a supported objective.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        model_out = self.model(x, t, z)

        if self.objective == "pred_noise":
            target = noise
            x_0_pred = self.predict_start_from_noise(x, t, model_out)
        elif self.objective == "pred_x0":
            target = x_start
            x_0_pred = model_out
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target, reduction="none")
        loss = reduce(loss, "b ... -> b (...)", "mean")
        loss = loss * extract(self.p2_loss_weight, t, loss.shape)

        return {
            "loss": loss,
            "noise": noise,
            "x_0_pred": x_0_pred,
            "x_t": x,
            "t": t,
        }

    def forward(self, pose, z=None, *args, **kwargs):
        """Sample a random timestep per batch element and return ``p_losses``.

        Extra positional arguments are forwarded after ``z`` (i.e. the next
        one fills ``noise``); keyword arguments are forwarded as-is.
        """
        b = len(pose)
        t = torch.randint(
            0, self.num_timesteps, (b,), device=pose.device
        ).long()
        # Pass z positionally so *args land after it instead of colliding
        # with the z parameter ("got multiple values for argument 'z'").
        return self.p_losses(pose, t, z, *args, **kwargs)

    @property
    def loss_fn(self):
        """Elementwise loss function selected by ``self.loss_type``.

        Raises:
            ValueError: if ``loss_type`` is neither ``"l1"`` nor ``"l2"``.
        """
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")