import torch
import torch.nn as nn
import fairseq
import os
import hydra

# WavLM and WavLMConfig are not defined in fairseq; they come from Microsoft's
# WavLM codebase (https://github.com/microsoft/unilm/tree/master/wavlm). The
# import path below assumes that module is on PYTHONPATH.
from WavLM import WavLM, WavLMConfig
def load_ssl_model(cp_path):
    ssl_model_type = cp_path.split("/")[-1]
    wavlm = "WavLM" in ssl_model_type
    if wavlm:
        checkpoint = torch.load(cp_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
        if 'Large' in ssl_model_type:
            SSL_OUT_DIM = 1024
        else:
            SSL_OUT_DIM = 768
    else:
        if ssl_model_type == "wav2vec_small.pt":
            SSL_OUT_DIM = 768
        elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
            SSL_OUT_DIM = 1024
        else:
            raise ValueError("SSL model type " + ssl_model_type + " not supported.")
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        ssl_model = model[0]
        ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)
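
# A minimal usage sketch (the checkpoint path is hypothetical; any of the
# checkpoint filenames handled above would work):
#
#   ssl = load_ssl_model("checkpoints/wav2vec_small.pt")
#   out = ssl({"wav": torch.randn(2, 1, 16000)})  # 2 clips, 1 channel, 1 s @ 16 kHz
#   out["ssl-feature"].shape                      # -> (2, n_frames, 768)
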
class SSL_model(nn.Module):
    def __init__(self, ssl_model, ssl_out_dim, wavlm) -> None:
        super(SSL_model, self).__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
        self.WavLM = wavlm

    def forward(self, batch):
        wav = batch['wav']
        wav = wav.squeeze(1)  # [batches, audio_len]
        if self.WavLM:
            x = self.ssl_model.extract_features(wav)[0]
        else:
            res = self.ssl_model(wav, mask=False, features_only=True)
            x = res["x"]
        return {"ssl-feature": x}

    def get_output_dim(self):
        return self.ssl_out_dim
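
# Note: every encoder in this file follows the same contract -- forward()
# returns a dict keyed by feature name and get_output_dim() reports the
# feature width -- so downstream modules can be sized generically, e.g.:
#
#   in_dim = ssl.get_output_dim() + domain_emb.get_output_dim()  # hypothetical instances
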
class PhonemeEncoder(nn.Module):
    '''
    PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
    Args:
        vocab_size: the size of the vocabulary
        hidden_dim: the size of the hidden state of the LSTM
        emb_dim: the size of the embedding layer
        out_dim: the size of the output of the linear layer
        n_lstm_layers: the number of LSTM layers
    '''
    def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim,
                 n_lstm_layers, with_reference=True) -> None:
        super().__init__()
        self.with_reference = with_reference
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim,
                               num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
        self.linear = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim * self.with_reference, out_dim),
            nn.ReLU()
        )
        self.out_dim = out_dim

    def forward(self, batch):
        seq = batch['phonemes']
        lens = batch['phoneme_lens']
        reference_seq = batch['reference']
        reference_lens = batch['reference_lens']
        emb = self.embedding(seq)
        emb = torch.nn.utils.rnn.pack_padded_sequence(
            emb, lens, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.encoder(emb)
        # Combine the first and last hidden states of the bidirectional LSTM.
        feature = ht[-1] + ht[0]
        if self.with_reference:
            if reference_seq is None or reference_lens is None:
                raise ValueError(
                    "reference_seq and reference_lens must not be None "
                    "when with_reference is True")
            reference_emb = self.embedding(reference_seq)
            reference_emb = torch.nn.utils.rnn.pack_padded_sequence(
                reference_emb, reference_lens, batch_first=True, enforce_sorted=False)
            # Encode the reference sequence (not the input embedding again).
            _, (ht_ref, _) = self.encoder(reference_emb)
            reference_feature = ht_ref[-1] + ht_ref[0]
            feature = self.linear(torch.cat([feature, reference_feature], 1))
        else:
            feature = self.linear(feature)
        return {"phoneme-feature": feature}

    def get_output_dim(self):
        return self.out_dim
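
# Sketch of a PhonemeEncoder call on a padded batch (all sizes are
# illustrative, not values from the original training config):
#
#   enc = PhonemeEncoder(vocab_size=70, hidden_dim=256, emb_dim=128,
#                        out_dim=256, n_lstm_layers=2)
#   batch = {
#       'phonemes': torch.randint(0, 70, (4, 30)),    # padded ids, [B, T_max]
#       'phoneme_lens': torch.tensor([30, 25, 18, 9]),
#       'reference': torch.randint(0, 70, (4, 30)),
#       'reference_lens': torch.tensor([30, 27, 20, 9]),
#   }
#   enc(batch)["phoneme-feature"].shape               # -> (4, 256)
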
class DomainEmbedding(nn.Module):
    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        return {"domain-feature": self.embedding(batch['domains'])}

    def get_output_dim(self):
        return self.output_dim
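
# Sketch: one learned embedding per data domain (e.g. per listening-test
# corpus); n_domains=3 here is illustrative.
#
#   dom = DomainEmbedding(n_domains=3, domain_dim=128)
#   dom({'domains': torch.tensor([0, 2])})["domain-feature"].shape  # -> (2, 128)
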
class LDConditioner(nn.Module):
    '''
    Conditions SSL output by listener (judge) embedding.
    '''
    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges is not None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]
        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )  # linear?
        self.out_dim = self.decoder_rnn.hidden_size * 2

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        judge_ids = batch['judge_id']
        # Broadcast each utterance-level feature over the frame axis and
        # concatenate it with the frame-level SSL features.
        if 'phoneme-feature' in x.keys():
            concatenated_feature = torch.cat(
                (
                    x['ssl-feature'],
                    x['phoneme-feature']
                    .unsqueeze(1)
                    .expand(-1, x['ssl-feature'].size(1), -1),
                ),
                dim=2,
            )
        else:
            concatenated_feature = x['ssl-feature']
        if 'domain-feature' in x.keys():
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    x['domain-feature']
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        if judge_ids is not None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids)
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output
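
# Sketch of conditioning frame-level SSL features on a judge (listener)
# embedding; all dimensions are illustrative:
#
#   cond = LDConditioner(input_dim=768 + 128, judge_dim=128, num_judges=3000)
#   feats = {'ssl-feature': torch.randn(2, 50, 768),
#            'domain-feature': torch.randn(2, 128)}
#   out = cond(feats, {'judge_id': torch.tensor([5, 17])})
#   out.shape  # -> (2, 50, 1024): BiLSTM with hidden_size=512, both directions
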
class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super(Projection, self).__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        output = self.net(x)
        # Range clipping: tanh is in (-1, 1), so *2 + 3 maps scores into (1, 5).
        if self.range_clipping:
            return self.proj(output) * 2.0 + 3
        else:
            return output

    def get_output_dim(self):
        return self.output_dim
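
# End-to-end sketch: feed the conditioner output through Projection to get a
# frame-level score in (1, 5) when range_clipping is on (sizes illustrative):
#
#   proj = Projection(input_dim=1024, hidden_dim=2048,
#                     activation=nn.ReLU(), range_clipping=True)
#   scores = proj(torch.randn(2, 50, 1024), batch=None)  # -> (2, 50, 1)
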