Liyan06 committed
Commit 2d158d3 · 1 Parent(s): 93e9112

add entity highlight

Files changed (1): handler.py (+22, -3)
handler.py CHANGED

@@ -3,6 +3,16 @@ from web_retrieval import *
 from nltk.tokenize import sent_tokenize
 import evaluate
 
+import spacy
+from spacy.cli import download
+
+try:
+    nlp = spacy.load("en_core_web_lg")
+except:
+    # If loading fails, download the model
+    download("en_core_web_lg")
+    nlp = spacy.load("en_core_web_lg")
+
 
 def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
     '''
@@ -19,7 +29,13 @@ def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
     ranked_docs, scores = zip(*ranked_doc_score)
 
     return ranked_docs, scores
-
+
+
+def extract_entities(text):
+    text = nlp(text)
+    ents = list({ent.text for ent in text.ents})
+    return ents
+
 
 class EndpointHandler():
     def __init__(self, path="./"):
@@ -30,6 +46,7 @@ class EndpointHandler():
     def __call__(self, data):
 
         claim = data['inputs']['claims'][0]
+        ents = extract_entities(claim)
 
         # Using user-provided document to do fact-checking
         if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
@@ -48,7 +65,8 @@ class EndpointHandler():
             outputs = {
                 'ranked_docs': ranked_docs,
                 'scores': scores,
-                'span_to_highlight': span_to_highlight
+                'span_to_highlight': span_to_highlight,
+                'entities': ents
             }
 
         else:
@@ -69,7 +87,8 @@ class EndpointHandler():
                 'ranked_docs': ranked_docs,
                 'scores': scores,
                 'ranked_urls': ranked_urls,
-                'span_to_highlight': span_to_highlight
+                'span_to_highlight': span_to_highlight,
+                'entities': ents
             }
 
         return outputs
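
As context for the new model-loading block at the top of the diff: `spacy.load` raises `OSError` when the named model package is not installed, so a slightly tighter variant of the commit's load-or-download fallback would catch that exception specifically rather than using a bare `except:`. A sketch of that variant (not what the commit does):

import spacy
from spacy.cli import download

def load_spacy_model(name="en_core_web_lg"):
    # spacy.load raises OSError when the model package is missing,
    # so catch that specifically instead of a bare except.
    try:
        return spacy.load(name)
    except OSError:
        download(name)
        return spacy.load(name)

nlp = load_spacy_model()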
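
Note that `extract_entities` deduplicates by surface form through a set comprehension, so repeated mentions collapse to a single entry and the ordering of the returned list is arbitrary. A quick illustration (the exact output depends on the spaCy model):

ents = extract_entities("Paris is in France. Paris is the capital.")
# e.g. ['Paris', 'France'] -- one entry per distinct surface form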
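
Finally, a minimal sketch of how a client might exercise the updated handler and consume the new `entities` field for highlighting. The payload shape mirrors the fields `__call__` reads (`data['inputs']['claims']` and `data['inputs']['docs']`); the example text and the `mark_entities` helper are hypothetical, not part of the repository:

from handler import EndpointHandler

handler = EndpointHandler(path="./")

payload = {
    "inputs": {
        "claims": ["Marie Curie won two Nobel Prizes."],
        "docs": ["Marie Curie won the 1903 Nobel Prize in Physics "
                 "and the 1911 Nobel Prize in Chemistry."],
    }
}

outputs = handler(payload)

def mark_entities(text, entities):
    # Naive highlighting: wrap each entity surface form in brackets.
    # Assumes the entity string appears verbatim in the document text.
    for ent in entities:
        text = text.replace(ent, f"[{ent}]")
    return text

print(mark_entities(outputs["ranked_docs"][0], outputs["entities"]))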