|
|
|
import os |
|
import logging |
|
from omegaconf import OmegaConf |
|
|
|
import torch |
|
from vocos import Vocos |
|
from .model.dvae import DVAE |
|
from .model.gpt import GPT_warpper |
|
from .utils.gpu_utils import select_device |
|
from .utils.io_utils import get_latest_modified_file |
|
from .infer.api import refine_text, infer_code |
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
# Install a root handler at INFO level when this module is imported so the
# progress messages emitted through ``Chat.logger`` are visible by default.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
class Chat:
    """Load the pretrained sub-models and run text-to-waveform inference.

    Loaded components are stored in ``self.pretrain_models`` under the keys
    ``'vocos'``, ``'dvae'``, ``'gpt'``, ``'decoder'`` and ``'tokenizer'``.
    """

    def __init__(self, ):
        # Maps component name -> loaded model/tokenizer object.
        self.pretrain_models = {}
        self.logger = logging.getLogger(__name__)

    def check_model(self, level = logging.INFO, use_decoder = False):
        """Report whether every component needed for inference is loaded.

        Args:
            level: Logging level used for the success message only; missing
                components are always reported at WARNING.
            use_decoder: When true, ``'decoder'`` is required; otherwise
                ``'dvae'`` is required.

        Returns:
            ``True`` if all required components are present, else ``False``.
        """
        required = ['vocos', 'gpt', 'tokenizer']
        required.append('decoder' if use_decoder else 'dvae')

        missing = [name for name in required if name not in self.pretrain_models]
        for name in missing:
            self.logger.log(logging.WARNING, f'{name} not initialized.')

        if missing:
            return False

        self.logger.log(level, 'All initialized.')
        return True

    def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
        """Locate the model files and load every component listed in ``config/path.yaml``.

        Args:
            source: ``'huggingface'`` to use the local Hugging Face cache
                (downloading the snapshot when it is absent), or ``'local'``
                to read from ``local_path``. Any other value loads nothing
                and only logs a warning.
            force_redownload: With ``source='huggingface'``, call
                ``snapshot_download`` even when a cached snapshot was found.
            local_path: Root directory used when ``source='local'``. The
                default is a placeholder string, so pass a real directory.
        """
        if source == 'huggingface':
            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
            snapshots_dir = os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots')
            try:
                download_path = get_latest_modified_file(snapshots_dir)
            except Exception:
                # Best effort: any failure to inspect the cache means "not
                # cached". KeyboardInterrupt/SystemExit are allowed through.
                download_path = None

            if download_path is None or force_redownload:
                self.logger.log(logging.INFO, 'Download from HF: https://huggingface.co/2Noise/ChatTTS')
                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            else:
                self.logger.log(logging.INFO, f'Load from cache: {download_path}')

            self._load_from_dir(download_path)
        elif source == 'local':
            self.logger.log(logging.INFO, f'Load from local: {local_path}')
            self._load_from_dir(local_path)
        else:
            # Previously an unrecognised source was silently ignored; keep the
            # no-op behaviour but make the mistake visible.
            self.logger.log(logging.WARNING, f'Unknown source {source!r}; no models loaded.')

    def _load_from_dir(self, root):
        """Read ``<root>/config/path.yaml`` and pass its entries to ``_load``.

        Every value in the YAML file is joined onto ``root`` and forwarded as
        the keyword argument named by its key.
        """
        relative_paths = OmegaConf.load(os.path.join(root, 'config', 'path.yaml'))
        self._load(**{key: os.path.join(root, rel) for key, rel in relative_paths.items()})

    def _load(
        self,
        vocos_config_path: str = None,
        vocos_ckpt_path: str = None,
        dvae_config_path: str = None,
        dvae_ckpt_path: str = None,
        gpt_config_path: str = None,
        gpt_ckpt_path: str = None,
        decoder_config_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
        device: str = None
    ):
        """Instantiate each component whose config path is given and load its weights.

        A component is skipped when its ``*_config_path`` (or
        ``tokenizer_path``) is falsy. When a config path is given, the
        matching ``*_ckpt_path`` must be given too.

        NOTE(review): ``torch.load`` unpickles the file it reads; only load
        checkpoints and tokenizers from a trusted source.

        Args:
            device: Target device for the models. When falsy it is chosen by
                ``select_device(4096)`` (the argument's meaning is defined in
                ``utils.gpu_utils`` -- presumably a memory threshold; confirm).

        Raises:
            AssertionError: A config path was supplied without its checkpoint.
        """
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            # map_location='cpu' like the other checkpoints, so a file saved
            # from a GPU can still be read on a host without one;
            # load_state_dict copies the values into the model's own tensors.
            vocos.load_state_dict(torch.load(vocos_ckpt_path, map_location='cpu'))
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**cfg).to(device).eval()
            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**cfg).to(device).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
            self.pretrain_models['gpt'] = gpt
            self.logger.log(logging.INFO, 'gpt loaded.')

        if decoder_config_path:
            # The decoder is built from the same DVAE class with its own config.
            cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**cfg).to(device).eval()
            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            # The tokenizer is stored as a whole pickled object, not a state dict.
            tokenizer = torch.load(tokenizer_path, map_location='cpu')
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()

    @staticmethod
    def _apply_prompt(text, params_infer_code):
        """Prefix each text with the optional ``'prompt'`` entry of the params.

        Returns ``(texts, params)`` where ``params`` is a copy of
        ``params_infer_code`` without ``'prompt'``; the caller's mapping is
        never modified. A single ``str`` is treated as a one-element batch.
        """
        params = dict(params_infer_code) if params_infer_code else {}
        prompt = params.pop('prompt', '')
        if isinstance(text, str):
            # Iterating a bare string would prefix every character separately.
            text = [text]
        return [prompt + item for item in text], params

    def infer(
        self,
        text,
        skip_refine_text=False,
        refine_text_only=False,
        params_refine_text=None,
        params_infer_code=None,
        use_decoder=False
    ):
        """Turn text into waveforms, optionally refining the text first.

        Args:
            text: Input text(s); a list of strings, or a single string.
            skip_refine_text: Skip the ``refine_text`` pass and use ``text``
                as given.
            refine_text_only: Return the refined text instead of audio. Only
                honoured when the refine pass actually runs.
            params_refine_text: Extra keyword arguments for ``refine_text``.
            params_infer_code: Extra keyword arguments for ``infer_code``. An
                optional ``'prompt'`` entry is prepended to every text and
                not forwarded. The mapping itself is left untouched.
            use_decoder: Decode from the hidden states with the ``'decoder'``
                model instead of decoding the ids with ``'dvae'``.

        Returns:
            The refined texts (``refine_text_only``) or a list with one NumPy
            array per input text, produced by ``vocos.decode``.

        Raises:
            AssertionError: A required component has not been loaded.
        """
        # Explicit raise instead of ``assert`` so the guard survives ``-O``
        # while callers still see the same exception type.
        if not self.check_model(use_decoder=use_decoder):
            raise AssertionError('required models are not initialized')

        if not skip_refine_text:
            refine_kwargs = params_refine_text if params_refine_text is not None else {}
            text_tokens = refine_text(self.pretrain_models, text, **refine_kwargs)['ids']
            tokenizer = self.pretrain_models['tokenizer']
            # Keep only ids below '[break_0]'; the lookup is loop-invariant.
            break_id = tokenizer.convert_tokens_to_ids('[break_0]')
            text_tokens = [ids[ids < break_id] for ids in text_tokens]
            text = tokenizer.batch_decode(text_tokens)
            if refine_text_only:
                return text

        text, infer_kwargs = self._apply_prompt(text, params_infer_code)
        result = infer_code(self.pretrain_models, text, **infer_kwargs, return_hidden=use_decoder)

        if use_decoder:
            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
        else:
            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]

        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]

        return wav
|
|
|
|
|
|
|
|