import streamlit as st
from transformers import MarianTokenizer, MarianMTModel, BertTokenizer, AutoModelForSeq2SeqLM, pipeline
from ar_correct import ar_correct
import mishkal.tashkeel
from arabert.preprocess import ArabertPreprocessor
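
# Dependency note: the Space's requirements.txt must list every package imported
# above; a missing or misnamed entry is a common cause of a Spaces build error.
# Assumed package names (verify against PyPI; ar_correct may be a local module):
#   streamlit, transformers, torch, sentencepiece, mishkal, arabert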

# Initialize Mishkal vocalizer (adds tashkeel, i.e. Arabic diacritics)
vocalizer = mishkal.tashkeel.TashkeelClass()

# Initialize Marian tokenizer and model for English-to-Arabic translation
mname = "marefa-nlp/marefa-mt-en-ar"
tokenizer = MarianTokenizer.from_pretrained(mname)
model = MarianMTModel.from_pretrained(mname)

# Initialize BERT tokenizer and model for Arabic text summarization
model_name = "malmarjeh/mbert2mbert-arabic-text-summarization"
preprocessor = ArabertPreprocessor(model_name="")
tokenizer_summarization = BertTokenizer.from_pretrained(model_name)
model_summarization = AutoModelForSeq2SeqLM.from_pretrained(model_name)
pipeline_summarization = pipeline("text2text-generation", model=model_summarization, tokenizer=tokenizer_summarization)
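
# Optional (not in the original code): Streamlit re-runs this whole script on every
# user interaction, so wrapping the model loading above in functions decorated with
# @st.cache_resource would avoid reloading the models on each rerun.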


def main():
    st.title("U3reb Demo")

    # Text input
    input_text = st.text_area("Enter Arabic Text:")

    # Diacritization: Mishkal adds tashkeel (Arabic diacritics) to the text
    st.subheader("Diacritization (Mishkal)")
    if input_text:
        text_mishkal = vocalizer.tashkeel(input_text)
        st.write("Diacritized Text:", text_mishkal)
    # Translation
    st.subheader("Translation")
    if input_text:
        # prepare_seq2seq_batch is deprecated in recent transformers releases;
        # calling the tokenizer directly is the supported equivalent.
        batch = tokenizer([input_text], return_tensors="pt")
        translated_tokens = model.generate(**batch)
        translated_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated_tokens]
        st.write("Translated Text:", translated_text)
    # Arabic text correction
    st.subheader("Arabic Text Correction (ar_correct)")
    if input_text:
        corrected_text = ar_correct(input_text)
        st.write("Corrected Text:", corrected_text)
    # Text summarization
    st.subheader("Text Summarization")
    if input_text:
        preprocessed_text = preprocessor.preprocess(input_text)
        result = pipeline_summarization(preprocessed_text,
                                        pad_token_id=tokenizer_summarization.eos_token_id,
                                        num_beams=3,
                                        repetition_penalty=3.0,
                                        max_length=200,
                                        length_penalty=1.0,
                                        no_repeat_ngram_size=3)[0]['generated_text']
        st.write("Summarized Text:", result)


if __name__ == "__main__":
    main()
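
# To run locally, assuming this file is saved as app.py (the default entry point
# for a Streamlit Space): streamlit run app.py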