import sys
import json
import torch
import gradio as gr
from pyvis.network import Network
sys.path.append(".")
import re
from src.benchmarks import get_semistructured_data
# Name suggests a cap on concurrent requests for the app; the consumer is not
# visible in this chunk — TODO confirm where it is used.
CONCURRENCY_LIMIT = 1000
# Title text for the explorer (where it is displayed is not visible here).
TITLE = "STaRK Semi-structured Knowledge Base Explorer"
# Maps the short dataset keys to their STaRK display names.
BRAND_NAME = {
    "amazon": "STaRK-Amazon",
    "mag": "STaRK-MAG",
    "primekg": "STaRK-Prime",
}
# Hex colors for graph nodes. generate_network picks a node's color from this
# list based on its integer node-type id (the queried node itself is drawn
# with a separate fixed color).
NODE_COLORS = [
    "#4285F4", # Blue
    "#F4B400", # Yellow
    "#0F9D58", # Green
    "#00796B", # Teal
    "#03A9F4", # Light Blue
    "#CDDC39", # Lime
    "#3F51B5", # Indigo
    "#00BCD4", # Cyan
    "#FFC107", # Amber
    "#8BC34A", # Light Green
    "#9E9E9E", # Grey
    "#607D8B", # Blue Grey
    "#FFEB3B", # Bright Yellow
    "#E1F5FE", # Light Blue 50
    "#F1F8E9", # Light Green 50
    "#FFF3E0", # Orange 50
    "#FFFDE7", # Yellow 50
    "#E0F7FA", # Cyan 50
    "#E8F5E9", # Green 50
    "#E3F2FD", # Blue 50
    "#FFF8E1", # Amber 50
    "#E0F2F1", # Teal 50
    "#F9FBE7", # Lime 50
]
# Hex colors for graph edges. generate_network picks an edge's color from this
# list based on its integer edge-type id.
EDGE_COLORS = [
    "#1B5E20", # Green 900
    "#004D40", # Teal 900
    "#1A237E", # Indigo 900
    "#3E2723", # Brown 900
    "#880E4F", # Pink 900
    "#01579B", # Light Blue 900
    "#F57F17", # Yellow 900
    "#FF6F00", # Amber 900
    "#4A148C", # Purple 900
    "#0D47A1", # Blue 900
    "#006064", # Cyan 900
    "#827717", # Lime 900
    "#E8EAF6", # Indigo 50
    "#ECEFF1", # Blue Grey 50
    "#9C27B0", # Purple
    "#311B92", # Deep Purple 900
    "#673AB7", # Deep Purple
    "#EDE7F6", # Deep Purple 50
]
# Head snippet for the page (the name suggests vis.js markup; the consumer is
# not visible in this chunk). As written it holds only a newline.
VISJS_HEAD = """
"""
# NOTE(review): the script file is opened but never read, and the appended
# f-string is empty, so VISJS_HEAD is left unchanged by this statement.
# Presumably the contents of draw_graph.js were meant to be embedded here —
# confirm against the intended markup before relying on it.
with open("interactive/draw_graph.js", "r") as f:
    VISJS_HEAD += f""
def relabel(x, edge_index, batch, pos=None):
    """Keep only the nodes referenced by ``edge_index`` and renumber them.

    The distinct ids found in ``edge_index`` (in the sorted order produced by
    ``torch.unique``) are mapped to ``0 .. k-1``.

    Args:
        x: Tensor whose first dimension indexes the nodes.
        edge_index: Integer tensor that unpacks into two rows (sources and
            targets) of node ids.
        batch: Per-node tensor, indexed the same way as ``x``.
        pos: Optional per-node tensor, indexed the same way as ``x``.

    Returns:
        A tuple ``(x, edge_index, batch, pos)`` where ``x``, ``batch`` and
        ``pos`` are restricted to the kept nodes and ``edge_index`` is
        rewritten in terms of the new ids. ``pos`` is ``None`` if it was not
        supplied.
    """
    total = x.size(0)
    kept = torch.unique(edge_index)
    kept_x = x[kept]
    kept_batch = batch[kept]
    src, _dst = edge_index

    # Lookup table old id -> new id; ids that occur in no edge stay at -1.
    lookup = src.new_full((total,), -1)
    lookup[kept] = torch.arange(kept.size(0), device=src.device)

    kept_pos = None if pos is None else pos[kept]
    return kept_x, lookup[edge_index], kept_batch, kept_pos
def _sample_incident_edges(kb, node_id, max_edges):
    """Return up to ``max_edges`` randomly chosen edges of ``kb`` touching ``node_id``.

    Returns:
        ``(edge_index, edge_types)``: the selected columns of
        ``kb.edge_index`` and the matching entries of ``kb.edge_types``.
    """
    edge_index = kb.edge_index
    touches = (edge_index[0] == node_id) | (edge_index[1] == node_id)
    incident = edge_index[:, touches]
    incident_types = kb.edge_types[touches]
    # Randomly subsample when the node has more incident edges than allowed.
    if incident.size(1) > max_edges:
        keep = torch.randperm(incident.size(1))[:max_edges]
        incident = incident[:, keep]
        incident_types = incident_types[keep]
    return incident, incident_types


def generate_network(kb, node_id, max_nodes=10, num_hops='2'):
    """Build the pyvis network data for the neighbourhood of ``node_id``.

    Args:
        kb: Knowledge base exposing ``edge_index``, ``edge_types``,
            ``node_types``, ``node_type_dict``, ``edge_type_dict`` and
            ``num_nodes()``.
        node_id: Id of the node to centre the view on (coerced with ``int``).
        max_nodes: Sampling budget (coerced with ``int``). It caps the number
            of edges sampled around ``node_id`` for ``"1"``/``"2"`` hops, or
            the number of edges sampled from the whole graph for ``"inf"``.
        num_hops: ``"1"``, ``"2"`` or ``"inf"``; non-string values are
            converted with ``str`` first, so ``1`` and ``2`` also work.

    Returns:
        The value of ``Network.get_network_data()`` for the built network.

    Raises:
        ValueError: If ``num_hops`` is not one of the supported values.
    """
    max_nodes = int(max_nodes)
    node_id = int(node_id)
    hops = str(num_hops)
    if hops not in ("1", "2", "inf"):
        raise ValueError(
            f"num_hops must be '1', '2' or 'inf', got {num_hops!r}"
        )

    # Knowledge bases with a 'gene/protein' node type are drawn without arrows.
    undirected = 'gene/protein' in kb.node_type_dict.values()
    net = Network(directed=False) if undirected else Network()

    if hops == "inf":
        # Sample edges from the whole graph ...
        edge_index, edge_types = kb.edge_index, kb.edge_types
        if edge_index.size(1) > max_nodes:
            keep = torch.randperm(edge_index.size(1))[:max_nodes]
            edge_index = edge_index[:, keep]
            edge_types = edge_types[keep]
        # ... plus one edge that touches the queried node, when it has any.
        extra_index, extra_types = _sample_incident_edges(kb, node_id, 1)
        edge_index = torch.cat([edge_index, extra_index], dim=1)
        edge_types = torch.cat([edge_types, extra_types], dim=0)
    else:
        edge_index, edge_types = _sample_incident_edges(kb, node_id, max_nodes)
        if hops == "2":
            # Extend every sampled neighbour by one of its own edges.
            neighbors = [
                n for n in torch.unique(edge_index).tolist() if n != node_id
            ]
            index_parts, type_parts = [edge_index], [edge_types]
            for neighbor in neighbors:
                hop_index, hop_types = _sample_incident_edges(kb, neighbor, 1)
                index_parts.append(hop_index)
                type_parts.append(hop_types)
            edge_index = torch.cat(index_parts, dim=1)
            edge_types = torch.cat(type_parts, dim=0)

    # Append a self-loop so the queried node is kept even when it has no edges.
    # It has no entry in ``edge_types`` and is never drawn (see the edge loop).
    self_loop = torch.tensor(
        [[node_id], [node_id]], dtype=torch.long, device=edge_index.device
    )
    edge_index = torch.cat([edge_index, self_loop], dim=1)

    num_nodes = kb.num_nodes()
    node_ids, local_edge_index, _, _ = relabel(
        torch.arange(num_nodes), edge_index, batch=torch.zeros(num_nodes)
    )

    for idx, n_id in enumerate(node_ids.tolist()):
        type_id = kb.node_types[n_id].item()
        type_name = kb.node_type_dict[type_id]
        if n_id == node_id:
            # The queried node is highlighted and labelled with its id.
            net.add_node(
                idx,
                node_id=n_id,
                color="#DB4437",
                size=20,
                label=f"{type_name}<{n_id}>",
                font={"align": "middle", "size": 10},
            )
        else:
            net.add_node(
                idx,
                node_id=n_id,
                size=15,
                # Wrap around so extra node types cannot overrun the palette.
                color=NODE_COLORS[type_id % len(NODE_COLORS)],
                label=f"{type_name}",
                font={"align": "middle", "size": 10},
            )

    arrow_options = (
        {} if undirected else {"arrows": "to", "arrowStrikethrough": False}
    )
    sources, targets = local_edge_index.tolist()
    # ``edge_types`` is one entry shorter than the edge list (the trailing
    # self-loop has no type), so ``zip`` stops right before that self-loop.
    for src, dst, type_id in zip(sources, targets, edge_types.tolist()):
        if src == dst:
            continue  # self-loops are not drawn
        net.add_edge(
            src,
            dst,
            # Wrap around so extra edge types cannot overrun the palette.
            color=EDGE_COLORS[type_id % len(EDGE_COLORS)],
            label=kb.edge_type_dict[type_id]
            .replace('___', " ")
            .replace('_', " "),
            width=1,
            font={"align": "middle", "size": 10},
            **arrow_options,
        )
    return net.get_network_data()
def get_text_html(kb, node_id):
text = kb.get_doc_info(node_id, add_rel=False, compact=False)
# add a title
text = text.replace("\n", "
").replace(" ", " ")
text = f"