CosyVoice committed
Commit f1e374a
Parent: 8b097f7

add trt script TODO

cosyvoice/bin/{export.py → export_jit.py} RENAMED
@@ -44,7 +44,7 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    cosyvoice = CosyVoice(args.model_dir, load_script=False)
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
 
     # 1. export llm text_encoder
     llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
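From this point the renamed script proceeds as before: each half-precision module is scripted with TorchScript and saved under the filename that load_jit later reads. A minimal sketch of that pattern, continuing inside main(); the scripting calls are inferred from the surrounding code, not shown in this hunk:

```python
import torch

# Sketch of the TorchScript export pattern export_jit.py follows;
# the .zip filenames match what CosyVoiceModel.load_jit reads below.
llm_text_encoder = torch.jit.script(cosyvoice.model.llm.text_encoder.half())
llm_text_encoder.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))

# 2. export llm itself the same way
llm_llm = torch.jit.script(cosyvoice.model.llm.llm.half())
llm_llm.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
```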
cosyvoice/bin/export_trt.py ADDED
@@ -0,0 +1,8 @@
+# TODO: same logic as export_jit; finish the ONNX export of the flow estimator.
+# TensorRT install steps should be printed as hints here. If TensorRT is not installed, do not run this script; tell the user to install it first, with no other option.
+try:
+    import tensorrt
+except ImportError:
+    print('step 1: download TensorRT\nstep 2: extract it and install the whl package')
+# Take the TensorRT root directory from an environment variable in the build command, e.g. os.environ['tensorrt_root_dir'] + '/bin/trtexec', then run the export command from Python via subprocess.
+# I will add the launch command to run.sh later: tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
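A sketch of what the finished script could look like once these TODOs are resolved: export the flow estimator to ONNX, then shell out to trtexec from the TensorRT install pointed to by tensorrt_root_dir to build an FP16 engine. The dummy input shapes, the ONNX/engine filenames, and the trtexec flags are assumptions, not the final interface:

```python
import argparse
import os
import subprocess

import torch

from cosyvoice.cli.cosyvoice import CosyVoice

# Hypothetical sketch of the finished export_trt.py, not the committed code.
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', required=True)
args = parser.parse_args()

cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
estimator = cosyvoice.model.flow.decoder.estimator

# Dummy inputs for tracing; the argument order follows the estimator call in
# flow_matching.py, but the channel sizes here are assumptions.
x = torch.randn(1, 80, 200)
mask = torch.ones(1, 1, 200)
mu = torch.randn(1, 80, 200)
t = torch.rand(1)
spks = torch.randn(1, 80)
cond = torch.randn(1, 80, 200)

onnx_path = os.path.join(args.model_dir, 'flow.decoder.estimator.onnx')
torch.onnx.export(estimator, (x, mask, mu, t, spks, cond), onnx_path,
                  input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
                  output_names=['dphi_dt'],
                  dynamic_axes={'x': {2: 'T'}, 'mask': {2: 'T'}, 'mu': {2: 'T'},
                                'cond': {2: 'T'}, 'dphi_dt': {2: 'T'}})

# Build an FP16 engine with trtexec from the TensorRT root that run.sh will
# export as tensorrt_root_dir, per the TODO above.
trtexec = os.path.join(os.environ['tensorrt_root_dir'], 'bin', 'trtexec')
engine_path = os.path.join(args.model_dir, 'flow.decoder.estimator.fp16.plan')
subprocess.run([trtexec, '--onnx={}'.format(onnx_path),
                '--saveEngine={}'.format(engine_path), '--fp16'], check=True)
```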
cosyvoice/cli/cosyvoice.py CHANGED
@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_script=True):
+    def __init__(self, model_dir, load_jit=True, load_trt=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -39,9 +39,12 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
-        if load_script:
-            self.model.load_script('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                 '{}/llm.llm.fp16.zip'.format(model_dir))
+        if load_trt:
+            # TODO
+            self.model.load_trt()
         del configs
 
     def list_avaliable_spks(self):
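For callers, the two flags make each accelerated path individually optional; a usage sketch (the model directory name is only an example):

```python
from cosyvoice.cli.cosyvoice import CosyVoice

# Plain PyTorch inference: skip both the TorchScript modules and the
# (not yet implemented) TensorRT engine.
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M',
                      load_jit=False, load_trt=False)

# Default: load both accelerated paths.
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
```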
cosyvoice/cli/model.py CHANGED
@@ -53,12 +53,17 @@ class CosyVoiceModel:
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def load_script(self, llm_text_encoder_model, llm_llm_model):
+    def load_jit(self, llm_text_encoder_model, llm_llm_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model)
         self.llm.llm = llm_llm
 
+    def load_trt(self):
+        # TODO: whatever preparation TRT inference needs
+        self.flow.decoder.estimator = xxx
+        self.flow.decoder.session = xxx
+
     def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
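The two xxx placeholders are still open. One plausible way to fill them with the standard TensorRT Python API, assuming export_trt.py has produced a serialized engine file; the method signature, the filename argument, and keeping both the engine and an execution context are assumptions:

```python
import tensorrt as trt

def load_trt(self, flow_decoder_estimator_model):
    # Hypothetical sketch: deserialize the prebuilt engine and keep an
    # execution context for inference; not the committed implementation.
    logger = trt.Logger(trt.Logger.INFO)
    with open(flow_decoder_estimator_model, 'rb') as f:
        engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
    self.flow.decoder.estimator = engine
    self.flow.decoder.session = engine.create_execution_context()
```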
cosyvoice/flow/flow_matching.py CHANGED
@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.estimator(
+                cfg_dphi_dt = self.forward_estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,
@@ -96,6 +96,14 @@
 
         return sol[-1]
 
+    # TODO
+    def forward_estimator(self, x, mask, mu, t, spks, cond):
+        if isinstance(self.estimator, trt):
+            assert self.training is False, 'tensorrt cannot be used in training'
+            return xxx
+        else:
+            return self.estimator(x, mask, mu, t, spks, cond)
+
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss