import logging
import os
import re
from functools import lru_cache
from urllib.parse import unquote

import streamlit as st
import tokenizers
from codetiming import Timer
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BertTokenizer,
    GPT2TokenizerFast,
    pipeline,
)

from preprocess import ArabertPreprocessor
logger = logging.getLogger(__name__)

# Silence the HuggingFace tokenizers fork-parallelism warning inside the server.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger.info("Loading models...")
reader_time = Timer("loading", text="Time: {:.2f}", logger=logging.info)
reader_time.start()
#####
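# The loaders below are cached by Streamlit for 24 hours. The hash_funcs
# entries map model/tokenizer types to a constant so Streamlit skips hashing
# those (unhashable) objects when looking up the cache.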
@st.cache(ttl=24*3600, hash_funcs={AutoModelForSeq2SeqLM: lambda _: None})
def load_seq2seqLM_model(model_path):  # Currently unused; see load_generation_pipeline.
    return AutoModelForSeq2SeqLM.from_pretrained(model_path)
@st.cache(ttl=24*3600, hash_funcs={AutoModelForCausalLM: lambda _: None})
def load_causalLM_model(model_path):
    return AutoModelForCausalLM.from_pretrained(model_path)
@st.cache(ttl=24*3600, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_autotokenizer_model(tokenizer_path):
    return AutoTokenizer.from_pretrained(tokenizer_path)
@st.cache(ttl=24*3600, hash_funcs={BertTokenizer: lambda _: None})
def load_berttokenizer_model(tokenizer_path):
    return BertTokenizer.from_pretrained(tokenizer_path)
@st.cache(ttl=24*3600, hash_funcs={GPT2TokenizerFast: lambda _: None})
def load_gpt2tokenizer_model(tokenizer_path):
    return GPT2TokenizerFast.from_pretrained(tokenizer_path)
@st.cache(ttl=24*3600, allow_output_mutation=True,
          hash_funcs={pipeline: lambda _: None, tokenizers.Tokenizer: lambda _: None})
def load_generation_pipeline(model_path):
    # The mBERT2mBERT checkpoint needs a plain BertTokenizer; the other
    # checkpoints load fine through AutoTokenizer.
    if model_path == "malmarjeh/mbert2mbert-arabic-text-summarization":
        tokenizer = load_berttokenizer_model(model_path)
    else:
        tokenizer = load_autotokenizer_model(model_path)
    # model = load_seq2seqLM_model(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    return pipeline("text2text-generation", model=model, tokenizer=tokenizer)
@st.cache(ttl=24*3600, hash_funcs={ArabertPreprocessor: lambda _: None})
def load_preprocessor():
    return ArabertPreprocessor(model_name="")
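# Load every model and tokenizer eagerly at startup so individual requests
# don't pay the download/initialization cost; reader_time logs the total.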
tokenizer = load_autotokenizer_model("malmarjeh/bert2bert")
generation_pipeline = load_generation_pipeline("malmarjeh/bert2bert")
logger.info("BERT2BERT is loaded")
tokenizer_mbert = load_berttokenizer_model("malmarjeh/mbert2mbert-arabic-text-summarization")
generation_pipeline_mbert = load_generation_pipeline("malmarjeh/mbert2mbert-arabic-text-summarization")
logger.info("mBERT2mBERT is loaded")
tokenizer_t5 = load_autotokenizer_model("malmarjeh/t5-arabic-text-summarization")
generation_pipeline_t5 = load_generation_pipeline("malmarjeh/t5-arabic-text-summarization")
logger.info("T5 is loaded")
tokenizer_transformer = load_autotokenizer_model("malmarjeh/transformer")
generation_pipeline_transformer = load_generation_pipeline("malmarjeh/transformer")
logger.info("Transformer is loaded")
tokenizer_gpt2 = load_gpt2tokenizer_model("aubmindlab/aragpt2-base")
model_gpt2 = load_causalLM_model("malmarjeh/gpt2")
logger.info("GPT-2 is loaded")
reader_time.stop()
preprocessor = load_preprocessor()
logger.info("Finished loading the models...")
logger.info(f"Time spent loading: {reader_time.last}")
@lru_cache(maxsize=200)
def get_results(text, model_selected, num_beams, length_penalty):
    logger.info("\n=================================================================")
    logger.info(f"Text: {text}")
    logger.info(f"model_selected: {model_selected}")
    logger.info(f"length_penalty: {length_penalty}")
    reader_time = Timer("summarize", text="Time: {:.2f}", logger=logging.info)
    reader_time.start()
    # GPT-2 gets a shorter input so the prompt fits the 100-token padded
    # encoding used below; the seq2seq models take up to 150 words.
    if model_selected == 'GPT-2':
        number_of_tokens_limit = 80
    else:
        number_of_tokens_limit = 150
    text = preprocessor.preprocess(text)
    logger.info(f"input length: {len(text.split())}")
    text = ' '.join(text.split()[:number_of_tokens_limit])
    if model_selected == 'Transformer':
        result = generation_pipeline_transformer(
            text,
            pad_token_id=tokenizer_transformer.eos_token_id,
            num_beams=num_beams,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=length_penalty,
            no_repeat_ngram_size=3)[0]['generated_text']
        logger.info('Transformer')
    elif model_selected == 'GPT-2':
        # Prompt template: "النص:" ("Text:") followed by "الملخص:" ("Summary:").
        text_processed = '\n النص: ' + text + ' \n الملخص: \n '
        tokenizer_gpt2.add_special_tokens({'pad_token': '<pad>'})
        text_tokens = tokenizer_gpt2.batch_encode_plus(
            [text_processed], return_tensors='pt',
            padding='max_length', max_length=100)
        output_ = model_gpt2.generate(
            input_ids=text_tokens['input_ids'],
            repetition_penalty=3.0, num_beams=num_beams, max_length=140,
            pad_token_id=2, eos_token_id=0, bos_token_id=10611)
        # Skip the 100 prompt tokens and decode only the generated summary.
        result = tokenizer_gpt2.decode(output_[0][100:], skip_special_tokens=True).strip()
        logger.info('GPT-2')
    elif model_selected == 'mBERT2mBERT':
        result = generation_pipeline_mbert(
            text,
            pad_token_id=tokenizer_mbert.eos_token_id,
            num_beams=num_beams,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=length_penalty,
            no_repeat_ngram_size=3)[0]['generated_text']
        logger.info('mBERT')
    elif model_selected == 'T5':
        result = generation_pipeline_t5(
            text,
            pad_token_id=tokenizer_t5.eos_token_id,
            num_beams=num_beams,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=length_penalty,
            no_repeat_ngram_size=3)[0]['generated_text']
        logger.info('t5')
    elif model_selected == 'BERT2BERT':
        result = generation_pipeline(
            text,
            pad_token_id=tokenizer.eos_token_id,
            num_beams=num_beams,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=length_penalty,
            no_repeat_ngram_size=3)[0]['generated_text']
        logger.info('bert2bert')
    else:
        result = "الرجاء اختيار نموذج"  # "Please choose a model."
    reader_time.stop()
    logger.info(f"Time spent summarizing: {reader_time.last}")
    return result
if __name__ == "__main__":
    results_dict = ""
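    # --- Illustrative sketch only, not part of the original file. ---
    # The Space's real UI code is truncated here; the widgets below are an
    # assumption showing how get_results() could be wired to Streamlit.
    st.title("Arabic Text Summarization")
    input_text = st.text_area("Text to summarize:")
    model_choice = st.selectbox(
        "Model", ['BERT2BERT', 'mBERT2mBERT', 'T5', 'Transformer', 'GPT-2'])
    beams = st.slider("num_beams", min_value=1, max_value=10, value=3)
    penalty = st.slider("length_penalty", min_value=0.1, max_value=3.0, value=1.0)
    if st.button("Summarize") and input_text:
        st.write(get_results(input_text, model_choice, beams, penalty))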