voice-queries / app.py
Julien Simon
Replace deprecated append() with concat()
ea11adf
raw
history blame
4.58 kB
import gradio as gr
import nltk
import numpy as np
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
# ---- Configuration ----
# Data files shipped alongside the app: the 2020 S&P500 filings (zipped CSV)
# and the pre-computed sentence embeddings for that corpus.
filename = "df10k_SP500_2020.csv.zip"
embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npz"
# Hugging Face model identifiers: sentence-embedding model and speech-to-text model.
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"
# Assigned to the embedding model's max_seq_length after it is loaded.
max_sequence_length = 512
# Load corpus
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f"Number of documents: {len(df)}")
nltk.download("punkt")
# Split each filing's 'mdna' column ('Management discussion and analysis')
# into sentences. `corpus` is the flat list of every sentence, and
# `sentence_count[i]` is how many of them came from row i of `df`; together
# they let a sentence index be mapped back to its source filing.
corpus = []
sentence_count = []
# Iterate the single column we need instead of df.iterrows(), which builds a
# whole Series per row. str() keeps the original handling of missing values
# (NaN becomes the string "nan").
for mdna_text in df["mdna"]:
    sentences = nltk.tokenize.sent_tokenize(str(mdna_text), language="english")
    sentence_count.append(len(sentences))
    corpus.extend(sentences)
print(f"Number of sentences: {len(corpus)}")
# Load pre-embedded corpus. np.load on an .npz returns an NpzFile that keeps
# the archive open until closed, so read the array inside a context manager.
with np.load(embeddings_filename) as npz_archive:
    corpus_embeddings = npz_archive["arr_0"]
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")
if corpus_embeddings.shape[0] != len(corpus):
    # Search hits are row indices into the embeddings and are used to index
    # `corpus`. If the two sizes differ, the embeddings were built from a
    # different sentence split and results would point at the wrong sentences.
    print(
        f"WARNING: {corpus_embeddings.shape[0]} embeddings but "
        f"{len(corpus)} corpus sentences; search results may be misaligned."
    )
# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length
# Load speech to text model
asr = pipeline(
    "automatic-speech-recognition", model=asr_model, feature_extractor=asr_model
)
def find_sentences(query, hits):
    """Return the corpus sentences most semantically similar to ``query``.

    The query is embedded with the module-level ``model`` and compared against
    the pre-computed ``corpus_embeddings`` using
    ``sentence_transformers.util.semantic_search``. Each hit is mapped back to
    the filing (row of ``df``) it came from using ``sentence_count``.

    Args:
        query: Free-text query string (typed or transcribed).
        hits: Number of results to request. Cast to ``int`` in case the UI
            slider delivers a float.

    Returns:
        A pandas DataFrame with columns "Ticker", "Form type", "Filing date",
        "Text" (the first 80 characters of the sentence) and "Score" (the
        similarity score formatted to two decimals). Rows are in the order
        returned by ``semantic_search``.
    """
    query_embedding = model.encode(query)
    # semantic_search returns one result list per query; we sent one query.
    top_hits = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(hits)
    )[0]

    # doc_ends[i] is one past the index of document i's last sentence, so the
    # sentence at index j belongs to the first document whose end exceeds j.
    # A binary search replaces the per-hit linear scan over sentence_count.
    doc_ends = np.cumsum(sentence_count)

    rows = []
    for hit in top_hits:
        corpus_id = hit["corpus_id"]
        doc_idx = int(np.searchsorted(doc_ends, corpus_id, side="right"))
        if doc_idx >= len(doc_ends):
            # No document owns this sentence index (embeddings/corpus size
            # mismatch): skip it rather than raise.
            continue
        doc = df.iloc[doc_idx]
        rows.append(
            {
                "Ticker": doc["ticker"],
                "Form type": doc["form_type"],
                "Filing date": doc["filing_date"],
                "Text": corpus[corpus_id][:80],
                "Score": "{:.2f}".format(hit["score"]),
            }
        )
    # Build the frame once. Concatenating inside the loop is quadratic, and
    # concatenating onto an empty frame is deprecated in recent pandas.
    return pd.DataFrame(
        rows, columns=["Ticker", "Form type", "Filing date", "Text", "Score"]
    )
def process(input_selection, query, filepath, hits):
    """Resolve the user's query (typed or spoken) and run the semantic search.

    Args:
        input_selection: "speech" to transcribe the recording at ``filepath``;
            any other value uses the typed ``query``.
        query: Text query, used when ``input_selection`` is not "speech".
        filepath: Path to the recorded audio file. It may be None when the
            optional microphone input is left empty.
        hits: Number of results to return, forwarded to ``find_sentences``.

    Returns:
        A ``(text, results)`` tuple: the query string that was actually searched
        (the transcription for speech input) and the results DataFrame.

    Raises:
        ValueError: If speech input is selected but no audio file was provided.
    """
    if input_selection == "speech":
        if not filepath:
            raise ValueError("Speech input selected, but no audio was recorded.")
        # Decode and resample to 16 kHz in one step. librosa.load's default
        # (sr=22050) would resample to 22.05 kHz first, forcing a second,
        # lossy resample to 16 kHz.
        speech, _ = load(filepath, sr=16000)
        text = asr(speech)["text"]
    else:
        text = query
    return text, find_sentences(text, hits)
# ---- Gradio input components ----
# Chooses whether the query comes from the text box or the microphone.
buttons = gr.Radio(
    ["text", "speech"],
    label="Input selection",
    type="value",
    value="speech",
)
# Typed query, pre-filled with a sample question.
text_query = gr.Textbox(
    label="Text input",
    lines=1,
    value="The company is under investigation by tax authorities for potential fraud.",
)
# Recorded query, handed to the callback as a file path.
mic = gr.Audio(
    label="Speech input",
    source="microphone",
    type="filepath",
    optional=True,
)
# Number of search results to return.
slider = gr.Slider(
    label="Number of hits",
    minimum=1,
    maximum=10,
    step=1,
    value=3,
)
# ---- Gradio output components ----
# Shows the query string that was searched (the transcription for speech input).
speech_query = gr.Textbox(label="Query string", type="text")
# Search results table; headers match the columns built by find_sentences.
results = gr.Dataframe(
    label="Query results",
    type="pandas",
    headers=["Ticker", "Form type", "Filing date", "Text", "Score"],
)
# Wire the components to `process`. Input order must match the callback's
# parameters: (input_selection, query, filepath, hits).
iface = gr.Interface(
    theme="huggingface",
    description=(
        "This Space lets you query a text corpus containing 2020 annual "
        "filings for all S&P500 companies. You can type a text query in "
        "English, or record an audio query in 21 languages. You can find a "
        "technical deep dive at https://www.youtube.com/watch?v=YPme-gR0f80"
    ),
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    # Each example is [input selection, text query, audio file, hits]. With
    # "speech" selected, process() ignores the text column; here it shows
    # what each recording says.
    examples=[
        [
            "speech",
            "Nos ventes internationales ont significativement augmenté.",
            "sales_16k_fr.wav",
            3,
        ],
        [
            "speech",
            "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.",
            "energy_16k_fr.wav",
            3,
        ],
        [
            "speech",
            "El precio de la energía podría tener un impacto negativo en el futuro.",
            "energy_24k_es.wav",
            3,
        ],
        [
            "speech",
            "Mehrere Steuerbehörden untersuchen unser Unternehmen.",
            "tax_24k_de.wav",
            3,
        ],
    ],
)
iface.launch()