Spaces:
Runtime error
Runtime error
from transformers import PreTrainedModel | |
from OmicsConfig import OmicsConfig | |
from transformers import PretrainedConfig, PreTrainedModel | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch_geometric.nn import GATv2Conv | |
from torch_geometric.data import Batch | |
from torch.utils.data import DataLoader | |
from torch.optim import AdamW | |
from torch_geometric.utils import negative_sampling | |
from torch.nn.functional import cosine_similarity | |
from torch.optim.lr_scheduler import StepLR | |
class EdgeWeightPredictorModel(PreTrainedModel):
    """Feed-forward decoder that scores an edge from its two endpoint embeddings.

    The network is a stack of ``nn.Linear`` layers whose widths come from
    ``config.edge_decoder_hidden_sizes``; each hidden layer is optionally
    followed by the activation named at the same position in
    ``config.edge_decoder_activations``. A final ``nn.Linear`` maps the last
    hidden width to a single output value per edge.
    """

    config_class = OmicsConfig
    base_model_prefix = "edge_weight_predictor"

    def __init__(self, config):
        """Build the decoder stack described by ``config``.

        Reads ``config.out_channels``, ``config.edge_decoder_hidden_sizes`` and
        ``config.edge_decoder_activations``. The first linear layer takes
        ``2 * config.out_channels`` inputs because ``forward`` concatenates
        the embeddings of both endpoints of an edge.
        """
        super().__init__(config)

        # Recognised activation names and the module each one maps to. Kept as
        # pairs compared with ``==`` so any config value can be looked up.
        known_activations = (
            ("ReLU", nn.ReLU),
            ("Sigmoid", nn.Sigmoid),
            ("Tanh", nn.Tanh),
        )

        modules = []
        in_features = 2 * config.out_channels
        layer_specs = zip(
            config.edge_decoder_hidden_sizes,
            config.edge_decoder_activations,
        )
        # NOTE: ``zip`` stops at the shorter of the two config lists, so any
        # surplus sizes or activation names are ignored.
        for out_features, act_name in layer_specs:
            modules.append(nn.Linear(in_features, out_features))
            act_cls = next(
                (cls for name, cls in known_activations if act_name == name),
                None,
            )
            # NOTE: a name outside ``known_activations`` adds no activation
            # module; the hidden layer is then followed directly by the next
            # linear layer.
            if act_cls is not None:
                modules.append(act_cls())
            in_features = out_features

        # Output head: one value per edge.
        modules.append(nn.Linear(in_features, 1))
        self.predictor = nn.Sequential(*modules)

    def forward(self, z, edge_index):
        """Return the decoder output for every edge in ``edge_index``.

        Args:
            z: Embedding tensor indexed along its first dimension by the
                entries of ``edge_index`` (presumably one row per node with
                ``config.out_channels`` features -- TODO confirm).
            edge_index: Indexable whose element 0 and element 1 hold the
                indices of the two endpoints of each edge (presumably a
                ``(2, num_edges)`` integer tensor -- TODO confirm).

        Returns:
            The output of ``self.predictor`` applied to the endpoint
            embeddings concatenated along the last dimension; the last
            dimension of the result has size 1.
        """
        first_endpoint = edge_index[0]
        second_endpoint = edge_index[1]
        pair_features = torch.cat((z[first_endpoint], z[second_endpoint]), dim=-1)
        return self.predictor(pair_features)