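# Residue from the Hugging Face file viewer (not part of the module itself):
#   uploader: VatsalPatel18 | commit message: "Model files" | commit: c238491
#   viewer links: raw / history / blame | file size: 1.37 kB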
from transformers import PreTrainedModel
from OmicsConfig import OmicsConfig
from transformers import PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Batch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch_geometric.utils import negative_sampling
from torch.nn.functional import cosine_similarity
from torch.optim.lr_scheduler import StepLR
from EdgeWeightPredictorModel import EdgeWeightPredictorModel
class GATv2DecoderModel(PreTrainedModel):
    """Decode latent embeddings back to the original feature size.

    The decoder is a stack of ``config.num_layers`` square linear layers
    (``config.out_channels`` -> ``config.out_channels``), each followed by a
    ReLU, and a final linear projection to ``config.original_feature_size``.
    It also owns an ``EdgeWeightPredictorModel`` built from the same config,
    exposed through :meth:`predict_edge_weights`.

    NOTE(review): despite the class name, no graph-attention layer is
    instantiated in this class; only ``nn.Linear`` layers are used here.
    """

    config_class = OmicsConfig
    base_model_prefix = "gatv2_decoder"

    def __init__(self, config):
        """Build the decoder layers from ``config``.

        Args:
            config: Configuration object; this constructor reads
                ``out_channels``, ``num_layers`` and
                ``original_feature_size`` from it and forwards the whole
                object to ``EdgeWeightPredictorModel``.
        """
        super().__init__(config)
        # Every hidden layer keeps the latent width unchanged, so the input
        # and output sizes are the same for all of them.
        self.layers = nn.ModuleList([
            nn.Linear(config.out_channels, config.out_channels)
            for _ in range(config.num_layers)
        ])
        # Final projection from the latent width to the original feature size.
        self.fc = nn.Linear(config.out_channels, config.original_feature_size)
        self.edge_weight_predictor = EdgeWeightPredictorModel(config)

    def forward(self, z):
        """Run ``z`` through the hidden layers and the output projection.

        Args:
            z: Input tensor whose last dimension must equal
                ``config.out_channels`` (required by the linear layers).

        Returns:
            The output of the final linear layer; its last dimension is
            ``config.original_feature_size``. No activation is applied after
            the final projection.
        """
        for layer in self.layers:
            z = F.relu(layer(z))
        return self.fc(z)

    def predict_edge_weights(self, z, edge_index):
        """Delegate edge-weight prediction to the wrapped predictor.

        Args:
            z: Passed through unchanged to ``self.edge_weight_predictor``.
            edge_index: Passed through unchanged; presumably a
                torch_geometric-style edge index -- confirm against
                ``EdgeWeightPredictorModel``.

        Returns:
            Whatever ``self.edge_weight_predictor(z, edge_index)`` returns.
        """
        return self.edge_weight_predictor(z, edge_index)