|
|
|
|
|
import argparse |
|
import json |
|
import time |
|
import os.path as osp |
|
import numpy as np |
|
from tqdm.auto import tqdm |
|
import npeet.entropy_estimators as ee |
|
import pickle |
|
import pathlib |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from torch.utils.data import DataLoader |
|
from torch.optim.lr_scheduler import CosineAnnealingLR |
|
from ema_pytorch import EMA |
|
|
|
import datasets |
|
|
|
# Device shared by every model and tensor in this script: CUDA when PyTorch
# can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
class SinusoidalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal embedding applied elementwise.

    Each input value ``v`` is multiplied by ``scale`` and mapped to
    ``[sin(v*f_0), ..., sin(v*f_{h-1}), cos(v*f_0), ..., cos(v*f_{h-1})]``
    where ``h = dim // 2`` and the frequencies ``f_k`` are spaced
    geometrically from 1 down to 1/10000. The output width is therefore
    ``2 * (dim // 2)``, which equals ``dim`` only when ``dim`` is even.

    Args:
        dim: requested embedding width; must be at least 4.
        scale: multiplier applied to the input before embedding.

    Raises:
        ValueError: if ``dim < 4``. The frequency spacing divides by
            ``dim // 2 - 1``, which would be zero and yield NaN embeddings.
    """

    def __init__(self, dim: int, scale: float = 1.0):
        super().__init__()
        if dim < 4:
            raise ValueError(f"SinusoidalEmbedding needs dim >= 4, got {dim}")
        self.dim = dim
        self.scale = scale

    def forward(self, x: torch.Tensor):
        """Embed ``x``; a 1-D input of length N gives an (N, 2 * (dim // 2)) tensor.

        The result lives on ``x``'s device. Integer inputs are promoted to
        floating point by the multiplication with ``scale``.
        """
        x = x * self.scale
        half_dim = self.dim // 2
        # Frequencies are built on the CPU in float32 (unchanged numerics) and
        # then moved to wherever ``x`` lives, so the module works on any
        # device rather than only on a single module-level one.
        emb = torch.log(torch.Tensor([10000.0])) / (half_dim - 1)
        emb = torch.exp(-emb * torch.arange(half_dim)).to(x.device)
        emb = x.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Pre-activation residual layer computing ``x + Linear(ReLU(x))``.

    Input and output widths are both ``width``, so blocks can be stacked.
    """

    def __init__(self, width: int):
        super().__init__()
        # Submodule names define the state_dict keys; keep them stable.
        self.ff = nn.Linear(width, width)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor):
        activated = self.act(x)
        update = self.ff(activated)
        return x + update
|
|
|
|
|
class MLPDenoiser(nn.Module):
    """Two-expert mixture denoiser for 2-D points.

    Column 0 and column 1 of ``x`` are each embedded by their own
    ``SinusoidalEmbedding`` (input scaled by 25.0), and ``t`` by a third one
    (scale 1.0). The three embeddings are concatenated and fed to:

    * ``gating_network`` -- an MLP ending in a sigmoid, giving one weight
      ``g`` in (0, 1) per row;
    * ``expert1`` / ``expert2`` -- identical-shaped residual MLPs, each with
      2 output features.

    The prediction is ``g * expert1 + (1 - g) * expert2``.

    Args:
        embedding_dim: width passed to each sinusoidal embedding.
        hidden_dim: width of the hidden layers.
        hidden_layers: number of ``ResidualBlock``s inside each expert.
    """

    def __init__(
        self,
        embedding_dim: int = 128,
        hidden_dim: int = 256,
        hidden_layers: int = 3,
    ):
        super().__init__()
        joint_dim = embedding_dim * 3  # coord-0 emb + coord-1 emb + time emb
        bottleneck = hidden_dim // 2

        # Registration order (and therefore parameter-init RNG order and
        # state_dict layout) matches the original definition.
        self.time_mlp = SinusoidalEmbedding(embedding_dim)
        self.input_mlp1 = SinusoidalEmbedding(embedding_dim, scale=25.0)
        self.input_mlp2 = SinusoidalEmbedding(embedding_dim, scale=25.0)

        self.gating_network = nn.Sequential(
            nn.Linear(joint_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
            nn.Sigmoid(),
        )

        def build_expert():
            # Linear -> N residual blocks -> ReLU -> Linear -> ReLU -> Linear(2)
            layers = [nn.Linear(joint_dim, hidden_dim)]
            layers.extend(ResidualBlock(hidden_dim) for _ in range(hidden_layers))
            layers += [
                nn.ReLU(),
                nn.Linear(hidden_dim, bottleneck),
                nn.ReLU(),
                nn.Linear(bottleneck, 2),
            ]
            return nn.Sequential(*layers)

        self.expert1 = build_expert()
        self.expert2 = build_expert()

    def forward(self, x, t):
        """Blend the two experts' predictions using the learned gate.

        Assumes ``x`` is (batch, 2) and ``t`` is (batch,) -- TODO confirm
        against callers; only columns 0 and 1 of ``x`` are read.
        """
        features = torch.cat(
            [self.input_mlp1(x[:, 0]), self.input_mlp2(x[:, 1]), self.time_mlp(t)],
            dim=-1,
        )
        gate = self.gating_network(features)
        out_a = self.expert1(features)
        out_b = self.expert2(features)
        return gate * out_a + (1 - gate) * out_b
|
|
|
|
|
class NoiseScheduler:
    """DDPM-style noise schedule: forward noising and one reverse step.

    Every per-timestep coefficient is precomputed once in ``__init__`` as a
    1-D tensor of length ``num_timesteps`` and indexed by timestep later.

    Args:
        num_timesteps: number of diffusion steps.
        beta_start: beta at the first timestep.
        beta_end: beta at the last timestep.
        beta_schedule: ``"linear"`` spaces the betas linearly; ``"quadratic"``
            spaces ``sqrt(beta)`` linearly and squares the result.
        device: device holding the precomputed tensors. ``None`` (default)
            selects CUDA when available and the CPU otherwise.

    Raises:
        ValueError: if ``beta_schedule`` is not one of the names above.
    """

    def __init__(
        self,
        num_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        device=None,
    ):
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.num_timesteps = num_timesteps

        # Betas are built on the CPU and then moved, so the schedule values
        # themselves do not depend on the target device.
        if beta_schedule == "linear":
            betas = torch.linspace(
                beta_start, beta_end, num_timesteps, dtype=torch.float32)
        elif beta_schedule == "quadratic":
            betas = torch.linspace(
                beta_start ** 0.5, beta_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2
        else:
            raise ValueError(f"Unknown beta schedule: {beta_schedule}")
        self.betas = betas.to(device)

        # Everything below derives from self.betas and so stays on `device`.
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}; index 0 is padded with 1.0 (nothing accumulated yet).
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)

        # Forward process: x_t = sqrt(ab_t) * x_0 + sqrt(1 - ab_t) * eps
        self.sqrt_alphas_cumprod = self.alphas_cumprod ** 0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5

        # Its inverse: x_0 = sqrt(1/ab_t) * x_t - sqrt(1/ab_t - 1) * eps
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(1 / self.alphas_cumprod - 1)

        # Posterior q(x_{t-1} | x_t, x_0) mean = coef1 * x_0 + coef2 * x_t
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod))
        self.posterior_mean_coef2 = (
            (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1. - self.alphas_cumprod))

    def reconstruct_x0(self, x_t, t, noise):
        """Estimate x_0 from ``x_t`` and a noise prediction at timestep(s) ``t``.

        Coefficients are reshaped to a column so they broadcast per row of a
        (batch, features) ``x_t``.
        """
        s1 = self.sqrt_inv_alphas_cumprod[t].reshape(-1, 1)
        s2 = self.sqrt_inv_alphas_cumprod_minus_one[t].reshape(-1, 1)
        return s1 * x_t - s2 * noise

    def q_posterior(self, x_0, x_t, t):
        """Mean of the posterior q(x_{t-1} | x_t, x_0)."""
        s1 = self.posterior_mean_coef1[t].reshape(-1, 1)
        s2 = self.posterior_mean_coef2[t].reshape(-1, 1)
        return s1 * x_0 + s2 * x_t

    def get_variance(self, t):
        """Posterior variance at scalar timestep ``t``; plain ``0`` when t == 0.

        Clipped from below at 1e-20 so its square root is always defined.
        """
        if t == 0:
            return 0

        variance = self.betas[t] * (1. - self.alphas_cumprod_prev[t]) / (1. - self.alphas_cumprod[t])
        return variance.clip(1e-20)

    def step(self, model_output, timestep, sample):
        """One reverse step x_t -> x_{t-1} from a noise prediction.

        ``timestep`` must be a scalar (int or 0-d tensor) because it is used in
        ``t > 0``. Gaussian noise is added only when ``timestep > 0``; the final
        step returns the posterior mean.
        """
        t = timestep
        pred_original_sample = self.reconstruct_x0(sample, t, model_output)
        pred_prev_sample = self.q_posterior(pred_original_sample, sample, t)

        # std * z; stays 0 on the last step so the output is deterministic.
        noise_term = 0
        if t > 0:
            noise = torch.randn_like(model_output)
            noise_term = (self.get_variance(t) ** 0.5) * noise

        return pred_prev_sample + noise_term

    def add_noise(self, x_start, x_noise, timesteps):
        """Sample x_t = sqrt(ab_t) * x_start + sqrt(1 - ab_t) * x_noise."""
        s1 = self.sqrt_alphas_cumprod[timesteps].reshape(-1, 1)
        s2 = self.sqrt_one_minus_alphas_cumprod[timesteps].reshape(-1, 1)
        return s1 * x_start + s2 * x_noise

    def __len__(self):
        return self.num_timesteps
|
|
|
|
|
def _parse_args():
    """Parse the command-line configuration for this experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_batch_size", type=int, default=256)
    parser.add_argument("--eval_batch_size", type=int, default=10000)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--num_timesteps", type=int, default=100)
    parser.add_argument("--num_train_steps", type=int, default=10000)
    parser.add_argument("--beta_schedule", type=str, default="linear", choices=["linear", "quadratic"])
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--hidden_size", type=int, default=256)
    parser.add_argument("--hidden_layers", type=int, default=3)
    parser.add_argument("--out_dir", type=str, default="run_0")
    return parser.parse_args()


def _noise_batch(batch, noise_scheduler):
    """Draw noise and uniform timesteps for ``batch``; return (noisy, noise, timesteps).

    Random tensors are created directly on ``device`` instead of on the CPU
    followed by a copy.
    """
    noise = torch.randn(batch.shape, device=device)
    timesteps = torch.randint(
        0, noise_scheduler.num_timesteps, (batch.shape[0],), device=device)
    noisy = noise_scheduler.add_noise(batch, noise, timesteps)
    return noisy, noise, timesteps


def _train(model, ema_model, noise_scheduler, dataloader, config):
    """Train ``model`` to predict the added noise (MSE objective).

    Runs exactly ``config.num_train_steps`` optimizer steps, cycling through
    ``dataloader`` as often as needed. It uses AdamW with cosine LR annealing
    and gradient-norm clipping at 0.5, and calls ``ema_model.update()`` after
    every step.

    Returns:
        (train_losses, training_time): the per-step loss values and the
        wall-clock seconds spent in the training loop.

    Raises:
        ValueError: if ``dataloader`` yields no batches (the loop could
            otherwise never reach ``num_train_steps``).
    """
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=config.num_train_steps)
    train_losses = []
    print("Training model...")

    model.train()
    global_step = 0
    progress_bar = tqdm(total=config.num_train_steps)
    progress_bar.set_description("Training")

    start_time = time.time()
    while global_step < config.num_train_steps:
        steps_at_epoch_start = global_step
        for batch in dataloader:
            if global_step >= config.num_train_steps:
                break
            batch = batch[0].to(device)
            noisy, noise, timesteps = _noise_batch(batch, noise_scheduler)
            noise_pred = model(noisy, timesteps)
            loss = F.mse_loss(noise_pred, noise)
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            optimizer.zero_grad()
            ema_model.update()

            scheduler.step()
            progress_bar.update(1)
            # One host sync per step instead of two.
            loss_value = loss.detach().item()
            train_losses.append(loss_value)
            progress_bar.set_postfix(loss=loss_value)
            global_step += 1
        if global_step == steps_at_epoch_start:
            progress_bar.close()
            raise ValueError("dataloader yielded no batches; training cannot progress")

    progress_bar.close()
    training_time = time.time() - start_time
    return train_losses, training_time


@torch.no_grad()
def _eval_loss(model, noise_scheduler, dataloader):
    """Mean noise-prediction MSE of ``model`` over one pass of ``dataloader``.

    Runs under ``torch.no_grad()`` so no autograd graph is built per batch.
    NOTE(review): the caller passes the training dataloader, so this is loss
    on training data with freshly drawn noise and timesteps.
    """
    model.eval()
    eval_losses = []
    for batch in dataloader:
        batch = batch[0].to(device)
        noisy, noise, timesteps = _noise_batch(batch, noise_scheduler)
        noise_pred = model(noisy, timesteps)
        eval_losses.append(F.mse_loss(noise_pred, noise).item())
    return np.mean(eval_losses)


@torch.no_grad()
def _sample(ema_model, noise_scheduler, num_samples):
    """Generate ``num_samples`` 2-D points with the EMA weights.

    Starts from standard Gaussian noise and applies every reverse step from
    ``len(noise_scheduler) - 1`` down to 0.

    Returns:
        (samples, inference_time): a (num_samples, 2) numpy array and the
        wall-clock seconds for the reverse loop plus the copy to host.
    """
    ema_model.eval()
    sample = torch.randn(num_samples, 2, device=device)
    inference_start_time = time.time()
    for step_t in reversed(range(len(noise_scheduler))):
        t = torch.full((num_samples,), step_t, dtype=torch.long, device=device)
        residual = ema_model(sample, t)
        sample = noise_scheduler.step(residual, t[0], sample)
    sample = sample.cpu().numpy()
    return sample, time.time() - inference_start_time


@torch.no_grad()
def _gating_weights(denoiser, points):
    """Gating-network output of ``denoiser`` for ``points`` (numpy) at timestep 0."""
    x = torch.from_numpy(points).float().to(device)
    t = torch.zeros(x.shape[0], dtype=torch.long, device=device)
    emb = torch.cat([
        denoiser.input_mlp1(x[:, 0]),
        denoiser.input_mlp2(x[:, 1]),
        denoiser.time_mlp(t),
    ], dim=-1)
    return denoiser.gating_network(emb).cpu().numpy()


if __name__ == "__main__":
    config = _parse_args()

    final_infos = {}
    all_results = {}

    pathlib.Path(config.out_dir).mkdir(parents=True, exist_ok=True)

    for dataset_name in ["circle", "dino", "line", "moons"]:
        dataset = datasets.get_dataset(dataset_name, n=100000)
        dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

        model = MLPDenoiser(
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_size,
            hidden_layers=config.hidden_layers,
        ).to(device)
        ema_model = EMA(model, beta=0.995, update_every=10).to(device)

        noise_scheduler = NoiseScheduler(num_timesteps=config.num_timesteps, beta_schedule=config.beta_schedule)

        train_losses, training_time = _train(model, ema_model, noise_scheduler, dataloader, config)
        eval_loss = _eval_loss(model, noise_scheduler, dataloader)
        sample, inference_time = _sample(ema_model, noise_scheduler, config.eval_batch_size)

        # Assumes get_dataset returns a TensorDataset whose first tensor holds
        # the points -- TODO confirm against datasets.get_dataset.
        real_data = dataset.tensors[0].numpy()
        kl_divergence = ee.kldiv(real_data, sample, k=5)

        gating_weights = _gating_weights(ema_model.ema_model, sample)

        final_infos[dataset_name] = {
            "means": {
                "training_time": training_time,
                "eval_loss": eval_loss,
                "inference_time": inference_time,
                "kl_divergence": kl_divergence,
            }
        }

        all_results[dataset_name] = {
            "train_losses": train_losses,
            "images": sample,
            "gating_weights": gating_weights,
        }

    with open(osp.join(config.out_dir, "final_info.json"), "w") as f:
        json.dump(final_infos, f)

    with open(osp.join(config.out_dir, "all_results.pkl"), "wb") as f:
        pickle.dump(all_results, f)
|
|