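# Spaces: Runtime error
# (The line above appears to be a status banner captured together with the source; it is not part of the module.)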
import torch | |
import pickle | |
import numpy as np | |
class Attention_Extracter:
    """Run an encoder over a stored collection of graphs and keep its outputs.

    On construction the graph dictionary is loaded from disk, the encoder is
    moved to the selected device and switched to eval mode, and every graph
    is pushed through the encoder once.  The results are exposed as:

    * ``latent_feat_dict``  -- graph id -> first encoder output, on CPU.
    * ``attention_scores1`` -- graph id -> second encoder output, on CPU.
    """

    def __init__(self, graph_data_dict_path, encoder_model, gpu=False):
        """Load the graphs, prepare the encoder and extract all features.

        Args:
            graph_data_dict_path: Path given to ``torch.load``; the loaded
                object must provide ``.items()`` yielding ``(graph_id, data)``
                pairs where ``data`` has ``.to()``, ``.x``, ``.edge_index``
                and ``.edge_attr``.
            encoder_model: Module called as
                ``encoder_model(x, edge_index, edge_attr)`` and returning a
                ``(latent, attention_weights)`` pair.
            gpu: Use ``'cuda'`` when true, ``'cpu'`` otherwise.
        """
        self.torch_device = 'cuda' if gpu else 'cpu'
        # Deserialize onto the CPU so a dictionary saved from a CUDA process
        # can still be opened on a CPU-only host; each graph is moved to the
        # target device individually during extraction.
        # NOTE(review): PyTorch releases whose torch.load defaults to
        # weights_only=True may refuse arbitrary pickled objects -- confirm
        # against the deployed version.
        self.graph_data_dict = torch.load(graph_data_dict_path, map_location='cpu')
        self.encoder_model = encoder_model
        self.encoder_model.to(self.torch_device)
        self.encoder_model.eval()
        self.latent_feat_dict, self.attention_scores1 = self.extract_latent_attention_features()

    @staticmethod
    def _to_cpu(value):
        """Return *value* with every tensor inside it moved to the CPU.

        Lists and tuples are rebuilt with the same container type; anything
        that is neither a tensor nor a list/tuple is returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            return value.cpu()
        if isinstance(value, list):
            return [Attention_Extracter._to_cpu(item) for item in value]
        if isinstance(value, tuple):
            moved = [Attention_Extracter._to_cpu(item) for item in value]
            # Named tuples take positional fields, plain tuples an iterable.
            return type(value)(*moved) if hasattr(value, '_fields') else tuple(moved)
        return value

    def extract_latent_attention_features(self):
        """Encode every stored graph with gradients disabled.

        Returns:
            A ``(latent_features, attention_scores1)`` pair of dictionaries
            keyed by graph id.  Latents are CPU tensors.  Attention output
            that arrives as a list or tuple is stored as a list; otherwise it
            is stored as a single value.  In both cases all tensors are moved
            to the CPU so the results do not pin device memory.
        """
        latent_features = {}
        attention_scores1 = {}
        with torch.no_grad():
            for graph_id, data in self.graph_data_dict.items():
                data = data.to(self.torch_device)
                z, attention_weights = self.encoder_model(data.x, data.edge_index, data.edge_attr)
                latent_features[graph_id] = z.cpu()
                # The encoder may hand back several attention objects at once
                # (list/tuple) or a single one; keep a list in the former
                # case, matching what callers already receive.
                if isinstance(attention_weights, (list, tuple)):
                    attention_scores1[graph_id] = [self._to_cpu(aw) for aw in attention_weights]
                else:
                    attention_scores1[graph_id] = self._to_cpu(attention_weights)
        return latent_features, attention_scores1

    def load_edge_indices(self, glist_path, edge_matrix_path):
        """Map every non-zero entry of an edge matrix to its pair of names.

        Args:
            glist_path: Path to a pickled sequence that is indexed by the
                row/column positions of the edge matrix.
            edge_matrix_path: Path to a ``.npy`` file holding the edge
                matrix; it is expected to be 2-D.

        Returns:
            ``{(row, col): (glist[row], glist[col])}`` for every entry of the
            matrix that is non-zero after conversion to ``torch.float``, with
            ``row`` and ``col`` as plain Python ints.
        """
        # SECURITY: pickle.load can execute arbitrary code while
        # deserializing; only pass files from a trusted source.
        with open(glist_path, 'rb') as f:
            glist = pickle.load(f)
        edge_matrix = torch.tensor(np.load(edge_matrix_path), dtype=torch.float)
        # One (row, col) pair per non-zero entry; a single tolist() converts
        # them all instead of calling .item() twice per edge.
        edge_pairs = torch.nonzero(edge_matrix, as_tuple=False).tolist()
        return {(src, dst): (glist[src], glist[dst]) for src, dst in edge_pairs}