Spaces:
Runtime error
Runtime error
import nltk | |
import ssl | |
import re | |
try: | |
_create_unverified_https_context = ssl._create_unverified_context | |
except AttributeError: | |
pass | |
else: | |
ssl._create_default_https_context = _create_unverified_https_context | |
nltk.download('punkt') | |
nltk.download('stopwords') | |
from transformers import BartTokenizer, PegasusTokenizer | |
from transformers import BartForConditionalGeneration, PegasusForConditionalGeneration | |
from tqdm.notebook import tqdm | |
class Abstractive_Summarization_Model: | |
def __init__(self): | |
self.text = None | |
self.IS_CNNDM = True # whether to use CNNDM dataset or XSum dataset | |
self.LOWER = False | |
self.max_length = 1024 #if self.IS_CNNDM else 512 | |
self.model, self.tokenizer = self.load_model() | |
def load_model(self): | |
# Load our model checkpoints | |
print('[INFO]: Loading model ...') | |
if self.IS_CNNDM: | |
model = BartForConditionalGeneration.from_pretrained('Yale-LILY/brio-cnndm-uncased') | |
tokenizer = BartTokenizer.from_pretrained('Yale-LILY/brio-cnndm-uncased') | |
else: | |
model = PegasusForConditionalGeneration.from_pretrained('Yale-LILY/brio-xsum-cased') | |
tokenizer = PegasusTokenizer.from_pretrained('Yale-LILY/brio-xsum-cased') | |
print('[INFO]: Model Successfully Loaded :)') | |
return model, tokenizer | |
def summarize(self, text): | |
# generation example | |
if self.LOWER: | |
article = text.lower() | |
else: | |
article = text | |
inputs = self.tokenizer([article], max_length=self.max_length, return_tensors="pt", truncation=True) | |
# Generate Summary | |
summary_ids = self.model.generate(inputs["input_ids"]) | |
return self.tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |