import gradio as gr
from transformers import pipeline, AutoTokenizer
from sentence_transformers import SentenceTransformer, util
import nltk
from nltk.tokenize import sent_tokenize

# Sentence-splitter data used by sent_tokenize below.
nltk.download('punkt')
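# Newer NLTK releases (3.8.2+) also resolve sent_tokenize through the
# separate 'punkt_tab' resource; fetching it too keeps the app working on
# those versions (assumption: a harmless extra download on older installs).
nltk.download('punkt_tab')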

# MarianMT checkpoints for each supported target language.
translation_models = {
    'Vietnamese': "Helsinki-NLP/opus-mt-en-vi",
    'Japanese': "Helsinki-NLP/opus-mt-en-jap",
    'Thai': "Helsinki-NLP/opus-mt-en-tha",
    'Spanish': "Helsinki-NLP/opus-mt-en-es"
}

# Summarization checkpoints, keyed by the display name shown in the UI.
summarization_models = {
    'Scientific': "sshleifer/distilbart-cnn-12-6",
    'Literature': "t5-small"
}

# A single tokenizer is used to measure chunk sizes for both summarizers.
# This is an approximation: t5-small's tokenizer will not count tokens
# exactly the way distilbart's does.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Load both summarization pipelines once at startup (slower start, no
# per-request model loading), keyed by display name so the value coming
# from the Gradio radio button can be used directly as the lookup key.
summarizers = {name: pipeline("summarization", model=model_id) for name, model_id in summarization_models.items()}

def get_translator(language):
    """Build a translation pipeline for the target language, or None if unsupported."""
    model_name = translation_models.get(language)
    if model_name:
        return pipeline("translation", model=model_name)
    return None
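
# get_translator rebuilds a pipeline (and reloads model weights) on every
# request. A minimal memoized variant, sketched with functools.lru_cache
# (illustrative only, not wired into the app):
from functools import lru_cache

@lru_cache(maxsize=None)
def get_translator_cached(language):
    model_name = translation_models.get(language)
    return pipeline("translation", model=model_name) if model_name else None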

def split_text(text, max_tokens=1024):
    """Split text into chunks of at most max_tokens tokens.

    1024 matches distilbart-cnn's input limit; note that t5-small is
    typically capped at 512 tokens, so its chunks may still be truncated
    by the pipeline.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=False)
    input_ids = inputs['input_ids'][0]
    total_tokens = len(input_ids)

    chunks = []
    start = 0
    while start < total_tokens:
        end = min(start + max_tokens, total_tokens)
        chunk_ids = input_ids[start:end]
        # Decode back to text so the summarization pipeline can re-tokenize it.
        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True)
        chunks.append(chunk_text)
        start = end

    return chunks
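
# Cutting on raw token positions can split a sentence across two chunks. A
# sentence-aware variant (an illustrative sketch, not wired into the app):
# accumulate whole sentences until the token budget would be exceeded.
def split_text_by_sentence(text, max_tokens=1024):
    chunks, current, current_len = [], [], 0
    for sentence in sent_tokenize(text):
        sent_len = len(tokenizer(sentence)['input_ids'])
        if current and current_len + sent_len > max_tokens:
            chunks.append(" ".join(current))
            current, current_len = [], 0
        current.append(sentence)
        current_len += sent_len
    if current:
        chunks.append(" ".join(current))
    return chunks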

def summarize_text(text, model):
    chunks = split_text(text)
    summaries = []
    for chunk in chunks:
        try:
            summary = summarizers[model](chunk, max_length=150, min_length=40, do_sample=False)[0]['summary_text']
            summaries.append(summary)
        except Exception as e:
            # Skip chunks that fail rather than aborting the whole request.
            print(f"Error summarizing chunk: {chunk}\nError: {e}")
    return " ".join(summaries)

def translate_text(text, language):
    translator = get_translator(language)
    if translator:
        try:
            translated_text = translator(text)[0]['translation_text']
            return translated_text
        except Exception as e:
            print(f"Error translating text: {text}\nError: {e}")
            return text
    # No model for this language: fall back to the untranslated text.
    return text
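
# translate_text sends the whole bullet list through the model as one string.
# Marian models are sentence-level translators, so translating each bullet on
# its own often reads better (an illustrative sketch, not wired into the app):
def translate_bullets(bullet_points, language):
    translator = get_translator(language)
    if not translator:
        return bullet_points
    lines = [line[2:] if line.startswith("- ") else line
             for line in bullet_points.splitlines() if line.strip()]
    translated = [translator(line)[0]['translation_text'] for line in lines]
    return "\n".join(f"- {t}" for t in translated)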

def process_text(input_text, model, language):
    print(f"Input text: {input_text[:500]}...")
    summary = summarize_text(input_text, model)
    print(f"Summary: {summary[:500]}...")
    bullet_points = generate_bullet_points(summary)
    print(f"Bullet Points: {bullet_points}")
    translated_text = translate_text(bullet_points, language)
    print(f"Translated Text: {translated_text}")
    return bullet_points, translated_text

def generate_bullet_points(summary):
    print("Summary Text:", summary)

    # Take the first three sentences of the summary as the key points.
    sentences = sent_tokenize(summary)
    key_sentences = sentences[:3]

    bullet_points = "\n".join(f"- {sentence}" for sentence in key_sentences)
    print("Bullet Points:", bullet_points)

    return bullet_points
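
# The SentenceTransformer import above suggests an embedding-based selection
# was intended. A sketch of that idea: rank sentences by similarity to the
# summary's mean embedding and keep the top three in their original order.
# (Illustrative only; the 'all-MiniLM-L6-v2' checkpoint is an assumption.)
def generate_bullet_points_semantic(summary, top_k=3):
    sentences = sent_tokenize(summary)
    if len(sentences) <= top_k:
        key_sentences = sentences
    else:
        embedder = SentenceTransformer('all-MiniLM-L6-v2')
        embeddings = embedder.encode(sentences, convert_to_tensor=True)
        centroid = embeddings.mean(dim=0, keepdim=True)
        scores = util.cos_sim(centroid, embeddings)[0]
        top_idx = sorted(scores.argsort(descending=True)[:top_k].tolist())
        key_sentences = [sentences[i] for i in top_idx]
    return "\n".join(f"- {s}" for s in key_sentences)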

iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Paste your text here...", lines=10),
        # Default to "Scientific" so a model is always selected; an unselected
        # radio would pass None and fail the summarizers lookup.
        gr.Radio(choices=["Scientific", "Literature"], label="Summarization Model", value="Scientific"),
        gr.Dropdown(choices=["Vietnamese", "Japanese", "Thai", "Spanish"], label="Translate to", value="Vietnamese")
    ],
    outputs=[
        gr.Textbox(label="Bullet Points", lines=10),
        gr.Textbox(label="Translated Bullet Points", lines=10)
    ],
    title="Text to Bullet Points and Translation",
    description="Paste any text, choose the summarization model, and optionally translate the bullet points into Vietnamese, Japanese, Thai, or Spanish."
)

iface.launch()
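
# Tip: iface.launch(share=True) serves the app over a temporary public URL,
# which is handy when running in a hosted notebook.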