|
import os |
|
import time |
|
import numpy as np |
|
import torch |
|
|
|
from logger.saver import Saver |
|
from logger import utils |
|
|
|
def test(args, model, loss_func, loader_test, saver):
    """Run one full validation pass over ``loader_test``.

    Puts the model in eval mode, runs every batch under ``torch.no_grad()``,
    logs ground-truth and predicted audio via ``saver``, and prints the
    real-time factor (RTF) for each utterance.

    Args:
        args: config namespace; reads ``args.device`` and
            ``args.data.sampling_rate``.
        model: synthesis model, called as
            ``model(units, f0, volume, spk_id)``; returns
            ``(signal, _, (s_h, s_n))``.
        loss_func: callable ``(signal, audio) -> scalar loss tensor``.
        loader_test: DataLoader yielding dict batches; ``data['name'][0]``
            is the utterance id, other keys are tensors.
        saver: logger used for audio dumps (``saver.log_audio``).

    Returns:
        float: mean loss over all validation batches.
    """
    print(' [*] testing...')
    model.eval()

    test_loss = 0.

    num_batches = len(loader_test)
    rtf_all = []

    with torch.no_grad():
        for bidx, data in enumerate(loader_test):
            fn = data['name'][0]
            print('--------')
            print(f'{bidx}/{num_batches} - {fn}')

            # move every tensor in the batch to the target device
            for k in data.keys():
                if k != 'name':
                    data[k] = data[k].to(args.device)
            print('>>', fn)

            # forward pass, timed so we can report the real-time factor
            st_time = time.time()
            signal, _, (s_h, s_n) = model(data['units'], data['f0'], data['volume'], data['spk_id'])
            ed_time = time.time()

            # crop prediction and target to a common length before the loss
            min_len = np.min([signal.shape[1], data['audio'].shape[1]])
            signal = signal[:, :min_len]
            data['audio'] = data['audio'][:, :min_len]

            # RTF = processing time / audio duration (lower is faster)
            run_time = ed_time - st_time
            song_time = data['audio'].shape[-1] / args.data.sampling_rate
            rtf = run_time / song_time
            print(f'RTF: {rtf} | {run_time} / {song_time}')
            rtf_all.append(rtf)

            loss = loss_func(signal, data['audio'])
            test_loss += loss.item()

            # log reference and prediction side by side for listening tests
            saver.log_audio({fn + '/gt.wav': data['audio'], fn + '/pred.wav': signal})

    test_loss /= num_batches

    print(' [test_loss] test_loss:', test_loss)
    print(' Real Time Factor', np.mean(rtf_all))
    return test_loss
|
|
|
|
|
def train(args, initial_global_step, model, optimizer, loss_func, loader_train, loader_test):
    """Main training loop with periodic logging, validation and checkpointing.

    Iterates ``args.train.epochs`` epochs over ``loader_train``, logging the
    training loss every ``args.train.interval_log`` steps and running
    validation (plus checkpoint saves) every ``args.train.interval_val``
    steps. The best-so-far model (by validation loss) is saved under the
    ``'best'`` postfix.

    Args:
        args: config namespace (``args.device``, ``args.train.*``,
            ``args.env.expdir``).
        initial_global_step: step counter to resume the Saver from.
        model: synthesis model; called with ``infer=False`` during training.
        optimizer: torch optimizer updating ``model``'s parameters.
        loss_func: callable ``(signal, audio) -> scalar loss tensor``.
        loader_train: training DataLoader yielding dict batches.
        loader_test: validation DataLoader, forwarded to ``test``.

    Raises:
        ValueError: if the training loss becomes NaN.
    """
    saver = Saver(args, initial_global_step=initial_global_step)

    # report model size once at startup
    params_count = utils.get_network_paras_amount({'model': model})
    saver.log_info('--- model size ---')
    saver.log_info(params_count)

    best_loss = np.inf
    num_batches = len(loader_train)
    model.train()
    saver.log_info('======= start training =======')
    for epoch in range(args.train.epochs):
        for batch_idx, data in enumerate(loader_train):
            saver.global_step_increment()
            optimizer.zero_grad()

            # move every tensor in the batch to the target device
            for k in data.keys():
                if k != 'name':
                    data[k] = data[k].to(args.device)

            # forward
            signal, _, (s_h, s_n) = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], infer=False)

            loss = loss_func(signal, data['audio'])

            # guard clause: abort immediately rather than corrupt the
            # weights with a NaN gradient step
            if torch.isnan(loss):
                raise ValueError(' [x] nan loss ')

            # backpropagate and update parameters
            loss.backward()
            optimizer.step()

            # periodic console / scalar logging
            if saver.global_step % args.train.interval_log == 0:
                saver.log_info(
                    'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | loss: {:.3f} | time: {} | step: {}'.format(
                        epoch,
                        batch_idx,
                        num_batches,
                        args.env.expdir,
                        args.train.interval_log/saver.get_interval_time(),
                        loss.item(),
                        saver.get_total_time(),
                        saver.global_step
                    )
                )
                saver.log_value({
                    'train/loss': loss.item()
                })

            # periodic validation + checkpointing
            if saver.global_step % args.train.interval_val == 0:
                # snapshot tagged with the current step
                saver.save_model(model, optimizer, postfix=f'{saver.global_step}')

                test_loss = test(args, model, loss_func, loader_test, saver)

                saver.log_info(
                    ' --- <validation> --- \nloss: {:.3f}. '.format(
                        test_loss,
                    )
                )
                saver.log_value({
                    'validation/loss': test_loss
                })
                # test() switched the model to eval mode; restore training mode
                model.train()

                # keep a separate copy of the best model by validation loss
                if test_loss < best_loss:
                    saver.log_info(' [V] best model updated.')
                    saver.save_model(model, optimizer, postfix='best')
                    best_loss = test_loss