64FC's picture
Change Vits tokenizer to Facebook's
6a07d2a
raw
history blame
No virus
2.82 kB
import torch
from transformers import pipeline, VitsModel, VitsTokenizer
import numpy as np
import gradio as gr
# Run the ASR pipeline on the first GPU when one is present, otherwise on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Speech recognition: Distil-Whisper large-v2.
# (Earlier revisions of this app used "openai/whisper-small" here.)
ASR_CHECKPOINT = "distil-whisper/distil-large-v2"
pipe = pipeline(
    "automatic-speech-recognition",
    model=ASR_CHECKPOINT,
    device=device,
)

# Text-to-speech: Facebook's French MMS-TTS checkpoint plus its tokenizer.
# (Earlier revisions loaded "Matthijs/mms-tts-fra" instead.)
TTS_CHECKPOINT = "facebook/mms-tts-fra"
model = VitsModel.from_pretrained(TTS_CHECKPOINT)
tokenizer = VitsTokenizer.from_pretrained(TTS_CHECKPOINT)
# Speech -> French text step of the cascade.
def translate(audio):
    """Decode ``audio`` into French text with the module-level ASR pipeline.

    Args:
        audio: Any input the transformers ASR pipeline accepts; in this app
            it is the file path produced by the ``gr.Audio(type="filepath")``
            inputs.

    Returns:
        The decoded text, i.e. ``outputs["text"]`` from the pipeline.
    """
    # Whisper's built-in "translate" task only targets English, so French
    # output is obtained by running "transcribe" with the French language
    # token forced.
    # max_new_tokens lives in generate_kwargs: passing it directly to the
    # pipeline call is deprecated in recent transformers releases, while
    # generate_kwargs is forwarded unchanged to model.generate().
    outputs = pipe(
        audio,
        generate_kwargs={
            "task": "transcribe",
            "language": "fr",
            "max_new_tokens": 256,
        },
    )
    return outputs["text"]
# French text -> speech step of the cascade.
def synthesise(text):
    """Synthesise ``text`` into a waveform with the module-level VITS model.

    Args:
        text: The text to speak (French, matching the loaded MMS-TTS model).

    Returns:
        The waveform tensor for the single input sequence (the first item of
        the model's batched output). The caller scales it by 32767 to make
        int16 PCM, so it is presumably float audio in [-1, 1] — confirm
        against the model if that matters.
    """
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    with torch.no_grad():  # inference only; no autograd graph is needed
        outputs = model(input_ids)
    # Released transformers versions expose the generated audio as
    # ``waveform``; accessing ``outputs.audio`` there raises AttributeError.
    # ``audio`` is kept only as a fallback for the pre-release VITS port this
    # code was presumably first written against.
    waveform = getattr(outputs, "waveform", None)
    if waveform is None:
        waveform = outputs.audio
    return waveform[0]
# Full cascade wired into the Gradio interfaces.
def speech_to_speech_translation(audio):
    """Run the cascade: input speech -> French text -> French speech.

    Args:
        audio: Path to the input audio file, as delivered by the
            ``gr.Audio(type="filepath")`` inputs; ``None`` when the user
            submits without recording or uploading anything.

    Returns:
        A ``(sampling_rate, samples)`` tuple with 16-bit PCM samples, the
        form accepted by the ``gr.Audio(type="numpy")`` output.

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio is None:
        # Show a readable message in the UI instead of a pipeline traceback.
        raise gr.Error("No audio provided: please record or upload a clip first.")
    translated_text = translate(audio)
    synthesised_speech = synthesise(translated_text)
    # Clip before scaling: samples outside [-1, 1] would overflow int16 and
    # come out as garbage (typically wrapped to the opposite sign).
    samples = np.clip(synthesised_speech.numpy(), -1.0, 1.0)
    synthesised_speech = (samples * 32767).astype(np.int16)
    # Report the TTS model's own output rate rather than hard-coding 16 kHz.
    return model.config.sampling_rate, synthesised_speech
# Title and Markdown description shown at the top of both interface tabs.
# Keep the model names/links in sync with the checkpoints loaded above.
title = "Cascaded STST"
description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in French. Demo uses [Distil-Whisper large-v2](https://huggingface.co/distil-whisper/distil-large-v2) for speech recognition, with decoding forced to French, and Facebook's
[MMS TTS](https://huggingface.co/facebook/mms-tts-fra) French model for text-to-speech:
![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
"""
demo = gr.Blocks()


def _make_tab(audio_source, examples=None):
    """Build one speech-to-speech tab whose input audio comes from ``audio_source``."""
    return gr.Interface(
        fn=speech_to_speech_translation,
        inputs=gr.Audio(source=audio_source, type="filepath"),
        outputs=gr.Audio(label="Generated Speech", type="numpy"),
        examples=examples,
        title=title,
        description=description,
    )


# Live-recording tab first, then the upload tab with a bundled example clip.
mic_translate = _make_tab("microphone")
file_translate = _make_tab("upload", examples=[["./example.wav"]])

with demo:
    gr.TabbedInterface(
        [mic_translate, file_translate],
        ["Microphone", "Audio File"],
    )

demo.launch()