# Source snapshot: revision fa4dd2b (4,259 bytes).
from abc import ABCMeta
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from .modules import TFC_TDF
# Output group count used by ``Mixer``: it maps (dim_s + 1) * 2 input channels
# to dim_s * 2 output channels and reshapes its result to (dim_s, 2, -1).
dim_s: int = 4
class AbstractMDXNet(LightningModule):
    """Common base for MDX-Net models: STFT settings plus optimizer choice.

    Args:
        target_name: Identifier of the target this model is trained for
            (stored as-is; presumably a stem name -- confirm with callers).
        lr: Learning rate used by ``configure_optimizers``.
        optimizer: Either ``'rmsprop'`` or ``'adamw'``.
        dim_c: Channel count stored for subclasses and used as the channel
            axis of ``freq_pad``.
        dim_f: Number of frequency bins the network keeps; must not exceed
            ``n_fft // 2 + 1`` or building ``freq_pad`` fails with a negative
            dimension.
        dim_t: Number of time frames per chunk (last axis of ``freq_pad``).
        n_fft: FFT size; also the length of the Hann window.
        hop_length: STFT hop length (stored as-is).
        overlap: Chunk overlap setting (stored as-is; not used in this class).
    """

    # Python 2 metaclass spelling: it has no effect under Python 3. Kept so the
    # class attribute remains for any code that inspects it.
    __metaclass__ = ABCMeta

    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
        super().__init__()
        self.target_name = target_name
        self.lr = lr
        self.optimizer = optimizer
        self.dim_c = dim_c
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.n_fft = n_fft
        # Number of bins in a one-sided (real-input) FFT of size n_fft.
        self.n_bins = n_fft // 2 + 1
        self.hop_length = hop_length
        self.overlap = overlap
        # Stored as frozen nn.Parameters rather than plain tensors so they follow
        # the module across devices and are part of its state_dict.
        self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
        # Zeros covering the n_bins - dim_f bins the network does not model;
        # presumably concatenated back onto the output elsewhere (not shown here).
        self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)

    def configure_optimizers(self):
        """Build the optimizer named by ``self.optimizer``.

        Returns:
            ``torch.optim.RMSprop`` or ``torch.optim.AdamW`` over
            ``self.parameters()`` with learning rate ``self.lr``.

        Raises:
            ValueError: if ``self.optimizer`` is neither ``'rmsprop'`` nor
                ``'adamw'``.
        """
        if self.optimizer == 'rmsprop':
            return torch.optim.RMSprop(self.parameters(), self.lr)
        if self.optimizer == 'adamw':
            return torch.optim.AdamW(self.parameters(), self.lr)
        # Returning None would let Lightning run fit() without any optimizer;
        # fail loudly on a misconfigured name instead.
        raise ValueError(
            f"Unsupported optimizer {self.optimizer!r}; expected 'rmsprop' or 'adamw'"
        )
class ConvTDFNet(AbstractMDXNet):
    """MDX-Net U-Net built from ``TFC_TDF`` blocks.

    Layout:
        * a 1x1 input convolution to ``g`` channels;
        * ``num_blocks // 2`` encoder levels, each a ``TFC_TDF`` block followed
          by a stride-2 2x2 convolution that adds ``g`` channels;
        * a ``TFC_TDF`` bottleneck;
        * a mirrored decoder whose upsampled features are multiplied
          element-wise with the matching encoder output;
        * a 1x1 output convolution back to ``dim_c`` channels.

    Args:
        target_name, lr, dim_c, dim_f, dim_t, n_fft, hop_length, overlap:
            Forwarded to ``AbstractMDXNet``.
        optimizer: ``'rmsprop'`` or ``'adamw'``. Also selects the normalisation
            layer: ``nn.BatchNorm2d`` for ``'rmsprop'`` and
            ``nn.GroupNorm(2, channels)`` for ``'adamw'``.
        num_blocks: ``num_blocks // 2`` encoder levels and the same number of
            decoder levels are built.
        l, k, bn, bias: Forwarded unchanged to every ``TFC_TDF`` block (see
            ``.modules`` for their meaning).
        g: Channel growth per level; the input convolution outputs ``g``.

    Raises:
        ValueError: if ``optimizer`` is not ``'rmsprop'`` or ``'adamw'``.
    """

    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
                 num_blocks, l, g, k, bn, bias, overlap):
        super(ConvTDFNet, self).__init__(
            target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
        self.save_hyperparameters()

        self.num_blocks = num_blocks
        self.l = l
        self.g = g
        self.k = k
        self.bn = bn
        self.bias = bias

        # The normalisation layer is chosen together with the optimizer.
        if optimizer == 'rmsprop':
            norm = nn.BatchNorm2d
        elif optimizer == 'adamw':
            # GroupNorm requires num_channels divisible by its 2 groups. Every
            # channel count built below is a multiple of g, so g must be even.
            def norm(num_channels):
                return nn.GroupNorm(2, num_channels)
        else:
            # Otherwise `norm` would be unbound and the first norm(...) call
            # below would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported optimizer {optimizer!r}; expected 'rmsprop' or 'adamw'"
            )

        self.n = num_blocks // 2
        scale = (2, 2)

        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
            norm(g),
            nn.ReLU(),
        )

        # Encoder: each level halves both spatial axes (2x2 conv, stride 2) and
        # adds g channels. NOTE: f is floor-halved here and doubled in the
        # decoder, so decoder sizes only match the encoder's when dim_f is
        # divisible by 2 ** self.n.
        f = self.dim_f
        c = g
        self.encoding_blocks = nn.ModuleList()
        self.ds = nn.ModuleList()
        for _ in range(self.n):
            self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
            self.ds.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
                    norm(c + g),
                    nn.ReLU(),
                )
            )
            f = f // 2
            c += g

        self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)

        # Decoder mirrors the encoder. A transposed 2x2 conv with stride 2
        # doubles both spatial axes and removes g channels before each TFC_TDF.
        self.decoding_blocks = nn.ModuleList()
        self.us = nn.ModuleList()
        for _ in range(self.n):
            self.us.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale
                    ),
                    norm(c - g),
                    nn.ReLU(),
                )
            )
            f = f * 2
            c -= g
            self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
        )

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: 4-D tensor with ``dim_c`` channels on axis 1; presumably
                ``(batch, dim_c, dim_f, dim_t)`` -- TODO confirm with callers.
                Both trailing axes must be divisible by ``2 ** self.n`` so the
                skip connections line up.

        Returns:
            Tensor with ``dim_c`` channels in the same axis order as ``x``
            (the internal transpose of the last two axes is undone).
        """
        x = self.first_conv(x)
        # The TFC_TDF blocks are built with the frequency size f; swapping the
        # last two axes presumably puts frequency last for them.
        x = x.transpose(-1, -2)

        skips = []
        for encode, downsample in zip(self.encoding_blocks, self.ds):
            x = encode(x)
            skips.append(x)
            x = downsample(x)

        x = self.bottleneck_block(x)

        for upsample, decode, skip in zip(self.us, self.decoding_blocks, reversed(skips)):
            # Out-of-place multiply: each upsample path ends in nn.ReLU, whose
            # backward reuses its output. An in-place `x *= skip` would modify
            # that saved tensor and make autograd raise during training.
            x = upsample(x) * skip
            x = decode(x)

        x = x.transpose(-1, -2)
        return self.final_conv(x)
class Mixer(nn.Module):
    """Bias-free linear remix over stacked channel groups.

    Maps ``(dim_s + 1) * 2`` input channels to ``dim_s * 2`` output channels
    independently at every position along the last axis. Presumably the
    inputs are ``dim_s`` separated sources plus the mixture, each with 2
    channels -- confirm against the caller.

    Args:
        device: Passed to ``torch.load`` as ``map_location``.
        mixer_path: Path to a saved state dict for this module; with the
            default strict loading it must contain exactly ``linear.weight``.
    """

    def __init__(self, device, mixer_path):
        super().__init__()
        self.linear = nn.Linear((dim_s + 1) * 2, dim_s * 2, bias=False)
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code, so mixer_path must point to a trusted checkpoint.
        # NOTE(review): map_location only decides where the loaded tensors
        # land. load_state_dict copies them into self.linear's existing
        # parameters, which stay where they were created (CPU by default).
        # Callers needing the layer on `device` must move the module themselves.
        self.load_state_dict(
            torch.load(mixer_path, map_location=device)
        )

    def forward(self, x):
        """Apply the learned remix.

        Args:
            x: Tensor with ``(dim_s + 1) * 2 * T`` elements for some ``T``.
                It is viewed as a single batch of shape
                ``(1, (dim_s + 1) * 2, T)``; presumably ``(dim_s + 1, 2, T)``.

        Returns:
            Tensor of shape ``(dim_s, 2, T)``.
        """
        # (1, channels_in, T) -> (1, T, channels_in) so Linear acts on channels.
        x = x.reshape(1, (dim_s + 1) * 2, -1).transpose(-1, -2)
        x = self.linear(x)  # (1, T, dim_s * 2)
        return x.transpose(-1, -2).reshape(dim_s, 2, -1)