Mark Duppenthaler committed on
Commit • 2d522b6
1 Parent(s): 1727d3b

Update with temp work

Browse files:
- app.py +6 -1
- internal_demo_simuleval_transcoder.py +272 -0
- requirements.txt +16 -2
- simuleval_transcoder.py +178 -0
app.py
CHANGED
@@ -10,6 +10,8 @@ from seamless_communication.models.inference.translator import Translator
 
 
 from m4t_app import *
+from simuleval_transcoder import *
+# from simuleval_transcoder import *
 
 from pydub import AudioSegment
 import time
@@ -19,6 +21,7 @@ from time import sleep
 
 USE_M4T = True
 
+Transcoder = SimulevalTranscoder()
 
 def translate_audio_file_segment(audio_file):
     print("translate_m4t state")
@@ -90,7 +93,9 @@ def blocks():
         )
 
         most_recent_input_audio_segment = gr.Audio(
-            label="Recent Input Audio Segment segments",
+            label="Recent Input Audio Segment segments",
+            format="bytes",
+            streaming=True
        )
         # TODO: Should add combined input audio segments...
 
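The hunks above only create the global Transcoder and switch the gr.Audio input to streaming byte chunks; the wiring between the two is not part of this commit. Below is a minimal sketch of a chunk handler, assuming the process_incoming_bytes / get_buffered_output interface defined in internal_demo_simuleval_transcoder.py further down (the handler name and text-only return value are illustrative):

# Sketch only (not in this commit): queue each streamed chunk of raw bytes for the
# simuleval pipeline and return whatever buffered translation is ready as text.
def on_audio_chunk(chunk_bytes):
    if chunk_bytes:
        Transcoder.process_incoming_bytes(chunk_bytes)
    out = Transcoder.get_buffered_output()
    return out.text if out is not None else ""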
internal_demo_simuleval_transcoder.py
ADDED
@@ -0,0 +1,272 @@
from simuleval.utils.agent import build_system_from_dir
from typing import Any, Tuple
import numpy as np
import soundfile
from fairseq.data.audio.audio_utils import convert_waveform
import io
import asyncio
from simuleval.data.segments import SpeechSegment, EmptySegment
import threading
import math
import logging
import sys
from pathlib import Path
import time
from g2p_en import G2p
import torch
import traceback
import time
import random

from .speech_and_text_output import SpeechAndTextOutput

MODEL_SAMPLE_RATE = 16_000

logger = logging.getLogger()
logger.addHandler(logging.StreamHandler(sys.stdout))


class SimulevalTranscoder:
    def __init__(self, agent, sample_rate, debug, buffer_limit):
        self.agent = agent
        self.input_queue = asyncio.Queue()
        self.output_queue = asyncio.Queue()
        self.states = self.agent.build_states()
        if debug:
            self.states[0].debug = True
        self.incoming_sample_rate = sample_rate
        self.close = False
        self.g2p = G2p()

        # buffer all outgoing translations within this amount of time
        self.output_buffer_idle_ms = 5000
        self.output_buffer_size_limit = (
            buffer_limit  # phonemes for text, seconds for speech
        )
        self.output_buffer_cur_size = 0
        self.output_buffer = []
        self.speech_output_sample_rate = None

        self.last_output_ts = time.time() * 1000
        self.timeout_ms = (
            30000  # close the transcoder thread after this amount of silence
        )
        self.first_input_ts = None
        self.first_output_ts = None
        self.output_data_type = None  # speech or text
        self.debug = debug
        self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
        if self.debug:
            debug_folder = Path(__file__).resolve().parent.parent / "debug"
            self.test_incoming_wav = soundfile.SoundFile(
                debug_folder / f"{self.debug_ts}_test_incoming.wav",
                mode="w+",
                format="WAV",
                subtype="PCM_16",
                samplerate=self.incoming_sample_rate,
                channels=1,
            )
            self.states[0].test_input_segments_wav = soundfile.SoundFile(
                debug_folder / f"{self.debug_ts}_test_input_segments.wav",
                mode="w+",
                format="WAV",
                samplerate=MODEL_SAMPLE_RATE,
                channels=1,
            )

    def debug_log(self, *args):
        if self.debug:
            logger.info(*args)

    @classmethod
    def build_agent(cls, model_path):
        logger.info(f"Building simuleval agent: {model_path}")
        agent = build_system_from_dir(
            Path(__file__).resolve().parent.parent / f"models/{model_path}",
            config_name="vad_main.yaml",
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        agent.to(device, fp16=True)
        logger.info(
            f"Successfully built simuleval agent {model_path} on device {device}"
        )

        return agent

    def process_incoming_bytes(self, incoming_bytes):
        segment, _sr = self._preprocess_wav(incoming_bytes)
        # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
        self.input_queue.put_nowait(segment)

    def get_input_segment(self):
        if self.input_queue.empty():
            return None
        chunk = self.input_queue.get_nowait()
        self.input_queue.task_done()
        return chunk

    def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
        segment, sample_rate = soundfile.read(
            io.BytesIO(data),
            dtype="float32",
            always_2d=True,
            frames=-1,
            start=0,
            format="RAW",
            subtype="PCM_16",
            samplerate=self.incoming_sample_rate,
            channels=1,
        )
        if self.debug:
            self.test_incoming_wav.seek(0, soundfile.SEEK_END)
            self.test_incoming_wav.write(segment)

        segment = segment.T
        segment, new_sample_rate = convert_waveform(
            segment,
            sample_rate,
            normalize_volume=False,
            to_mono=True,
            to_sample_rate=MODEL_SAMPLE_RATE,
        )

        assert MODEL_SAMPLE_RATE == new_sample_rate
        segment = segment.squeeze(axis=0)
        return segment, new_sample_rate

    def process_pipeline_impl(self, input_segment):
        try:
            output_segment = self.agent.pushpop(input_segment, self.states)
            if (
                self.states[0].first_input_ts is not None
                and self.first_input_ts is None
            ):
                # TODO: this is hacky
                self.first_input_ts = self.states[0].first_input_ts

            if not output_segment.is_empty:
                self.output_queue.put_nowait(output_segment)

            if output_segment.finished:
                self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")

                for state in self.states:
                    state.reset()

                if self.debug:
                    # when we rebuild states, this value is reset to whatever
                    # is in the system dir config, which defaults debug=False.
                    self.states[0].debug = True
        except Exception as e:
            logger.error(f"Got exception while processing pipeline: {e}")
            traceback.print_exc()
        return input_segment

    def process_pipeline_loop(self):
        if self.close:
            return  # closes the thread

        self.debug_log("processing_pipeline")
        while not self.close:
            input_segment = self.get_input_segment()
            if input_segment is None:
                if self.states[0].is_fresh_state:  # TODO: this is hacky
                    time.sleep(0.3)
                else:
                    time.sleep(0.03)
                continue
            self.process_pipeline_impl(input_segment)
        self.debug_log("finished processing_pipeline")

    def process_pipeline_once(self):
        if self.close:
            return

        self.debug_log("processing pipeline once")
        input_segment = self.get_input_segment()
        if input_segment is None:
            return
        self.process_pipeline_impl(input_segment)
        self.debug_log("finished processing_pipeline_once")

    def get_output_segment(self):
        if self.output_queue.empty():
            return None

        output_chunk = self.output_queue.get_nowait()
        self.output_queue.task_done()
        return output_chunk

    def start(self):
        self.debug_log("starting transcoder in a thread")
        threading.Thread(target=self.process_pipeline_loop).start()

    def first_translation_time(self):
        return round((self.first_output_ts - self.first_input_ts) / 1000, 2)

    def get_buffered_output(self) -> SpeechAndTextOutput:
        now = time.time() * 1000
        self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
        while not self.output_queue.empty():
            tmp_out = self.get_output_segment()
            if tmp_out and len(tmp_out.content) > 0:
                if not self.output_data_type:
                    self.output_data_type = tmp_out.data_type
                if len(self.output_buffer) == 0:
                    self.last_output_ts = now
                self._populate_output_buffer(tmp_out)
                self._increment_output_buffer_size(tmp_out)

            if tmp_out.finished:
                res = self._gather_output_buffer_data(final=True)
                self.output_buffer = []
                self.increment_output_buffer_size = 0
                self.last_output_ts = now
                self.first_output_ts = now
                return res

        if len(self.output_buffer) > 0 and (
            now - self.last_output_ts >= self.output_buffer_idle_ms
            or self.output_buffer_cur_size >= self.output_buffer_size_limit
        ):
            self.last_output_ts = now
            res = self._gather_output_buffer_data(final=False)
            self.output_buffer = []
            self.output_buffer_phoneme_count = 0
            self.first_output_ts = now
            return res
        else:
            return None

    def _gather_output_buffer_data(self, final):
        if self.output_data_type == "text":
            return SpeechAndTextOutput(text=" ".join(self.output_buffer), final=final)
        elif self.output_data_type == "speech":
            return SpeechAndTextOutput(
                speech_samples=self.output_buffer,
                speech_sample_rate=MODEL_SAMPLE_RATE,
                final=final,
            )
        else:
            raise ValueError(
                f"Invalid output buffer data type: {self.output_data_type}"
            )

    def _increment_output_buffer_size(self, segment):
        if segment.data_type == "text":
            self.output_buffer_cur_size += self._compute_phoneme_count(segment.content)
        elif segment.data_type == "speech":
            self.output_buffer_cur_size += (
                len(segment.content) / MODEL_SAMPLE_RATE
            )  # seconds

    def _populate_output_buffer(self, segment):
        if segment.data_type == "text":
            self.output_buffer.append(segment.content)
        elif segment.data_type == "speech":
            self.output_buffer += segment.content
        else:
            raise ValueError(f"Invalid segment data type: {segment.data_type}")

    def _compute_phoneme_count(self, string: str) -> int:
        return len([x for x in self.g2p(string) if x != " "])
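A minimal usage sketch for the class above; the agent directory name, input sample rate, buffer limit, and file name below are illustrative assumptions, not part of this commit:

import time

agent = SimulevalTranscoder.build_agent("s2t_agent")  # assumed directory under models/
transcoder = SimulevalTranscoder(agent, sample_rate=48_000, debug=False, buffer_limit=5)
transcoder.start()  # runs process_pipeline_loop in a background thread

with open("chunk.raw", "rb") as f:  # headerless 16-bit mono PCM at sample_rate (illustrative)
    transcoder.process_incoming_bytes(f.read())

out = None
while out is None:  # poll until the output buffer flushes (idle timeout or size limit)
    time.sleep(0.1)
    out = transcoder.get_buffered_output()  # SpeechAndTextOutput carrying text or speech

transcoder.close = True  # stop the processing loop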
requirements.txt
CHANGED
@@ -1,9 +1,23 @@
 # fairseq2==0.1.0
+
+# Temp to skip
 git+https://github.com/mduppes/fairseq2.git@93420c86ba01349ee8f90d7adda439b666b50557
-git+https://github.com/facebookresearch/seamless_communication
+# git+https://github.com/facebookresearch/seamless_communication
+./seamless_communication
+# comment this out to test fairseq1 first
+# git+https://github.com/facebookresearch/SimulEval.git
 gradio==3.41.0
 huggingface_hub==0.16.4
 torch==2.0.1
 torchaudio==2.0.2
 transformers==4.32.1
-pydub
+pydub
+
+
+# Can't import fairseq1 together.. causes conflict:
+# The conflict is caused by:
+#   The user requested simuleval 1.1.0 (from git+ssh://****@github.com/facebookresearch/SimulEval.git@tree_pipeline)
+#   seamless-communication 1.0.0 depends on simuleval 1.0.3.dev36+gd84fa60 (from git+https://github.com/mduppes/SimulEval.git@main)
+# From fairseq1 pipeline
+# git+ssh://git@github.com/fairinternal/fairseq-py.git@emma_incremental_decoder
+# git+ssh://git@github.com/facebookresearch/SimulEval.git@tree_pipeline
simuleval_transcoder.py
ADDED
@@ -0,0 +1,178 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from fairseq2.assets.card import AssetCard
from fairseq2.data import Collater
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.data.text.text_tokenizer import TextTokenizer
from fairseq2.data.typing import StringLike
from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
from fairseq2.memory import MemoryBlock
from fairseq2.typing import DataType, Device
from torch import Tensor
from enum import Enum, auto
from seamless_communication.models.inference.ngram_repeat_block_processor import (
    NGramRepeatBlockProcessor,
)

from seamless_communication.models.unity import (
    UnitTokenizer,
    UnitYGenerator,
    UnitYModel,
    load_unity_model,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
from seamless_communication.models.unity.generator import SequenceToUnitOutput
from seamless_communication.models.vocoder import load_vocoder_model, Vocoder


from seamless_communication.models.streaming.agents import (
    SileroVADAgent,
    TestTimeWaitKS2TVAD,
    TestTimeWaitKUnityV1M4T
)

### From test_pipeline
import math
import soundfile
from argparse import Namespace, ArgumentParser
from simuleval.data.segments import SpeechSegment, EmptySegment
from simuleval.utils import build_system_from_dir
from pathlib import Path
import numpy as np


class AudioFrontEnd:
    def __init__(self, wav_file, segment_size) -> None:
        self.samples, self.sample_rate = soundfile.read(wav_file)
        # print(len(self.samples), self.samples[:100])
        self.samples = self.samples.tolist()
        self.segment_size = segment_size
        self.step = 0

    def send_segment(self):
        """
        This is the front-end logic in simuleval instance.py
        """
        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
        print("self.segment_size", self.segment_size)
        print('num_samples is', num_samples)
        print('self.sample_rate is', self.sample_rate)
        if self.step < len(self.samples):
            if self.step + num_samples >= len(self.samples):
                samples = self.samples[self.step :]
                is_finished = True
            else:
                samples = self.samples[self.step : self.step + num_samples]
                is_finished = False
            self.step = min(self.step + num_samples, len(self.samples))
            # print("len(samples) is", len(samples))
            # import pdb
            # pdb.set_trace()
            segment = SpeechSegment(
                index=self.step / self.sample_rate * 1000,
                content=samples,
                sample_rate=self.sample_rate,
                finished=is_finished,
            )
        else:
            # Finish reading this audio
            segment = EmptySegment(
                index=self.step / self.sample_rate * 1000,
                finished=True,
            )
        return segment


def load_model_for_inference(
    load_model_fn: Callable[..., nn.Module],
    model_name_or_card: Union[str, AssetCard],
    device: Device,
    dtype: DataType,
) -> nn.Module:
    model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
    model.eval()
    return model


class SimulevalTranscoder:
    # def __init__(self, agent, sample_rate, debug, buffer_limit):
    def __init__(self):
        print("MDUPPES in here", SileroVADAgent, TestTimeWaitKS2TVAD)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        device = "cpu"
        print("DEVICE", device)
        model_name_or_card = "seamlessM4T_medium"
        vocoder_name_or_card = "vocoder_36langs"
        # dtype=torch.float16,
        # For CPU mode we need to use float32; float16 causes errors downstream
        dtype = torch.float32

        model: UnitYModel = load_model_for_inference(
            load_unity_model, model_name_or_card, device, dtype
        )

        print(model, type(model))
        parser = ArgumentParser()
        source_segment_size = 320  # milliseconds
        audio_frontend = AudioFrontEnd(
            wav_file="/checkpoint/mduppes/samples/marta.wav",
            segment_size=source_segment_size,
        )

        # mostly taken from S2S first agent: OnlineFeatureExtractorAgent defaults
        SHIFT_SIZE = 10
        WINDOW_SIZE = 25
        SAMPLE_RATE = 16000
        FEATURE_DIM = 80

        # args, converted to a Namespace so fields can be accessed via "."
        args = {
            "shift_size": SHIFT_SIZE,
            "window_size": WINDOW_SIZE,
            "sample_rate": audio_frontend.sample_rate,
            "feature_dim": 160,  # from Wav2Vec2Frontend
            "denormalize": False,  # not sure..
            "global_stats": None,  # default file path containing cmvn stats..
        }
        print(args)
        args = Namespace(**args)

        pipeline = TestTimeWaitKUnityV1M4T(model, args)
        system_states = pipeline.build_states()
        print('system states')
        print(system_states)
        input_segment = np.empty(0, dtype=np.int16)
        segments = []
        while True:
            speech_segment = audio_frontend.send_segment()
            input_segment = np.concatenate((input_segment, np.array(speech_segment.content)))
            # Translation happens here
            output_segment = pipeline.pushpop(speech_segment, system_states)
            print('pushpop result')
            print(output_segment)
            if output_segment.finished:
                segments.append(input_segment)
                input_segment = np.empty(0, dtype=np.int16)
                print("Resetting states")
                for state in system_states:
                    state.reset()
            if speech_segment.finished:
                break
        # The VAD-segmented samples from the full input audio
        for i, seg in enumerate(segments):
            with soundfile.SoundFile(
                Path("/checkpoint/mduppes/samples") / f"marta_{i}.wav",
                mode="w+",
                format="WAV",
                samplerate=16000,
                channels=1,
            ) as f:
                f.seek(0, soundfile.SEEK_END)
                f.write(seg)
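For reference, the chunking arithmetic in AudioFrontEnd.send_segment() works out as follows with the 320 ms source_segment_size used above, assuming a 16 kHz input file:

import math
# 320 ms at 16 kHz: each send_segment() call slices 5120 samples off the front of the buffer.
assert math.ceil(320 / 1000 * 16000) == 5120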