File size: 6,331 Bytes
cfa1e90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import logging
import os
import re
from functools import lru_cache
from urllib.parse import unquote

import streamlit as st
from codetiming import Timer
from transformers import pipeline
from preprocess import ArabertPreprocessor
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers import GPT2TokenizerFast, BertTokenizer
import tokenizers

logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger.info("Loading models...")
reader_time = Timer("loading", text="Time: {:.2f}", logger=logging.info)
reader_time.start()
#####

@st.cache(ttl=24*3600, hash_funcs={AutoModelForSeq2SeqLM: lambda _: None})
def load_seq2seqLM_model(model_path): #This function is not used
    return AutoModelForSeq2SeqLM.from_pretrained(model_path)
@st.cache(ttl=24*3600, hash_funcs={AutoModelForCausalLM: lambda _: None})
def load_casualLM_model(model_path):
    return AutoModelForCausalLM.from_pretrained(model_path)

@st.cache(ttl=24*3600, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_autotokenizer_model(tokenizer_path):
    return AutoTokenizer.from_pretrained(tokenizer_path)
@st.cache(ttl=24*3600, hash_funcs={BertTokenizer: lambda _: None})
def load_berttokenizer_model(tokenizer_path):
    return BertTokenizer.from_pretrained(tokenizer_path)
@st.cache(ttl=24*3600, hash_funcs={GPT2TokenizerFast: lambda _: None})
def load_gpt2tokenizer_model(tokenizer_path):
    return GPT2TokenizerFast.from_pretrained(tokenizer_path)

@st.cache(ttl=24*3600, allow_output_mutation=True, hash_funcs={pipeline: lambda _: None, tokenizers.Tokenizer: lambda _: None})
def load_generation_pipeline(model_path):
    if model_path == "malmarjeh/mbert2mbert-arabic-text-summarization":
        tokenizer = load_berttokenizer_model(model_path)
    else:
        tokenizer = load_autotokenizer_model(model_path)
    #model = load_seq2seqLM_model(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    return pipeline("text2text-generation",model=model,tokenizer=tokenizer)

@st.cache(ttl=24*3600, hash_funcs={ArabertPreprocessor: lambda _: None})
def load_preprocessor():
    return ArabertPreprocessor(model_name="")

tokenizer = load_autotokenizer_model("malmarjeh/bert2bert")
generation_pipeline = load_generation_pipeline("malmarjeh/bert2bert")
logger.info("BERT2BERT is loaded")

tokenizer_mbert = load_berttokenizer_model("malmarjeh/mbert2mbert-arabic-text-summarization")
generation_pipeline_mbert = load_generation_pipeline("malmarjeh/mbert2mbert-arabic-text-summarization")
logger.info("mBERT2mBERT is loaded")

tokenizer_t5 = load_autotokenizer_model("malmarjeh/t5-arabic-text-summarization")
generation_pipeline_t5 = load_generation_pipeline("malmarjeh/t5-arabic-text-summarization")
logger.info("T5 is loaded")

tokenizer_transformer = load_autotokenizer_model("malmarjeh/transformer")
generation_pipeline_transformer = load_generation_pipeline("malmarjeh/transformer")
logger.info("Transformer is loaded")

tokenizer_gpt2 = load_gpt2tokenizer_model("aubmindlab/aragpt2-base")
model_gpt2 = load_casualLM_model("malmarjeh/gpt2")
logger.info("GPT-2 is loaded")

reader_time.stop()

preprocessor = load_preprocessor()

logger.info("Finished loading the models...")
logger.info(f"Time spent loading: {reader_time.last}")

@lru_cache(maxsize=200)
def get_results(text, model_selected, num_beams, length_penalty):
    logger.info("\n=================================================================")
    logger.info(f"Text: {text}")
    logger.info(f"model_selected: {model_selected}")
    logger.info(f"length_penalty: {length_penalty}")
    reader_time = Timer("summarize", text="Time: {:.2f}", logger=logging.info)
    reader_time.start()
    if model_selected == 'GPT-2':
        number_of_tokens_limit = 80
    else:
        number_of_tokens_limit = 150
    text = preprocessor.preprocess(text)
    logger.info(f"input length: {len(text.split())}")
    text = ' '.join(text.split()[:number_of_tokens_limit])
    
    if model_selected == 'Transformer':
        result = generation_pipeline_transformer(text,
            pad_token_id=tokenizer_transformer.eos_token_id,
            num_beams=num_beams,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=length_penalty,
            no_repeat_ngram_size = 3)[0]['generated_text']
        logger.info('Transformer')
    elif model_selected == 'GPT-2':
        text_processed = '\n النص: ' + text + ' \n الملخص: \n '
        tokenizer_gpt2.add_special_tokens({'pad_token': '<pad>'})
        text_tokens = tokenizer_gpt2.batch_encode_plus([text_processed], return_tensors='pt', padding='max_length', max_length=100)
        output_ = model_gpt2.generate(input_ids=text_tokens['input_ids'],repetition_penalty=3.0, num_beams=num_beams, max_length=140, pad_token_id=2, eos_token_id=0, bos_token_id=10611)
        result = tokenizer_gpt2.decode(output_[0][100:], skip_special_tokens=True).strip()
        logger.info('GPT-2')
    elif model_selected == 'mBERT2mBERT':
        result = generation_pipeline_mbert(text,
            pad_token_id=tokenizer_mbert.eos_token_id,
            num_beams=num_beams,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=length_penalty,
            no_repeat_ngram_size = 3)[0]['generated_text']
        logger.info('mBERT')
    elif model_selected == 'T5':
        result = generation_pipeline_t5(text,
            pad_token_id=tokenizer_t5.eos_token_id,
            num_beams=num_beams,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=length_penalty,
            no_repeat_ngram_size = 3)[0]['generated_text']
        logger.info('t5')
    elif model_selected == 'BERT2BERT':
        result = generation_pipeline(text,
            pad_token_id=tokenizer.eos_token_id,
            num_beams=num_beams,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=length_penalty,
            no_repeat_ngram_size = 3)[0]['generated_text']
        logger.info('bert2bert')
    else:
        result = "الرجاء اختيار نموذج"

    reader_time.stop()
    logger.info(f"Time spent summarizing: {reader_time.last}")

    return result


if __name__ == "__main__":
    results_dict = ""