# Spaces: Runtime error
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from torch_geometric.nn import GATv2Conv
from torch_geometric.utils import negative_sampling
from transformers import PretrainedConfig, PreTrainedModel

from EdgeWeightPredictorModel import EdgeWeightPredictorModel
from OmicsConfig import OmicsConfig
class GATv2DecoderModel(PreTrainedModel):
    """Decoder that maps node embeddings back to the original feature space.

    The embeddings pass through ``config.num_layers`` hidden ``nn.Linear``
    layers. Each is ``out_channels -> out_channels`` and is followed by a ReLU.
    A final ``nn.Linear`` then projects to ``config.original_feature_size``.
    Edge-weight prediction is delegated to an ``EdgeWeightPredictorModel``
    built from the same config.
    """

    config_class = OmicsConfig
    base_model_prefix = "gatv2_decoder"

    def __init__(self, config):
        """Build the decoder layers from *config*.

        Args:
            config: Model configuration. This class reads ``out_channels``,
                ``num_layers`` and ``original_feature_size``, and passes the
                whole config on to ``EdgeWeightPredictorModel``.
        """
        super().__init__(config)
        # Hidden layers all keep the embedding width, so each one is a
        # square out_channels -> out_channels projection.
        self.layers = nn.ModuleList(
            [
                nn.Linear(config.out_channels, config.out_channels)
                for _ in range(config.num_layers)
            ]
        )
        # Output head: project back to the original per-node feature size.
        self.fc = nn.Linear(config.out_channels, config.original_feature_size)
        self.edge_weight_predictor = EdgeWeightPredictorModel(config)

    def forward(self, z):
        """Reconstruct node features from embeddings *z*.

        Args:
            z: Embedding tensor whose last dimension is ``config.out_channels``.
                It is presumably shaped (num_nodes, out_channels); TODO confirm
                with callers.

        Returns:
            A tensor with the same leading dimensions as *z* and last dimension
            ``config.original_feature_size``.
        """
        for layer in self.layers:
            # ReLU follows every hidden layer, including the last one before
            # the output head; the head itself is left linear.
            z = F.relu(layer(z))
        x_reconstructed = self.fc(z)
        return x_reconstructed

    def predict_edge_weights(self, z, edge_index):
        """Return the edge-weight predictor's output for *z* and *edge_index*.

        This is a thin delegate to ``self.edge_weight_predictor``. The
        output's shape and meaning are defined by ``EdgeWeightPredictorModel``.
        """
        return self.edge_weight_predictor(z, edge_index)