Upload models.py
models.py CHANGED
@@ -1,49 +1,53 @@
 import torch
 import sentencepiece
-from transformers import
-
-import
+from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
+import os
+import spacy
+import spacy_transformers
+import zipfile
+from collections import defaultdict
+
 class Models():
     def __init__(self) -> None:
-        self.template = """
-            A virtual assistant answers questions from a user based on the provided text.
-            USER: Text: {input_text}
-            ASSISTANT: I’ve read this text.
-            USER: What describes {entity_type} in the text?
-            ASSISTANT:
-            """
         self.load_trained_models()
 
     def load_trained_models(self):
-        self.prompt = PromptTemplate(template=self.template, input_variables=["input_text","entity_type"])
-        self.llm_chain = LLMChain(prompt=self.prompt, llm=self.llm)
-
-    def extract_ner(self, context, entity_type):
-        return ast.literal_eval(self.llm_chain.run({"input_text":context,"entity_type":entity_type}))
-
-    def get_ner(self, clean_lines, entity):
-        tokens = []
-        try_num = 0
-        while try_num < 5 and tokens == []:
-            tokens = self.extract_ner(' '.join(clean_lines), entity)
-        if len(tokens) == 0:
-            raise ValueError("Couldnt extract {entity}")
-        return tokens
-
+        tokenizer = AutoTokenizer.from_pretrained("Jean-Baptiste/camembert-ner-with-dates",use_fast=False)
+        model = AutoModelForTokenClassification.from_pretrained("Jean-Baptiste/camembert-ner-with-dates")
+        self.ner = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple")
+        current_directory = os.path.dirname(os.path.realpath(__file__))
+        custom_ner_path = os.path.join(current_directory, 'spacy_model_v2/output/model-best')
+        if not os.path.exists(custom_ner_path):
+            with zipfile.ZipFile(r"models\prototype\spacy_model_v2.zip", 'r') as zip_ref:
+                # Extract all contents in the current working directory
+                zip_ref.extractall()
+        self.custom_ner = spacy.load(custom_ner_path)
+
+    def extract_ner(self, text):
+        entities = self.ner(text)
+        keys = ['DATE', 'ORG', 'LOC']
+        sort_dict = defaultdict(list)
+        for entity in entities:
+            if entity['score'] > 0.75:
+                sort_dict[entity['entity_group']].append(entity['word'])
+        filtered_dict = {key: value for key, value in sort_dict.items() if key in keys}
+        filtered_dict = defaultdict(list, filtered_dict)
+        return filtered_dict['DATE'], filtered_dict['ORG'], filtered_dict['LOC']
+    def get_ner(self, text, recover_text):
+        dates, companies, locations = self.extract_ner(text)
+        alternative_dates, alternative_companies, alternative_locations = self.extract_ner(recover_text)
 
+        if dates == []:
+            dates = alternative_dates
+        if companies == []:
+            companies = alternative_companies
+        if locations == []:
+            locations = alternative_locations
+        return dates, companies, locations
+    def get_custom_ner(self, text):
+        doc = self.custom_ner(text)
+        entities = list(doc.ents)
+        sort_dict = defaultdict(list)
+        for entity in entities:
+            sort_dict[entity.label_].append(entity.text)
+        return sort_dict
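
A minimal usage sketch for the new version of the class, assuming `models.py` above is importable from the working directory, the `Jean-Baptiste/camembert-ner-with-dates` checkpoint can be downloaded, and `spacy_model_v2/output/model-best` (or the zip it is extracted from) is present where `load_trained_models()` expects it; the French sample sentences and the company name are illustrative only.

from models import Models

models = Models()  # loads the camembert NER pipeline and the custom spaCy model

text = "Le contrat a été signé le 12 mars 2021 par la société Acme à Lyon."
recover_text = "Contrat signé le 12 mars 2021, société Acme, Lyon."  # fallback text, used only for entity types that come back empty

# get_ner() returns (dates, companies, locations); entities scored below 0.75 are dropped
dates, companies, locations = models.get_ner(text, recover_text)
print(dates, companies, locations)

# get_custom_ner() returns a defaultdict mapping the custom model's labels to matched spans
print(models.get_custom_ner(text))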