KevinGeng committed
Commit
6d26a9c
1 Parent(s): e2a4e13

Upload model.py

Files changed (1)
  1. model.py +191 -0
model.py ADDED
@@ -0,0 +1,191 @@
import torch
import torch.nn as nn
import fairseq
import os
import hydra
# WavLM and WavLMConfig are used below but were never imported; this import
# assumes the WavLM module from microsoft/unilm is on the Python path.
from WavLM import WavLM, WavLMConfig

def load_ssl_model(cp_path):
    # infer the SSL variant from the checkpoint filename
    ssl_model_type = cp_path.split("/")[-1]
    wavlm = "WavLM" in ssl_model_type
    if wavlm:
        checkpoint = torch.load(cp_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
        if 'Large' in ssl_model_type:
            SSL_OUT_DIM = 1024
        else:
            SSL_OUT_DIM = 768
    else:
        if ssl_model_type == "wav2vec_small.pt":
            SSL_OUT_DIM = 768
        elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
            SSL_OUT_DIM = 1024
        else:
            raise ValueError("SSL model type " + ssl_model_type + " not supported.")
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        ssl_model = model[0]
        ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)
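
# Usage sketch (illustrative comment, not part of the uploaded file; the
# checkpoint path is a placeholder):
#
#   ssl = load_ssl_model("checkpoints/wav2vec_small.pt")
#   ssl.get_output_dim()  # -> 768 for wav2vec_small.pt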

class SSL_model(nn.Module):
    def __init__(self, ssl_model, ssl_out_dim, wavlm) -> None:
        super(SSL_model, self).__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
        self.WavLM = wavlm

    def forward(self, batch):
        wav = batch['wav']
        wav = wav.squeeze(1)  # [batches, audio_len]
        if self.WavLM:
            x = self.ssl_model.extract_features(wav)[0]
        else:
            res = self.ssl_model(wav, mask=False, features_only=True)
            x = res["x"]
        return {"ssl-feature": x}

    def get_output_dim(self):
        return self.ssl_out_dim
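
# The forward pass expects batch['wav'] shaped [batch, 1, audio_len]; the
# channel axis is squeezed before feature extraction. Illustrative only:
#
#   batch = {"wav": torch.randn(2, 1, 16000)}  # two 1 s clips at 16 kHz
#   ssl(batch)["ssl-feature"]                  # [2, frames, ssl_out_dim]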

class PhonemeEncoder(nn.Module):
    '''
    PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
    Args:
        vocab_size: the size of the vocabulary
        hidden_dim: the size of the hidden state of the LSTM
        emb_dim: the size of the embedding layer
        out_dim: the size of the output of the linear layer
        n_lstm_layers: the number of LSTM layers
    '''
    def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim,
                 n_lstm_layers, with_reference=True) -> None:
        super().__init__()
        self.with_reference = with_reference
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim,
                               num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
        self.linear = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim * self.with_reference, out_dim),
            nn.ReLU()
        )
        self.out_dim = out_dim

    def forward(self, batch):
        seq = batch['phonemes']
        lens = batch['phoneme_lens']
        reference_seq = batch.get('reference')
        reference_lens = batch.get('reference_lens')
        emb = self.embedding(seq)
        emb = torch.nn.utils.rnn.pack_padded_sequence(
            emb, lens, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.encoder(emb)
        # combine final hidden states (ht[0]: first-layer forward,
        # ht[-1]: last-layer backward)
        feature = ht[-1] + ht[0]
        if self.with_reference:
            if reference_seq is None or reference_lens is None:
                raise ValueError("reference_seq and reference_lens should not be None when with_reference is True")
            reference_emb = self.embedding(reference_seq)
            reference_emb = torch.nn.utils.rnn.pack_padded_sequence(
                reference_emb, reference_lens, batch_first=True, enforce_sorted=False)
            # encode the packed reference sequence with the shared LSTM
            _, (ht_ref, _) = self.encoder(reference_emb)
            reference_feature = ht_ref[-1] + ht_ref[0]
            feature = self.linear(torch.cat([feature, reference_feature], 1))
        else:
            feature = self.linear(feature)
        return {"phoneme-feature": feature}

    def get_output_dim(self):
        return self.out_dim
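
# Usage sketch (illustrative; all sizes are placeholders). With the default
# with_reference=True, the batch must also carry a reference phoneme sequence:
#
#   enc = PhonemeEncoder(vocab_size=40, hidden_dim=256, emb_dim=128,
#                        out_dim=256, n_lstm_layers=2)
#   batch = {"phonemes": torch.randint(0, 40, (2, 10)),
#            "phoneme_lens": torch.tensor([10, 7]),
#            "reference": torch.randint(0, 40, (2, 12)),
#            "reference_lens": torch.tensor([12, 9])}
#   enc(batch)["phoneme-feature"]  # [2, 256]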

class DomainEmbedding(nn.Module):
    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        return {"domain-feature": self.embedding(batch['domains'])}

    def get_output_dim(self):
        return self.output_dim

class LDConditioner(nn.Module):
    '''
    Conditions the SSL output on a listener (judge) embedding.
    '''
    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges is not None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]
        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )  # linear?
        self.out_dim = self.decoder_rnn.hidden_size * 2

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        judge_ids = batch['judge_id']
        # utterance-level features are broadcast along the SSL time axis
        # before concatenation
        if 'phoneme-feature' in x.keys():
            concatenated_feature = torch.cat(
                (
                    x['ssl-feature'],
                    x['phoneme-feature']
                    .unsqueeze(1)
                    .expand(-1, x['ssl-feature'].size(1), -1),
                ),
                dim=2,
            )
        else:
            concatenated_feature = x['ssl-feature']
        if 'domain-feature' in x.keys():
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    x['domain-feature']
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        if judge_ids is not None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids)
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output
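
# Shape sketch (illustrative; input_dim must equal the total width of the
# concatenated SSL/phoneme/domain features fed in):
#
#   cond = LDConditioner(input_dim=768, judge_dim=128, num_judges=3000)
#   out = cond({"ssl-feature": torch.randn(2, 100, 768)},
#              {"judge_id": torch.tensor([5, 42])})
#   out.shape  # [2, 100, 1024]: bidirectional LSTM, 2 * hidden_size of 512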

class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super(Projection, self).__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        output = self.net(x)

        # range clipping: rescale tanh output from [-1, 1] to [1, 5]
        if self.range_clipping:
            return self.proj(output) * 2.0 + 3
        else:
            return output

    def get_output_dim(self):
        return self.output_dim
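
# End-to-end sketch (illustrative; dimensions are placeholders). The head maps
# each conditioned frame to one score; with range_clipping the output falls in
# [1, 5], matching the MOS scale:
#
#   proj = Projection(input_dim=1024, hidden_dim=2048,
#                     activation=nn.ReLU(), range_clipping=True)
#   proj(out, batch)  # [2, 100, 1] frame scores, typically averaged over
#                     # frames to give an utterance-level MOS estimate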