Spaces:
Runtime error
Runtime error
import torch | |
from torch.nn import functional as F | |
from .dit import DiffusionTransformer | |
from .adp import UNet1d | |
from .sampling import sample | |
import math | |
from model.base import BaseModule | |
import pdb | |
# Module-level default pad length. pad_and_create_mask takes its own
# ``target_length`` argument, which shadows this name inside the function.
target_length = 1536


def pad_and_create_mask(matrix, target_length):
    """Zero-pad ``matrix`` on the right of its last axis and build a validity mask.

    The current length is read from dimension 2 while ``F.pad`` pads the last
    dimension, so this assumes a 3-D input (presumably ``(batch, channels,
    time)`` -- TODO confirm with callers).

    Args:
        matrix: Tensor to pad.
        target_length: Length that dimension 2 is padded up to.

    Returns:
        ``(padded_matrix, mask)``:
        - ``padded_matrix`` is ``matrix`` zero-padded on the right to
          ``target_length``.
        - ``mask`` is a float tensor of shape ``(1, target_length)``. It holds
          1 at original positions and 0 at padded ones.
        Both tensors live on ``matrix.device``.

    Raises:
        ValueError: if dimension 2 of ``matrix`` is longer than ``target_length``.
    """
    length = matrix.shape[2]
    if length > target_length:
        raise ValueError(
            f"The third dimension length {length} should not exceed {target_length}"
        )
    # F.pad already keeps the input's device and dtype; no extra .to() needed.
    padded_matrix = F.pad(matrix, (0, target_length - length), "constant", 0)
    # Build the mask directly on the target device rather than allocating it
    # on the CPU and copying it over on every call.
    mask = torch.zeros((1, target_length), device=matrix.device)
    mask[:, :length] = 1
    return padded_matrix, mask
class Stable_Diffusion(BaseModule):
    """Diffusion decoder built on a ``DiffusionTransformer`` backbone.

    The conditioning tensor ``mu`` reaches the transformer through
    ``input_concat_cond``. Training uses a cosine schedule
    (``alpha = cos(t*pi/2)``, ``sigma = sin(t*pi/2)``) with the target
    ``noise * alpha - x0 * sigma``. Timesteps are drawn from a scrambled
    1-D Sobol sequence.
    """

    def __init__(self, io_channels, input_concat_dim=None, embed_dim=768, depth=24, num_heads=24,
                 project_cond_tokens=False, transformer_type="continuous_transformer"):
        """Build the transformer backbone and the Sobol timestep sampler.

        All arguments are forwarded unchanged to ``DiffusionTransformer``.
        """
        super().__init__()
        self.diffusion = DiffusionTransformer(
            io_channels=io_channels,
            input_concat_dim=input_concat_dim,
            embed_dim=embed_dim,
            # cond_token_dim=target_length,
            depth=depth,
            num_heads=num_heads,
            project_cond_tokens=project_cond_tokens,
            transformer_type=transformer_type,
        )
        # Alternative U-Net backbone, kept for reference (disabled):
        # self.diffusion = UNet1d(
        #     in_channels=80,
        #     channels=256,
        #     resnet_groups=16,
        #     kernel_multiplier_downsample=2,
        #     multipliers=[4, 4, 4, 5, 5],
        #     factors=[1, 2, 2, 4],  # original (mis-encoded Chinese) note, approx.: "input length mismatch -- conv downsampling"
        #     num_blocks=[2, 2, 2, 2],
        #     attentions=[1, 3, 3, 3, 3],
        #     attention_heads=16,
        #     attention_multiplier=4,
        #     use_nearest_upsample=False,
        #     use_skip_scale=True,
        #     use_context_time=True
        # )
        # Quasi-random (Sobol) sampler for training timesteps in [0, 1).
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

    def forward(self, mu, mask, n_timesteps):
        """Generate a sample conditioned on ``mu``.

        Args:
            mu: Conditioning tensor. The starting noise is drawn with the same
                shape, so ``mu`` is assumed to match the output shape
                (presumably ``(batch, io_channels, time)`` -- TODO confirm).
            mask: Validity mask. Axis 1 is squeezed before use; this is a no-op
                if that axis is not size 1.
            n_timesteps: Number of sampling steps passed to ``sample``.

        Returns:
            The result of ``sample``.
        """
        mask = mask.squeeze(1)
        # randn_like already allocates on mu's device with mu's dtype.
        noise = torch.randn_like(mu)
        # Cross-attention conditioning alternative (disabled):
        # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
        # extra_args = {"cross_attn_cond": mu, "cross_attn_cond_mask": mask, "mask": mask}
        extra_args = {"input_concat_cond": mu, "mask": mask}
        # The literal 0 is sample()'s fourth positional argument
        # (presumably eta -- TODO confirm in .sampling).
        return sample(self.diffusion, noise, n_timesteps, 0, **extra_args)

    def compute_loss(self, x0, mask, mu):
        """Compute the masked diffusion training loss.

        Args:
            x0: Clean target tensor, e.g. ``[bs, 80, T]``.
            mask: Validity mask. Axis 1 is squeezed to give e.g. ``[bs, T]``.
            mu: Conditioning tensor fed via ``input_concat_cond``.

        Returns:
            ``(loss, output)``: ``loss`` is the masked MSE between the network
            output and the target; ``output`` is the raw network output.
        """
        # One timestep per batch element, drawn from the Sobol sequence.
        t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)
        # Cosine schedule: alphas**2 + sigmas**2 == 1 for every t.
        alphas = torch.cos(t * math.pi / 2)[:, None, None]
        sigmas = torch.sin(t * math.pi / 2)[:, None, None]
        noise = torch.randn_like(x0)
        noised_inputs = x0 * alphas + noise * sigmas
        # Derivative of noised_inputs w.r.t. the angle t*pi/2 ("v" target).
        targets = noise * alphas - x0 * sigmas
        mask = mask.squeeze(1)
        # Cross-attention conditioning alternative (disabled):
        # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
        # output = self.diffusion(noised_inputs, t, cross_attn_cond=mu,
        #                         cross_attn_cond_mask=mask, mask=mask, cfg_dropout_prob=0.1)
        output = self.diffusion(noised_inputs,  # e.g. [bs, 80, T]
                                t,  # (bs,)
                                input_concat_cond=mu,
                                mask=mask,  # e.g. [bs, T]
                                # presumably conditioning dropout for CFG -- TODO confirm
                                cfg_dropout_prob=0.1)
        return self.mse_loss(output, targets, mask), output

    def mse_loss(self, output, targets, mask):
        """Mean squared error averaged over valid (non-zero mask) elements only.

        Args:
            output: Network prediction.
            targets: Regression target with the same shape as ``output``.
            mask: Validity mask. A 2-D mask against a 3-D loss gets a channel
                axis inserted and is broadcast over channels. Bool and float
                masks are both accepted.

        Returns:
            Scalar tensor: the sum of masked squared errors divided by the
            number of valid elements. An all-zero mask yields 0 rather than NaN.
        """
        per_elem = F.mse_loss(output, targets, reduction='none')
        mask = mask.to(per_elem.dtype)
        if mask.ndim == 2 and per_elem.ndim == 3:
            mask = mask.unsqueeze(1)
        # Broadcast over channels so the denominator counts every valid element.
        mask = mask.expand_as(per_elem)
        # Divide by the valid-element count, not the padded total. A plain
        # .mean() would shrink the loss in proportion to the padding.
        return (per_elem * mask).sum() / mask.sum().clamp_min(1.0)