Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import BertModel, BertConfig, BertTokenizer | |
try: | |
from contants import config | |
except: | |
from contants import config | |
class CharEmbedding(nn.Module):
    """BERT encoder whose token outputs are projected down to 256 features.

    Args:
        model_dir: directory from which the ``BertTokenizer`` and the
            ``BertConfig`` are loaded. The ``BertModel`` itself is built from
            the config only, so its weights are freshly initialised here and
            must be loaded separately (e.g. via ``load_state_dict``).
    """

    def __init__(self, model_dir):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        cfg = BertConfig.from_pretrained(model_dir)
        self.bert_config = cfg
        self.hidden_size = cfg.hidden_size
        self.bert = BertModel(cfg)
        # hidden_size -> 256 projection; this is what ``forward`` returns.
        self.proj = nn.Linear(cfg.hidden_size, 256)
        # NOTE(review): ``linear`` is never used by ``forward``; presumably it
        # is kept so the module's parameters match a saved checkpoint —
        # confirm before removing.
        self.linear = nn.Linear(256, 3)

    def text2Token(self, text):
        """Tokenize ``text`` and return the list of vocabulary ids.

        No special tokens are added: only ``tokenize`` followed by
        ``convert_tokens_to_ids`` is applied.
        """
        tok = self.tokenizer
        return tok.convert_tokens_to_ids(tok.tokenize(text))

    def forward(self, inputs_ids, inputs_masks, tokens_type_ids):
        """Run BERT and project its per-token hidden states to 256 features.

        Args:
            inputs_ids: token id tensor (presumably ``(batch, seq)``).
            inputs_masks: attention mask, same shape as ``inputs_ids``.
            tokens_type_ids: segment ids, same shape as ``inputs_ids``.

        Returns:
            Tensor whose last dimension is 256.
        """
        encoder_outputs = self.bert(
            input_ids=inputs_ids,
            attention_mask=inputs_masks,
            token_type_ids=tokens_type_ids,
        )
        # Index 0 of BertModel's output is the last layer's hidden states.
        hidden_states = encoder_outputs[0]
        return self.proj(hidden_states)
class TTSProsody:
    """Wrap a ``CharEmbedding`` model to produce per-character prosody features.

    Args:
        path: directory handed to ``CharEmbedding``; it supplies the tokenizer
            vocabulary and ``BertConfig``. The weights file is NOT read from
            here but from the project config (see ``__init__``).
        device: ``torch.device`` (or device string) the model runs on and
            that input tensors are moved to.
    """

    def __init__(self, path, device):
        self.device = device
        self.char_model = CharEmbedding(path)
        weights_path = os.path.join(
            config.abs_path,
            config.system.data_path,
            config.model_config.vits_chinese_bert,
            "prosody_model.pt",
        )
        # Map the checkpoint to CPU rather than to ``config.system.device``.
        # The freshly built model still lives on CPU at this point, and the
        # whole model is moved to ``self.device`` below. Mapping to the config
        # device would crash when that device (e.g. CUDA) is unavailable, even
        # though the caller explicitly asked for ``device``.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        # strict=False: keys missing from / unexpected in the checkpoint are
        # ignored instead of raising.
        self.char_model.load_state_dict(state_dict, strict=False)
        self.char_model.eval()
        self.char_model.to(self.device)

    def get_char_embeds(self, text):
        """Return one embedding row per token of ``text``.

        Tokenization is done by ``CharEmbedding.text2Token`` (presumably one
        token per Chinese character — confirm for non-CJK input).

        Returns:
            CPU tensor of shape ``(num_tokens, 256)``, computed without
            gradient tracking.
        """
        token_ids = self.char_model.text2Token(text)
        n_tokens = len(token_ids)
        # Single-sequence batch: every position is attended to (no padding)
        # and belongs to segment 0.
        input_ids = torch.LongTensor([token_ids]).to(self.device)
        input_masks = torch.LongTensor([[1] * n_tokens]).to(self.device)
        type_ids = torch.LongTensor([[0] * n_tokens]).to(self.device)
        with torch.no_grad():
            char_embeds = self.char_model(input_ids, input_masks, type_ids)
        # Drop the batch dimension and bring the result back to the host.
        return char_embeds.squeeze(0).cpu()

    def expand_for_phone(self, char_embeds, length):  # length of phones for char
        """Repeat each character embedding once per phone of that character.

        Args:
            char_embeds: 2-D CPU tensor with one row per character, e.g. the
                output of ``get_char_embeds``.
            length: sequence of non-negative ints; ``length[i]`` is the number
                of phones of character ``i``. Must have one entry per row.

        Returns:
            numpy array with ``sum(length)`` rows, where row ``i`` of
            ``char_embeds`` appears ``length[i]`` times consecutively.

        Raises:
            AssertionError: if ``len(length)`` differs from the row count.
        """
        assert char_embeds.size(0) == len(length), (
            f"got {char_embeds.size(0)} character embeddings but "
            f"{len(length)} phone counts"
        )
        repeats = torch.as_tensor(length, dtype=torch.long, device=char_embeds.device)
        # Vectorized equivalent of expanding each row ``leng`` times and
        # concatenating the results.
        expand_embeds = torch.repeat_interleave(char_embeds, repeats, dim=0)
        assert expand_embeds.size(0) == sum(length)
        return expand_embeds.numpy()
if __name__ == "__main__":
    # Interactive smoke test: embed each line typed by the user and print the
    # shape of the resulting tensor. EOF (Ctrl-D) or Ctrl-C ends the loop
    # cleanly instead of dumping a traceback.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    prosody = TTSProsody('./bert/', device)
    while True:
        try:
            # Prompt text means "Please enter text:".
            text = input("请输入文本:")
        except (EOFError, KeyboardInterrupt):
            break
        # Skip blank lines rather than running BERT on an empty sequence.
        if not text.strip():
            continue
        char_embeds = prosody.get_char_embeds(text)
        print(tuple(char_embeds.shape))