pradachan's picture
Upload folder using huggingface_hub
f71c233 verified
raw
history blame
12 kB
# This file trains a DDPM diffusion model on 2D datasets.
import argparse
import json
import math
import os.path as osp
import pathlib
import pickle
import time

import numpy as np
import npeet.entropy_estimators as ee
import torch
from ema_pytorch import EMA
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import datasets
# Default compute device for the whole script: CUDA when PyTorch can see a GPU, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class SinusoidalEmbedding(nn.Module):
    """Map every scalar in ``x`` to a vector of sines and cosines.

    Frequencies form a geometric progression from 1 down to 1/10000 over
    ``dim // 2`` steps. The output concatenates the ``sin`` features and then
    the ``cos`` features along a new trailing axis. An odd ``dim`` gets one
    extra zero column, so the last axis always has exactly ``dim`` features.

    Args:
        dim: Width of the embedding produced for each scalar.
        scale: Multiplier applied to ``x`` before embedding. Larger values make
            the embedding vary faster with ``x``.
    """

    def __init__(self, dim: int, scale: float = 1.0):
        super().__init__()
        self.dim = dim
        self.scale = scale

    def forward(self, x: torch.Tensor):
        """Embed ``x``; the result has shape ``(*x.shape, dim)``."""
        x = x * self.scale
        half_dim = self.dim // 2
        # Spacing between log-frequencies. Guard the denominator: for dim in
        # {2, 3}, half_dim - 1 == 0 would give inf * 0 = NaN frequencies.
        log_step = math.log(10000.0) / max(half_dim - 1, 1)
        # Build the frequency table on x's own device (not a global default)
        # so the module works wherever its input lives.
        freqs = torch.exp(
            -log_step * torch.arange(half_dim, device=x.device, dtype=torch.float32)
        )
        emb = x.unsqueeze(-1) * freqs.unsqueeze(0)
        out = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        if self.dim % 2:
            # Zero-pad odd widths so callers sizing layers by `dim` stay correct.
            out = F.pad(out, (0, 1))
        return out
class ResidualBlock(nn.Module):
    """Width-preserving residual layer computing ``x + Linear(ReLU(x))``.

    The activation is applied before the linear map and the input is added
    back unchanged, so any number of blocks of the same width can be stacked.

    Args:
        width: Size of the last axis of both the input and the output.
    """

    def __init__(self, width: int):
        super().__init__()
        self.ff = nn.Linear(width, width)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor):
        activated = self.act(x)
        update = self.ff(activated)
        return x + update
class MLPDenoiser(nn.Module):
    """Two-expert MLP that maps a noisy 2-D point and a timestep to a 2-D output.

    Coordinate 0 of ``x``, coordinate 1 of ``x`` and the timestep ``t`` are each
    embedded with a SinusoidalEmbedding. The three embeddings are concatenated
    and fed to two expert MLPs. Their 2-D outputs are blended with a per-sample
    sigmoid gate: ``g * expert1 + (1 - g) * expert2``.

    Args:
        embedding_dim: Width of each sinusoidal embedding.
        hidden_dim: Width of the hidden layers in the gate and in both experts.
        hidden_layers: Number of ResidualBlock layers inside each expert.
    """

    def __init__(
        self,
        embedding_dim: int = 128,
        hidden_dim: int = 256,
        hidden_layers: int = 3,
    ):
        super().__init__()
        self.time_mlp = SinusoidalEmbedding(embedding_dim)
        # sinusoidal embeddings help capture high-frequency patterns for low-dim data
        self.input_mlp1 = SinusoidalEmbedding(embedding_dim, scale=25.0)
        self.input_mlp2 = SinusoidalEmbedding(embedding_dim, scale=25.0)

        joint_dim = embedding_dim * 3
        half_hidden = hidden_dim // 2
        self.gating_network = nn.Sequential(
            nn.Linear(joint_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, 1),
            nn.Sigmoid(),
        )
        # Both experts share one architecture. They are built in a fixed order
        # (expert1, then expert2) so weight initialisation draws from the RNG
        # in the same sequence.
        self.expert1 = self._build_expert(joint_dim, hidden_dim, hidden_layers)
        self.expert2 = self._build_expert(joint_dim, hidden_dim, hidden_layers)

    @staticmethod
    def _build_expert(in_dim, hidden_dim, hidden_layers):
        """Return one expert: input projection, residual stack, then a 2-output head."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            *[ResidualBlock(hidden_dim) for _ in range(hidden_layers)],
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2),
        )

    def forward(self, x, t):
        # x is indexed as x[:, 0] / x[:, 1], so it is presumably (batch, 2).
        joint = torch.cat(
            [self.input_mlp1(x[:, 0]), self.input_mlp2(x[:, 1]), self.time_mlp(t)],
            dim=-1,
        )
        gate = self.gating_network(joint)
        first = self.expert1(joint)
        second = self.expert2(joint)
        return gate * first + (1 - gate) * second
class NoiseScheduler:
    """Precomputed DDPM noise schedule with forward-noising and reverse-step helpers.

    Every schedule table is a 1-D float32 tensor of length ``num_timesteps``
    stored on the module-level ``device``. Tables are indexed by integer
    timesteps; gathered coefficients are reshaped to a column ``(-1, 1)`` so
    they broadcast against per-sample rows (assumes samples are laid out as
    (batch, features) — TODO confirm with callers).

    Args:
        num_timesteps: Number of diffusion steps T.
        beta_start: Beta value at the first step.
        beta_end: Beta value at the last step.
        beta_schedule: ``"linear"`` (betas evenly spaced) or ``"quadratic"``
            (square roots of the betas evenly spaced, then squared).

    Raises:
        ValueError: If ``beta_schedule`` is neither ``"linear"`` nor ``"quadratic"``.
    """

    def __init__(
        self,
        num_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
    ):
        self.num_timesteps = num_timesteps

        if beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32)
        elif beta_schedule == "quadratic":
            root_betas = torch.linspace(
                beta_start ** 0.5, beta_end ** 0.5, num_timesteps, dtype=torch.float32)
            betas = root_betas ** 2
        else:
            raise ValueError(f"Unknown beta schedule: {beta_schedule}")
        self.betas = betas.to(device)

        # Every table below is derived from self.betas, so it already lives on `device`.
        self.alphas = 1.0 - self.betas
        abar = torch.cumprod(self.alphas, dim=0)
        # abar shifted right by one step, with 1.0 standing in for step -1.
        abar_prev = F.pad(abar[:-1], (1, 0), value=1.)
        self.alphas_cumprod = abar
        self.alphas_cumprod_prev = abar_prev

        # Coefficients used by add_noise.
        self.sqrt_alphas_cumprod = abar ** 0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - abar) ** 0.5

        # Coefficients used by reconstruct_x0.
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / abar)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(1 / abar - 1)

        # Coefficients used by q_posterior.
        self.posterior_mean_coef1 = self.betas * torch.sqrt(abar_prev) / (1. - abar)
        self.posterior_mean_coef2 = (1. - abar_prev) * torch.sqrt(self.alphas) / (1. - abar)

    @staticmethod
    def _gather(table, t):
        """Return ``table[t]`` as a column vector for broadcasting over features."""
        return table[t].reshape(-1, 1)

    def reconstruct_x0(self, x_t, t, noise):
        """Estimate x_0 as ``sqrt(1/abar_t) * x_t - sqrt(1/abar_t - 1) * noise``."""
        coef_xt = self._gather(self.sqrt_inv_alphas_cumprod, t)
        coef_noise = self._gather(self.sqrt_inv_alphas_cumprod_minus_one, t)
        return coef_xt * x_t - coef_noise * noise

    def q_posterior(self, x_0, x_t, t):
        """Return the posterior mean ``coef1_t * x_0 + coef2_t * x_t``."""
        coef_x0 = self._gather(self.posterior_mean_coef1, t)
        coef_xt = self._gather(self.posterior_mean_coef2, t)
        return coef_x0 * x_0 + coef_xt * x_t

    def get_variance(self, t):
        """Return the posterior variance at step ``t`` (plain ``0`` when ``t == 0``).

        For ``t > 0`` the value is clipped below at 1e-20.
        """
        if t == 0:
            return 0
        numerator = self.betas[t] * (1. - self.alphas_cumprod_prev[t])
        return (numerator / (1. - self.alphas_cumprod[t])).clip(1e-20)

    def step(self, model_output, timestep, sample):
        """Take one reverse step from ``sample`` at ``timestep`` using the predicted noise.

        Gaussian noise scaled by the posterior standard deviation is added only
        when ``timestep > 0``; at step 0 the posterior mean is returned.
        """
        x0_estimate = self.reconstruct_x0(sample, timestep, model_output)
        mean = self.q_posterior(x0_estimate, sample, timestep)
        perturbation = 0
        if timestep > 0:
            eps = torch.randn_like(model_output)
            perturbation = (self.get_variance(timestep) ** 0.5) * eps
        return mean + perturbation

    def add_noise(self, x_start, x_noise, timesteps):
        """Return ``sqrt(abar_t) * x_start + sqrt(1 - abar_t) * x_noise`` per sample."""
        coef_signal = self._gather(self.sqrt_alphas_cumprod, timesteps)
        coef_noise = self._gather(self.sqrt_one_minus_alphas_cumprod, timesteps)
        return coef_signal * x_start + coef_noise * x_noise

    def __len__(self):
        return self.num_timesteps
def _diffusion_loss(model, noise_scheduler, batch):
    """Return the DDPM noise-prediction MSE loss for one batch of clean samples.

    Draws standard-normal noise and one uniform random timestep per sample,
    noises the batch with ``noise_scheduler.add_noise`` and compares the
    model's prediction against the drawn noise. Tensors are created directly
    on ``batch.device`` to avoid host-to-device copies.
    """
    noise = torch.randn(batch.shape, device=batch.device)
    timesteps = torch.randint(
        0, noise_scheduler.num_timesteps, (batch.shape[0],), device=batch.device
    )
    noisy = noise_scheduler.add_noise(batch, noise, timesteps)
    noise_pred = model(noisy, timesteps)
    return F.mse_loss(noise_pred, noise)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_batch_size", type=int, default=256)
    parser.add_argument("--eval_batch_size", type=int, default=10000)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--num_timesteps", type=int, default=100)
    parser.add_argument("--num_train_steps", type=int, default=10000)
    parser.add_argument("--beta_schedule", type=str, default="linear", choices=["linear", "quadratic"])
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--hidden_size", type=int, default=256)
    parser.add_argument("--hidden_layers", type=int, default=3)
    parser.add_argument("--out_dir", type=str, default="run_0")
    config = parser.parse_args()

    final_infos = {}
    all_results = {}

    pathlib.Path(config.out_dir).mkdir(parents=True, exist_ok=True)

    for dataset_name in ["circle", "dino", "line", "moons"]:
        dataset = datasets.get_dataset(dataset_name, n=100000)
        dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

        model = MLPDenoiser(
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_size,
            hidden_layers=config.hidden_layers,
        ).to(device)
        ema_model = EMA(model, beta=0.995, update_every=10).to(device)

        noise_scheduler = NoiseScheduler(num_timesteps=config.num_timesteps, beta_schedule=config.beta_schedule)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
        )
        scheduler = CosineAnnealingLR(optimizer, T_max=config.num_train_steps)
        train_losses = []
        print("Training model...")

        # ---- Training: iterate over the dataloader until num_train_steps updates ----
        model.train()
        global_step = 0
        progress_bar = tqdm(total=config.num_train_steps)
        progress_bar.set_description("Training")
        start_time = time.time()
        while global_step < config.num_train_steps:
            for batch in dataloader:
                if global_step >= config.num_train_steps:
                    break
                batch = batch[0].to(device)
                loss = _diffusion_loss(model, noise_scheduler, batch)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
                optimizer.zero_grad()
                ema_model.update()
                scheduler.step()

                loss_value = loss.detach().item()
                train_losses.append(loss_value)
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loss_value)
                global_step += 1

        progress_bar.close()
        training_time = time.time() - start_time

        # ---- Eval loss of the raw (non-EMA) model over the same dataloader ----
        # No gradients are needed here; without no_grad every batch would build
        # and keep an autograd graph.
        model.eval()
        eval_losses = []
        with torch.no_grad():
            for batch in dataloader:
                batch = batch[0].to(device)
                eval_losses.append(_diffusion_loss(model, noise_scheduler, batch).item())
        eval_loss = float(np.mean(eval_losses))

        # ---- Sampling with the EMA model: run the reverse chain from T-1 down to 0 ----
        ema_model.eval()
        sample = torch.randn(config.eval_batch_size, 2).to(device)
        inference_start_time = time.time()
        with torch.no_grad():
            for t in reversed(range(len(noise_scheduler))):
                t_batch = torch.full(
                    (config.eval_batch_size,), t, dtype=torch.long, device=device
                )
                residual = ema_model(sample, t_batch)
                sample = noise_scheduler.step(residual, t_batch[0], sample)
        # .cpu() synchronises with the device, so the timing below includes all kernels.
        sample = sample.cpu().numpy()
        inference_time = time.time() - inference_start_time

        # ---- Estimated KL divergence between real data and generated samples ----
        real_data = dataset.tensors[0].numpy()
        kl_divergence = float(ee.kldiv(real_data, sample, k=5))

        # ---- Gating weights of the EMA model at t = 0, evaluated on generated samples ----
        denoiser = ema_model.ema_model
        with torch.no_grad():
            x = torch.from_numpy(sample).float().to(device)
            t = torch.zeros(x.shape[0], dtype=torch.long, device=device)
            joint_emb = torch.cat([
                denoiser.input_mlp1(x[:, 0]),
                denoiser.input_mlp2(x[:, 1]),
                denoiser.time_mlp(t),
            ], dim=-1)
            gating_weights = denoiser.gating_network(joint_emb).cpu().numpy()

        # Metrics are cast to plain floats so json.dump never sees numpy scalars.
        final_infos[dataset_name] = {
            "means": {
                "training_time": training_time,
                "eval_loss": eval_loss,
                "inference_time": inference_time,
                "kl_divergence": kl_divergence,
            }
        }

        all_results[dataset_name] = {
            "train_losses": train_losses,
            "images": sample,
            "gating_weights": gating_weights,
        }

    with open(osp.join(config.out_dir, "final_info.json"), "w") as f:
        json.dump(final_infos, f)

    with open(osp.join(config.out_dir, "all_results.pkl"), "wb") as f:
        pickle.dump(all_results, f)