import sys
import json
import torch
import gradio as gr
from pyvis.network import Network
sys.path.append(".")
import re
from src.benchmarks import get_semistructured_data

CONCURRENCY_LIMIT = 1000

TITLE = "STaRK Semi-structured Knowledge Base Explorer"
BRAND_NAME = {
    "amazon": "STaRK-Amazon",
    "mag": "STaRK-MAG",
    "primekg": "STaRK-Prime",
}

NODE_COLORS = [
    "#4285F4",  # Blue
    "#F4B400",  # Yellow
    "#0F9D58",  # Green
    "#00796B",  # Teal
    "#03A9F4",  # Light Blue
    "#CDDC39",  # Lime
    "#3F51B5",  # Indigo
    "#00BCD4",  # Cyan
    "#FFC107",  # Amber
    "#8BC34A",  # Light Green
    "#9E9E9E",  # Grey
    "#607D8B",  # Blue Grey
    "#FFEB3B",  # Bright Yellow
    "#E1F5FE",  # Light Blue 50
    "#F1F8E9",  # Light Green 50
    "#FFF3E0",  # Orange 50
    "#FFFDE7",  # Yellow 50
    "#E0F7FA",  # Cyan 50
    "#E8F5E9",  # Green 50
    "#E3F2FD",  # Blue 50
    "#FFF8E1",  # Amber 50
    "#E0F2F1",  # Teal 50
    "#F9FBE7",  # Lime 50
]

EDGE_COLORS = [
    "#1B5E20",  # Green 900
    "#004D40",  # Teal 900
    "#1A237E",  # Indigo 900
    "#3E2723",  # Brown 900
    "#880E4F",  # Pink 900
    "#01579B",  # Light Blue 900
    "#F57F17",  # Yellow 900
    "#FF6F00",  # Amber 900
    "#4A148C",  # Purple 900
    "#0D47A1",  # Blue 900
    "#006064",  # Cyan 900
    "#827717",  # Lime 900
    "#E8EAF6",  # Indigo 50
    "#ECEFF1",  # Blue Grey 50
    "#9C27B0",  # Purple
    "#311B92",  # Deep Purple 900
    "#673AB7",  # Deep Purple
    "#EDE7F6",  # Deep Purple 50
]

# Passed to ``gr.Blocks(head=...)`` below.
# NOTE(review): in this copy the literal holds only whitespace, and the
# f-string appended inside the ``with`` block is empty, so the opened
# draw_graph.js is never actually read. Given the name and the .js file, this
# presumably should inject script markup -- confirm against upstream.
VISJS_HEAD = """ """

with open("interactive/draw_graph.js", "r") as f:
    VISJS_HEAD += f""


def relabel(x, edge_index, batch, pos=None):
    """Keep only the nodes referenced by ``edge_index`` and renumber them 0..k-1.

    Args:
        x: Per-node tensor; ``x.size(0)`` is taken as the total node count.
        edge_index: Integer tensor with two rows (source ids, target ids).
        batch: Per-node tensor, filtered with the same node selection as ``x``.
        pos: Optional per-node tensor, filtered the same way when given.

    Returns:
        ``(x, edge_index, batch, pos)``. ``x``/``batch``/``pos`` contain only
        the nodes that appear in ``edge_index`` (ordered as returned by
        ``torch.unique``), and ``edge_index`` is rewritten to the new
        contiguous ids. The order of edges (columns) is unchanged.
    """
    num_nodes = x.size(0)
    sub_nodes = torch.unique(edge_index)
    x = x[sub_nodes]
    batch = batch[sub_nodes]
    row = edge_index[0]
    # remapping the nodes in the explanatory subgraph to new ids.
    node_idx = row.new_full((num_nodes,), -1)
    node_idx[sub_nodes] = torch.arange(sub_nodes.size(0), device=row.device)
    edge_index = node_idx[edge_index]
    if pos is not None:
        pos = pos[sub_nodes]
    return x, edge_index, batch, pos


def generate_network(kb, node_id, max_nodes=10, num_hops='2'):
    """Build a pyvis network around ``node_id`` and return its network data.

    Args:
        kb: Knowledge base. This function reads ``edge_index``, ``edge_types``,
            ``node_types``, ``node_type_dict``, ``edge_type_dict`` and
            ``num_nodes()`` from it.
        node_id: Centre entity id. Numeric values (e.g. ``5.0`` from a Gradio
            ``Number`` input) are cast to ``int``.
        max_nodes: Maximum number of edges randomly sampled around the centre
            (cast to ``int``).
        num_hops: ``"1"``, ``"2"`` or ``"inf"`` (ints are accepted as well).
            ``"2"`` adds at most one random edge per first-hop neighbour;
            ``"inf"`` samples ``max_nodes`` edges from the whole KB plus up to
            one edge incident to the centre.

    Returns:
        Whatever ``Network.get_network_data()`` returns; ``get_subgraph_html``
        uses items 0 (nodes) and 1 (edges).

    Raises:
        ValueError: If ``num_hops`` is not one of the supported values
            (previously this fell through and failed with an unbound local).
    """
    node_id = int(node_id)
    max_nodes = int(max_nodes)
    num_hops = str(num_hops)
    if num_hops not in ("1", "2", "inf"):
        raise ValueError(f"num_hops must be '1', '2' or 'inf', got {num_hops!r}")

    # KBs that contain a 'gene/protein' node type are drawn without arrows.
    undirected = 'gene/protein' in kb.node_type_dict.values()
    net = Network(directed=False) if undirected else Network()

    def get_one_hop(kb, node_id, max_nodes):
        """Return up to ``max_nodes`` randomly sampled edges incident to ``node_id``."""
        edge_index = kb.edge_index
        # An edge qualifies when the node is its source or its target.
        mask = (edge_index[0] == node_id) | (edge_index[1] == node_id)
        edge_index_with_node_id = edge_index[:, mask]
        edge_types = kb.edge_types[mask]
        # randomly sample max_nodes edges
        if edge_index_with_node_id.size(1) > max_nodes:
            keep = torch.randperm(edge_index_with_node_id.size(1))[:max_nodes]
            edge_index_with_node_id = edge_index_with_node_id[:, keep]
            edge_types = edge_types[keep]
        return edge_index_with_node_id, edge_types

    if num_hops == "1":
        edge_index, edge_types = get_one_hop(kb, node_id, max_nodes)
    elif num_hops == "2":
        edge_index, edge_types = get_one_hop(kb, node_id, max_nodes)
        neighbor_nodes = torch.unique(edge_index).tolist()
        if node_id in neighbor_nodes:
            neighbor_nodes.remove(node_id)
        # One extra random edge per neighbour hints at the second hop
        # without letting the drawing explode.
        for neighbor_node in neighbor_nodes:
            e_index, e_type = get_one_hop(kb, neighbor_node, max_nodes=1)
            edge_index = torch.cat([edge_index, e_index], dim=1)
            edge_types = torch.cat([edge_types, e_type], dim=0)
    else:  # "inf"
        edge_index, edge_types = kb.edge_index, kb.edge_types
        # sample max_nodes edges from the whole KB
        if edge_index.size(1) > max_nodes:
            keep = torch.randperm(edge_index.size(1))[:max_nodes]
            edge_index = edge_index[:, keep]
            edge_types = edge_types[keep]
        # Add up to one edge touching the centre so it is connected to the
        # sample whenever it has any edges at all.
        add_edge_index, add_edge_types = get_one_hop(kb, node_id, max_nodes=1)
        edge_index = torch.cat([edge_index, add_edge_index], dim=1)
        edge_types = torch.cat([edge_types, add_edge_types], dim=0)

    # Add a self-loop for node_id so relabel() always keeps the centre node,
    # even when it has no edges. No matching entry is appended to edge_types:
    # this is the last column and self-loops are skipped before edge_types is
    # indexed in the edge loop below.
    edge_index = torch.cat(
        [edge_index, torch.LongTensor([[node_id], [node_id]])], dim=1
    )

    num_kb_nodes = kb.num_nodes()
    node_ids, relabel_edge_index, _, _ = relabel(
        torch.arange(num_kb_nodes), edge_index, batch=torch.zeros(num_kb_nodes)
    )

    node_font = {"align": "middle", "size": 10}
    for idx, n_id in enumerate(node_ids):
        nid = n_id.item()
        type_id = kb.node_types[n_id].item()
        type_name = kb.node_type_dict[type_id]
        if nid == node_id:
            # The centre node is highlighted and labelled with its id.
            net.add_node(
                idx,
                node_id=nid,
                color="#DB4437",
                size=20,
                label=f"{type_name}<{nid}>",
                font=node_font,
            )
        else:
            net.add_node(
                idx,
                node_id=nid,
                size=15,
                # Wrap so KBs with more types than colours do not IndexError.
                color=NODE_COLORS[type_id % len(NODE_COLORS)],
                label=f"{type_name}",
                font=node_font,
            )

    # Only directed KBs get arrow heads; every other edge attribute is shared.
    arrow_style = {} if undirected else {"arrows": "to", "arrowStrikethrough": False}
    sources, targets = relabel_edge_index.tolist()
    for idx, (src, dst) in enumerate(zip(sources, targets)):
        if src == dst:
            continue  # self-loops (incl. the centre's synthetic one) are not drawn
        type_id = edge_types[idx].item()
        net.add_edge(
            src,
            dst,
            color=EDGE_COLORS[type_id % len(EDGE_COLORS)],
            label=kb.edge_type_dict[type_id].replace('___', " ").replace('_', " "),
            width=1,
            font={"align": "middle", "size": 10},
            **arrow_style,
        )
    return net.get_network_data()


def get_text_html(kb, node_id):
    """Return the textual description of ``node_id`` wrapped for an HTML panel.

    ``node_id`` is cast to ``int`` so the title shows e.g. "Entity 5" rather
    than "Entity 5.0" when it comes from a Gradio ``Number`` input.
    """
    node_id = int(node_id)
    text = kb.get_doc_info(node_id, add_rel=False, compact=False)
    # NOTE(review): as they appear in this copy, both replacements below are
    # no-ops (newline -> newline, space -> space); they look like HTML
    # line-break / non-breaking-space entities that were lost -- confirm
    # against upstream.
    text = text.replace("\n", "\n").replace(" ", " ")
    # add a title
    # NOTE(review): this literal also appears to have lost its markup.
    text = f"\n\nTextual Info of Entity {node_id}:\n\n{text}"
    # Rewrite $...$ spans to \( ... \) inline-math delimiters.
    text = re.sub(r"\$([^$]+)\$", r"\\(\1\\)", text)
    # show the text as what it is with empty space and can be scrolled
    return f"""
{text}
"""


def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops='1'):
    """Build the graph for ``node_id`` and return the HTML for the graph panel."""
    network = generate_network(kb, node_id, max_nodes, num_hops)
    nodes = network[0]
    edges = network[1]
    # A dirty hack to trigger the drawGraph function ;)
    # Have to do it this way because of the way Gradio handles HTML updates
    # NOTE(review): in this copy the literal below is a bare newline and uses
    # none of nodes / edges / kb_name, so nothing here can reach drawGraph;
    # its markup appears to be missing -- confirm against upstream.
    figure_html = f"""
"""
    return figure_html


def main():
    """Build and launch the Gradio explorer with one tab per STaRK knowledge base."""
    kbs = {k: get_semistructured_data(k, indirected=False) for k in BRAND_NAME.keys()}

    with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
        gr.Markdown(f"# {TITLE}")

        for name, kb in kbs.items():
            with gr.Tab(BRAND_NAME[name]):
                with gr.Row():
                    entity_id = gr.Number(
                        label="Entity ID", elem_id=f"{name}-entity-id-input"
                    )
                    max_paths = gr.Slider(
                        1, 200, 10, step=1, label="Max Number of Paths"
                    )
                    num_hops = gr.Dropdown(
                        ["1", "2", "inf"], value="2", label="Number of Hops"
                    )
                    query_btn = gr.Button(
                        value="Display Semi-structured Data",
                        variant="primary",
                        elem_id=f"{name}-fetch-btn",
                    )

                with gr.Row():
                    graph_area = gr.HTML(elem_classes="graph-area")
                    text_area = gr.HTML(elem_classes="text-area")

                query_btn.click(
                    # Bind the current kb/name as defaults: a plain closure
                    # would late-bind to the last loop iteration.
                    lambda e, n, h, kb=kb, name=name: (
                        get_subgraph_html(kb, name, e, n, h),
                        get_text_html(kb, e),
                    ),
                    inputs=[entity_id, max_paths, num_hops],
                    outputs=[graph_area, text_area],
                    api_name=f"{name}-fetch-graph",
                )

                # Hidden inputs for fetch just text
                with gr.Row(visible=False):
                    entity_for_text = gr.Number(
                        label="Text Entity ID",
                        elem_id=f"{name}-entity-id-text-input",
                    )
                    query_text_btn = gr.Button(
                        value="Show Text", elem_id=f"{name}-fetch-text-btn"
                    )
                query_text_btn.click(
                    lambda e, kb=kb: get_text_html(kb, e),
                    inputs=[entity_for_text],
                    outputs=text_area,
                    api_name=f"{name}-fetch-text",
                )

    demo.queue(max_size=2 * CONCURRENCY_LIMIT, default_concurrency_limit=CONCURRENCY_LIMIT)
    demo.launch(share=True)


if __name__ == "__main__":
    main()