khrek commited on
Commit
822daa6
1 Parent(s): 547a2b6

Upload models.py

Browse files
Files changed (1) hide show
  1. models.py +45 -41
models.py CHANGED
@@ -1,49 +1,53 @@
1
  import torch
2
  import sentencepiece
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
- from langchain import PromptTemplate, LLMChain, HuggingFacePipeline
5
- import ast
 
 
 
 
6
  class Models():
7
  def __init__(self) -> None:
8
- self.template = """
9
- A virtual assistant answers questions from a user based on the provided text.
10
- USER: Text: {input_text}
11
- ASSISTANT: I’ve read this text.
12
- USER: What describes {entity_type} in the text?
13
- ASSISTANT:
14
- """
15
  self.load_trained_models()
16
 
17
  def load_trained_models(self):
18
- #is it best to keep in memory why not pickle?
19
- checkpoint = "Universal-NER/UniNER-7B-all"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- ner_model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float32, offload_folder="offload", offload_state_dict = True)
22
- tokenizer = AutoTokenizer.from_pretrained("Universal-NER/UniNER-7B-all", use_fast=False, padding="max_length")
23
- hf_pipeline = pipeline(
24
- "text-generation", #task
25
- model=ner_model,
26
- max_length=1000,
27
- tokenizer=tokenizer,
28
- trust_remote_code=True,
29
- do_sample=True,
30
- top_k=10,
31
- num_return_sequences=1
32
- )
33
-
34
- self.llm = HuggingFacePipeline(pipeline = hf_pipeline, model_kwargs = {'temperature':0})
35
- self.prompt = PromptTemplate(template=self.template, input_variables=["input_text","entity_type"])
36
- self.llm_chain = LLMChain(prompt=self.prompt, llm=self.llm)
37
-
38
- def extract_ner(self, context, entity_type):
39
- return ast.literal_eval(self.llm_chain.run({"input_text":context,"entity_type":entity_type}))
40
-
41
- def get_ner(self, clean_lines, entity):
42
- tokens = []
43
- try_num = 0
44
- while try_num < 5 and tokens == []:
45
- tokens = self.extract_ner(' '.join(clean_lines), entity)
46
- if len(tokens) == 0:
47
- raise ValueError("Couldnt extract {entity}")
48
- return tokens
49
-
 
1
  import torch
2
  import sentencepiece
3
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
4
+ import os
5
+ import spacy
6
+ import spacy_transformers
7
+ import zipfile
8
+ from collections import defaultdict
9
+
10
  class Models():
11
  def __init__(self) -> None:
 
 
 
 
 
 
 
12
  self.load_trained_models()
13
 
14
  def load_trained_models(self):
15
+ tokenizer = AutoTokenizer.from_pretrained("Jean-Baptiste/camembert-ner-with-dates",use_fast=False)
16
+ model = AutoModelForTokenClassification.from_pretrained("Jean-Baptiste/camembert-ner-with-dates")
17
+ self.ner = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple")
18
+ current_directory = os.path.dirname(os.path.realpath(__file__))
19
+ custom_ner_path = os.path.join(current_directory, 'spacy_model_v2/output/model-best')
20
+ if not os.path.exists(custom_ner_path):
21
+ with zipfile.ZipFile(r"models\prototype\spacy_model_v2.zip", 'r') as zip_ref:
22
+ # Extract all contents in the current working directory
23
+ zip_ref.extractall()
24
+ self.custom_ner = spacy.load(custom_ner_path)
25
+
26
+ def extract_ner(self, text):
27
+ entities = self.ner(text)
28
+ keys = ['DATE', 'ORG', 'LOC']
29
+ sort_dict = defaultdict(list)
30
+ for entity in entities:
31
+ if entity['score'] > 0.75:
32
+ sort_dict[entity['entity_group']].append(entity['word'])
33
+ filtered_dict = {key: value for key, value in sort_dict.items() if key in keys}
34
+ filtered_dict = defaultdict(list, filtered_dict)
35
+ return filtered_dict['DATE'], filtered_dict['ORG'], filtered_dict['LOC']
36
+ def get_ner(self, text, recover_text):
37
+ dates, companies, locations = self.extract_ner(text)
38
+ alternative_dates, alternative_companies, alternative_locations = self.extract_ner(recover_text)
39
 
40
+ if dates == [] :
41
+ dates = alternative_dates
42
+ if companies == []:
43
+ companies = alternative_companies
44
+ if locations == []:
45
+ locations = alternative_locations
46
+ return dates, companies, locations
47
+ def get_custom_ner(self, text):
48
+ doc = self.custom_ner(text)
49
+ entities = list(doc.ents)
50
+ sort_dict = defaultdict(list)
51
+ for entity in entities:
52
+ sort_dict[entity.label_].append(entity.text)
53
+ return sort_dict