CosyVoice commited on
Commit
ff8e635
1 Parent(s): 2665b06

use thread pool in tools

Browse files
tools/extract_embedding.py CHANGED
@@ -13,74 +13,39 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  import argparse
16
- import os
17
- from concurrent.futures import ThreadPoolExecutor
18
-
19
  import onnxruntime
20
  import torch
21
  import torchaudio
22
  import torchaudio.compliance.kaldi as kaldi
23
  from tqdm import tqdm
24
- from itertools import repeat
25
 
26
 
27
- def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
28
- audio, sample_rate = torchaudio.load(wav_file)
29
  if sample_rate != 16000:
30
- audio = torchaudio.transforms.Resample(
31
- orig_freq=sample_rate, new_freq=16000
32
- )(audio)
33
- feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
 
34
  feat = feat - feat.mean(dim=0, keepdim=True)
35
  embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
36
- return (utt, embedding)
37
 
38
 
39
  def main(args):
40
- utt2wav, utt2spk = {}, {}
41
- with open("{}/wav.scp".format(args.dir)) as f:
42
- for l in f:
43
- l = l.replace("\n", "").split()
44
- utt2wav[l[0]] = l[1]
45
- with open("{}/utt2spk".format(args.dir)) as f:
46
- for l in f:
47
- l = l.replace("\n", "").split()
48
- utt2spk[l[0]] = l[1]
49
-
50
- assert os.path.exists(args.onnx_path), "onnx_path not exists"
51
-
52
- option = onnxruntime.SessionOptions()
53
- option.graph_optimization_level = (
54
- onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
55
- )
56
- option.intra_op_num_threads = 1
57
- providers = ["CPUExecutionProvider"]
58
- ort_session = onnxruntime.InferenceSession(
59
- args.onnx_path, sess_options=option, providers=providers
60
- )
61
-
62
- all_utt = utt2wav.keys()
63
-
64
- with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
65
- results = list(
66
- tqdm(
67
- executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
68
- total=len(utt2wav),
69
- desc="Process data: "
70
- )
71
- )
72
-
73
  utt2embedding, spk2embedding = {}, {}
74
- for utt, embedding in results:
 
75
  utt2embedding[utt] = embedding
76
  spk = utt2spk[utt]
77
  if spk not in spk2embedding:
78
  spk2embedding[spk] = []
79
  spk2embedding[spk].append(embedding)
80
-
81
  for k, v in spk2embedding.items():
82
  spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
83
-
84
  torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
85
  torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
86
 
@@ -91,4 +56,22 @@ if __name__ == "__main__":
91
  parser.add_argument("--onnx_path", type=str)
92
  parser.add_argument("--num_thread", type=int, default=8)
93
  args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  main(args)
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  import argparse
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed
 
 
17
  import onnxruntime
18
  import torch
19
  import torchaudio
20
  import torchaudio.compliance.kaldi as kaldi
21
  from tqdm import tqdm
 
22
 
23
 
24
+ def single_job(utt):
25
+ audio, sample_rate = torchaudio.load(utt2wav[utt])
26
  if sample_rate != 16000:
27
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
28
+ feat = kaldi.fbank(audio,
29
+ num_mel_bins=80,
30
+ dither=0,
31
+ sample_frequency=16000)
32
  feat = feat - feat.mean(dim=0, keepdim=True)
33
  embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
34
+ return utt, embedding
35
 
36
 
37
  def main(args):
38
+ all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  utt2embedding, spk2embedding = {}, {}
40
+ for future in tqdm(as_completed(all_task)):
41
+ utt, embedding = future.result()
42
  utt2embedding[utt] = embedding
43
  spk = utt2spk[utt]
44
  if spk not in spk2embedding:
45
  spk2embedding[spk] = []
46
  spk2embedding[spk].append(embedding)
 
47
  for k, v in spk2embedding.items():
48
  spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
 
49
  torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
50
  torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
51
 
 
56
  parser.add_argument("--onnx_path", type=str)
57
  parser.add_argument("--num_thread", type=int, default=8)
58
  args = parser.parse_args()
59
+
60
+ utt2wav, utt2spk = {}, {}
61
+ with open('{}/wav.scp'.format(args.dir)) as f:
62
+ for l in f:
63
+ l = l.replace('\n', '').split()
64
+ utt2wav[l[0]] = l[1]
65
+ with open('{}/utt2spk'.format(args.dir)) as f:
66
+ for l in f:
67
+ l = l.replace('\n', '').split()
68
+ utt2spk[l[0]] = l[1]
69
+
70
+ option = onnxruntime.SessionOptions()
71
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
72
+ option.intra_op_num_threads = 1
73
+ providers = ["CPUExecutionProvider"]
74
+ ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
75
+ executor = ThreadPoolExecutor(max_workers=args.num_thread)
76
+
77
  main(args)
tools/extract_speech_token.py CHANGED
@@ -13,6 +13,7 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  import argparse
 
16
  import logging
17
  import torch
18
  from tqdm import tqdm
@@ -22,7 +23,36 @@ import torchaudio
22
  import whisper
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def main(args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  utt2wav = {}
27
  with open('{}/wav.scp'.format(args.dir)) as f:
28
  for l in f:
@@ -34,28 +64,6 @@ def main(args):
34
  option.intra_op_num_threads = 1
35
  providers = ["CUDAExecutionProvider"]
36
  ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
 
37
 
38
- utt2speech_token = {}
39
- for utt in tqdm(utt2wav.keys()):
40
- audio, sample_rate = torchaudio.load(utt2wav[utt])
41
- if sample_rate != 16000:
42
- audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
43
- if audio.shape[1] / 16000 > 30:
44
- logging.warning('do not support extract speech token for audio longer than 30s')
45
- speech_token = []
46
- else:
47
- feat = whisper.log_mel_spectrogram(audio, n_mels=128)
48
- speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
49
- ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
50
- utt2speech_token[utt] = speech_token
51
- torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
52
-
53
-
54
- if __name__ == "__main__":
55
- parser = argparse.ArgumentParser()
56
- parser.add_argument('--dir',
57
- type=str)
58
- parser.add_argument('--onnx_path',
59
- type=str)
60
- args = parser.parse_args()
61
  main(args)
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  import argparse
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed
17
  import logging
18
  import torch
19
  from tqdm import tqdm
 
23
  import whisper
24
 
25
 
26
+ def single_job(utt):
27
+ audio, sample_rate = torchaudio.load(utt2wav[utt])
28
+ if sample_rate != 16000:
29
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
30
+ if audio.shape[1] / 16000 > 30:
31
+ logging.warning('do not support extract speech token for audio longer than 30s')
32
+ speech_token = []
33
+ else:
34
+ feat = whisper.log_mel_spectrogram(audio, n_mels=128)
35
+ speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
36
+ ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
37
+ return utt, speech_token
38
+
39
+
40
  def main(args):
41
+ all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
42
+ utt2speech_token = {}
43
+ for future in tqdm(as_completed(all_task)):
44
+ utt, speech_token = future.result()
45
+ utt2speech_token[utt] = speech_token
46
+ torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
47
+
48
+
49
+ if __name__ == "__main__":
50
+ parser = argparse.ArgumentParser()
51
+ parser.add_argument("--dir", type=str)
52
+ parser.add_argument("--onnx_path", type=str)
53
+ parser.add_argument("--num_thread", type=int, default=8)
54
+ args = parser.parse_args()
55
+
56
  utt2wav = {}
57
  with open('{}/wav.scp'.format(args.dir)) as f:
58
  for l in f:
 
64
  option.intra_op_num_threads = 1
65
  providers = ["CUDAExecutionProvider"]
66
  ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
67
+ executor = ThreadPoolExecutor(max_workers=args.num_thread)
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  main(args)