"""
Holds the models shared across the project.
"""
import gc | |
import logging | |
import os | |
import torch | |
from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, MegatronBertModel | |
from contants import config | |
from utils.download import download_file | |
from bert_vits2.text.chinese_bert import get_bert_feature as zh_bert | |
from bert_vits2.text.english_bert_mock import get_bert_feature as en_bert | |
from bert_vits2.text.japanese_bert import get_bert_feature as ja_bert | |
from bert_vits2.text.japanese_bert_v111 import get_bert_feature as ja_bert_v111 | |
from bert_vits2.text.japanese_bert_v200 import get_bert_feature as ja_bert_v200 | |
from bert_vits2.text.english_bert_mock_v200 import get_bert_feature as en_bert_v200 | |
from bert_vits2.text.chinese_bert_extra import get_bert_feature as zh_bert_extra | |
from bert_vits2.text.japanese_bert_extra import get_bert_feature as ja_bert_extra | |
class ModelHandler:
    """Registry of the auxiliary models shared by the synthesis back ends.

    Models are loaded on demand from the directories in ``self.model_path``.
    When loading fails, the weights are downloaded from the mirrors listed in
    ``self.DOWNLOAD_PATHS`` and loading is retried.

    Every loaded model carries a reference count: the ``load_*`` methods
    increase it, the ``release_*`` methods decrease it, and the model is
    dropped once the count reaches zero.

    State invariant: ``self.emotion``, ``self.clap`` and ``self.ssl_model``
    are either ``None`` (not loaded) or a fully populated dict holding
    ``"model"``, ``"reference_count"`` and, for emotion/clap, ``"processor"``.
    A failed load never leaves a partially filled dict behind.
    """

    def __init__(self, device=config.system.device):
        """Set up the download tables, the local model paths and empty model slots.

        Args:
            device: Device every model is moved to after loading.
        """
        # Download mirrors per model, passed in this order to ``download_file``.
        self.DOWNLOAD_PATHS = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": [
                "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
            ],
            "BERT_BASE_JAPANESE_V3": [
                "https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
            ],
            "BERT_LARGE_JAPANESE_V2": [
                "https://huggingface.co/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
            ],
            "DEBERTA_V2_LARGE_JAPANESE": [
                "https://huggingface.co/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
            ],
            "DEBERTA_V3_LARGE": [
                "https://huggingface.co/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
            ],
            "SPM": [
                "https://huggingface.co/microsoft/deberta-v3-large/resolve/main/spm.model",
                "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/spm.model",
            ],
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": [
                "https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
            ],
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": [
                "https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
            ],
            "CLAP_HTSAT_FUSED": [
                "https://huggingface.co/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
                "https://hf-mirror.com/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
            ],
            "Erlangshen_MegatronBert_1.3B_Chinese": [
                "https://huggingface.co/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
            ],
            "G2PWModel": [
                # "https://storage.googleapis.com/esun-ai/g2pW/G2PWModel-v2-onnx.zip",
                "https://huggingface.co/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
                "https://hf-mirror.com/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
            ],
            "CHINESE_HUBERT_BASE": [
                "https://huggingface.co/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
            ],
        }

        # Checksums handed to ``download_file`` as ``expected_sha256``.
        self.SHA256 = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": "4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd",
            "BERT_BASE_JAPANESE_V3": "e172862e0674054d65e0ba40d67df2a4687982f589db44aa27091c386e5450a4",
            "BERT_LARGE_JAPANESE_V2": "50212d714f79af45d3e47205faa356d0e5030e1c9a37138eadda544180f9e7c9",
            "DEBERTA_V2_LARGE_JAPANESE": "a6c15feac0dea77ab8835c70e1befa4cf4c2137862c6fb2443b1553f70840047",
            "DEBERTA_V3_LARGE": "dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5",
            "SPM": "c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd",
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": "bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201",
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": "176d9d1ce29a8bddbab44068b9c1c194c51624c7f1812905e01355da58b18816",
            "CLAP_HTSAT_FUSED": "1ed5d0215d887551ddd0a49ce7311b21429ebdf1e6a129d4e68f743357225253",
            "Erlangshen_MegatronBert_1.3B_Chinese": "3456bb8f2c7157985688a4cb5cecdb9e229cb1dcf785b01545c611462ffe3579",
            # "G2PWModel": "bb40c8c7b5baa755b2acd317c6bc5a65e4af7b80c40a569247fbd76989299999",
            "G2PWModel": "",
            "CHINESE_HUBERT_BASE": "2fefccd26c2794a583b80f6f7210c721873cb7ebae2c1cde3baf9b27855e24d8",
        }

        # Local directory of every model: <abs_path>/<data_path>/<configured sub-path>.
        data_root = os.path.join(config.abs_path, config.system.data_path)
        model_cfg = config.model_config
        self.model_path = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": os.path.join(data_root, model_cfg.chinese_roberta_wwm_ext_large),
            "BERT_BASE_JAPANESE_V3": os.path.join(data_root, model_cfg.bert_base_japanese_v3),
            "BERT_LARGE_JAPANESE_V2": os.path.join(data_root, model_cfg.bert_large_japanese_v2),
            "DEBERTA_V2_LARGE_JAPANESE": os.path.join(data_root, model_cfg.deberta_v2_large_japanese),
            "DEBERTA_V3_LARGE": os.path.join(data_root, model_cfg.deberta_v3_large),
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": os.path.join(
                data_root, model_cfg.deberta_v2_large_japanese_char_wwm),
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": os.path.join(
                data_root, model_cfg.wav2vec2_large_robust_12_ft_emotion_msp_dim),
            "CLAP_HTSAT_FUSED": os.path.join(data_root, model_cfg.clap_htsat_fused),
            "Erlangshen_MegatronBert_1.3B_Chinese": os.path.join(
                data_root, model_cfg.erlangshen_MegatronBert_1_3B_Chinese),
            "G2PWModel": os.path.join(data_root, model_cfg.g2pw_model),
            "CHINESE_HUBERT_BASE": os.path.join(data_root, model_cfg.chinese_hubert_base),
        }

        # Language tag -> feature extraction function used by get_bert_feature().
        self.lang_bert_func_map = {
            "zh": zh_bert,
            "en": en_bert,
            "ja": ja_bert,
            "ja_v111": ja_bert_v111,
            "ja_v200": ja_bert_v200,
            "en_v200": en_bert_v200,
            "zh_extra": zh_bert_extra,
            "ja_extra": ja_bert_extra,
        }

        self.bert_models = {}  # Value: (tokenizer, model, reference_count)
        self.emotion = None
        self.clap = None
        self.pinyinPlus = None
        self.device = device
        self.ssl_model = None

        # Half precision only when the configuration asks for it; ``None``
        # is passed to ``from_pretrained`` unchanged otherwise.
        if config.bert_vits2_config.torch_data_type.lower() in ["float16", "fp16"]:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = None

    def emotion_model(self):
        """Return the loaded emotion model (requires a prior successful load_emotion())."""
        return self.emotion["model"]

    def emotion_processor(self):
        """Return the processor loaded alongside the emotion model."""
        return self.emotion["processor"]

    def clap_model(self):
        """Return the loaded CLAP model (requires a prior successful load_clap())."""
        return self.clap["model"]

    def clap_processor(self):
        """Return the processor loaded alongside the CLAP model."""
        return self.clap["processor"]

    @staticmethod
    def _free_memory():
        """Run the garbage collector, then ask torch to empty its CUDA cache."""
        gc.collect()
        torch.cuda.empty_cache()

    def _download_model(self, model_name, target_path=None):
        """Download the file registered for ``model_name`` and log the outcome.

        Args:
            model_name: Key into ``DOWNLOAD_PATHS`` / ``SHA256``.
            target_path: Destination file. Defaults to ``pytorch_model.bin``
                inside the model's directory from ``model_path``.
        """
        urls = self.DOWNLOAD_PATHS[model_name]
        if target_path is None:
            target_path = os.path.join(self.model_path[model_name], "pytorch_model.bin")

        expected_sha256 = self.SHA256[model_name]
        success, message = download_file(urls, target_path, expected_sha256=expected_sha256)
        if success:
            logging.info(f"{message}")
        else:
            logging.error(f"Failed to download {model_name}: {message}")

    def _load_with_retries(self, label, model_name, model_path, build, max_retries, before_download=None):
        """Call ``build()`` until it succeeds, downloading the weights after each failure.

        Args:
            label: Name used in the "Loading ..." log line.
            model_name: Key passed to ``_download_model`` after a failure.
            model_path: Path shown in the log messages.
            build: Zero-argument callable that loads and returns the model object(s).
            max_retries: Maximum number of ``build()`` attempts.
            before_download: Optional zero-argument callable run after a failure,
                just before the weights are downloaded.

        Returns:
            Whatever ``build()`` returned, or ``None`` when every attempt failed.
        """
        for _ in range(max_retries):
            logging.info(f"Loading {label}: {model_path}")
            try:
                loaded = build()
            except Exception as e:
                # Any failure is treated as "weights missing or corrupt":
                # fetch them again and give the loader another chance.
                logging.error(f"Failed loading {model_path}. {e}")
                if before_download is not None:
                    before_download()
                self._download_model(model_name)
            else:
                logging.info(f"Success loading: {model_path}")
                return loaded

        logging.error(f"Failed to load {model_path} after {max_retries} retries.")
        return None

    def load_bert(self, bert_model_name, max_retries=3):
        """Load a BERT model, or add one reference if it is already loaded.

        On failure nothing is stored in ``self.bert_models``; the failure is
        only logged.

        Args:
            bert_model_name: Key into ``model_path``.
            max_retries: Maximum number of load attempts.
        """
        if bert_model_name in self.bert_models:
            tokenizer, model, count = self.bert_models[bert_model_name]
            self.bert_models[bert_model_name] = (tokenizer, model, count + 1)
            return

        model_path = self.model_path[bert_model_name]
        if bert_model_name == "Erlangshen_MegatronBert_1.3B_Chinese":
            tokenizer_cls, model_cls = BertTokenizer, MegatronBertModel
        else:
            tokenizer_cls, model_cls = AutoTokenizer, AutoModelForMaskedLM

        def build():
            tokenizer = tokenizer_cls.from_pretrained(model_path, torch_dtype=self.torch_dtype)
            model = model_cls.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(self.device)
            return tokenizer, model

        def before_download():
            logging.info("Trying to download.")
            # DeBERTa v3 additionally needs its SentencePiece model file.
            spm_path = os.path.join(model_path, "spm.model")
            if bert_model_name == "DEBERTA_V3_LARGE" and not os.path.exists(spm_path):
                self._download_model("SPM", spm_path)

        loaded = self._load_with_retries(
            "BERT model", bert_model_name, model_path, build, max_retries, before_download)
        if loaded is not None:
            tokenizer, model = loaded
            self.bert_models[bert_model_name] = (tokenizer, model, 1)  # reference count starts at 1

    def load_emotion(self, max_retries=3):
        """Bert-VITS2 v2.1 EmotionModel: load it, or add one reference if already loaded.

        ``self.emotion`` stays ``None`` when every attempt fails, so a later
        call can try again.
        """
        if self.emotion is not None:
            self.emotion["reference_count"] += 1
            return

        from transformers import Wav2Vec2Processor
        from bert_vits2.get_emo import EmotionModel

        model_name = "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM"
        model_path = self.model_path[model_name]

        def build():
            # Built locally and published only on success, so a failure
            # halfway through never leaves a partial dict in self.emotion.
            return {
                "model": EmotionModel.from_pretrained(model_path).to(self.device),
                "processor": Wav2Vec2Processor.from_pretrained(model_path),
                "reference_count": 1,
            }

        self.emotion = self._load_with_retries(model_name, model_name, model_path, build, max_retries)

    def release_emotion(self):
        """Drop one reference to the emotion model and free it when none remain."""
        if self.emotion is None:
            return

        self.emotion["reference_count"] -= 1
        if self.emotion["reference_count"] <= 0:
            self.emotion = None
            self._free_memory()
            logging.info("Emotion model has been released.")

    def load_clap(self, max_retries=3):
        """Bert-VITS2 v2.2 ClapModel: load it, or add one reference if already loaded.

        ``self.clap`` stays ``None`` when every attempt fails, so a later call
        can try again.
        """
        if self.clap is not None:
            self.clap["reference_count"] += 1
            return

        from transformers import ClapModel, ClapProcessor

        model_name = "CLAP_HTSAT_FUSED"
        model_path = self.model_path[model_name]

        def build():
            # Built locally and published only on success (see load_emotion).
            return {
                "model": ClapModel.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(self.device),
                "processor": ClapProcessor.from_pretrained(model_path, torch_dtype=self.torch_dtype),
                "reference_count": 1,
            }

        self.clap = self._load_with_retries(model_name, model_name, model_path, build, max_retries)

    def release_clap(self):
        """Drop one reference to the CLAP model and free it when none remain."""
        if self.clap is None:
            return

        self.clap["reference_count"] -= 1
        if self.clap["reference_count"] <= 0:
            self.clap = None
            self._free_memory()
            logging.info("Clap model has been released.")

    def get_bert_model(self, bert_model_name):
        """Return ``(tokenizer, model)`` for ``bert_model_name``, loading it first if needed.

        Raises:
            KeyError: If the model is not loaded and loading it failed.
        """
        if bert_model_name not in self.bert_models:
            self.load_bert(bert_model_name)

        tokenizer, model, _ = self.bert_models[bert_model_name]
        return tokenizer, model

    def get_bert_feature(self, norm_text, word2ph, language, bert_model_name, style_text=None, style_weight=0.7):
        """Compute the BERT feature of ``norm_text`` with the extractor registered for ``language``.

        Args:
            norm_text: Text handed to the extractor.
            word2ph: Passed through to the extractor unchanged.
            language: Key into ``lang_bert_func_map``.
            bert_model_name: Model to use; loaded on demand.
            style_text: Passed through to the extractor.
            style_weight: Passed through to the extractor.

        Returns:
            Whatever the language-specific extractor returns.
        """
        tokenizer, model = self.get_bert_model(bert_model_name)
        extract = self.lang_bert_func_map[language]
        return extract(norm_text, word2ph, tokenizer, model, self.device,
                       style_text=style_text, style_weight=style_weight)

    def get_pinyinPlus(self):
        """Return the G2PW pinyin converter, creating it on first use."""
        if self.pinyinPlus is None:
            from bert_vits2.g2pW.pypinyin_G2pW_bv2 import G2PWPinyin

            logging.info(f"Loading G2PWModel: {self.model_path['G2PWModel']}")
            self.pinyinPlus = G2PWPinyin(
                model_dir=self.model_path["G2PWModel"],
                model_source=self.model_path["Erlangshen_MegatronBert_1.3B_Chinese"],
                v_to_u=False,
                neutral_tone_with_five=True,
            )
            logging.info("Success loading G2PWModel")
        return self.pinyinPlus

    def release_bert(self, bert_model_name):
        """Drop one reference to a BERT model and free it when none remain.

        Unknown or already released names are ignored.
        """
        if bert_model_name not in self.bert_models:
            return

        # Read only the counter: binding the tokenizer/model to local names
        # here would keep them alive while _free_memory() runs and stop the
        # memory from being reclaimed.
        remaining = self.bert_models[bert_model_name][2] - 1
        if remaining <= 0:
            # No references left: delete the model and release its resources.
            del self.bert_models[bert_model_name]
            self._free_memory()
            logging.info(f"BERT model {bert_model_name} has been released.")
        else:
            tokenizer, model = self.bert_models[bert_model_name][:2]
            self.bert_models[bert_model_name] = (tokenizer, model, remaining)

    def load_ssl(self, max_retries=3):
        """GPT-SoVITS: load the CNHubert SSL model, or add one reference if already loaded.

        ``self.ssl_model`` stays ``None`` when every attempt fails, so a later
        call can try again.
        """
        if self.ssl_model is not None:
            self.ssl_model["reference_count"] += 1
            return

        model_name = "CHINESE_HUBERT_BASE"
        model_path = self.model_path[model_name]

        def build():
            # Imported inside the attempt so that an import failure is
            # handled like any other load failure.
            from gpt_sovits.feature_extractor.cnhubert import CNHubert

            model = CNHubert(model_path)
            model.eval()
            if config.gpt_sovits_config.is_half:
                model = model.half()
            model = model.to(self.device)
            return {"model": model, "reference_count": 1}

        self.ssl_model = self._load_with_retries(model_name, model_name, model_path, build, max_retries)

    def get_ssl_model(self):
        """Return the SSL model, loading it first if needed; ``None`` if loading failed."""
        if self.ssl_model is None:
            self.load_ssl()
        if self.ssl_model is None:
            return None
        return self.ssl_model.get("model")

    def release_ssl_model(self):
        """Drop one reference to the SSL model and free it when none remain."""
        if self.ssl_model is None:
            return

        self.ssl_model["reference_count"] -= 1
        if self.ssl_model["reference_count"] <= 0:
            self.ssl_model = None
            self._free_memory()
            logging.info("SSL model has been released.")

    def is_model_loaded(self, bert_model_name):
        """Return True when ``bert_model_name`` is currently held in ``bert_models``."""
        return bert_model_name in self.bert_models

    def reference_count(self, bert_model_name):
        """Return the reference count of a BERT model, or 0 when it is not loaded."""
        entry = self.bert_models.get(bert_model_name)
        return entry[2] if entry is not None else 0