Commit 0ba9aa2, committed by fabiochiusano
Parent(s): b7aa506

first commit
Files changed:
- .gitignore (+1)
- app.py (+186)
- kb.py (+99)
- networks/.DS_Store (binary)
- networks/network_1_bryant.p (binary)
- networks/network_1_google.p (binary)
- networks/network_1_napoleon.p (binary)
- networks/network_2_crypto.p (binary)
- networks/network_2_depp.p (binary)
- networks/network_2_rome.p (binary)
- networks/network_3_amazon.p (binary)
- networks/network_3_apple.p (binary)
- networks/network_3_bryant.p (binary)
- networks/network_3_google.p (binary)
- networks/network_3_musk.p (binary)
- requirements.txt (+6)
- utils.py (+198)
.gitignore ADDED
@@ -0,0 +1 @@
myvenv
app.py ADDED
@@ -0,0 +1,186 @@
import streamlit as st
import streamlit.components.v1 as components
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import utils
from kb import KB

texts = {
    "Napoleon": "Napoleon Bonaparte (born Napoleone di Buonaparte; 15 August 1769 – 5 May 1821), and later known by his regnal name Napoleon I, was a French military and political leader who rose to prominence during the French Revolution and led several successful campaigns during the Revolutionary Wars. He was the de facto leader of the French Republic as First Consul from 1799 to 1804. As Napoleon I, he was Emperor of the French from 1804 until 1814 and again in 1815. Napoleon's political and cultural legacy has endured, and he has been one of the most celebrated and controversial leaders in world history.",
    "Kobe Bryant": "Kobe Bean Bryant (August 23, 1978 – January 26, 2020) was an American professional basketball player. A shooting guard, he spent his entire 20-year career with the Los Angeles Lakers in the National Basketball Association (NBA). Widely regarded as one of the greatest basketball players of all time, Bryant won five NBA championships, was an 18-time All-Star, a 15-time member of the All-NBA Team, a 12-time member of the All-Defensive Team, the 2008 NBA Most Valuable Player (MVP), and a two-time NBA Finals MVP. Bryant also led the NBA in scoring twice, and ranks fourth in league all-time regular season and postseason scoring. He was posthumously voted into the Naismith Memorial Basketball Hall of Fame in 2020 and named to the NBA 75th Anniversary Team in 2021.",
    "Google": "Originally known as BackRub. Google is a search engine that started development in 1996 by Sergey Brin and Larry Page as a research project at Stanford University to find files on the Internet. Larry and Sergey later decided the name of their search engine needed to change and chose Google, which is inspired from the term googol. The company is headquartered in Mountain View, California."
}

urls = {
    "Crypto": "https://www.investopedia.com/terms/c/cryptocurrency.asp",
    "Johnny Depp": "https://www.britannica.com/biography/Johnny-Depp",
    "Rome": "https://www.timeout.com/rome/things-to-do/best-things-to-do-in-rome"
}

st.header("Extracting a Knowledge Base from text")
st_model_load = st.text('Loading the REBEL model... It may take a while.')

@st.cache(allow_output_mutation=True)
def load_model():
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
    model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
    print("Model loaded!")
    return tokenizer, model

tokenizer, model = load_model()
st.success('Model loaded!')
st_model_load.text("")

# sidebar
with st.sidebar:
    st.header("What is a Knowledge Base")
    st.markdown("A [**Knowledge Base (KB)**](https://en.wikipedia.org/wiki/Knowledge_base) is information stored as structured data, ready to be used for analysis or inference. Usually a KB is stored as a graph (i.e. a [**Knowledge Graph**](https://www.ibm.com/cloud/learn/knowledge-graph)), where nodes are **entities** and edges are **relations** between entities.")
    st.markdown("_For example, from the text \"Fabio lives in Italy\" we can extract the relation triplet <Fabio, lives in, Italy>, where \"Fabio\" and \"Italy\" are entities._")
    st.header("How to build a Knowledge Graph")
    st.markdown("To build a Knowledge Graph from text, we typically need to perform two steps:\n- Extract entities, a.k.a. **Named Entity Recognition (NER)**, i.e. the nodes.\n- Extract relations between the entities, a.k.a. **Relation Classification (RC)**, i.e. the edges.\n\nRecently, end-to-end approaches have been proposed that tackle both tasks simultaneously. This task is usually referred to as **Relation Extraction (RE)**. In this demo, an end-to-end model called [**REBEL**](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf) is used.")
    st.header("How REBEL works")
    st.markdown("REBEL is a **text2text** model obtained by fine-tuning [**BART**](https://huggingface.co/docs/transformers/model_doc/bart) to translate a raw input sentence containing entities and implicit relations into a set of triplets that explicitly express those relations. You can find [REBEL in the Hugging Face Hub](https://huggingface.co/Babelscape/rebel-large).")
    st.header("Further steps")
    st.markdown("Even though it is not visualized, the knowledge graph stores the provenance of each relation (e.g. the articles it has been extracted from and other metadata), along with Wikipedia data about each entity.")
    st.markdown("Other libraries used:\n- [wikipedia](https://pypi.org/project/wikipedia/): for validating extracted entities by checking whether they have a corresponding Wikipedia page.\n- [newspaper](https://github.com/codelucas/newspaper): for parsing articles from URLs.\n- [pyvis](https://pyvis.readthedocs.io/en/latest/index.html): for graph visualization.\n- [GoogleNews](https://github.com/Iceloof/GoogleNews): for reading the latest Google News articles about a topic.")
    st.header("Considerations")
    st.markdown("If you look closely at the extracted knowledge graphs, you will notice that some extracted relations are false. Relation extraction models are still far from perfect and require further steps in the pipeline to build reliable knowledge graphs. Consider this demo a starting point!")
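# Widget values are mirrored into st.session_state (both below and inside
# generate_kb) so that the user's selections survive Streamlit's
# top-to-bottom reruns after each interaction.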
# Choose from where to generate the KB
options = [
    "Text",
    "Article at URL",
    "Multiple news articles"
]
if 'option' not in st.session_state:
    st.session_state.option = options[0]
option = st.selectbox('Build a Knowledge Base from:', options, index=options.index(st.session_state.option))

text_option, text = None, None
url_option, url = None, None
news_option = None

if option == "Text":
    text_options = [
        "Napoleon",
        "Kobe Bryant",
        "Google",
        "Free text"
    ]
    if 'text_option' not in st.session_state or st.session_state.text_option is None:
        st.session_state.text_option = text_options[0]
    text_option = st.selectbox('Choose text option:', text_options, index=text_options.index(st.session_state.text_option))

    disabled = False
    if text_option != "Free text":
        disabled = True
        text = texts[text_option]
    else:
        if 'text' not in st.session_state:
            st.session_state.text = ""
        text = st.session_state.text
    text = st.text_area('Text:', value=text, height=300, disabled=disabled)
elif option == "Article at URL":
    url_options = [
        "Crypto",
        "Johnny Depp",
        "Rome",
        "Free URL"
    ]
    if 'url_option' not in st.session_state or st.session_state.url_option is None:
        st.session_state.url_option = url_options[0]
    url_option = st.selectbox('Choose URL option:', url_options, index=url_options.index(st.session_state.url_option))

    disabled = False
    if url_option != "Free URL":
        disabled = True
        url = urls[url_option]
    else:
        if 'url' not in st.session_state:
            st.session_state.url = ""
        url = st.session_state.url
    url = st.text_input('URL:', value=url, disabled=disabled)
else:
    news_options = [
        "Google",
        "Apple",
        "Elon Musk",
        "Kobe Bryant"
    ]
    if 'news_option' not in st.session_state or st.session_state.news_option is None:
        st.session_state.news_option = news_options[0]
    news_option = st.selectbox('Use articles about:', news_options, index=news_options.index(st.session_state.news_option))

placeholder = st.empty()

def generate_kb():
    st.session_state.option = option
    st.session_state.text_option = text_option
    st.session_state.text = text
    st.session_state.url_option = url_option
    st.session_state.url = url
    st.session_state.news_option = news_option

    # load the correct kb: precomputed networks for the fixed examples,
    # live extraction for free text / free URLs
    if option == "Text":
        if text_option == "Napoleon":
            kb = utils.load_kb("networks/network_1_napoleon.p")
        elif text_option == "Kobe Bryant":
            kb = utils.load_kb("networks/network_1_bryant.p")
        elif text_option == "Google":
            kb = utils.load_kb("networks/network_1_google.p")
        else:
            kb = utils.from_text_to_kb(text, model, tokenizer, "", verbose=True)
    elif option == "Article at URL":
        if url_option == "Crypto":
            kb = utils.load_kb("networks/network_2_crypto.p")
        elif url_option == "Johnny Depp":
            kb = utils.load_kb("networks/network_2_depp.p")
        elif url_option == "Rome":
            kb = utils.load_kb("networks/network_2_rome.p")
        else:
            kb = utils.from_url_to_kb(url, model, tokenizer)
    else:
        if news_option == "Google":
            kb = utils.load_kb("networks/network_3_google.p")
        elif news_option == "Apple":
            kb = utils.load_kb("networks/network_3_apple.p")
        elif news_option == "Elon Musk":
            kb = utils.load_kb("networks/network_3_musk.p")
        elif news_option == "Kobe Bryant":
            kb = utils.load_kb("networks/network_3_bryant.p")

    # save chart
    utils.save_network_html(kb, filename="networks/network.html")
    st.session_state.kb_chart = "networks/network.html"
    st.session_state.kb_text = kb.get_textual_representation()


st.session_state.option = option
st.session_state.text_option = text_option
st.session_state.text = text
st.session_state.url_option = url_option
st.session_state.url = url
st.session_state.news_option = news_option

button_text = "Show KB"
if (option == "Text" and text_option == "Free text") or (option == "Article at URL" and url_option == "Free URL"):
    button_text = "Generate KB"

# generate KB button
st.button(button_text, on_click=generate_kb)

# kb chart session state
if 'kb_chart' not in st.session_state:
    st.session_state.kb_chart = None
if 'kb_text' not in st.session_state:
    st.session_state.kb_text = None

# show graph
if st.session_state.kb_chart:
    with st.container():
        st.subheader("Generated KB")
        st.markdown("*You can interact with the graph and zoom.*")
        with open(st.session_state.kb_chart, 'r', encoding='utf-8') as f:
            html_source_code = f.read()
        components.html(html_source_code, width=700, height=700)
        st.markdown(st.session_state.kb_text)
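For reference, a minimal sketch of the raw REBEL output that the app consumes, using the same checkpoint as above. The decoded string and the "residence" relation label are illustrative; actual output may differ:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

inputs = tokenizer("Fabio lives in Italy.", return_tensors="pt")
generated = model.generate(**inputs, max_length=256, num_beams=3)
print(tokenizer.decode(generated[0], skip_special_tokens=False))
# e.g. "<s><triplet> Fabio <subj> Italy <obj> residence</s>" (illustrative)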
kb.py ADDED
@@ -0,0 +1,99 @@
import wikipedia

class KB():
    def __init__(self):
        self.entities = {}  # { entity_title: {...} }
        self.relations = []  # [ head: entity_title, type: ..., tail: entity_title,
                             #   meta: { article_url: { spans: [...] } } ]
        self.sources = {}  # { article_url: {...} }

    def merge_with_kb(self, kb2):
        for r in kb2.relations:
            article_url = list(r["meta"].keys())[0]
            source_data = kb2.sources[article_url]
            self.add_relation(r, source_data["article_title"],
                              source_data["article_publish_date"])

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r2):
        r1 = [r for r in self.relations
              if self.are_relations_equal(r2, r)][0]

        # if different article
        article_url = list(r2["meta"].keys())[0]
        if article_url not in r1["meta"]:
            r1["meta"][article_url] = r2["meta"][article_url]

        # if existing article
        else:
            spans_to_add = [span for span in r2["meta"][article_url]["spans"]
                            if span not in r1["meta"][article_url]["spans"]]
            r1["meta"][article_url]["spans"] += spans_to_add

    def get_wikipedia_data(self, candidate_entity):
        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            entity_data = {
                "title": page.title,
                "url": page.url,
                "summary": page.summary
            }
            return entity_data
        except Exception:
            return None
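    # Because entities are keyed by Wikipedia page title, different surface
    # forms of the same entity (e.g. "Napoleon I" vs. "Napoleon") should
    # collapse into a single node once wikipedia.page follows redirects,
    # which deduplicates the graph.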
    def add_entity(self, e):
        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}

    def add_relation(self, r, article_title, article_publish_date):
        # check on wikipedia
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # if one entity does not exist, stop
        if any(ent is None for ent in entities):
            return

        # manage new entities
        for e in entities:
            self.add_entity(e)

        # rename relation entities with their wikipedia titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # add source if not in kb
        article_url = list(r["meta"].keys())[0]
        if article_url not in self.sources:
            self.sources[article_url] = {
                "article_title": article_title,
                "article_publish_date": article_publish_date
            }

        # manage new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def get_textual_representation(self):
        res = ""
        res += "### Entities\n"
        for e in self.entities.items():
            # shorten summary
            e_temp = (e[0], {k: (v[:100] + "..." if k == "summary" else v) for k, v in e[1].items()})
            res += f"- {e_temp}\n"
        res += "\n"
        res += "### Relations\n"
        for r in self.relations:
            res += f"- {r}\n"
        res += "\n"
        res += "### Sources\n"
        for s in self.sources.items():
            res += f"- {s}\n"
        return res
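A minimal usage sketch of the KB class. The URL and span values are hypothetical, and add_relation performs live wikipedia lookups, so network access is required:

from kb import KB

kb = KB()
relation = {
    "head": "Napoleon",
    "type": "country of citizenship",
    "tail": "France",
    # hypothetical provenance metadata: source URL and token span
    "meta": {"https://example.com/article": {"spans": [[0, 128]]}},
}
kb.add_relation(relation, "Example article", None)
print(kb.get_textual_representation())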
networks/.DS_Store ADDED
Binary file (6.15 kB)

networks/network_1_bryant.p ADDED
Binary file (20.9 kB)

networks/network_1_google.p ADDED
Binary file (11.2 kB)

networks/network_1_napoleon.p ADDED
Binary file (11.9 kB)

networks/network_2_crypto.p ADDED
Binary file (37.7 kB)

networks/network_2_depp.p ADDED
Binary file (7.83 kB)

networks/network_2_rome.p ADDED
Binary file (4.92 kB)

networks/network_3_amazon.p ADDED
Binary file (153 kB)

networks/network_3_apple.p ADDED
Binary file (227 kB)

networks/network_3_bryant.p ADDED
Binary file (185 kB)

networks/network_3_google.p ADDED
Binary file (190 kB)

networks/network_3_musk.p ADDED
Binary file (113 kB)
requirements.txt ADDED
@@ -0,0 +1,6 @@
torch
transformers
pyvis
GoogleNews
newspaper3k
wikipedia
utils.py ADDED
@@ -0,0 +1,198 @@
from pyvis.network import Network
from GoogleNews import GoogleNews
from newspaper import Article, ArticleException
import math
import torch
from kb import KB
import pickle

def extract_relations_from_model_output(text):
    # REBEL linearizes each triplet as "<triplet> head <subj> tail <obj> relation"
    relations = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations
+
def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
|
54 |
+
article_publish_date=None, verbose=False):
|
55 |
+
# tokenize whole text
|
56 |
+
inputs = tokenizer([text], return_tensors="pt")
|
57 |
+
|
58 |
+
# compute span boundaries
|
59 |
+
num_tokens = len(inputs["input_ids"][0])
|
60 |
+
if verbose:
|
61 |
+
print(f"Input has {num_tokens} tokens")
|
62 |
+
num_spans = math.ceil(num_tokens / span_length)
|
63 |
+
if verbose:
|
64 |
+
print(f"Input has {num_spans} spans")
|
65 |
+
overlap = math.ceil((num_spans * span_length - num_tokens) /
|
66 |
+
max(num_spans - 1, 1))
|
67 |
+
spans_boundaries = []
|
68 |
+
start = 0
|
69 |
+
for i in range(num_spans):
|
70 |
+
spans_boundaries.append([start + span_length * i,
|
71 |
+
start + span_length * (i + 1)])
|
72 |
+
start -= overlap
|
73 |
+
if verbose:
|
74 |
+
print(f"Span boundaries are {spans_boundaries}")
|
75 |
+
|
76 |
+
# transform input with spans
|
77 |
+
tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
|
78 |
+
for boundary in spans_boundaries]
|
79 |
+
tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
|
80 |
+
for boundary in spans_boundaries]
|
81 |
+
inputs = {
|
82 |
+
"input_ids": torch.stack(tensor_ids),
|
83 |
+
"attention_mask": torch.stack(tensor_masks)
|
84 |
+
}
|
85 |
+
|
86 |
+
# generate relations
|
87 |
+
num_return_sequences = 3
|
88 |
+
gen_kwargs = {
|
89 |
+
"max_length": 256,
|
90 |
+
"length_penalty": 0,
|
91 |
+
"num_beams": 3,
|
92 |
+
"num_return_sequences": num_return_sequences
|
93 |
+
}
|
94 |
+
generated_tokens = model.generate(
|
95 |
+
**inputs,
|
96 |
+
**gen_kwargs,
|
97 |
+
)
|
98 |
+
|
99 |
+
# decode relations
|
100 |
+
decoded_preds = tokenizer.batch_decode(generated_tokens,
|
101 |
+
skip_special_tokens=False)
|
102 |
+
|
103 |
+
# create kb
|
104 |
+
kb = KB()
|
105 |
+
i = 0
|
106 |
+
for sentence_pred in decoded_preds:
|
107 |
+
current_span_index = i // num_return_sequences
|
108 |
+
relations = extract_relations_from_model_output(sentence_pred)
|
109 |
+
for relation in relations:
|
110 |
+
relation["meta"] = {
|
111 |
+
article_url: {
|
112 |
+
"spans": [spans_boundaries[current_span_index]]
|
113 |
+
}
|
114 |
+
}
|
115 |
+
kb.add_relation(relation, article_title, article_publish_date)
|
116 |
+
i += 1
|
117 |
+
|
118 |
+
return kb
|
119 |
+
|
120 |
+
def get_article(url):
|
121 |
+
article = Article(url)
|
122 |
+
article.download()
|
123 |
+
article.parse()
|
124 |
+
return article
|
125 |
+
|
126 |
+
def from_url_to_kb(url, model, tokenizer):
|
127 |
+
article = get_article(url)
|
128 |
+
config = {
|
129 |
+
"article_title": article.title,
|
130 |
+
"article_publish_date": article.publish_date
|
131 |
+
}
|
132 |
+
kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
|
133 |
+
return kb
|
134 |
+
|
135 |
+
def get_news_links(query, lang="en", region="US", pages=1):
|
136 |
+
googlenews = GoogleNews(lang=lang, region=region)
|
137 |
+
googlenews.search(query)
|
138 |
+
all_urls = []
|
139 |
+
for page in range(pages):
|
140 |
+
googlenews.get_page(page)
|
141 |
+
all_urls += googlenews.get_links()
|
142 |
+
return list(set(all_urls))
|
143 |
+
|
144 |
+
def from_urls_to_kb(urls, model, tokenizer, verbose=False):
|
145 |
+
kb = KB()
|
146 |
+
if verbose:
|
147 |
+
print(f"{len(urls)} links to visit")
|
148 |
+
for url in urls:
|
149 |
+
if verbose:
|
150 |
+
print(f"Visiting {url}...")
|
151 |
+
try:
|
152 |
+
kb_url = from_url_to_kb(url, model, tokenizer)
|
153 |
+
kb.merge_with_kb(kb_url)
|
154 |
+
except ArticleException:
|
155 |
+
if verbose:
|
156 |
+
print(f" Couldn't download article at url {url}")
|
157 |
+
return kb
|
158 |
+
|
159 |
+
def save_network_html(kb, filename="network.html"):
|
160 |
+
# create network
|
161 |
+
net = Network(directed=True, width="700px", height="700px")
|
162 |
+
|
163 |
+
# nodes
|
164 |
+
color_entity = "#00FF00"
|
165 |
+
for e in kb.entities:
|
166 |
+
net.add_node(e, shape="circle", color=color_entity)
|
167 |
+
|
168 |
+
# edges
|
169 |
+
for r in kb.relations:
|
170 |
+
net.add_edge(r["head"], r["tail"],
|
171 |
+
title=r["type"], label=r["type"])
|
172 |
+
|
173 |
+
# save network
|
174 |
+
net.repulsion(
|
175 |
+
node_distance=200,
|
176 |
+
central_gravity=0.2,
|
177 |
+
spring_length=200,
|
178 |
+
spring_strength=0.05,
|
179 |
+
damping=0.09
|
180 |
+
)
|
181 |
+
net.set_edge_smooth('dynamic')
|
182 |
+
net.show(filename)
|
183 |
+
|
184 |
+
def save_kb(kb, filename):
|
185 |
+
with open(filename, "wb") as f:
|
186 |
+
pickle.dump(kb, f)
|
187 |
+
|
188 |
+
class CustomUnpickler(pickle.Unpickler):
|
189 |
+
def find_class(self, module, name):
|
190 |
+
if name == 'KB':
|
191 |
+
return KB
|
192 |
+
return super().find_class(module, name)
|
193 |
+
|
194 |
+
def load_kb(filename):
|
195 |
+
res = None
|
196 |
+
with open(filename, "rb") as f:
|
197 |
+
res = CustomUnpickler(f).load()
|
198 |
+
return res
|
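Putting the pieces together, a hedged end-to-end sketch of what the app does for free text. The example sentence is arbitrary, and the wikipedia checks inside kb.add_relation require network access:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import utils

tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

kb = utils.from_text_to_kb("Napoleon was Emperor of the French from 1804 until 1814.",
                           model, tokenizer, article_url="", verbose=True)
print(kb.get_textual_representation())
utils.save_network_html(kb, filename="network.html")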