wenmengzhou committed on
Commit
5297156
1 Parent(s): 589560d

update model according to hysts advice

Browse files
Files changed (1) hide show
  1. cosyvoice/cli/model.py +4 -6
cosyvoice/cli/model.py CHANGED
@@ -19,18 +19,17 @@ class CosyVoiceModel:
19
  llm: torch.nn.Module,
20
  flow: torch.nn.Module,
21
  hift: torch.nn.Module):
22
- #self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
- self.device = 'cpu'
24
  self.llm = llm
25
  self.flow = flow
26
  self.hift = hift
27
 
28
  def load(self, llm_model, flow_model, hift_model):
29
- self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
30
  self.llm.to(self.device).eval()
31
- self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
32
  self.flow.to(self.device).eval()
33
- self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
34
  self.hift.to(self.device).eval()
35
 
36
  def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
@@ -38,7 +37,6 @@ class CosyVoiceModel:
38
  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
39
  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
40
  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
41
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
  tts_speech_token = self.llm.inference(text=text.to(self.device),
43
  text_len=text_len.to(self.device),
44
  prompt_text=prompt_text.to(self.device),
 
19
  llm: torch.nn.Module,
20
  flow: torch.nn.Module,
21
  hift: torch.nn.Module):
22
+ self.device = torch.device('cuda')
 
23
  self.llm = llm
24
  self.flow = flow
25
  self.hift = hift
26
 
27
  def load(self, llm_model, flow_model, hift_model):
28
+ self.llm.load_state_dict(torch.load(llm_model, map_location='cpu'))
29
  self.llm.to(self.device).eval()
30
+ self.flow.load_state_dict(torch.load(flow_model, map_location='cpu'))
31
  self.flow.to(self.device).eval()
32
+ self.hift.load_state_dict(torch.load(hift_model, map_location='cpu'))
33
  self.hift.to(self.device).eval()
34
 
35
  def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
 
37
  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
38
  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
39
  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
 
40
  tts_speech_token = self.llm.inference(text=text.to(self.device),
41
  text_len=text_len.to(self.device),
42
  prompt_text=prompt_text.to(self.device),