File size: 1,523 Bytes
c238491
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from torch_geometric.nn import GATv2Conv
from torch_geometric.utils import negative_sampling
from transformers import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel

from OmicsConfig import OmicsConfig


class EdgeWeightPredictorModel(PreTrainedModel):
    """MLP decoder that scores edges from the embeddings of their endpoints.

    For each edge ``(src, dst)`` in ``edge_index`` the node embeddings
    ``z[src]`` and ``z[dst]`` are concatenated and fed through a stack of
    ``nn.Linear`` layers (each optionally followed by the activation named in
    the config) that ends in a single-unit ``nn.Linear`` output layer.
    """

    config_class = OmicsConfig
    base_model_prefix = "edge_weight_predictor"

    # Activation names accepted in ``config.edge_decoder_activations`` and the
    # ``torch.nn`` module class instantiated for each. Matching is exact and
    # case-sensitive. Subclasses may extend this mapping.
    _ACTIVATIONS = {
        "ReLU": nn.ReLU,
        "Sigmoid": nn.Sigmoid,
        "Tanh": nn.Tanh,
    }

    def __init__(self, config):
        """Build the edge-decoder MLP from ``config``.

        Args:
            config: Model configuration. Reads ``out_channels`` (width of one
                node embedding; the first layer takes two of them
                concatenated), ``edge_decoder_hidden_sizes`` and
                ``edge_decoder_activations``. The last two are paired
                element-wise: hidden layer ``i`` has width
                ``edge_decoder_hidden_sizes[i]`` and is followed by activation
                ``edge_decoder_activations[i]``.

        Warns:
            UserWarning: If the hidden-size and activation lists differ in
                length, only the first ``min(len_a, len_b)`` pairs are used. If
                an activation name is not in ``_ACTIVATIONS``, no activation
                module is inserted after that layer. The resulting
                ``nn.Sequential`` layout (and hence state_dict key names) is
                unaffected by these warnings.
        """
        super().__init__(config)
        hidden_sizes = list(config.edge_decoder_hidden_sizes)
        activations = list(config.edge_decoder_activations)

        # zip() silently truncates to the shorter list; make that visible,
        # since it drops requested hidden layers (or activations) from the MLP.
        if len(hidden_sizes) != len(activations):
            warnings.warn(
                f"{type(self).__name__}: edge_decoder_hidden_sizes has "
                f"{len(hidden_sizes)} entries but edge_decoder_activations has "
                f"{len(activations)}; only the first "
                f"{min(len(hidden_sizes), len(activations))} pairs are used.",
                UserWarning,
                stacklevel=2,
            )

        layers = []
        input_size = 2 * config.out_channels  # concatenation of two node embeddings
        for layer_idx, (hidden_size, activation) in enumerate(zip(hidden_sizes, activations)):
            layers.append(nn.Linear(input_size, hidden_size))
            # Guard the lookup so non-string entries (e.g. None) stay no-ops
            # instead of raising on an unhashable dict key.
            activation_cls = (
                self._ACTIVATIONS.get(activation) if isinstance(activation, str) else None
            )
            if activation_cls is not None:
                layers.append(activation_cls())
            else:
                # Unknown name: keep the original layout (no module added) so
                # existing checkpoints still load, but surface the likely typo.
                warnings.warn(
                    f"{type(self).__name__}: unrecognized activation "
                    f"{activation!r} for edge-decoder layer {layer_idx}; no "
                    f"activation is applied after this layer. Supported: "
                    f"{sorted(self._ACTIVATIONS)}.",
                    UserWarning,
                    stacklevel=2,
                )
            input_size = hidden_size
        layers.append(nn.Linear(input_size, 1))
        self.predictor = nn.Sequential(*layers)

    def forward(self, z, edge_index):
        """Score each edge in ``edge_index``.

        Args:
            z: Node embeddings indexed by node id. Presumably shaped
                ``(num_nodes, out_channels)``, since the first decoder layer
                expects ``2 * config.out_channels`` features after
                concatenation. TODO confirm against callers.
            edge_index: Index tensor whose rows 0 and 1 hold the source and
                destination node ids of each edge.

        Returns:
            The decoder output for the concatenated endpoint embeddings. The
            last dimension has size 1 (the final layer is ``nn.Linear(..., 1)``).
        """
        edge_embeddings = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=-1)
        return self.predictor(edge_embeddings)