MiXaiLL76 committed on
Commit
73271d4
1 Parent(s): 7b3e285

Implementing concurrent.futures

Browse files
Files changed (1) hide show
  1. tools/extract_embedding.py +58 -83
tools/extract_embedding.py CHANGED
@@ -13,71 +13,40 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  import argparse
 
 
 
 
16
  import torch
17
  import torchaudio
18
- from tqdm import tqdm
19
- import onnxruntime
20
  import torchaudio.compliance.kaldi as kaldi
21
- from queue import Queue, Empty
22
- from threading import Thread
23
-
24
-
25
class ExtractEmbedding:
    """Background worker that turns queued wav files into ONNX embeddings.

    The worker reads ``(utt, wav_file)`` pairs from ``queue``, runs the ONNX
    model at ``model_path`` on each file's fbank features, and writes
    ``(utt, embedding)`` pairs to ``out_queue``.

    Attributes:
        model_path: Path handed to ``onnxruntime.InferenceSession``.
        queue: Input queue of ``(utt, wav_file)`` pairs.
        out_queue: Output queue of ``(utt, embedding)`` pairs.
        is_run: ``True`` while the worker is (or may still be) processing.
            It is reset to ``False`` whenever ``consumer`` exits, including
            on error, so callers polling this flag cannot wait forever.
        consumer_thread: The thread started by ``run``, or ``None`` before.
    """

    def __init__(self, model_path: str, queue: Queue, out_queue: Queue):
        self.model_path = model_path
        self.queue = queue
        self.out_queue = out_queue
        self.is_run = True
        # Initialised here so stop() is safe even if run() was never called.
        self.consumer_thread = None

    def run(self):
        """Start ``consumer`` in a new thread."""
        self.consumer_thread = Thread(target=self.consumer)
        self.consumer_thread.start()

    def stop(self):
        """Ask the worker to finish and wait for its thread, if started."""
        self.is_run = False
        if self.consumer_thread is not None:
            self.consumer_thread.join()

    def consumer(self):
        """Process queued files until stopped or the input queue stays empty.

        The loop ends when ``is_run`` is cleared or when no item arrives
        within one second. Exceptions raised while loading the model or
        processing a file propagate to the caller, but ``is_run`` is always
        cleared first.
        """
        try:
            option = onnxruntime.SessionOptions()
            option.graph_optimization_level = (
                onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
            )
            option.intra_op_num_threads = 1
            providers = ["CPUExecutionProvider"]
            ort_session = onnxruntime.InferenceSession(
                self.model_path, sess_options=option, providers=providers
            )
            # The model's input name does not change between items.
            input_name = ort_session.get_inputs()[0].name

            while self.is_run:
                # Only the queue read can raise Empty, so keep the try tight.
                try:
                    utt, wav_file = self.queue.get(timeout=1)
                except Empty:
                    break

                audio, sample_rate = torchaudio.load(wav_file)
                if sample_rate != 16000:
                    audio = torchaudio.transforms.Resample(
                        orig_freq=sample_rate, new_freq=16000
                    )(audio)
                feat = kaldi.fbank(
                    audio, num_mel_bins=80, dither=0, sample_frequency=16000
                )
                # Subtract the mean taken over dim 0 of the fbank output.
                feat = feat - feat.mean(dim=0, keepdim=True)
                embedding = (
                    ort_session.run(
                        None,
                        {input_name: feat.unsqueeze(dim=0).cpu().numpy()},
                    )[0]
                    .flatten()
                    .tolist()
                )
                self.out_queue.put((utt, embedding))
        finally:
            # Always signal completion, even when the thread dies on an
            # error, so pollers of is_run do not spin forever.
            self.is_run = False
81
 
82
 
83
  def main(args):
@@ -91,32 +60,38 @@ def main(args):
91
  l = l.replace("\n", "").split()
92
  utt2spk[l[0]] = l[1]
93
 
94
- input_queue = Queue()
95
- output_queue = Queue()
96
- consumers = [
97
- ExtractEmbedding(args.onnx_path, input_queue, output_queue)
98
- for _ in range(args.num_thread)
 
 
 
 
 
 
 
 
 
 
99
  ]
 
 
 
 
 
 
 
 
100
 
101
  utt2embedding, spk2embedding = {}, {}
102
- for utt in tqdm(utt2wav.keys(), desc="Load data"):
103
- input_queue.put((utt, utt2wav[utt]))
104
-
105
- for c in consumers:
106
- c.run()
107
-
108
- with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
109
- while any([c.is_run for c in consumers]):
110
- try:
111
- utt, embedding = output_queue.get(timeout=1)
112
- utt2embedding[utt] = embedding
113
- spk = utt2spk[utt]
114
- if spk not in spk2embedding:
115
- spk2embedding[spk] = []
116
- spk2embedding[spk].append(embedding)
117
- pbar.update(1)
118
- except Empty:
119
- continue
120
 
121
  for k, v in spk2embedding.items():
122
  spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  import argparse
16
+ import os
17
+ from concurrent.futures import ThreadPoolExecutor
18
+
19
+ import onnxruntime
20
  import torch
21
  import torchaudio
 
 
22
  import torchaudio.compliance.kaldi as kaldi
23
+ from tqdm import tqdm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
25
 
26
def extract_embedding(input_list):
    """Compute the embedding of one utterance.

    Args:
        input_list: A ``(utt, wav_file, ort_session)`` triple: the utterance
            id, the path of the audio file to load, and the ONNX Runtime
            session used for inference.

    Returns:
        A ``(utt, embedding)`` tuple where ``embedding`` is the first model
        output, flattened into a plain Python list.
    """
    utt, wav_file, ort_session = input_list

    waveform, orig_rate = torchaudio.load(wav_file)
    if orig_rate != 16000:
        # The features below are computed at 16 kHz, so resample first.
        resampler = torchaudio.transforms.Resample(
            orig_freq=orig_rate, new_freq=16000
        )
        waveform = resampler(waveform)

    fbank = kaldi.fbank(
        waveform, num_mel_bins=80, dither=0, sample_frequency=16000
    )
    # Subtract the mean taken over dim 0 of the fbank output.
    fbank = fbank - fbank.mean(dim=0, keepdim=True)

    input_name = ort_session.get_inputs()[0].name
    model_input = fbank.unsqueeze(dim=0).cpu().numpy()
    outputs = ort_session.run(None, {input_name: model_input})
    embedding = outputs[0].flatten().tolist()
    return (utt, embedding)
 
 
50
 
51
 
52
  def main(args):
 
60
  l = l.replace("\n", "").split()
61
  utt2spk[l[0]] = l[1]
62
 
63
+ assert os.path.exists(args.onnx_path), "onnx_path not exists"
64
+
65
+ option = onnxruntime.SessionOptions()
66
+ option.graph_optimization_level = (
67
+ onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
68
+ )
69
+ option.intra_op_num_threads = 1
70
+ providers = ["CPUExecutionProvider"]
71
+ ort_session = onnxruntime.InferenceSession(
72
+ args.onnx_path, sess_options=option, providers=providers
73
+ )
74
+
75
+ inputs = [
76
+ (utt, utt2wav[utt], ort_session)
77
+ for utt in tqdm(utt2wav.keys(), desc="Load data")
78
  ]
79
+ with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
80
+ results = list(
81
+ tqdm(
82
+ executor.map(extract_embedding, inputs),
83
+ total=len(inputs),
84
+ desc="Process data: ",
85
+ )
86
+ )
87
 
88
  utt2embedding, spk2embedding = {}, {}
89
+ for utt, embedding in results:
90
+ utt2embedding[utt] = embedding
91
+ spk = utt2spk[utt]
92
+ if spk not in spk2embedding:
93
+ spk2embedding[spk] = []
94
+ spk2embedding[spk].append(embedding)
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  for k, v in spk2embedding.items():
97
  spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()