kapllan's picture
Upload app.py
671198b verified
raw
history blame contribute delete
No virus
6.38 kB
import json as js
import os
import re
from collections import defaultdict
from functools import lru_cache
from typing import Dict, List, Optional, Tuple

import fasttext
import gradio as gr
import joblib
import omikuji
from huggingface_hub import snapshot_download

from install_packages import download_model
# fastText language-identification model, loaded below as `lang_model`.
download_model('https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin', 'lid.176.bin')

# Omikuji topic models, one per supported language. Each Hugging Face repo is
# mirrored into a local directory named after the repo id, which is the
# relative path find_model() loads from.
MODEL_REPO_IDS = (
    'kapllan/omikuji-bonsai-parliament-de-spacy',
    'kapllan/omikuji-bonsai-parliament-fr-spacy',
    'kapllan/omikuji-bonsai-parliament-it-spacy',
)
for repo_id in MODEL_REPO_IDS:
    os.makedirs(repo_id, exist_ok=True)
    snapshot_download(repo_id=repo_id, local_dir=repo_id)

lang_model = fasttext.load_model('lid.176.bin')

# JSON is UTF-8 by specification; pass the encoding explicitly so possibly
# non-ASCII label text does not depend on the platform's default encoding.
with open('./id2label.json', 'r', encoding='utf-8') as f:
    id2label = js.load(f)
with open('topics_hierarchy.json', 'r', encoding='utf-8') as f:
    topics_hierarchy = js.load(f)
def map_language(language: str) -> str:
    """Translate a language code into an English display name.

    Only 'de', 'it' and 'fr' are known; any other value is returned unchanged.
    """
    display_names = {'de': 'German', 'it': 'Italian', 'fr': 'French'}
    return display_names.get(language, language)
_SUPPORTED_MODEL_LANGUAGES = ('de', 'fr', 'it')


@lru_cache(maxsize=None)  # key space is the three supported languages
def _load_model(language: str):
    """Load (vectorizer, omikuji model) for *language* from disk, once per language."""
    model_dir = f'./kapllan/omikuji-bonsai-parliament-{language}-spacy'
    # NOTE(review): joblib.load unpickles the file, which can execute code; the
    # file comes from a downloaded Hugging Face repo, so only trusted repos belong here.
    vectorizer = joblib.load(f'{model_dir}/vectorizer')
    model = omikuji.Model.load(f'{model_dir}/omikuji-model')
    return vectorizer, model


def find_model(language: str):
    """Return ``(vectorizer, model)`` for a language code, or ``(None, None)``.

    Only 'de', 'fr' and 'it' have models. They are loaded on first use and
    cached, so later requests do not reload them from disk.
    """
    if language not in _SUPPORTED_MODEL_LANGUAGES:
        return None, None
    return _load_model(language)
def predict_lang(text: str) -> str:
    """Detect the language of *text* with the fastText language-ID model.

    Returns the top-1 language code with fastText's ``__label__`` prefix
    removed (e.g. ``'de'``).
    """
    # fastText cannot process input containing line breaks. Replace them with a
    # space (rather than deleting them) so words on adjacent lines stay separate.
    single_line = re.sub(r'[\r\n]+', ' ', text)
    labels, _scores = lang_model.predict(single_line, k=1)  # top-1 prediction only
    return re.sub(r'^__label__', '', labels[0])
def predict_topic(text: str) -> Tuple[List[Tuple[str, float]], str]:
    """Predict topics for *text* using the Omikuji model for its detected language.

    Returns:
        ``(results, language)``. ``results`` is a list of ``(label, score)``
        pairs in the order the model yields them (up to 1000). It is empty when
        the language has no model or the text produces no known features.
        ``language`` is the display name from map_language, or the raw code
        when no name is mapped.
    """
    language_code = predict_lang(text)
    vectorizer, model = find_model(language_code)
    language = map_language(language_code)
    if vectorizer is None:
        return [], language

    results = []
    # transform() takes a batch of documents; a single one is passed here.
    for row in vectorizer.transform([text]):
        if row.nnz == 0:  # all-zero vector: no known features, nothing to predict
            continue
        # Sparse features are handed to Omikuji as (column_index, value) pairs.
        feature_values = [(col, row[0, col]) for col in row.nonzero()[1]]
        for subj_id, score in model.predict(feature_values, top_k=1000):
            results.append((id2label[str(subj_id)], score))
    return results, language
def get_row_color(type: str) -> str:  # `type` shadows the builtin; kept for keyword callers
    """Return the inline CSS background for a table row of the given kind.

    Matching is a case-insensitive substring test and 'main' takes precedence
    over 'sub'. Any other kind gets an empty string, so callers never
    interpolate ``None`` into a ``style`` attribute.
    """
    row_kind = type.lower()
    if 'main' in row_kind:
        return 'background-color: darkgrey;'
    if 'sub' in row_kind:
        return 'background-color: lightgrey;'
    return ''
def generate_html_table(topics: List[Tuple[str, str, float]]) -> str:
    """Render ``(type, topic, score)`` rows as an HTML table string.

    Rows whose type contains 'main' (case-insensitive) have all three cells
    bolded. Each row's background comes from get_row_color. Cell text is
    HTML-escaped. An empty *topics* list yields a table with only the header row.
    """
    from html import escape

    parts = [
        '<table style="width:100%; border: 1px solid black; border-collapse: collapse;">',
        '<tr><th>Type</th><th>Topic</th><th>Score</th></tr>',
    ]
    for row_type, topic, score in topics:
        # Decide "main" once, from the raw type, before any markup is added.
        is_main = 'main' in row_type.lower()
        color = get_row_color(row_type) or ''  # guard against a None colour
        cells = [escape(str(row_type)), escape(str(topic)), escape(str(score))]
        if is_main:
            cells = [f'<strong>{cell}</strong>' for cell in cells]
        parts.append(f'<tr style="{color}"><td>{cells[0]}</td><td>{cells[1]}</td><td>{cells[2]}</td></tr>')
    parts.append('</table>')
    return ''.join(parts)
def restructure_topics(topics: List[Tuple[str, float]],
                       hierarchy: Optional[dict] = None) -> List[Tuple[str, str, float]]:
    """Group flat ``(topic, score)`` predictions into main/sub-topic rows.

    Args:
        topics: Predicted ``(topic, score)`` pairs. Names are lower-cased
            before matching.
        hierarchy: Mapping from main-topic name to its sub-topics. Defaults
            to the module-level ``topics_hierarchy`` loaded from JSON.

    Returns:
        For every predicted main topic that has at least one predicted
        sub-topic, in prediction order: one ``('Main Topic', name, score)``
        row followed by its ``('Sub Topic', name, score)`` rows, each sub-topic
        listed once. Predicted main topics with no predicted sub-topics, and
        sub-topics whose main topic was not predicted, are omitted. When a
        name is predicted more than once, its first score is used.
    """
    if hierarchy is None:
        hierarchy = topics_hierarchy

    # Name -> first predicted score. Insertion order is the order of first
    # appearance, so this doubles as an ordered, de-duplicated name list and
    # replaces the former linear scan per score lookup.
    first_score = {}
    for name, score in topics:
        first_score.setdefault(str(name).lower(), score)

    restructured = []
    for main_topic in first_score:
        if main_topic not in hierarchy:
            continue
        sub_topics = hierarchy[main_topic]
        predicted_subs = [name for name in first_score
                          if name != main_topic and name in sub_topics]
        if not predicted_subs:
            continue
        restructured.append(('Main Topic', main_topic, first_score[main_topic]))
        restructured.extend(('Sub Topic', name, first_score[name]) for name in predicted_subs)
    return restructured
def topic_modeling(text: str, threshold: float) -> Tuple[str, str]:
    """Gradio callback: classify *text* and render the topics as an HTML table.

    Args:
        text: Document entered by the user.
        threshold: Minimum score a predicted topic must reach to be kept.

    Returns:
        ``(html_table, language)``. ``language`` is the detected language's
        display name, or the raw code when no name is mapped. Only German,
        French and Italian produce table rows; other languages get a
        header-only table. Blank input skips both models and returns a
        header-only table with an empty language.
    """
    if not text or not text.strip():
        return generate_html_table([]), ''

    predictions, language = predict_topic(text)
    if language in ('German', 'French', 'Italian'):
        kept = [t for t in predictions if t[1] >= threshold]
    else:
        kept = []
    return generate_html_table(restructure_topics(kept)), language
# Gradio UI. The left column holds the input box, submit button, threshold
# slider and detected-language box; the right column holds the HTML table.
# Gradio places components in the order they are created inside each layout
# context, so statement order here defines the layout.
with gr.Blocks() as iface:
    gr.Markdown("# Topic Modeling")
    gr.Markdown("Enter a document and get each topic along with its score.")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=10, placeholder="Enter a document")
            submit_button = gr.Button("Submit")
            threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Score Threshold",
                                         value=0.0)
            # Read-only box that displays the language returned by topic_modeling.
            language_text = gr.Textbox(lines=1, placeholder="Detected language will be shown here...",
                                       interactive=False, label="Detected Language")
        with gr.Column():
            output_data = gr.HTML()
    # topic_modeling(text, threshold) returns (html_table, language); the
    # outputs list is in the same order.
    submit_button.click(topic_modeling, inputs=[input_text, threshold_slider],
                        outputs=[output_data, language_text])
# Launch the app. share=True asks Gradio for a public share link.
# NOTE(review): confirm a public link is intended when running locally.
iface.launch(share=True)