|
import os |
|
import argparse |
|
import torch |
|
|
|
from logger import utils |
|
from data_loaders import get_data_loaders |
|
from solver import train |
|
from ddsp.vocoder import Sins, CombSub, CombSubFast |
|
from ddsp.loss import RSSLoss |
|
|
|
|
|
def parse_args(args=None, namespace=None):
    """Parse the command line for the training script.

    Args:
        args: Argument list to parse; ``None`` means ``sys.argv[1:]``.
        namespace: Optional object to populate instead of a fresh
            ``argparse.Namespace``.

    Returns:
        The populated namespace. Its ``config`` attribute holds the value of
        the required ``-c/--config`` option.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-c", "--config", type=str, required=True, help="path to the config file")
    parsed = cli.parse_args(args=args, namespace=namespace)
    return parsed
|
|
|
|
|
def _build_model(args):
    """Instantiate the vocoder named by ``args.model.type``.

    Supported values are ``'Sins'``, ``'CombSub'`` and ``'CombSubFast'``.
    Constructor arguments are read from ``args.data`` and ``args.model``.

    Raises:
        ValueError: if ``args.model.type`` is not one of the supported names.
    """
    model_type = args.model.type
    if model_type == 'Sins':
        return Sins(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_harmonics=args.model.n_harmonics,
            n_mag_allpass=args.model.n_mag_allpass,
            n_mag_noise=args.model.n_mag_noise,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)
    if model_type == 'CombSub':
        return CombSub(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_mag_allpass=args.model.n_mag_allpass,
            n_mag_harmonic=args.model.n_mag_harmonic,
            n_mag_noise=args.model.n_mag_noise,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)
    if model_type == 'CombSubFast':
        return CombSubFast(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)
    raise ValueError(f" [x] Unknown Model: {args.model.type}")


def _move_optimizer_state(optimizer, device):
    """Move every tensor in ``optimizer``'s per-parameter state to ``device``.

    ``model.to(device)`` does not touch optimizer state, so tensors held
    there (e.g. AdamW moment estimates, presumably restored by
    ``utils.load_model``) must be moved explicitly.
    """
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                # Reassigning an existing key does not resize the dict, so
                # this is safe while iterating over items().
                state[key] = value.to(device)


def main():
    """Load the config named on the command line, build everything, and train."""
    cmd = parse_args()
    args = utils.load_config(cmd.config)
    print(' > config:', cmd.config)
    print(' > exp:', args.env.expdir)

    # Select the GPU before any CUDA work. Both checkpoint loading and the
    # loss constructor below receive args.device; if the device were selected
    # afterwards, their allocations would land on the default GPU (cuda:0)
    # instead of args.env.gpu_id.
    if args.device == 'cuda':
        torch.cuda.set_device(args.env.gpu_id)

    model = _build_model(args)

    optimizer = torch.optim.AdamW(model.parameters())
    initial_global_step, model, optimizer = utils.load_model(
        args.env.expdir, model, optimizer, device=args.device)
    # The config's lr / weight decay always win, overriding the AdamW defaults
    # and, presumably, any values restored from a checkpoint.
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.train.lr
        param_group['weight_decay'] = args.train.weight_decay

    loss_func = RSSLoss(args.loss.fft_min, args.loss.fft_max, args.loss.n_scale, device=args.device)

    model.to(args.device)
    _move_optimizer_state(optimizer, args.device)
    loss_func.to(args.device)

    loader_train, loader_valid = get_data_loaders(args, whole_audio=False)

    train(args, initial_global_step, model, optimizer, loss_func, loader_train, loader_valid)


if __name__ == '__main__':
    main()
|
|
|
|