|
import os |
|
import argparse |
|
import torch |
|
from torch.optim import lr_scheduler |
|
from logger import utils |
|
from diffusion.data_loaders import get_data_loaders |
|
from diffusion.solver import train |
|
from diffusion.unit2mel import Unit2Mel |
|
from diffusion.vocoder import Vocoder |
|
|
|
|
|
def parse_args(args=None, namespace=None):
    """Build the command-line parser and parse the given arguments.

    Args:
        args: Sequence of argument strings to parse. ``None`` lets argparse
            fall back to the process command line.
        namespace: Optional object to populate with the parsed attributes;
            ``None`` lets argparse create a fresh namespace.

    Returns:
        The populated namespace. It carries a ``config`` attribute holding
        the path given via ``-c`` / ``--config`` (the option is required).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-c", "--config",
                     help="path to the config file",
                     required=True,
                     type=str)
    parsed = cli.parse_args(args=args, namespace=namespace)
    return parsed
|
|
|
|
|
if __name__ == '__main__':
    # Parse CLI arguments and load the experiment configuration file.
    cmd = parse_args()
    args = utils.load_config(cmd.config)
    print(' > config:', cmd.config)
    print(' > exp:', args.env.expdir)

    # Select the CUDA device BEFORE anything is placed on the GPU. A bare
    # 'cuda' device resolves to the *current* CUDA device. If this ran after
    # the vocoder / checkpoint were loaded (as it used to), then with
    # gpu_id != 0 the vocoder would end up on cuda:0 while the model is later
    # moved to cuda:gpu_id, i.e. on a different device.
    if args.device == 'cuda':
        torch.cuda.set_device(args.env.gpu_id)

    # Load the vocoder; its `dimension` is passed to the model below.
    vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)

    # Build the diffusion model from config values.
    model = Unit2Mel(
        args.data.encoder_out_channels,
        args.model.n_spk,
        args.model.use_pitch_aug,
        vocoder.dimension,
        args.model.n_layers,
        args.model.n_chans,
        args.model.n_hidden)

    # Create the optimizer. utils.load_model returns the starting global step
    # together with the (presumably restored from expdir) model and optimizer.
    optimizer = torch.optim.AdamW(model.parameters())
    initial_global_step, model, optimizer = utils.load_model(
        args.env.expdir, model, optimizer, device=args.device)

    # Apply learning rate and weight decay from the config, overriding any
    # values the optimizer carried over from loading.
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.train.lr
        param_group['weight_decay'] = args.train.weight_decay

    # NOTE(review): StepLR is created with the default last_epoch=-1 even when
    # resuming at initial_global_step > 0, so the decay schedule restarts from
    # step 0 on resume — confirm this is intended.
    scheduler = lr_scheduler.StepLR(
        optimizer, step_size=args.train.decay_step, gamma=args.train.gamma)

    # Move the model and any tensor-valued optimizer state onto the device.
    model.to(args.device)
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(args.device)

    # Build data loaders and run training.
    loader_train, loader_valid = get_data_loaders(args, whole_audio=False)

    train(args, initial_global_step, model, optimizer, scheduler, vocoder,
          loader_train, loader_valid)
|
|
|
|