|
from tqdm import tqdm
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
class Audio2Exp(nn.Module):
    """Run an expression generator ``netG`` over audio features in fixed-size windows.

    The generator is called as ``netG(audiox, ref, ratio)`` on successive
    windows of ``_WINDOW`` frames, and the per-window outputs are concatenated
    along dim 1 (presumably the time axis -- confirm against ``netG``).
    """

    # Number of frames fed to ``netG`` per call.
    _WINDOW = 10

    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        """Store the config and move ``netG`` onto ``device``.

        Args:
            netG: Generator module; it is moved to ``device`` with ``.to``.
            cfg: Configuration object, stored as ``self.cfg`` (not read here).
            device: Target device, stored as ``self.device``.
            prepare_training_loss: Accepted for interface compatibility; unused.
        """
        super().__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def _predict_window(self, mel_input, ref, ratio, start):
        """Run ``netG`` on frames ``[start, start + _WINDOW)`` of the batch."""
        stop = start + self._WINDOW
        # ``reshape`` rather than ``view``. Slicing dim 1 leaves the tensor
        # non-contiguous whenever bs > 1 and the window is shorter than T, and
        # ``view`` cannot flatten dims 0 and 1 in that case. ``reshape`` returns
        # a view when possible and copies only when it must.
        audiox = mel_input[:, start:stop].reshape(-1, 1, 80, 16)
        return self.netG(audiox, ref[:, start:stop], ratio[:, start:stop])

    def test(self, batch):
        """Predict expression coefficients for a whole batch, window by window.

        Args:
            batch: Mapping with tensors under ``'indiv_mels'`` (presumably
                ``(bs, T, 1, 80, 16)``; each frame must hold 80*16 values),
                ``'ref'`` (only the first 64 entries of dim 2 are used) and
                ``'ratio_gt'`` (sliced along dim 1 like the others).

        Returns:
            ``{'exp_coeff_pred': tensor}``: the per-window ``netG`` outputs
            concatenated along dim 1. If T == 0 there are no windows and
            ``torch.cat`` raises ``RuntimeError``.
        """
        mel_input = batch['indiv_mels']
        num_frames = mel_input.shape[1]

        # Loop-invariant slices, taken once instead of once per window.
        ref = batch['ref'][:, :, :64]
        ratio = batch['ratio_gt']

        exp_coeff_pred = [
            self._predict_window(mel_input, ref, ratio, start)
            for start in tqdm(range(0, num_frames, self._WINDOW), 'audio2exp:')
        ]

        return {'exp_coeff_pred': torch.cat(exp_coeff_pred, dim=1)}
|
|
|
|
|
|
|