test python files
- app.py +77 -0
- app_record.py +65 -0
- app_record_streaming.py +63 -0
- lightning_module.py +41 -0
- model.py +191 -0
app.py
ADDED
@@ -0,0 +1,77 @@
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
import jiwer
# ASR part
from transformers import pipeline
p = pipeline("automatic-speech-recognition")

# WER part
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
])

class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.tensor) -> torch.tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = (torch.arange(new_length) * (self.input_rate / self.output_rate))
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
        return output

model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path, ref):
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    batch = wav.unsqueeze(0).repeat(10, 1, 1)  # unused; overwritten by the batch dict below
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    # ASR
    trans = p(audio_path)["text"]
    # WER
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)

    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)

    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3

    return predic_mos, trans, wer

description = """
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
This demo only accepts .wav format. Best at 16 kHz sampling rate.

Paper is available [here](https://arxiv.org/abs/2204.02152)
"""

iface = gr.Interface(
    fn=calc_mos,
    inputs=[gr.Audio(type='filepath'), gr.Textbox(placeholder="Insert reference here", label="Reference")],
    outputs=[gr.Textbox(label="Predicted MOS"), gr.Textbox(label="Hypothesis"), gr.Textbox(label="WER")],
    title="UTMOS Demo",
    description=description,
    allow_flagging="auto",
)
iface.launch()
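For reference, a minimal, self-contained sketch (not part of this commit) of the WER computation that calc_mos performs, using the same jiwer transformation chain on a made-up reference/hypothesis pair:

# Sketch only: mirrors the jiwer usage in app.py on toy strings.
import jiwer

transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
])
ref = "the quick brown fox"
hyp = "the quick brown box"
print(jiwer.wer(ref, hyp, truth_transform=transformation,
                hypothesis_transform=transformation))  # 0.25 (1 of 4 reference words wrong)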
app_record.py
ADDED
@@ -0,0 +1,65 @@
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb

# ASR part
from transformers import pipeline
p = pipeline("automatic-speech-recognition")

class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.tensor) -> torch.tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = (torch.arange(new_length) * (self.input_rate / self.output_rate))
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
        return output

model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path):
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    batch = wav.unsqueeze(0).repeat(10, 1, 1)  # unused; overwritten by the batch dict below
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)

    transcription = p(audio_path)["text"]
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
    return output.mean(dim=1).squeeze().detach().numpy() * 2 + 3, transcription


description = """
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
This demo only accepts .wav format. Best at 16 kHz sampling rate.

Paper is available [here](https://arxiv.org/abs/2204.02152)
"""

# inputs=gr.inputs.Audio(type='filepath'),
iface = gr.Interface(
    fn=calc_mos,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs=["text", "textbox"],
    title="UTMOS Demo",
    description=description,
    allow_flagging=True,
).launch()
app_record_streaming.py
ADDED
@@ -0,0 +1,63 @@
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
# ASR part
from transformers import pipeline
p = pipeline("automatic-speech-recognition")

class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.tensor) -> torch.tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = (torch.arange(new_length) * (self.input_rate / self.output_rate))
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
        return output

model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path):
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    batch = wav.unsqueeze(0).repeat(10, 1, 1)  # unused; overwritten by the batch dict below
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    transcription = p(audio_path)["text"]
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
    return output.mean(dim=1).squeeze().detach().numpy() * 2 + 3, transcription


description = """
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
This demo only accepts .wav format. Best at 16 kHz sampling rate.

Paper is available [here](https://arxiv.org/abs/2204.02152)
"""

# inputs=gr.inputs.Audio(type='filepath'),
iface = gr.Interface(
    fn=calc_mos,
    inputs=gr.Audio(source="microphone", type="filepath", streaming=True),
    outputs=["text", "textbox"],
    title="UTMOS Demo",
    description=description,
    allow_flagging=False,
).launch()
lightning_module.py
ADDED
@@ -0,0 +1,41 @@
import pytorch_lightning as pl
import torch
import torch.nn as nn
import os
import numpy as np
import hydra
from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection


class BaselineLightningModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.save_hyperparameters()

    def construct_model(self):
        self.feature_extractors = nn.ModuleList([
            load_ssl_model(cp_path='wav2vec_small.pt'),
            DomainEmbedding(3, 128),
        ])
        output_dim = sum([feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors])
        output_layers = [
            LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
        ]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim)
        )

        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x
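For context, a minimal sketch (not part of this commit) of running BaselineLightningModule outside Gradio, mirroring what calc_mos in app.py does. The checkpoint name comes from app.py; "sample.wav" is a placeholder for a 16 kHz mono WAV (app.py resamples other rates with ChangeSampleRate first):

# Sketch only: offline inference with the same batch keys app.py builds.
import torch
import torchaudio
import lightning_module

model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "epoch=3-step=7459.ckpt").eval()
wav, sr = torchaudio.load("sample.wav")      # assumed to already be 16 kHz mono
batch = {
    'wav': wav,                              # [1, num_samples]
    'domains': torch.tensor([0]),            # domain id used by app.py
    'judge_id': torch.tensor([288]),         # listener id used by app.py
}
with torch.no_grad():
    score = model(batch).mean(dim=1).squeeze() * 2 + 3   # rescaled as in app.py
print(float(score))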
model.py
ADDED
@@ -0,0 +1,191 @@
import torch
import torch.nn as nn
import fairseq
import os
import hydra

def load_ssl_model(cp_path):
    ssl_model_type = cp_path.split("/")[-1]
    wavlm = "WavLM" in ssl_model_type
    if wavlm:
        # NOTE: WavLM and WavLMConfig are not imported in this file; they must be
        # made importable (e.g. from the official WavLM codebase) before a WavLM
        # checkpoint can be loaded here.
        checkpoint = torch.load(cp_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
        if 'Large' in ssl_model_type:
            SSL_OUT_DIM = 1024
        else:
            SSL_OUT_DIM = 768
    else:
        if ssl_model_type == "wav2vec_small.pt":
            SSL_OUT_DIM = 768
        elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
            SSL_OUT_DIM = 1024
        else:
            print("*** ERROR *** SSL model type " + ssl_model_type + " not supported.")
            exit()
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        ssl_model = model[0]
        ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)

class SSL_model(nn.Module):
    def __init__(self, ssl_model, ssl_out_dim, wavlm) -> None:
        super(SSL_model, self).__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
        self.WavLM = wavlm

    def forward(self, batch):
        wav = batch['wav']
        wav = wav.squeeze(1)  # [batches, audio_len]
        if self.WavLM:
            x = self.ssl_model.extract_features(wav)[0]
        else:
            res = self.ssl_model(wav, mask=False, features_only=True)
            x = res["x"]
        return {"ssl-feature": x}

    def get_output_dim(self):
        return self.ssl_out_dim


class PhonemeEncoder(nn.Module):
    '''
    PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
    Args:
        vocab_size: the size of the vocabulary
        hidden_dim: the size of the hidden state of the LSTM
        emb_dim: the size of the embedding layer
        out_dim: the size of the output of the linear layer
        n_lstm_layers: the number of LSTM layers
    '''
    def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim, n_lstm_layers, with_reference=True) -> None:
        super().__init__()
        self.with_reference = with_reference
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim,
                               num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
        self.linear = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim * self.with_reference, out_dim),
            nn.ReLU()
        )
        self.out_dim = out_dim

    def forward(self, batch):
        seq = batch['phonemes']
        lens = batch['phoneme_lens']
        reference_seq = batch['reference']
        reference_lens = batch['reference_lens']
        emb = self.embedding(seq)
        emb = torch.nn.utils.rnn.pack_padded_sequence(
            emb, lens, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.encoder(emb)
        feature = ht[-1] + ht[0]
        if self.with_reference:
            if reference_seq is None or reference_lens is None:
                raise ValueError("reference_batch and reference_lens should not be None when with_reference is True")
            reference_emb = self.embedding(reference_seq)
            reference_emb = torch.nn.utils.rnn.pack_padded_sequence(
                reference_emb, reference_lens, batch_first=True, enforce_sorted=False)
            _, (ht_ref, _) = self.encoder(reference_emb)  # encode the packed reference sequence
            reference_feature = ht_ref[-1] + ht_ref[0]
            feature = self.linear(torch.cat([feature, reference_feature], 1))
        else:
            feature = self.linear(feature)
        return {"phoneme-feature": feature}

    def get_output_dim(self):
        return self.out_dim

class DomainEmbedding(nn.Module):
    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim
    def forward(self, batch):
        return {"domain-feature": self.embedding(batch['domains'])}
    def get_output_dim(self):
        return self.output_dim


class LDConditioner(nn.Module):
    '''
    Conditions ssl output by listener embedding
    '''
    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges is not None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]

        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )  # linear?
        self.out_dim = self.decoder_rnn.hidden_size * 2

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        judge_ids = batch['judge_id']
        if 'phoneme-feature' in x.keys():
            concatenated_feature = torch.cat((x['ssl-feature'], x['phoneme-feature'].unsqueeze(1).expand(-1, x['ssl-feature'].size(1), -1)), dim=2)
        else:
            concatenated_feature = x['ssl-feature']
        if 'domain-feature' in x.keys():
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    x['domain-feature']
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        if judge_ids is not None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids)
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output

class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super(Projection, self).__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        output = self.net(x)

        # range clipping
        if self.range_clipping:
            return self.proj(output) * 2.0 + 3
        else:
            return output

    def get_output_dim(self):
        return self.output_dim
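And a shape-level sketch (again not part of the commit) of how the blocks in model.py compose, substituting a random tensor for the fairseq SSL features so it runs without the wav2vec checkpoint; fairseq and hydra still need to be installed, since model.py imports them at module level:

# Sketch only: fake SSL features, real DomainEmbedding/LDConditioner/Projection.
import torch
from model import DomainEmbedding, LDConditioner, Projection

B, T, SSL_DIM = 2, 50, 768          # batch size, frames, wav2vec_small feature dim
dom = DomainEmbedding(3, 128)
cond = LDConditioner(input_dim=SSL_DIM + dom.get_output_dim(), judge_dim=128, num_judges=3000)
proj = Projection(input_dim=cond.get_output_dim(), hidden_dim=2048,
                  activation=torch.nn.ReLU(), range_clipping=False)

batch = {'domains': torch.tensor([0, 0]), 'judge_id': torch.tensor([288, 288])}
x = {'ssl-feature': torch.randn(B, T, SSL_DIM)}       # stand-in for SSL_model output
x.update(dom(batch))                                  # adds 'domain-feature'
frame_scores = proj(cond(x, batch), batch)            # [B, T, 1] frame-level scores
mos = frame_scores.mean(dim=1).squeeze() * 2 + 3      # same rescaling as app.py
print(mos.shape)                                      # torch.Size([2])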