from typing import Any, List, Tuple, Union, Optional
import numpy as np
import soundfile
import io
import asyncio
from simuleval.agents.pipeline import TreeAgentPipeline
from simuleval.agents.states import AgentStates
from simuleval.data.segments import Segment, EmptySegment, SpeechSegment
import threading
import math
import logging
import sys
from pathlib import Path
import time
from g2p_en import G2p
import torch
import traceback
import random
import colorlog

# Sanity check that pipeline is loadable
from seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t import (
    # TestTimeWaitKUnityS2TM4T,
    TestTimeWaitKUnityS2TM4TVAD,
)
from simuleval.utils.agent import build_system_args

# Sample rate (Hz) that incoming audio is resampled to before it reaches the
# agent, and that speech output is reported at.
MODEL_SAMPLE_RATE = 16_000

# Module logger: colored output to stdout, not propagated to the root logger.
logger = logging.getLogger(__name__)
logger.propagate = False
# Attach the handler only once; re-importing / reloading this module would
# otherwise stack handlers and print every record multiple times.
if not logger.handlers:
    handler = colorlog.StreamHandler(stream=sys.stdout)
    formatter = colorlog.ColoredFormatter(
        "%(log_color)s[%(asctime)s][%(levelname)s][%(module)s]:%(reset)s %(message)s",
        reset=True,
        log_colors={
            "DEBUG": "cyan",
            "INFO": "green",
            "WARNING": "yellow",
            "ERROR": "red",
            "CRITICAL": "red,bg_white",
        },
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
# TODO: Integrate this better so target lang and others can be changed.
# Also currently dependent on devserver internals
def build_agent(config_overrides: Optional[dict] = None):
    """Build the SimulEval agent used by ``SimulevalTranscoder``.

    The base configuration is hard-coded below (model name, device, wait-k
    lagging, target language, ...). ``config_overrides`` lets a caller replace
    or add individual keys without editing this function; with no overrides
    the agent is built from the base configuration exactly as before.

    Args:
        config_overrides: Optional mapping merged on top of the base
            configuration before it is passed to ``build_system_args``.

    Returns:
        The agent object returned (first tuple element) by
        ``build_system_args``.
    """
    config = {
        'dataloader': 'fairseq2_s2t',
        'data_file': '/large_experiments/seamless/ust/abinesh/data/s2st50_manifests/50-10/simuleval/dev_mtedx_filt_50-10_debug.tsv',
        'model_name': 'seamlessM4T_v2_large',
        'device': 'cuda:0',
        'source_segment_size': 320,
        'waitk_lagging': 7,
        'fixed_pre_decision_ratio': 2,
        'init_target_tokens': ' __eng__',
        'max_len_a': 0,
        'max_len_b': 200,
        'agent_class': 'seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t.TestTimeWaitKUnityS2TM4TVAD',
        'task': 's2st',
        'tgt_lang': 'eng',
        'latency_metrics': 'StartOffset EndOffset AL',
        'output': 'TestTimeWaitKUnityS2TM4TVAD-wait7-debug',
    }
    if config_overrides:
        config.update(config_overrides)
    agent, _ = build_system_args(config)
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # agent.to(device, fp16=True)
    logger.info("Successfully built simuleval agent")
    return agent


class SpeechAndTextOutput:
    """One chunk of translated output handed back to the caller.

    Attributes are plain fields:
        text: space-joined text output, or None if no text was produced.
        speech_samples: concatenated speech samples, or None.
        speech_sample_rate: sample rate of ``speech_samples``, or None.
        final: True when this chunk closes a finished output segment.
    """

    def __init__(
        self,
        text: Optional[str] = None,
        speech_samples: Optional[list] = None,
        speech_sample_rate: Optional[float] = None,
        final: bool = False,
    ):
        self.text = text
        self.speech_samples = speech_samples
        self.speech_sample_rate = speech_sample_rate
        self.final = final

    def __repr__(self) -> str:
        # Summarize speech by length so debug logs don't dump every sample.
        num_samples = (
            None if self.speech_samples is None else len(self.speech_samples)
        )
        return (
            f"{self.__class__.__name__}(text={self.text!r}, "
            f"speech_samples={num_samples}, "
            f"speech_sample_rate={self.speech_sample_rate}, final={self.final})"
        )


class OutputSegments:
    """Wrapper around the segment(s) produced by one ``agent.pushpop`` call."""

    def __init__(self, segments: Union[List[Segment], Segment]):
        # A single Segment is normalized to a one-element list.
        if isinstance(segments, Segment):
            segments = [segments]
        self.segments: List[Segment] = list(segments)

    @property
    def is_empty(self):
        """True if every wrapped segment is empty (also True for no segments)."""
        return all(segment.is_empty for segment in self.segments)

    @property
    def finished(self):
        """True if every wrapped segment is finished (also True for no segments)."""
        return all(segment.finished for segment in self.segments)

    def compute_length(self, g2p):
        """Return the largest per-segment length.

        Text segments count the non-space items returned by ``g2p`` on their
        content; speech segments count ``len(content) / MODEL_SAMPLE_RATE``.
        Empty or unrecognized segments contribute nothing, and the result is
        0 when no segment contributes a length.
        """
        lengths = []
        for segment in self.segments:
            if segment.data_type == "text":
                lengths.append(len([x for x in g2p(segment.content) if x != " "]))
            elif segment.data_type == "speech":
                lengths.append(len(segment.content) / MODEL_SAMPLE_RATE)
            elif isinstance(segment, EmptySegment):
                continue
            else:
                logger.warning(
                    f"Unexpected data_type: {segment.data_type} not in 'speech', 'text'"
                )
        # default=0: max() of an empty list would raise ValueError.
        return max(lengths, default=0)

    @classmethod
    def join_output_buffer(
        cls, buffer: List[List[Segment]], output: SpeechAndTextOutput
    ):
        """Merge buffered segment rows column-wise into ``output``.

        ``buffer[j][i]`` is the i-th segment of the j-th buffered
        ``OutputSegments``. For each column ``i`` the segments with a
        non-None ``data_type`` are merged: text is joined with single spaces,
        speech contents are concatenated (and the sample rate set to
        ``MODEL_SAMPLE_RATE``). Columns with mixed data types are skipped with
        a warning.

        Returns:
            The same ``output`` object, mutated in place.
        """
        if not buffer:
            return output
        num_segments = len(buffer[0])
        for i in range(num_segments):
            segment_list = [
                row[i] for row in buffer if row[i].data_type is not None
            ]
            if len(segment_list) == 0:
                continue
            data_types = set(segment.data_type for segment in segment_list)
            if len(data_types) != 1:
                logger.warning(f"Data type mismatch at {i}: {data_types}")
                continue
            data_type = segment_list[0].data_type
            if data_type == "text":
                if output.text is not None:
                    logger.warning("Multiple text outputs, overwriting!")
                output.text = " ".join([segment.content for segment in segment_list])
            elif data_type == "speech":
                if output.speech_samples is not None:
                    logger.warning("Multiple speech outputs, overwriting!")
                speech_out = []
                for segment in segment_list:
                    speech_out += segment.content
                output.speech_samples = speech_out
                output.speech_sample_rate = MODEL_SAMPLE_RATE
            elif isinstance(segment_list[0], EmptySegment):
                continue
            else:
                logger.warning(
                    f"Invalid output buffer data type: {data_type}, expected 'speech' or 'text'"
                )
        return output

    def __repr__(self) -> str:
        repr_str = str(self.segments)
        return f"{self.__class__.__name__}(\n\t{repr_str}\n)"


def convert_waveform(
    waveform: Union[np.ndarray, torch.Tensor],
    sample_rate: int,
    normalize_volume: bool = False,
    to_mono: bool = False,
    to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
    """convert a waveform:
    - to a target sample rate
    - from multi-channel to mono channel
    - volume normalization

    Args:
        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
            (channels x length)
        sample_rate (int): original sample rate
        normalize_volume (bool): perform volume normalization
        to_mono (bool): convert to mono channel if having multiple channels
        to_sample_rate (Optional[int]): target sample rate
    Returns:
        waveform (numpy.ndarray): converted 2D waveform (channels x length)
        sample_rate (float): target sample rate

    When no effect is needed the input waveform and sample rate are returned
    unchanged (same object, same type).

    Raises:
        ImportError: if torchaudio is not installed.
    """
    try:
        import torchaudio.sox_effects as ta_sox
    except ImportError as e:
        raise ImportError("Please install torchaudio: pip install torchaudio") from e

    effects = []
    if normalize_volume:
        effects.append(["gain", "-n"])
    if to_sample_rate is not None and to_sample_rate != sample_rate:
        effects.append(["rate", f"{to_sample_rate}"])
    if to_mono and waveform.shape[0] > 1:
        effects.append(["channels", "1"])
    if len(effects) > 0:
        is_np_input = isinstance(waveform, np.ndarray)
        _waveform = torch.from_numpy(waveform) if is_np_input else waveform
        converted, converted_sample_rate = ta_sox.apply_effects_tensor(
            _waveform, sample_rate, effects
        )
        # Hand back the same container type the caller passed in.
        if is_np_input:
            converted = converted.numpy()
        return converted, converted_sample_rate
    return waveform, sample_rate


class SimulevalTranscoder:
    """Run incoming audio through a SimulEval agent and buffer its output.

    Audio bytes are pushed with ``process_incoming_bytes``; a background
    thread started by ``start`` feeds queued input to the agent; translated
    output is collected with ``get_buffered_output``, which batches segments
    until an idle timeout or a size limit is reached.

    Args:
        sample_rate: sample rate of the incoming 16-bit PCM audio.
        debug: enable debug logging and write the incoming / preprocessed
            audio to WAV files under ``<repo>/debug``.
        buffer_limit: output-buffer flush threshold (phonemes for text,
            seconds for speech).
    """

    def __init__(self, sample_rate, debug, buffer_limit):
        self.agent = build_agent()
        # NOTE(review): asyncio.Queue is not documented as thread-safe, yet it
        # is shared with the worker thread from ``start``; only the
        # non-blocking put/get methods are used on it here.
        self.input_queue = asyncio.Queue()
        self.output_queue = asyncio.Queue()
        self.states = self.agent.build_states()
        if debug:
            self.get_states_root().debug = True
        self.incoming_sample_rate = sample_rate
        # Stop flag polled by the worker loop (attribute, not a method).
        self.close = False
        self.g2p = G2p()

        # buffer all outgoing translations within this amount of time
        self.output_buffer_idle_ms = 5000
        self.output_buffer_size_limit = (
            buffer_limit  # phonemes for text, seconds for speech
        )
        self.output_buffer_cur_size = 0
        self.output_buffer: List[List[Segment]] = []
        self.speech_output_sample_rate = None

        self.last_output_ts = time.time() * 1000
        self.timeout_ms = (
            30000  # close the transcoder thread after this amount of silence
        )
        self.first_input_ts = None
        self.first_output_ts = None
        self.debug = debug
        self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
        if self.debug:
            debug_folder = Path(__file__).resolve().parent.parent / "debug"
            # Opening the WAV files fails if the folder does not exist yet.
            debug_folder.mkdir(parents=True, exist_ok=True)
            # NOTE(review): these debug files are never explicitly closed.
            self.test_incoming_wav = soundfile.SoundFile(
                debug_folder / f"{self.debug_ts}_test_incoming.wav",
                mode="w+",
                format="WAV",
                subtype="PCM_16",
                samplerate=self.incoming_sample_rate,
                channels=1,
            )
            self.get_states_root().test_input_segments_wav = soundfile.SoundFile(
                debug_folder / f"{self.debug_ts}_test_input_segments.wav",
                mode="w+",
                format="WAV",
                samplerate=MODEL_SAMPLE_RATE,
                channels=1,
            )

    def get_states_root(self) -> AgentStates:
        """Return the states object of the agent's first/source module."""
        if isinstance(self.agent, TreeAgentPipeline):
            # self.states is a dict
            return self.states[self.agent.source_module]
        else:
            # self.states is a list
            return self.states[0]

    def reset_states(self):
        """Call ``reset()`` on every agent state."""
        if isinstance(self.agent, TreeAgentPipeline):
            states_iter = self.states.values()
        else:
            states_iter = self.states
        for state in states_iter:
            state.reset()

    def debug_log(self, *args):
        """Log at INFO level, but only when debug mode is on."""
        if self.debug:
            logger.info(*args)

    def process_incoming_bytes(self, incoming_bytes, target_language, sample_rate):
        """Decode raw PCM bytes and queue them as a ``SpeechSegment``."""
        # TODO: currently just taking sample rate here, refactor sample rate
        # bytes is 16bit signed int
        self.incoming_sample_rate = sample_rate
        segment, sr = self._preprocess_wav(incoming_bytes)

        segment = SpeechSegment(
            content=segment,
            sample_rate=sr,
            tgt_lang=target_language,
        )
        # segment.content is a 1-D float32 array at MODEL_SAMPLE_RATE
        self.input_queue.put_nowait(segment)

    def get_input_segment(self):
        """Pop one queued input segment, or return None if the queue is empty."""
        if self.input_queue.empty():
            return None
        chunk = self.input_queue.get_nowait()
        self.input_queue.task_done()
        return chunk

    def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
        """Decode 16-bit mono PCM bytes to float32 and resample to MODEL_SAMPLE_RATE.

        Returns:
            (1-D waveform, MODEL_SAMPLE_RATE)
        """
        segment, sample_rate = soundfile.read(
            io.BytesIO(data),
            dtype="float32",
            always_2d=True,
            frames=-1,
            start=0,
            format="RAW",
            subtype="PCM_16",
            samplerate=self.incoming_sample_rate,
            channels=1,
        )
        if self.debug:
            self.test_incoming_wav.seek(0, soundfile.SEEK_END)
            self.test_incoming_wav.write(segment)

        # soundfile gives (frames, channels); convert_waveform wants (channels, frames).
        segment = segment.T
        segment, new_sample_rate = convert_waveform(
            segment,
            sample_rate,
            normalize_volume=False,
            to_mono=True,
            to_sample_rate=MODEL_SAMPLE_RATE,
        )

        assert MODEL_SAMPLE_RATE == new_sample_rate
        segment = segment.squeeze(axis=0)
        return segment, new_sample_rate

    def process_pipeline_impl(self, input_segment):
        """Feed one input segment to the agent and queue any non-empty output.

        Resets the agent states when the output is finished. Exceptions are
        logged (with traceback) and swallowed so the worker loop keeps running.
        """
        try:
            with torch.no_grad():
                output_segment = OutputSegments(
                    self.agent.pushpop(input_segment, self.states)
                )
            if (
                self.get_states_root().first_input_ts is not None
                and self.first_input_ts is None
            ):
                # TODO: this is hacky
                self.first_input_ts = self.get_states_root().first_input_ts

            if not output_segment.is_empty:
                self.output_queue.put_nowait(output_segment)

            if output_segment.finished:
                self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")

                self.reset_states()

                if self.debug:
                    # when we rebuild states, this value is reset to whatever
                    # is in the system dir config, which defaults debug=False.
                    self.get_states_root().debug = True
        except Exception as e:
            logger.error(f"Got exception while processing pipeline: {e}")
            traceback.print_exc()
        return input_segment

    def process_pipeline_loop(self):
        """Worker loop: poll the input queue until ``self.close`` is set."""
        if self.close:
            return  # closes the thread

        self.debug_log("processing_pipeline")
        while not self.close:
            input_segment = self.get_input_segment()
            if input_segment is None:
                # Nothing queued yet; back off briefly before polling again.
                time.sleep(0.03)
                continue
            self.process_pipeline_impl(input_segment)
        self.debug_log("finished processing_pipeline")

    def process_pipeline_once(self):
        """Process at most one queued input segment on the calling thread."""
        if self.close:
            return

        self.debug_log("processing pipeline once")
        input_segment = self.get_input_segment()
        if input_segment is None:
            return
        self.process_pipeline_impl(input_segment)
        self.debug_log("finished processing_pipeline_once")

    def get_output_segment(self):
        """Pop one queued ``OutputSegments``, or return None if none are queued."""
        if self.output_queue.empty():
            return None

        output_chunk = self.output_queue.get_nowait()
        self.output_queue.task_done()
        return output_chunk

    def start(self):
        """Run ``process_pipeline_loop`` on a new thread."""
        self.debug_log("starting transcoder in a thread")
        threading.Thread(target=self.process_pipeline_loop).start()

    def first_translation_time(self):
        """Seconds between the first input and the most recent flushed output.

        ``first_output_ts`` is overwritten on every flush, so this measures up
        to the latest flush. Returns None until both timestamps are known.
        NOTE(review): assumes ``first_input_ts`` is in milliseconds like
        ``first_output_ts`` -- confirm against the agent states.
        """
        if self.first_output_ts is None or self.first_input_ts is None:
            return None
        return round((self.first_output_ts - self.first_input_ts) / 1000, 2)

    def get_buffered_output(self) -> Optional[SpeechAndTextOutput]:
        """Drain the output queue into the buffer and flush it when due.

        A flush happens immediately when a finished segment arrives
        (``final=True``), or when the buffer is non-empty and either
        ``output_buffer_idle_ms`` has elapsed since the buffer started / last
        flushed, or the accumulated size reached ``output_buffer_size_limit``.

        Returns:
            The flushed ``SpeechAndTextOutput``, or None if nothing is due.
        """
        now = time.time() * 1000
        self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
        while not self.output_queue.empty():
            tmp_out = self.get_output_segment()
            if tmp_out and tmp_out.compute_length(self.g2p) > 0:
                if len(self.output_buffer) == 0:
                    self.last_output_ts = now
                self._populate_output_buffer(tmp_out)
                self._increment_output_buffer_size(tmp_out)

                if tmp_out.finished:
                    self.debug_log("tmp_out.finished")
                    return self._flush_output_buffer(now, final=True)
            else:
                self.debug_log("tmp_out.compute_length is not > 0")

        if len(self.output_buffer) > 0 and (
            now - self.last_output_ts >= self.output_buffer_idle_ms
            or self.output_buffer_cur_size >= self.output_buffer_size_limit
        ):
            self.debug_log(
                "[get_buffered_output] output_buffer is not empty. getting res to return."
            )
            return self._flush_output_buffer(now, final=False)

        self.debug_log("[get_buffered_output] output_buffer is empty...")
        return None

    def _flush_output_buffer(self, now: float, final: bool) -> SpeechAndTextOutput:
        """Gather the buffered segments into one output and reset the buffer."""
        res = self._gather_output_buffer_data(final=final)
        self.debug_log(f"gathered output data: {res}")
        self.output_buffer = []
        # Reset the size counter too; otherwise it only grows and the size
        # limit check flushes on every call after the first flush.
        self.output_buffer_cur_size = 0
        self.last_output_ts = now
        self.first_output_ts = now
        return res

    def _gather_output_buffer_data(self, final):
        """Build a ``SpeechAndTextOutput`` from the current output buffer."""
        output = SpeechAndTextOutput()
        output.final = final
        output = OutputSegments.join_output_buffer(self.output_buffer, output)
        return output

    def _increment_output_buffer_size(self, segment: OutputSegments):
        """Add ``segment``'s length (phonemes or seconds) to the buffer size."""
        self.output_buffer_cur_size += segment.compute_length(self.g2p)

    def _populate_output_buffer(self, segment: OutputSegments):
        """Append ``segment``'s segment list as one row of the output buffer."""
        self.output_buffer.append(segment.segments)

    def _compute_phoneme_count(self, string: str) -> int:
        """Count the non-space items ``g2p`` returns for ``string``."""
        return len([x for x in self.g2p(string) if x != " "])