# vits-simple-api-gsv / manager / model_handler.py
# Last commit: "init" (960cd20), Artrajz
"""
Holds the models that are shared across the application.
"""
import gc
import logging
import os
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, MegatronBertModel
from contants import config
from utils.download import download_file
from bert_vits2.text.chinese_bert import get_bert_feature as zh_bert
from bert_vits2.text.english_bert_mock import get_bert_feature as en_bert
from bert_vits2.text.japanese_bert import get_bert_feature as ja_bert
from bert_vits2.text.japanese_bert_v111 import get_bert_feature as ja_bert_v111
from bert_vits2.text.japanese_bert_v200 import get_bert_feature as ja_bert_v200
from bert_vits2.text.english_bert_mock_v200 import get_bert_feature as en_bert_v200
from bert_vits2.text.chinese_bert_extra import get_bert_feature as zh_bert_extra
from bert_vits2.text.japanese_bert_extra import get_bert_feature as ja_bert_extra
class ModelHandler:
    """Load, share and release the auxiliary models used by the TTS back ends.

    The handler keeps one copy of each model and a reference count next to it:

    * ``bert_models`` maps a model key to ``(tokenizer, model, reference_count)``.
    * ``emotion``, ``clap`` and ``ssl_model`` are either ``None`` (not loaded)
      or a dict holding the loaded objects plus a ``"reference_count"`` entry.
    * ``pinyinPlus`` caches the G2PW pinyin converter (no reference count).

    When loading from the local model directory fails, the weights are
    downloaded with ``utils.download.download_file`` and loading is retried.
    """

    # Candidate download URLs for every model key, passed as a list to
    # ``download_file``.
    DOWNLOAD_PATHS = {
        "CHINESE_ROBERTA_WWM_EXT_LARGE": [
            "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
        ],
        "BERT_BASE_JAPANESE_V3": [
            "https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
        ],
        "BERT_LARGE_JAPANESE_V2": [
            "https://huggingface.co/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
        ],
        "DEBERTA_V2_LARGE_JAPANESE": [
            "https://huggingface.co/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
        ],
        "DEBERTA_V3_LARGE": [
            "https://huggingface.co/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
        ],
        "SPM": [
            "https://huggingface.co/microsoft/deberta-v3-large/resolve/main/spm.model",
            "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/spm.model",
        ],
        "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": [
            "https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
        ],
        "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": [
            "https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
        ],
        "CLAP_HTSAT_FUSED": [
            "https://huggingface.co/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
            "https://hf-mirror.com/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
        ],
        "Erlangshen_MegatronBert_1.3B_Chinese": [
            "https://huggingface.co/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
        ],
        "G2PWModel": [
            # Disabled alternative source, kept for reference:
            # "https://storage.googleapis.com/esun-ai/g2pW/G2PWModel-v2-onnx.zip",
            "https://huggingface.co/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
            "https://hf-mirror.com/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
        ],
        "CHINESE_HUBERT_BASE": [
            "https://huggingface.co/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
            "https://hf-mirror.com/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
        ],
    }

    # Expected SHA-256 digest per model key, handed to ``download_file``.
    SHA256 = {
        "CHINESE_ROBERTA_WWM_EXT_LARGE": "4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd",
        "BERT_BASE_JAPANESE_V3": "e172862e0674054d65e0ba40d67df2a4687982f589db44aa27091c386e5450a4",
        "BERT_LARGE_JAPANESE_V2": "50212d714f79af45d3e47205faa356d0e5030e1c9a37138eadda544180f9e7c9",
        "DEBERTA_V2_LARGE_JAPANESE": "a6c15feac0dea77ab8835c70e1befa4cf4c2137862c6fb2443b1553f70840047",
        "DEBERTA_V3_LARGE": "dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5",
        "SPM": "c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd",
        "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": "bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201",
        "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": "176d9d1ce29a8bddbab44068b9c1c194c51624c7f1812905e01355da58b18816",
        "CLAP_HTSAT_FUSED": "1ed5d0215d887551ddd0a49ce7311b21429ebdf1e6a129d4e68f743357225253",
        "Erlangshen_MegatronBert_1.3B_Chinese": "3456bb8f2c7157985688a4cb5cecdb9e229cb1dcf785b01545c611462ffe3579",
        # Previously used digest, kept for reference:
        # "G2PWModel": "bb40c8c7b5baa755b2acd317c6bc5a65e4af7b80c40a569247fbd76989299999",
        # NOTE(review): an empty digest presumably disables verification in
        # download_file -- confirm in utils.download.
        "G2PWModel": "",
        "CHINESE_HUBERT_BASE": "2fefccd26c2794a583b80f6f7210c721873cb7ebae2c1cde3baf9b27855e24d8",
    }

    def __init__(self, device=config.system.device):
        """Resolve the model directories from the configuration; load nothing yet.

        Args:
            device: Device the loaded models are moved to and that is passed
                to the BERT feature functions. Defaults to
                ``config.system.device``.
        """

        def data_path(relative):
            # Every model lives under <abs_path>/<data_path>/<configured name>.
            return os.path.join(config.abs_path, config.system.data_path, relative)

        model_config = config.model_config
        self.model_path = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": data_path(model_config.chinese_roberta_wwm_ext_large),
            "BERT_BASE_JAPANESE_V3": data_path(model_config.bert_base_japanese_v3),
            "BERT_LARGE_JAPANESE_V2": data_path(model_config.bert_large_japanese_v2),
            "DEBERTA_V2_LARGE_JAPANESE": data_path(model_config.deberta_v2_large_japanese),
            "DEBERTA_V3_LARGE": data_path(model_config.deberta_v3_large),
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": data_path(model_config.deberta_v2_large_japanese_char_wwm),
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": data_path(
                model_config.wav2vec2_large_robust_12_ft_emotion_msp_dim),
            "CLAP_HTSAT_FUSED": data_path(model_config.clap_htsat_fused),
            "Erlangshen_MegatronBert_1.3B_Chinese": data_path(model_config.erlangshen_MegatronBert_1_3B_Chinese),
            "G2PWModel": data_path(model_config.g2pw_model),
            "CHINESE_HUBERT_BASE": data_path(model_config.chinese_hubert_base),
        }

        # Language tag -> function that computes the BERT feature for it.
        self.lang_bert_func_map = {
            "zh": zh_bert,
            "en": en_bert,
            "ja": ja_bert,
            "ja_v111": ja_bert_v111,
            "ja_v200": ja_bert_v200,
            "en_v200": en_bert_v200,
            "zh_extra": zh_bert_extra,
            "ja_extra": ja_bert_extra,
        }

        self.bert_models = {}  # Value: (tokenizer, model, reference_count)
        self.emotion = None
        self.clap = None
        self.pinyinPlus = None
        self.device = device
        self.ssl_model = None

        # Half precision only when explicitly configured; None leaves the
        # choice to from_pretrained.
        if config.bert_vits2_config.torch_data_type.lower() in ["float16", "fp16"]:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = None

    # ------------------------------------------------------------------ #
    # Accessors for the emotion / CLAP entries
    # ------------------------------------------------------------------ #
    @property
    def emotion_model(self):
        """Emotion model; raises ``TypeError`` while ``self.emotion`` is None."""
        return self.emotion["model"]

    @property
    def emotion_processor(self):
        """Emotion processor; raises ``TypeError`` while ``self.emotion`` is None."""
        return self.emotion["processor"]

    @property
    def clap_model(self):
        """CLAP model; raises ``TypeError`` while ``self.clap`` is None."""
        return self.clap["model"]

    @property
    def clap_processor(self):
        """CLAP processor; raises ``TypeError`` while ``self.clap`` is None."""
        return self.clap["processor"]

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _download_model(self, model_name, target_path=None):
        """Download the file registered for ``model_name`` and log the outcome.

        Args:
            model_name: Key into ``DOWNLOAD_PATHS`` / ``SHA256``.
            target_path: Destination file. Defaults to ``pytorch_model.bin``
                inside the model's directory from ``self.model_path``.

        Raises:
            KeyError: If ``model_name`` is not a registered key.
        """
        urls = self.DOWNLOAD_PATHS[model_name]
        if target_path is None:
            target_path = os.path.join(self.model_path[model_name], "pytorch_model.bin")

        success, message = download_file(urls, target_path, expected_sha256=self.SHA256[model_name])
        if success:
            logging.info(f"{message}")
        else:
            logging.error(f"Failed to download {model_name}: {message}")

    def _load_with_retries(self, model_path, loading_message, build, fetch, max_retries):
        """Call ``build`` up to ``max_retries`` times, calling ``fetch`` after each failure.

        Args:
            model_path: Path used in the log messages.
            loading_message: Message logged before every attempt.
            build: Zero-argument callable returning the loaded object(s).
            fetch: Zero-argument callable run after a failed attempt
                (downloads the missing files).
            max_retries: Maximum number of ``build`` attempts.

        Returns:
            Whatever ``build`` returned, or ``None`` if every attempt failed.
        """
        for _ in range(max_retries):
            logging.info(loading_message)
            try:
                loaded = build()
            except Exception as e:
                # Deliberately broad: any loading problem (missing file,
                # corrupt weights, missing package) triggers a download
                # and another attempt.
                logging.error(f"Failed loading {model_path}. {e}")
                fetch()
            else:
                logging.info(f"Success loading: {model_path}")
                return loaded

        logging.error(f"Failed to load {model_path} after {max_retries} retries.")
        return None

    @staticmethod
    def _free_memory():
        """Run the garbage collector and empty the CUDA caching allocator."""
        gc.collect()
        torch.cuda.empty_cache()

    def _release_shared(self, attr, released_message):
        """Drop one reference to the dict stored in ``self.<attr>``.

        When the count reaches zero the attribute is reset to ``None`` and
        memory is reclaimed.
        """
        holder = getattr(self, attr)
        if holder is None:
            return

        holder["reference_count"] -= 1
        if holder["reference_count"] <= 0:
            setattr(self, attr, None)
            # The local name must go too, otherwise the model is still
            # referenced while the caches are being emptied.
            del holder
            self._free_memory()
            logging.info(released_message)

    # ------------------------------------------------------------------ #
    # BERT models
    # ------------------------------------------------------------------ #
    def load_bert(self, bert_model_name, max_retries=3):
        """Load a BERT model, or add a reference if it is already loaded.

        Args:
            bert_model_name: Key into ``self.model_path``.
            max_retries: Maximum number of load attempts; after each failed
                attempt the weights are downloaded.

        Raises:
            KeyError: If ``bert_model_name`` has no configured path.

        A load that fails on every attempt is only logged; the model is then
        simply absent from ``self.bert_models``.
        """
        if bert_model_name in self.bert_models:
            tokenizer, model, count = self.bert_models[bert_model_name]
            self.bert_models[bert_model_name] = (tokenizer, model, count + 1)
            return

        model_path = self.model_path[bert_model_name]

        def build():
            if bert_model_name == "Erlangshen_MegatronBert_1.3B_Chinese":
                tokenizer_cls, model_cls = BertTokenizer, MegatronBertModel
            else:
                tokenizer_cls, model_cls = AutoTokenizer, AutoModelForMaskedLM
            # NOTE(review): torch_dtype is forwarded to the tokenizer as in
            # the original code; presumably it is ignored there -- confirm.
            tokenizer = tokenizer_cls.from_pretrained(model_path, torch_dtype=self.torch_dtype)
            model = model_cls.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(self.device)
            return tokenizer, model, 1  # reference count starts at 1

        def fetch():
            logging.info("Trying to download.")
            spm_path = os.path.join(model_path, "spm.model")
            # DeBERTa v3 additionally needs its sentencepiece model file.
            if bert_model_name == "DEBERTA_V3_LARGE" and not os.path.exists(spm_path):
                self._download_model("SPM", spm_path)
            self._download_model(bert_model_name)

        loaded = self._load_with_retries(
            model_path, f"Loading BERT model: {model_path}", build, fetch, max_retries)
        if loaded is not None:
            self.bert_models[bert_model_name] = loaded

    def get_bert_model(self, bert_model_name):
        """Return ``(tokenizer, model)``, loading the model first if needed.

        Raises:
            KeyError: If the model is unknown or could not be loaded.
        """
        if bert_model_name not in self.bert_models:
            self.load_bert(bert_model_name)

        tokenizer, model, _ = self.bert_models[bert_model_name]
        return tokenizer, model

    def get_bert_feature(self, norm_text, word2ph, language, bert_model_name, style_text=None, style_weight=0.7):
        """Compute the BERT feature of ``norm_text`` with the function registered for ``language``.

        Args:
            norm_text: Text forwarded to the language-specific feature function.
            word2ph: Forwarded unchanged to the feature function.
            language: Key into ``self.lang_bert_func_map`` (e.g. ``"zh"``).
            bert_model_name: Model key resolved through ``get_bert_model``.
            style_text: Forwarded unchanged to the feature function.
            style_weight: Forwarded unchanged to the feature function.

        Returns:
            Whatever the language-specific feature function returns.
        """
        tokenizer, model = self.get_bert_model(bert_model_name)
        feature_func = self.lang_bert_func_map[language]
        return feature_func(norm_text, word2ph, tokenizer, model, self.device,
                            style_text=style_text, style_weight=style_weight)

    def release_bert(self, bert_model_name):
        """Drop one reference to a BERT model; free it when none remain.

        Unknown or already released names are ignored.
        """
        entry = self.bert_models.get(bert_model_name)
        if entry is None:
            return

        remaining = entry[2] - 1
        if remaining <= 0:
            # When the count reaches zero, drop the model and free its
            # resources. The local reference must be removed as well, or the
            # model would still be alive while the caches are emptied.
            del self.bert_models[bert_model_name]
            del entry
            self._free_memory()
            logging.info(f"BERT model {bert_model_name} has been released.")
        else:
            self.bert_models[bert_model_name] = (entry[0], entry[1], remaining)

    def is_model_loaded(self, bert_model_name):
        """Return True if the BERT model is currently held in ``self.bert_models``."""
        return bert_model_name in self.bert_models

    def reference_count(self, bert_model_name):
        """Return the reference count of a BERT model, or 0 if it is not loaded."""
        entry = self.bert_models.get(bert_model_name)
        return entry[2] if entry is not None else 0

    # ------------------------------------------------------------------ #
    # Emotion model
    # ------------------------------------------------------------------ #
    def load_emotion(self, max_retries=3):
        """Bert-VITS2 v2.1 EmotionModel

        Loads the model and its processor, or adds a reference when already
        loaded. If every attempt fails, ``self.emotion`` stays ``None``.
        """
        if self.emotion is not None:
            self.emotion["reference_count"] += 1
            return

        from transformers import Wav2Vec2Processor
        from bert_vits2.get_emo import EmotionModel

        model_name = "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM"
        model_path = self.model_path[model_name]

        def build():
            # Built locally and published only on success, so a failed load
            # never leaves a half-filled dict behind.
            return {
                "model": EmotionModel.from_pretrained(model_path).to(self.device),
                "processor": Wav2Vec2Processor.from_pretrained(model_path),
                "reference_count": 1,
            }

        self.emotion = self._load_with_retries(
            model_path,
            f"Loading WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM: {model_path}",
            build,
            lambda: self._download_model(model_name),
            max_retries,
        )

    def release_emotion(self):
        """Drop one reference to the emotion model; free it when none remain."""
        self._release_shared("emotion", "Emotion model has been released.")

    # ------------------------------------------------------------------ #
    # CLAP model
    # ------------------------------------------------------------------ #
    def load_clap(self, max_retries=3):
        """Bert-VITS2 v2.2 ClapModel

        Loads the model and its processor, or adds a reference when already
        loaded. If every attempt fails, ``self.clap`` stays ``None``.
        """
        if self.clap is not None:
            self.clap["reference_count"] += 1
            return

        from transformers import ClapModel, ClapProcessor

        model_name = "CLAP_HTSAT_FUSED"
        model_path = self.model_path[model_name]

        def build():
            return {
                "model": ClapModel.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(self.device),
                "processor": ClapProcessor.from_pretrained(model_path, torch_dtype=self.torch_dtype),
                "reference_count": 1,
            }

        self.clap = self._load_with_retries(
            model_path,
            f"Loading CLAP_HTSAT_FUSED: {model_path}",
            build,
            lambda: self._download_model(model_name),
            max_retries,
        )

    def release_clap(self):
        """Drop one reference to the CLAP model; free it when none remain."""
        self._release_shared("clap", "Clap model has been released.")

    # ------------------------------------------------------------------ #
    # G2PW pinyin converter
    # ------------------------------------------------------------------ #
    def get_pinyinPlus(self):
        """Return the G2PW pinyin converter, creating it on first use."""
        if self.pinyinPlus is None:
            from bert_vits2.g2pW.pypinyin_G2pW_bv2 import G2PWPinyin

            logging.info(f"Loading G2PWModel: {self.model_path['G2PWModel']}")
            self.pinyinPlus = G2PWPinyin(
                model_dir=self.model_path["G2PWModel"],
                model_source=self.model_path["Erlangshen_MegatronBert_1.3B_Chinese"],
                v_to_u=False,
                neutral_tone_with_five=True,
            )
            logging.info("Success loading G2PWModel")
        return self.pinyinPlus

    # ------------------------------------------------------------------ #
    # SSL model (CNHubert)
    # ------------------------------------------------------------------ #
    def load_ssl(self, max_retries=3):
        """GPT-SoVITS

        Loads the CNHubert model, or adds a reference when already loaded.
        If every attempt fails, ``self.ssl_model`` stays ``None``.
        """
        if self.ssl_model is not None:
            self.ssl_model["reference_count"] += 1
            return

        model_name = "CHINESE_HUBERT_BASE"
        model_path = self.model_path[model_name]

        def build():
            # Imported here so that a missing package is handled like any
            # other loading failure, as in the original code.
            from gpt_sovits.feature_extractor.cnhubert import CNHubert

            model = CNHubert(model_path)
            model.eval()
            if config.gpt_sovits_config.is_half:
                model = model.half()
            model = model.to(self.device)
            return {"model": model, "reference_count": 1}

        self.ssl_model = self._load_with_retries(
            model_path,
            f"Loading CHINESE_HUBERT_BASE: {model_path}",
            build,
            lambda: self._download_model(model_name),
            max_retries,
        )

    def get_ssl_model(self):
        """Return the SSL model, loading it first if needed; ``None`` if loading failed."""
        if self.ssl_model is None:
            self.load_ssl()
        if self.ssl_model is None:
            return None
        return self.ssl_model.get("model")

    def release_ssl_model(self):
        """Drop one reference to the SSL model; free it when none remain."""
        self._release_shared("ssl_model", "SSL model has been released.")