"""Utilities for converting Graphein Networks to Geometric Deep Learning formats. """ # %% # Graphein # Author: Kexin Huang, Arian Jamasb # License: MIT # Project Website: https://github.com/a-r-j/graphein # Code Repository: https://github.com/a-r-j/graphein from __future__ import annotations from typing import List, Optional import networkx as nx import numpy as np import torch try: from graphein.utils.dependencies import import_message except ImportError: raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html') try: import torch_geometric from torch_geometric.data import Data except ImportError: import_message( submodule="graphein.ml.conversion", package="torch_geometric", pip_install=True, conda_channel="rusty1s", ) try: import dgl except ImportError: import_message( submodule="graphein.ml.conversion", package="dgl", pip_install=True, conda_channel="dglteam", ) try: import jax.numpy as jnp except ImportError: import_message( submodule="graphein.ml.conversion", package="jax", pip_install=True, conda_channel="conda-forge", ) try: import jraph except ImportError: import_message( submodule="graphein.ml.conversion", package="jraph", pip_install=True, conda_channel="conda-forge", ) SUPPORTED_FORMATS = ["nx", "pyg", "dgl", "jraph"] """Supported conversion formats. ``"nx"``: NetworkX graph ``"pyg"``: PyTorch Geometric Data object ``"dgl"``: DGL graph ``"Jraph"``: Jraph GraphsTuple """ SUPPORTED_VERBOSITY = ["gnn", "default", "all_info"] """Supported verbosity levels for preserving graph features in conversion.""" class GraphFormatConvertor: """ Provides conversion utilities between NetworkX Graphs and geometric deep learning library destination formats. Currently, we provide support for converstion from ``nx.Graph`` to ``dgl.DGLGraph`` and ``pytorch_geometric.Data``. Supported conversion formats can be retrieved from :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`. :param src_format: The type of graph you'd like to convert from. Supported formats are available in :const:`~graphein.ml.conversion.SUPPORTED_FORMATS` :type src_format: Literal["nx", "pyg", "dgl", "jraph"] :param dst_format: The type of graph format you'd like to convert to. Supported formats are available in: ``graphein.ml.conversion.SUPPORTED_FORMATS`` :type dst_format: Literal["nx", "pyg", "dgl", "jraph"] :param verbose: Select from ``"gnn"``, ``"default"``, ``"all_info"`` to determine how much information is preserved (features) as some are unsupported by various downstream frameworks :type verbose: graphein.ml.conversion.SUPPORTED_VERBOSITY :param columns: List of columns in the node features to retain :type columns: List[str], optional """ def __init__( self, src_format: str, dst_format: str, verbose: SUPPORTED_VERBOSITY = "gnn", columns: Optional[List[str]] = None, ): if (src_format not in SUPPORTED_FORMATS) or ( dst_format not in SUPPORTED_FORMATS ): raise ValueError( "Please specify from supported format, " + "/".join(SUPPORTED_FORMATS) ) self.src_format = src_format self.dst_format = dst_format # supported_verbose_format = ["gnn", "default", "all_info"] if (columns is None) and (verbose not in SUPPORTED_VERBOSITY): raise ValueError( "Please specify the supported verbose mode (" + "/".join(SUPPORTED_VERBOSITY) + ") or specify column names!" ) if columns is None: if verbose == "gnn": columns = [ "edge_index", "coords", "dist_mat", "name", "node_id", ] elif verbose == "default": columns = [ "b_factor", "chain_id", "coords", "dist_mat", "edge_index", "kind", "name", "node_id", "residue_name", ] elif verbose == "all_info": columns = [ "atom_type", "b_factor", "chain_id", "chain_ids", "config", "coords", "dist_mat", "edge_index", "element_symbol", "kind", "name", "node_id", "node_type", "pdb_df", "raw_pdb_df", "residue_name", "residue_number", "rgroup_df", "sequence_A", "sequence_B", ] self.columns = columns self.type2form = { "atom_type": "str", "b_factor": "float", "chain_id": "str", "coords": "np.array", "dist_mat": "np.array", "element_symbol": "str", "node_id": "str", "residue_name": "str", "residue_number": "int", "edge_index": "torch.tensor", "kind": "str", } def convert_nx_to_dgl(self, G: nx.Graph) -> dgl.DGLGraph: """ Converts ``NetworkX`` graph to ``DGL`` :param G: ``nx.Graph`` to convert to ``DGLGraph`` :type G: nx.Graph :return: ``DGLGraph`` object version of input ``NetworkX`` graph :rtype: dgl.DGLGraph """ g = dgl.DGLGraph() node_id = list(G.nodes()) G = nx.convert_node_labels_to_integers(G) ## add node level feat node_dict = {} for i, (_, feat_dict) in enumerate(G.nodes(data=True)): for key, value in feat_dict.items(): if str(key) in self.columns: node_dict[str(key)] = ( [value] if i == 0 else node_dict[str(key)] + [value] ) string_dict = {} node_dict_transformed = {} for i, j in node_dict.items(): if i == "coords": node_dict_transformed[i] = torch.Tensor(np.asarray(j)).type( "torch.FloatTensor" ) elif i == "dist_mat": node_dict_transformed[i] = torch.Tensor( np.asarray(j[0].values) ).type("torch.FloatTensor") elif self.type2form[i] == "str": string_dict[i] = j elif self.type2form[i] in ["float", "int"]: node_dict_transformed[i] = torch.Tensor(np.array(j)) g.add_nodes( len(node_id), node_dict_transformed, ) edge_dict = {} edge_index = torch.LongTensor(list(G.edges)).t().contiguous() # add edge level features for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): for key, value in feat_dict.items(): if str(key) in self.columns: edge_dict[str(key)] = ( list(value) if i == 0 else edge_dict[str(key)] + list(value) ) edge_transform_dict = {} for i, j in node_dict.items(): if self.type2form[i] == "str": string_dict[i] = j elif self.type2form[i] in ["float", "int"]: edge_transform_dict[i] = torch.Tensor(np.array(j)) g.add_edges(edge_index[0], edge_index[1], edge_transform_dict) # add graph level features graph_dict = { str(feat_name): [G.graph[feat_name]] for feat_name in G.graph if str(feat_name) in self.columns } return g def convert_nx_to_pyg(self, G: nx.Graph) -> Data: """ Converts ``NetworkX`` graph to ``pytorch_geometric.data.Data`` object. Requires ``PyTorch Geometric`` (https://pytorch-geometric.readthedocs.io/en/latest/) to be installed. :param G: ``nx.Graph`` to convert to PyTorch Geometric ``Data`` object :type G: nx.Graph :return: ``Data`` object containing networkx graph data :rtype: pytorch_geometric.data.Data """ # Initialise dict used to construct Data object & Assign node ids as a feature data = {"node_id": list(G.nodes())} G = nx.convert_node_labels_to_integers(G) # Construct Edge Index edge_index = torch.LongTensor(list(G.edges)).t().contiguous() # Add node features for i, (_, feat_dict) in enumerate(G.nodes(data=True)): for key, value in feat_dict.items(): if str(key) in self.columns: data[str(key)] = ( [value] if i == 0 else data[str(key)] + [value] ) # Add edge features for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): for key, value in feat_dict.items(): if str(key) in self.columns: data[str(key)] = ( list(value) if i == 0 else data[str(key)] + list(value) ) # Add graph-level features for feat_name in G.graph: if str(feat_name) in self.columns: data[str(feat_name)] = [G.graph[feat_name]] if "edge_index" in self.columns: data["edge_index"] = edge_index.view(2, -1) data = Data.from_dict(data) data.num_nodes = G.number_of_nodes() return data @staticmethod def convert_nx_to_nx(G: nx.Graph) -> nx.Graph: """ Converts NetworkX graph (``nx.Graph``) to NetworkX graph (``nx.Graph``) object. Redundant - returns itself. :param G: NetworkX Graph :type G: nx.Graph :return: NetworkX Graph :rtype: nx.Graph """ return G @staticmethod def convert_dgl_to_nx(G: dgl.DGLGraph) -> nx.Graph: """ Converts a DGL Graph (``dgl.DGLGraph``) to a NetworkX (``nx.Graph``) object. Preserves node and edge attributes. :param G: ``dgl.DGLGraph`` to convert to ``NetworkX`` graph. :type G: dgl.DGLGraph :return: NetworkX graph object. :rtype: nx.Graph """ node_attrs = G.node_attr_schemes().keys() edge_attrs = G.edge_attr_schemes().keys() return dgl.to_networkx(G, node_attrs, edge_attrs) @staticmethod def convert_pyg_to_nx(G: Data) -> nx.Graph: """Converts PyTorch Geometric ``Data`` object to NetworkX graph (``nx.Graph``). :param G: Pytorch Geometric Data. :type G: torch_geometric.data.Data :returns: NetworkX graph. :rtype: nx.Graph """ return torch_geometric.utils.to_networkx(G) def convert_nx_to_jraph(self, G: nx.Graph) -> jraph.GraphsTuple: """Converts NetworkX graph (``nx.Graph``) to Jraph GraphsTuple graph. Requires ``jax`` and ``Jraph``. :param G: Networkx graph to convert. :type G: nx.Graph :return: Jraph GraphsTuple graph. :rtype: jraph.GraphsTuple """ G = nx.convert_node_labels_to_integers(G) n_node = len(G) n_edge = G.number_of_edges() edge_list = list(G.edges()) senders, receivers = zip(*edge_list) senders, receivers = jnp.array(senders), jnp.array(receivers) # Add node features node_features = {} for i, (_, feat_dict) in enumerate(G.nodes(data=True)): for key, value in feat_dict.items(): if str(key) in self.columns: # node_features[str(key)] = ( # [value] # if i == 0 # else node_features[str(key)] + [value] # ) feat = ( [value] if i == 0 else node_features[str(key)] + [value] ) try: feat = torch.tensor(feat) node_features[str(key)] = feat except TypeError: node_features[str(key)] = feat # Add edge features edge_features = {} for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): for key, value in feat_dict.items(): if str(key) in self.columns: edge_features[str(key)] = ( list(value) if i == 0 else edge_features[str(key)] + list(value) ) # Add graph features global_context = { str(feat_name): [G.graph[feat_name]] for feat_name in G.graph if str(feat_name) in self.columns } return jraph.GraphsTuple( nodes=node_features, senders=senders, receivers=receivers, edges=edge_features, n_node=n_node, n_edge=n_edge, globals=global_context, ) def __call__(self, G: nx.Graph): nx_g = eval("self.convert_" + self.src_format + "_to_nx(G)") dst_g = eval("self.convert_nx_to_" + self.dst_format + "(nx_g)") return dst_g # def convert_nx_to_pyg_data(G: nx.Graph) -> Data: # # Initialise dict used to construct Data object # data = {"node_id": list(G.nodes())} # G = nx.convert_node_labels_to_integers(G) # # Construct Edge Index # edge_index = torch.LongTensor(list(G.edges)).t().contiguous() # # Add node features # for i, (_, feat_dict) in enumerate(G.nodes(data=True)): # for key, value in feat_dict.items(): # data[str(key)] = [value] if i == 0 else data[str(key)] + [value] # # Add edge features # for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): # for key, value in feat_dict.items(): # data[str(key)] = ( # list(value) if i == 0 else data[str(key)] + list(value) # ) # # Add graph-level features # for feat_name in G.graph: # data[str(feat_name)] = [G.graph[feat_name]] # data["edge_index"] = edge_index.view(2, -1) # data = Data.from_dict(data) # data.num_nodes = G.number_of_nodes() # return data def convert_nx_to_pyg_data(G: nx.Graph) -> Data: # Initialise dict used to construct Data object data = {"node_id": list(G.nodes())} G = nx.convert_node_labels_to_integers(G) # Construct Edge Index edge_index = torch.LongTensor(list(G.edges)).t().contiguous() # Add node features for i, (_, feat_dict) in enumerate(G.nodes(data=True)): for key, value in feat_dict.items(): data[str(key)] = [value] if i == 0 else data[str(key)] + [value] # Add edge features for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): for key, value in feat_dict.items(): if key == 'distance': data[str(key)] = ( [value] if i == 0 else data[str(key)] + [value] ) else: data[str(key)] = ( [list(value)] if i == 0 else data[str(key)] + [list(value)] ) # Add graph-level features for feat_name in G.graph: data[str(feat_name)] = [G.graph[feat_name]] data["edge_index"] = edge_index.view(2, -1) data = Data.from_dict(data) data.num_nodes = G.number_of_nodes() return data