vits-simple-api-gsv / manager /model_handler.py
Artrajz's picture
init
960cd20
raw
history blame
18.7 kB
"""
放置公用模型
"""
import gc
import logging
import os
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, MegatronBertModel
from contants import config
from utils.download import download_file
from bert_vits2.text.chinese_bert import get_bert_feature as zh_bert
from bert_vits2.text.english_bert_mock import get_bert_feature as en_bert
from bert_vits2.text.japanese_bert import get_bert_feature as ja_bert
from bert_vits2.text.japanese_bert_v111 import get_bert_feature as ja_bert_v111
from bert_vits2.text.japanese_bert_v200 import get_bert_feature as ja_bert_v200
from bert_vits2.text.english_bert_mock_v200 import get_bert_feature as en_bert_v200
from bert_vits2.text.chinese_bert_extra import get_bert_feature as zh_bert_extra
from bert_vits2.text.japanese_bert_extra import get_bert_feature as ja_bert_extra
class ModelHandler:
def __init__(self, device=config.system.device):
self.DOWNLOAD_PATHS = {
"CHINESE_ROBERTA_WWM_EXT_LARGE": [
"https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
],
"BERT_BASE_JAPANESE_V3": [
"https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
],
"BERT_LARGE_JAPANESE_V2": [
"https://huggingface.co/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
],
"DEBERTA_V2_LARGE_JAPANESE": [
"https://huggingface.co/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
],
"DEBERTA_V3_LARGE": [
"https://huggingface.co/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
],
"SPM": [
"https://huggingface.co/microsoft/deberta-v3-large/resolve/main/spm.model",
"https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/spm.model",
],
"DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": [
"https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
],
"WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": [
"https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
],
"CLAP_HTSAT_FUSED": [
"https://huggingface.co/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
"https://hf-mirror.com/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
],
"Erlangshen_MegatronBert_1.3B_Chinese": [
"https://huggingface.co/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
],
"G2PWModel": [
# "https://storage.googleapis.com/esun-ai/g2pW/G2PWModel-v2-onnx.zip",
"https://huggingface.co/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
"https://hf-mirror.com/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
],
"CHINESE_HUBERT_BASE": [
"https://huggingface.co/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
]
}
self.SHA256 = {
"CHINESE_ROBERTA_WWM_EXT_LARGE": "4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd",
"BERT_BASE_JAPANESE_V3": "e172862e0674054d65e0ba40d67df2a4687982f589db44aa27091c386e5450a4",
"BERT_LARGE_JAPANESE_V2": "50212d714f79af45d3e47205faa356d0e5030e1c9a37138eadda544180f9e7c9",
"DEBERTA_V2_LARGE_JAPANESE": "a6c15feac0dea77ab8835c70e1befa4cf4c2137862c6fb2443b1553f70840047",
"DEBERTA_V3_LARGE": "dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5",
"SPM": "c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd",
"DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": "bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201",
"WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": "176d9d1ce29a8bddbab44068b9c1c194c51624c7f1812905e01355da58b18816",
"CLAP_HTSAT_FUSED": "1ed5d0215d887551ddd0a49ce7311b21429ebdf1e6a129d4e68f743357225253",
"Erlangshen_MegatronBert_1.3B_Chinese": "3456bb8f2c7157985688a4cb5cecdb9e229cb1dcf785b01545c611462ffe3579",
# "G2PWModel": "bb40c8c7b5baa755b2acd317c6bc5a65e4af7b80c40a569247fbd76989299999",
"G2PWModel": "",
"CHINESE_HUBERT_BASE": "2fefccd26c2794a583b80f6f7210c721873cb7ebae2c1cde3baf9b27855e24d8",
}
self.model_path = {
"CHINESE_ROBERTA_WWM_EXT_LARGE": os.path.join(config.abs_path, config.system.data_path,
config.model_config.chinese_roberta_wwm_ext_large),
"BERT_BASE_JAPANESE_V3": os.path.join(config.abs_path, config.system.data_path,
config.model_config.bert_base_japanese_v3),
"BERT_LARGE_JAPANESE_V2": os.path.join(config.abs_path, config.system.data_path,
config.model_config.bert_large_japanese_v2),
"DEBERTA_V2_LARGE_JAPANESE": os.path.join(config.abs_path, config.system.data_path,
config.model_config.deberta_v2_large_japanese),
"DEBERTA_V3_LARGE": os.path.join(config.abs_path, config.system.data_path,
config.model_config.deberta_v3_large),
"DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": os.path.join(config.abs_path, config.system.data_path,
config.model_config.deberta_v2_large_japanese_char_wwm),
"WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": os.path.join(config.abs_path, config.system.data_path,
config.model_config.wav2vec2_large_robust_12_ft_emotion_msp_dim),
"CLAP_HTSAT_FUSED": os.path.join(config.abs_path, config.system.data_path,
config.model_config.clap_htsat_fused),
"Erlangshen_MegatronBert_1.3B_Chinese": os.path.join(config.abs_path, config.system.data_path,
config.model_config.erlangshen_MegatronBert_1_3B_Chinese),
"G2PWModel": os.path.join(config.abs_path, config.system.data_path, config.model_config.g2pw_model),
"CHINESE_HUBERT_BASE": os.path.join(config.abs_path, config.system.data_path,
config.model_config.chinese_hubert_base),
}
self.lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert, "ja_v111": ja_bert_v111,
"ja_v200": ja_bert_v200, "en_v200": en_bert_v200, "zh_extra": zh_bert_extra,
"ja_extra": ja_bert_extra}
self.bert_models = {} # Value: (tokenizer, model, reference_count)
self.emotion = None
self.clap = None
self.pinyinPlus = None
self.device = device
self.ssl_model = None
if config.bert_vits2_config.torch_data_type.lower() in ["float16", "fp16"]:
self.torch_dtype = torch.float16
else:
self.torch_dtype = None
@property
def emotion_model(self):
return self.emotion["model"]
@property
def emotion_processor(self):
return self.emotion["processor"]
@property
def clap_model(self):
return self.clap["model"]
@property
def clap_processor(self):
return self.clap["processor"]
def _download_model(self, model_name, target_path=None):
urls = self.DOWNLOAD_PATHS[model_name]
if target_path is None:
target_path = os.path.join(self.model_path[model_name], "pytorch_model.bin")
expected_sha256 = self.SHA256[model_name]
success, message = download_file(urls, target_path, expected_sha256=expected_sha256)
if not success:
logging.error(f"Failed to download {model_name}: {message}")
else:
logging.info(f"{message}")
def load_bert(self, bert_model_name, max_retries=3):
if bert_model_name not in self.bert_models:
retries = 0
model_path = ""
while retries < max_retries:
model_path = self.model_path[bert_model_name]
logging.info(f"Loading BERT model: {model_path}")
try:
if bert_model_name == "Erlangshen_MegatronBert_1.3B_Chinese":
tokenizer = BertTokenizer.from_pretrained(model_path, torch_dtype=self.torch_dtype)
model = MegatronBertModel.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(
self.device)
else:
tokenizer = AutoTokenizer.from_pretrained(model_path, torch_dtype=self.torch_dtype)
model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(
self.device)
self.bert_models[bert_model_name] = (tokenizer, model, 1) # 初始化引用计数为1
logging.info(f"Success loading: {model_path}")
break
except Exception as e:
logging.error(f"Failed loading {model_path}. {e}")
logging.info(f"Trying to download.")
if bert_model_name == "DEBERTA_V3_LARGE" and not os.path.exists(
os.path.join(model_path, "spm.model")):
self._download_model("SPM", os.path.join(model_path, "spm.model"))
self._download_model(bert_model_name)
retries += 1
if retries == max_retries:
logging.error(f"Failed to load {model_path} after {max_retries} retries.")
else:
tokenizer, model, count = self.bert_models[bert_model_name]
self.bert_models[bert_model_name] = (tokenizer, model, count + 1)
def load_emotion(self, max_retries=3):
"""Bert-VITS2 v2.1 EmotionModel"""
if self.emotion is None:
from transformers import Wav2Vec2Processor
from bert_vits2.get_emo import EmotionModel
retries = 0
model_path = self.model_path["WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM"]
while retries < max_retries:
logging.info(f"Loading WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM: {model_path}")
try:
self.emotion = {}
self.emotion["model"] = EmotionModel.from_pretrained(model_path).to(self.device)
self.emotion["processor"] = Wav2Vec2Processor.from_pretrained(model_path)
self.emotion["reference_count"] = 1
logging.info(f"Success loading: {model_path}")
break
except Exception as e:
logging.error(f"Failed loading {model_path}. {e}")
self._download_model("WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM")
retries += 1
if retries == max_retries:
logging.error(f"Failed to load {model_path} after {max_retries} retries.")
else:
self.emotion["reference_count"] += 1
def release_emotion(self):
if self.emotion is not None:
self.emotion["reference_count"] -= 1
if self.emotion["reference_count"] <= 0:
del self.emotion
self.emotion = None
gc.collect()
torch.cuda.empty_cache()
logging.info(f"Emotion model has been released.")
def load_clap(self, max_retries=3):
"""Bert-VITS2 v2.2 ClapModel"""
if self.clap is None:
from transformers import ClapModel, ClapProcessor
retries = 0
model_path = self.model_path["CLAP_HTSAT_FUSED"]
while retries < max_retries:
logging.info(f"Loading CLAP_HTSAT_FUSED: {model_path}")
try:
self.clap = {}
self.clap["model"] = ClapModel.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(
self.device)
self.clap["processor"] = ClapProcessor.from_pretrained(model_path, torch_dtype=self.torch_dtype)
self.clap["reference_count"] = 1
logging.info(f"Success loading: {model_path}")
break
except Exception as e:
logging.error(f"Failed loading {model_path}. {e}")
self._download_model("CLAP_HTSAT_FUSED")
retries += 1
if retries == max_retries:
logging.error(f"Failed to load {model_path} after {max_retries} retries.")
else:
self.clap["reference_count"] += 1
def release_clap(self):
if self.clap is not None:
self.clap["reference_count"] -= 1
if self.clap["reference_count"] <= 0:
del self.clap
self.clap = None
gc.collect()
torch.cuda.empty_cache()
logging.info(f"Clap model has been released.")
def get_bert_model(self, bert_model_name):
if bert_model_name not in self.bert_models:
self.load_bert(bert_model_name)
tokenizer, model, _ = self.bert_models[bert_model_name]
return tokenizer, model
def get_bert_feature(self, norm_text, word2ph, language, bert_model_name, style_text=None, style_weight=0.7):
tokenizer, model = self.get_bert_model(bert_model_name)
bert_feature = self.lang_bert_func_map[language](norm_text, word2ph, tokenizer, model, self.device,
style_text=style_text, style_weight=style_weight)
return bert_feature
def get_pinyinPlus(self):
if self.pinyinPlus is None:
from bert_vits2.g2pW.pypinyin_G2pW_bv2 import G2PWPinyin
logging.info(f"Loading G2PWModel: {self.model_path['G2PWModel']}")
self.pinyinPlus = G2PWPinyin(
model_dir=self.model_path["G2PWModel"],
model_source=self.model_path["Erlangshen_MegatronBert_1.3B_Chinese"],
v_to_u=False,
neutral_tone_with_five=True,
)
logging.info("Success loading G2PWModel")
return self.pinyinPlus
def release_bert(self, bert_model_name):
if bert_model_name in self.bert_models:
_, _, count = self.bert_models[bert_model_name]
count -= 1
if count == 0:
# 当引用计数为0时,删除模型并释放其资源
del self.bert_models[bert_model_name]
gc.collect()
torch.cuda.empty_cache()
logging.info(f"BERT model {bert_model_name} has been released.")
else:
tokenizer, model = self.bert_models[bert_model_name][:2]
self.bert_models[bert_model_name] = (tokenizer, model, count)
def load_ssl(self, max_retries=3):
"""GPT-SoVITS"""
if self.ssl_model is None:
retries = 0
model_path = self.model_path["CHINESE_HUBERT_BASE"]
while retries < max_retries:
logging.info(f"Loading CHINESE_HUBERT_BASE: {model_path}")
try:
from gpt_sovits.feature_extractor.cnhubert import CNHubert
self.ssl_model = {}
model_path = self.model_path.get("CHINESE_HUBERT_BASE")
self.ssl_model["model"] = CNHubert(model_path)
self.ssl_model["model"].eval()
if config.gpt_sovits_config.is_half:
self.ssl_model["model"] = self.ssl_model["model"].half()
self.ssl_model["model"] = self.ssl_model["model"].to(self.device)
self.ssl_model["reference_count"] = 1
logging.info(f"Success loading: {model_path}")
break
except Exception as e:
logging.error(f"Failed loading {model_path}. {e}")
self._download_model("CHINESE_HUBERT_BASE")
retries += 1
if retries == max_retries:
logging.error(f"Failed to load {model_path} after {max_retries} retries.")
else:
self.ssl_model["reference_count"] += 1
def get_ssl_model(self):
if self.ssl_model is None:
self.load_ssl()
return self.ssl_model.get("model")
def release_ssl_model(self):
if self.ssl_model is not None:
self.ssl_model["reference_count"] -= 1
if self.ssl_model["reference_count"] <= 0:
del self.ssl_model
self.ssl_model = None
gc.collect()
torch.cuda.empty_cache()
logging.info(f"SSL model has been released.")
def is_model_loaded(self, bert_model_name):
return bert_model_name in self.bert_models
def reference_count(self, bert_model_name):
return self.bert_models[bert_model_name][2] if bert_model_name in self.bert_models else 0