Spaces:
Runtime error
Runtime error
File size: 2,273 Bytes
cfbfc47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
from transformers import pipeline
from flair.data import Sentence
from flair.models import SequenceTagger
import pickle
class Models:
    """Load, cache and (un)pickle the NLP models used by this module.

    Four models are managed:

    * ``ner_dates`` -- token-classification pipeline built from
      ``Jean-Baptiste/camembert-ner-with-dates``.
    * ``ner`` -- token-classification pipeline built from
      ``dslim/bert-base-NER``.
    * ``zero_shot_classifier`` -- zero-shot-classification pipeline built
      from ``facebook/bart-large-mnli``.
    * ``tagger`` -- flair ``SequenceTagger`` loaded from
      ``flair/pos-english-fast``.

    Both loader methods store the models on the instance under these
    attribute names, so ``pickle_models`` can be called after either one.
    """

    def pickle_it(self, obj, file_name: str) -> None:
        """Serialize ``obj`` with pickle to the file ``<file_name>.pickle``."""
        with open(f'{file_name}.pickle', 'wb') as f:
            pickle.dump(obj, f)

    def unpickle_it(self, file_name: str):
        """Return the object stored in the file ``<file_name>.pickle``.

        SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
        file. Only call this on pickle files produced by ``pickle_it`` from a
        trusted source, never on files supplied by users.
        """
        with open(f'{file_name}.pickle', 'rb') as f:
            return pickle.load(f)

    def _build_ner_pipeline(self, model_name: str):
        """Build an entity-grouping ``ner`` pipeline for ``model_name``."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        # aggregation_strategy="simple" replaces the deprecated
        # grouped_entities=True flag, so both NER pipelines are configured
        # the same way and no deprecation warning is emitted.
        return pipeline(
            'ner',
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="simple",
        )

    def load_trained_models(self, pickle: bool = False):
        """Load all pretrained models and store them on the instance.

        Args:
            pickle: When true, also write every loaded model to disk via
                ``pickle_models``. NOTE: inside this method the name
                ``pickle`` refers to this flag, not to the ``pickle`` module;
                the name is kept for backward compatibility with callers.

        Returns:
            The tuple ``(ner, ner_dates, zero_shot_classifier, tagger)``.
            This order differs from the one returned by
            ``load_pickled_models``; it is kept as-is for existing callers.
        """
        # NER (dates)
        self.ner_dates = self._build_ner_pipeline(
            "Jean-Baptiste/camembert-ner-with-dates"
        )

        # Zero-shot classification.
        # Alternative checkpoint left by the original author:
        # self.zero_shot_classifier = pipeline("zero-shot-classification", model='valhalla/distilbart-mnli-12-6')
        self.zero_shot_classifier = pipeline(
            "zero-shot-classification", model='facebook/bart-large-mnli'
        )

        # NER
        self.ner = self._build_ner_pipeline("dslim/bert-base-NER")

        # POS tagging
        self.tagger = SequenceTagger.load("flair/pos-english-fast")

        if pickle:
            self.pickle_models()

        return self.ner, self.ner_dates, self.zero_shot_classifier, self.tagger

    def pickle_models(self) -> None:
        """Write the four models stored on the instance to pickle files.

        Requires ``ner``, ``zero_shot_classifier``, ``ner_dates`` and
        ``tagger`` to be set, i.e. one of the loader methods must have been
        called first.
        """
        self.pickle_it(self.ner, "ner")
        self.pickle_it(self.zero_shot_classifier, "zero_shot_classifier_6")
        self.pickle_it(self.ner_dates, "ner_dates")
        self.pickle_it(self.tagger, "pos_tagger_fast")

    def load_pickled_models(self):
        """Load the four models from the files written by ``pickle_models``.

        The models are also stored on the instance (as
        ``load_trained_models`` does), so ``pickle_models`` can be called
        afterwards without raising ``AttributeError``.

        Returns:
            The tuple ``(ner_dates, ner, zero_shot_classifier, tagger)``.
            This order differs from the one returned by
            ``load_trained_models``; it is kept as-is for existing callers.
        """
        self.ner_dates = self.unpickle_it('ner_dates')
        self.ner = self.unpickle_it('ner')
        self.zero_shot_classifier = self.unpickle_it('zero_shot_classifier_6')
        self.tagger = self.unpickle_it("pos_tagger_fast")
        return self.ner_dates, self.ner, self.zero_shot_classifier, self.tagger

    def get_flair_sentence(self, sent):
        """Wrap ``sent`` in a flair ``Sentence`` object and return it."""
        return Sentence(sent)