Liyan06
commited on
Commit
·
2d158d3
1
Parent(s):
93e9112
add entity highlight
Browse files- handler.py +22 -3
handler.py
CHANGED
@@ -3,6 +3,16 @@ from web_retrieval import *
|
|
3 |
from nltk.tokenize import sent_tokenize
|
4 |
import evaluate
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
|
8 |
'''
|
@@ -19,7 +29,13 @@ def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
|
|
19 |
ranked_docs, scores = zip(*ranked_doc_score)
|
20 |
|
21 |
return ranked_docs, scores
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
class EndpointHandler():
|
25 |
def __init__(self, path="./"):
|
@@ -30,6 +46,7 @@ class EndpointHandler():
|
|
30 |
def __call__(self, data):
|
31 |
|
32 |
claim = data['inputs']['claims'][0]
|
|
|
33 |
|
34 |
# Using user-provided document to do fact-checking
|
35 |
if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
|
@@ -48,7 +65,8 @@ class EndpointHandler():
|
|
48 |
outputs = {
|
49 |
'ranked_docs': ranked_docs,
|
50 |
'scores': scores,
|
51 |
-
'span_to_highlight': span_to_highlight
|
|
|
52 |
}
|
53 |
|
54 |
else:
|
@@ -69,7 +87,8 @@ class EndpointHandler():
|
|
69 |
'ranked_docs': ranked_docs,
|
70 |
'scores': scores,
|
71 |
'ranked_urls': ranked_urls,
|
72 |
-
'span_to_highlight': span_to_highlight
|
|
|
73 |
}
|
74 |
|
75 |
return outputs
|
|
|
3 |
from nltk.tokenize import sent_tokenize
|
4 |
import evaluate
|
5 |
|
6 |
+
import spacy
|
7 |
+
from spacy.cli import download
|
8 |
+
|
9 |
+
try:
|
10 |
+
nlp = spacy.load("en_core_web_lg")
|
11 |
+
except:
|
12 |
+
# If loading fails, download the model
|
13 |
+
download("en_core_web_lg")
|
14 |
+
nlp = spacy.load("en_core_web_lg")
|
15 |
+
|
16 |
|
17 |
def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
|
18 |
'''
|
|
|
29 |
ranked_docs, scores = zip(*ranked_doc_score)
|
30 |
|
31 |
return ranked_docs, scores
|
32 |
+
|
33 |
+
|
34 |
+
def extract_entities(text):
|
35 |
+
text = nlp(text)
|
36 |
+
ents = list({ent.text for ent in text.ents})
|
37 |
+
return ents
|
38 |
+
|
39 |
|
40 |
class EndpointHandler():
|
41 |
def __init__(self, path="./"):
|
|
|
46 |
def __call__(self, data):
|
47 |
|
48 |
claim = data['inputs']['claims'][0]
|
49 |
+
ents = extract_entities(claim)
|
50 |
|
51 |
# Using user-provided document to do fact-checking
|
52 |
if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
|
|
|
65 |
outputs = {
|
66 |
'ranked_docs': ranked_docs,
|
67 |
'scores': scores,
|
68 |
+
'span_to_highlight': span_to_highlight,
|
69 |
+
'entities': ents
|
70 |
}
|
71 |
|
72 |
else:
|
|
|
87 |
'ranked_docs': ranked_docs,
|
88 |
'scores': scores,
|
89 |
'ranked_urls': ranked_urls,
|
90 |
+
'span_to_highlight': span_to_highlight,
|
91 |
+
'entities': ents
|
92 |
}
|
93 |
|
94 |
return outputs
|