# Hugging Face Spaces status: Running on L40S
import logging | |
from typing import Optional | |
import torch | |
import torch.nn as nn | |
from mmaudio.ext.autoencoder.edm2_utils import MPConv1D | |
from mmaudio.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D, | |
Upsample1D, nonlinearity) | |
from mmaudio.model.utils.distributions import DiagonalGaussianDistribution | |
# Module-level logger. Named after the module (rather than the root logger) so
# its records can be filtered/configured independently; it still propagates to
# the root handlers.
log = logging.getLogger(__name__)
# Per-channel means subtracted by VAE.normalize (and added back by
# VAE.unnormalize) when data_dim == 80, i.e. the '16k' configuration built by
# VAE_16k. 80 entries, one per input channel.
# NOTE(review): presumably statistics of the training log-mel spectrogram bins — confirm.
DATA_MEAN_80D = [
    -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
    -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
    -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
    -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
    -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
    -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
    -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
    -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
]
# Per-channel standard deviations that VAE.normalize divides by (and
# VAE.unnormalize multiplies by) when data_dim == 80 ('16k' configuration).
# 80 entries, aligned index-for-index with DATA_MEAN_80D.
DATA_STD_80D = [
    1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
    0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
    0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
    0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
    0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
    0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
    1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
]
# Per-channel means used by VAE.normalize / VAE.unnormalize when data_dim == 128,
# i.e. the '44k' configuration built by VAE_44k. 128 entries, one per channel.
# NOTE(review): presumably statistics of the training log-mel spectrogram bins — confirm.
DATA_MEAN_128D = [
    -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
    -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
    -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
    -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
    -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
    -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
    -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
    -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
    -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
    -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
    -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
    -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
    -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
]
# Per-channel standard deviations used by VAE.normalize / VAE.unnormalize when
# data_dim == 128 ('44k' configuration). 128 entries, aligned index-for-index
# with DATA_MEAN_128D.
DATA_STD_128D = [
    2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
    2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
    2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
    2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
    2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
    2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
    2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
    2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
    2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
    2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
    2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
]
class VAE(nn.Module):
    """1D convolutional VAE built from :class:`Encoder1D` and :class:`Decoder1D`.

    Inputs with ``data_dim`` channels are normalized with fixed per-channel
    statistics (``DATA_MEAN_80D``/``DATA_STD_80D`` or ``DATA_MEAN_128D``/
    ``DATA_STD_128D``) before encoding. Decoder outputs are un-normalized with the
    same statistics.

    Args:
        data_dim: Number of data channels. Only 80 and 128 are supported, because
            normalization statistics exist only for those sizes.
        embed_dim: Number of latent channels.
        hidden_dim: Base channel width of the encoder and decoder.

    Raises:
        ValueError: If ``data_dim`` is neither 80 nor 128.
    """

    def __init__(
        self,
        *,
        data_dim: int,
        embed_dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        if data_dim == 80:
            data_mean, data_std = DATA_MEAN_80D, DATA_STD_80D
        elif data_dim == 128:
            data_mean, data_std = DATA_MEAN_128D, DATA_STD_128D
        else:
            # Fail fast with a clear message instead of an AttributeError on the
            # never-registered ``data_mean`` buffer.
            raise ValueError(f'Unsupported data_dim: {data_dim} (expected 80 or 128)')

        # Persistent buffers (part of state_dict, follow .to()/.cuda()) stored with
        # shape (1, C, 1). They broadcast over channel dim 1 of rank-3 inputs,
        # presumably (batch, channels, length) — confirm with callers.
        self.register_buffer('data_mean',
                             torch.tensor(data_mean, dtype=torch.float32).view(1, -1, 1))
        self.register_buffer('data_std',
                             torch.tensor(data_std, dtype=torch.float32).view(1, -1, 1))

        self.encoder = Encoder1D(
            dim=hidden_dim,
            ch_mult=(1, 2, 4),
            num_res_blocks=2,
            attn_layers=[3],
            down_layers=[0],
            in_dim=data_dim,
            embed_dim=embed_dim,
        )
        self.decoder = Decoder1D(
            dim=hidden_dim,
            ch_mult=(1, 2, 4),
            num_res_blocks=2,
            attn_layers=[3],
            down_layers=[0],
            in_dim=data_dim,
            out_dim=data_dim,
            embed_dim=embed_dim,
        )

        self.embed_dim = embed_dim
        self.initialize_weights()

    def initialize_weights(self):
        """No-op hook: submodules keep the initialization they set themselves."""
        pass

    def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
        """Encode ``x`` into a diagonal Gaussian posterior.

        Args:
            x: Data tensor with ``data_dim`` channels on dim 1.
            normalize: If True, apply :meth:`normalize` to ``x`` first.

        Returns:
            ``DiagonalGaussianDistribution`` built from the encoder's output moments.
        """
        if normalize:
            x = self.normalize(x)
        moments = self.encoder(x)
        return DiagonalGaussianDistribution(moments)

    def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
        """Decode latent ``z`` back to data space.

        Args:
            z: Latent tensor with ``embed_dim`` channels.
            unnormalize: If True, apply :meth:`unnormalize` to the decoder output.
        """
        dec = self.decoder(z)
        if unnormalize:
            dec = self.unnormalize(dec)
        return dec

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x - data_mean) / data_std`` (per-channel standardization)."""
        return (x - self.data_mean) / self.data_std

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`normalize`: return ``x * data_std + data_mean``."""
        return x * self.data_std + self.data_mean

    def forward(
        self,
        x: torch.Tensor,
        sample_posterior: bool = True,
        rng: Optional[torch.Generator] = None,
        normalize: bool = True,
        unnormalize: bool = True,
    ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
        """Encode, pick a latent, and decode.

        Args:
            x: Input data tensor.
            sample_posterior: If True, draw ``z`` with ``posterior.sample(rng)``.
                Otherwise use ``posterior.mode()``.
            rng: Optional generator forwarded to ``posterior.sample``.
            normalize: Forwarded to :meth:`encode`.
            unnormalize: Forwarded to :meth:`decode`.

        Returns:
            ``(reconstruction, posterior)``.
        """
        posterior = self.encode(x, normalize=normalize)
        z = posterior.sample(rng) if sample_posterior else posterior.mode()
        dec = self.decode(z, unnormalize=unnormalize)
        return dec, posterior

    def load_weights(self, src_dict) -> None:
        """Load ``src_dict`` strictly: keys must match this module's state_dict exactly."""
        self.load_state_dict(src_dict, strict=True)

    def device(self) -> torch.device:
        """Return the device of the first parameter (a method, not a property)."""
        return next(self.parameters()).device

    def get_last_layer(self):
        """Return the weight of the decoder's final ``conv_out`` layer."""
        return self.decoder.conv_out.weight

    def remove_weight_norm(self):
        """Call ``remove_weight_norm()`` on every ``MPConv1D`` submodule.

        Returns:
            ``self``, so the call can be chained.
        """
        for name, m in self.named_modules():
            if isinstance(m, MPConv1D):
                m.remove_weight_norm()
                log.debug('Removed weight norm from %s', name)
        return self
class Encoder1D(nn.Module):
    """1D convolutional encoder mapping ``in_dim`` channels to latent moments.

    Layout:

    * ``conv_in``: ``MPConv1D`` from ``in_dim`` to ``dim`` channels.
    * ``down[i]`` for each level ``i`` in ``range(len(ch_mult))``:
      ``num_res_blocks`` ``ResnetBlock1D`` blocks producing ``dim * ch_mult[i]``
      channels. Each block is followed by an ``AttnBlock1D`` when
      ``i in attn_layers``, and the level ends with a ``Downsample1D`` when
      ``i in down_layers``.
    * ``mid``: ResnetBlock1D -> AttnBlock1D -> ResnetBlock1D at the deepest width.
    * ``conv_out``: nonlinearity, then ``MPConv1D`` to ``2 * embed_dim`` channels
      (``embed_dim`` if ``double_z`` is False), with gain ``learnable_gain + 1``.
      ``learnable_gain`` starts at 0, so the initial gain is 1.

    Intermediate activations are clamped to ``[-clip_act, clip_act]``.

    Args:
        dim: Base channel width.
        ch_mult: Per-level multipliers of ``dim``.
        num_res_blocks: Residual blocks per level.
        attn_layers: Level indices that get attention after each residual block.
            ``None`` (default) means no level does.
        down_layers: Level indices followed by downsampling. ``None`` (default)
            means none.
        resamp_with_conv: Forwarded to ``Downsample1D``.
        in_dim: Number of input channels.
        embed_dim: Number of latent channels.
        double_z: If True, output ``2 * embed_dim`` channels (presumably mean and
            log-variance for ``DiagonalGaussianDistribution`` — confirm).
        kernel_size: Kernel size for ``conv_in``, ``conv_out`` and residual blocks.
        clip_act: Absolute clamp bound for intermediate activations.
    """

    def __init__(self,
                 *,
                 dim: int,
                 ch_mult: tuple[int, ...] = (1, 2, 4, 8),
                 num_res_blocks: int,
                 attn_layers: Optional[list[int]] = None,
                 down_layers: Optional[list[int]] = None,
                 resamp_with_conv: bool = True,
                 in_dim: int,
                 embed_dim: int,
                 double_z: bool = True,
                 kernel_size: int = 3,
                 clip_act: float = 256.0):
        super().__init__()
        # Fresh per-instance lists. A shared ``[]`` default would be aliased by
        # every encoder through ``self.attn_layers`` / ``self.down_layers``.
        attn_layers = [] if attn_layers is None else attn_layers
        down_layers = [] if down_layers is None else down_layers

        self.dim = dim
        self.num_layers = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_dim
        self.clip_act = clip_act
        self.down_layers = down_layers
        self.attn_layers = attn_layers

        self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)

        # Input multiplier of level i is the output multiplier of level i - 1.
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult

        # downsampling
        self.down = nn.ModuleList()
        for i_level in range(self.num_layers):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = dim * in_ch_mult[i_level]
            block_out = dim * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ResnetBlock1D(in_dim=block_in,
                                  out_dim=block_out,
                                  kernel_size=kernel_size,
                                  use_norm=True))
                block_in = block_out
                if i_level in attn_layers:
                    attn.append(AttnBlock1D(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level in down_layers:
                down.downsample = Downsample1D(block_in, resamp_with_conv)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
                                         out_dim=block_in,
                                         kernel_size=kernel_size,
                                         use_norm=True)
        self.mid.attn_1 = AttnBlock1D(block_in)
        self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
                                         out_dim=block_in,
                                         kernel_size=kernel_size,
                                         use_norm=True)

        # end
        self.conv_out = MPConv1D(block_in,
                                 2 * embed_dim if double_z else embed_dim,
                                 kernel_size=kernel_size)

        self.learnable_gain = nn.Parameter(torch.zeros([]))

    def forward(self, x):
        """Encode ``x`` (``in_dim`` channels on dim 1) into the ``conv_out`` output."""
        # downsampling: carry one running activation. Only the latest output was
        # ever consumed, so keeping every intermediate alive was wasted memory.
        h = self.conv_in(x)
        for i_level in range(self.num_layers):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
                h = h.clamp(-self.clip_act, self.clip_act)
            if i_level in self.down_layers:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        h = h.clamp(-self.clip_act, self.clip_act)

        # end
        h = nonlinearity(h)
        h = self.conv_out(h, gain=(self.learnable_gain + 1))
        return h
class Decoder1D(nn.Module):
    """1D convolutional decoder mapping ``embed_dim`` latent channels to ``out_dim``.

    Layout:

    * ``conv_in``: ``MPConv1D`` from ``embed_dim`` to ``dim * ch_mult[-1]`` channels.
    * ``mid``: ResnetBlock1D -> AttnBlock1D -> ResnetBlock1D.
    * ``up[i]``, visited from the deepest level down to 0:
      ``num_res_blocks + 1`` ``ResnetBlock1D`` blocks producing ``dim * ch_mult[i]``
      channels. Each block is followed by an ``AttnBlock1D`` when
      ``i in attn_layers``, and the level ends with an ``Upsample1D`` when
      ``i - 1`` was a downsampling level.
    * ``conv_out``: nonlinearity, then ``MPConv1D`` to ``out_dim`` channels, with
      gain ``learnable_gain + 1`` (initially 1).

    Only ``conv_in``/``conv_out`` receive ``kernel_size``. The residual blocks use
    ``ResnetBlock1D``'s own default, unlike ``Encoder1D``.

    Args:
        dim: Base channel width.
        out_dim: Number of output channels.
        ch_mult: Per-level multipliers of ``dim``.
        num_res_blocks: Residual blocks per level, minus one.
        attn_layers: Level indices that get attention after each residual block.
            ``None`` (default) means no level does.
        down_layers: The *encoder's* downsampling level indices. ``None``
            (default) means none.
        kernel_size: Kernel size of ``conv_in`` and ``conv_out``.
        resamp_with_conv: Forwarded to ``Upsample1D``.
        in_dim: Stored as ``in_channels`` only; not otherwise used here.
        embed_dim: Number of latent channels.
        clip_act: Absolute clamp bound for intermediate activations.
    """

    def __init__(self,
                 *,
                 dim: int,
                 out_dim: int,
                 ch_mult: tuple[int, ...] = (1, 2, 4, 8),
                 num_res_blocks: int,
                 attn_layers: Optional[list[int]] = None,
                 down_layers: Optional[list[int]] = None,
                 kernel_size: int = 3,
                 resamp_with_conv: bool = True,
                 in_dim: int,
                 embed_dim: int,
                 clip_act: float = 256.0):
        super().__init__()
        # Fresh per-instance lists instead of shared mutable ``[]`` defaults.
        attn_layers = [] if attn_layers is None else attn_layers
        down_layers = [] if down_layers is None else down_layers

        self.ch = dim
        self.num_layers = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_dim
        self.clip_act = clip_act
        # The encoder downsamples *after* level i, so the mirrored upsample runs
        # after level i + 1 on the way back up. Decoder level i then works at the
        # same resolution as encoder level i.
        self.down_layers = [i + 1 for i in down_layers]

        # channel width at the deepest level
        block_in = dim * ch_mult[self.num_layers - 1]

        # z to block_in
        self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
        self.mid.attn_1 = AttnBlock1D(block_in)
        self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_layers)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = dim * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
                block_in = block_out
                if i_level in attn_layers:
                    attn.append(AttnBlock1D(block_in))

            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level in self.down_layers:
                up.upsample = Upsample1D(block_in, resamp_with_conv)
            # Prepend so that self.up[i] is level i (keeps state_dict keys stable).
            self.up.insert(0, up)

        # end
        self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
        self.learnable_gain = nn.Parameter(torch.zeros([]))

    def forward(self, z):
        """Decode latent ``z`` (``embed_dim`` channels on dim 1) to ``out_dim`` channels."""
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        h = h.clamp(-self.clip_act, self.clip_act)

        # upsampling, deepest level first
        for i_level in reversed(range(self.num_layers)):
            level = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
                h = h.clamp(-self.clip_act, self.clip_act)
            if i_level in self.down_layers:
                h = level.upsample(h)

        h = nonlinearity(h)
        h = self.conv_out(h, gain=(self.learnable_gain + 1))
        return h
def VAE_16k(**kwargs) -> VAE:
    """Build the '16k' VAE preset.

    Preset: 80 data channels (normalized with ``DATA_MEAN_80D``/``DATA_STD_80D``),
    20 latent channels and a hidden width of 384. Any extra keyword arguments are
    passed through to :class:`VAE`.
    """
    preset = dict(data_dim=80, embed_dim=20, hidden_dim=384)
    return VAE(**preset, **kwargs)
def VAE_44k(**kwargs) -> VAE:
    """Build the '44k' VAE preset.

    Preset: 128 data channels (normalized with ``DATA_MEAN_128D``/``DATA_STD_128D``),
    40 latent channels and a hidden width of 512. Any extra keyword arguments are
    passed through to :class:`VAE`.
    """
    preset = dict(data_dim=128, embed_dim=40, hidden_dim=512)
    return VAE(**preset, **kwargs)
def get_my_vae(name: str, **kwargs) -> VAE:
    """Look up a VAE preset by name and build it.

    Args:
        name: ``'16k'`` or ``'44k'``.
        **kwargs: Forwarded to the preset factory.

    Raises:
        ValueError: For any other ``name``.
    """
    if name == '16k':
        factory = VAE_16k
    elif name == '44k':
        factory = VAE_44k
    else:
        raise ValueError(f'Unknown model: {name}')
    return factory(**kwargs)
if __name__ == '__main__':
    # Report the size of every preset get_my_vae knows about. 'standard' is not
    # a registered name, so requesting it always raised ValueError.
    for model_name in ('16k', '44k'):
        network = get_my_vae(model_name)
        # number of parameters in millions
        num_params = sum(p.numel() for p in network.parameters()) / 1e6
        print(f'{model_name}: Number of parameters: {num_params:.2f}M')