import torch
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from torch_geometric.utils import negative_sampling
from transformers import PreTrainedModel

from OmicsConfig import OmicsConfig
from GATv2EncoderModel import GATv2EncoderModel
from GATv2DecoderModel import GATv2DecoderModel


class MultiOmicsGraphAttentionAutoencoderModel(PreTrainedModel):
    config_class = OmicsConfig
    base_model_prefix = "graph-attention-autoencoder"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = GATv2EncoderModel(config)
        self.decoder = GATv2DecoderModel(config)
        self.optimizer = AdamW(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr=config.learning_rate,
        )
        self.scheduler = StepLR(self.optimizer, step_size=30, gamma=0.7)

    def forward(self, x, edge_index, edge_attr):
        z, attention_weights = self.encoder(x, edge_index, edge_attr)
        x_reconstructed = self.decoder(z)
        return x_reconstructed, attention_weights

    def predict_edge_weights(self, z, edge_index):
        return self.decoder.predict_edge_weights(z, edge_index)

    def train_model(self, data_loader, device):
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.train()
        self.decoder.train()
        total_loss = 0.0
        total_cosine_similarity = 0.0
        loss_weight_node = 1.0
        loss_weight_edge = 1.0
        loss_weight_edge_attr = 1.0
        for data in data_loader:
            data = data.to(device)
            self.optimizer.zero_grad()
            z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
            x_reconstructed = self.decoder(z)
            node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
            edge_loss = edge_reconstruction_loss(z, data.edge_index)
            cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
            total_cosine_similarity += cos_sim.item()
            pred_edge_weights = self.decoder.predict_edge_weights(z, data.edge_index)
            edge_weight_loss = edge_weight_reconstruction_loss(pred_edge_weights, data.edge_attr)
            loss = (
                loss_weight_node * node_loss
                + loss_weight_edge * edge_loss
                + loss_weight_edge_attr * edge_weight_loss
            )
            print(
                f"node_loss: {node_loss:.4f}, edge_loss: {edge_loss:.4f}, "
                f"edge_weight_loss: {edge_weight_loss:.4f}, cosine_similarity: {cos_sim:.4f}"
            )
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(data_loader)
        avg_cosine_similarity = total_cosine_similarity / len(data_loader)
        return avg_loss, avg_cosine_similarity

    def fit(self, train_loader, validation_loader, epochs, device):
        train_losses = []
        val_losses = []
        for epoch in range(1, epochs + 1):
            train_loss, train_cosine_similarity = self.train_model(train_loader, device)
            torch.cuda.empty_cache()
            val_loss, val_cosine_similarity = self.validate(validation_loader, device)
            # Record per-epoch losses so the returned histories are populated.
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            print(
                f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, "
                f"Train Cosine Similarity: {train_cosine_similarity:.4f}, "
                f"Validation Loss: {val_loss:.4f}, "
                f"Validation Cosine Similarity: {val_cosine_similarity:.4f}"
            )
            self.scheduler.step()
        return train_losses, val_losses

    def validate(self, validation_loader, device):
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.eval()
        self.decoder.eval()
        total_loss = 0.0
        total_cosine_similarity = 0.0
        with torch.no_grad():
            for data in validation_loader:
                data = data.to(device)
                z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
                x_reconstructed = self.decoder(z)
                node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
                edge_loss = edge_reconstruction_loss(z, data.edge_index)
                cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
                total_cosine_similarity += cos_sim.item()
                loss = node_loss + edge_loss
                total_loss += loss.item()
        avg_loss = total_loss / len(validation_loader)
        avg_cosine_similarity = total_cosine_similarity / len(validation_loader)
        return avg_loss, avg_cosine_similarity

    def evaluate(self, test_loader, device):
        # Same reconstruction metrics as validate, computed on the held-out test set.
        return self.validate(test_loader, device)
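
# ---------------------------------------------------------------------------
# Hedged inference sketch (illustrative only): `reconstruct_graph` is a
# hypothetical helper, not part of the original API. It assumes `data` is a
# torch_geometric Data object with x, edge_index, and edge_attr, and uses only
# the encoder/decoder interfaces exercised by the class above.
# ---------------------------------------------------------------------------
def reconstruct_graph(model, data, device="cpu"):
    """Run one no-grad pass: reconstructed features, edge weights, attention."""
    model.encoder.to(device)
    model.decoder.to(device)
    model.encoder.eval()
    model.decoder.eval()
    data = data.to(device)
    with torch.no_grad():
        z, attention_weights = model.encoder(data.x, data.edge_index, data.edge_attr)
        x_reconstructed = model.decoder(z)
        edge_weights = model.predict_edge_weights(z, data.edge_index)
    return x_reconstructed, edge_weights, attention_weights
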
# Collate function for the DataLoader: merge individual graphs into one batched graph.
def collate_graph_data(batch):
    return Batch.from_data_list(batch)


# Create a DataLoader from a dict of torch_geometric Data objects.
def create_data_loader(train_data, batch_size=1, shuffle=True):
    graph_data = list(train_data.values())
    return DataLoader(graph_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_graph_data)


# Node-feature reconstruction loss: mean squared error against the true features.
def graph_reconstruction_loss(pred_features, true_features):
    return F.mse_loss(pred_features, true_features)


# Edge reconstruction loss: binary cross-entropy over inner-product edge logits,
# sampling negative edges when none are supplied.
def edge_reconstruction_loss(z, pos_edge_index, neg_edge_index=None):
    pos_logits = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)
    pos_loss = F.binary_cross_entropy_with_logits(pos_logits, torch.ones_like(pos_logits))
    if neg_edge_index is None:
        neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
    neg_logits = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)
    neg_loss = F.binary_cross_entropy_with_logits(neg_logits, torch.zeros_like(neg_logits))
    return pos_loss + neg_loss


# Edge-weight reconstruction loss: mean squared error on predicted edge weights.
def edge_weight_reconstruction_loss(pred_weights, true_weights):
    pred_weights = pred_weights.squeeze(-1)
    return F.mse_loss(pred_weights, true_weights)
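
# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (illustrative only). It exercises the loss
# functions above on toy tensors; the node count, latent dimension, and edge
# list are arbitrary assumptions, not values taken from OmicsConfig.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    num_nodes, latent_dim = 6, 8
    z = torch.randn(num_nodes, latent_dim)                   # toy latent embeddings
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # toy positive edges
    edge_attr = torch.rand(edge_index.size(1))               # toy edge weights

    x_true = torch.randn(num_nodes, latent_dim)
    x_pred = x_true + 0.1 * torch.randn_like(x_true)  # a near-perfect reconstruction

    print("node loss:", graph_reconstruction_loss(x_pred, x_true).item())
    print("edge loss:", edge_reconstruction_loss(z, edge_index).item())
    print("edge-weight loss:", edge_weight_reconstruction_loss(torch.rand(edge_index.size(1), 1), edge_attr).item())

    # End-to-end training would look like the following, assuming an OmicsConfig
    # instance `config` and a dict of Data objects `graphs` (both external to
    # this module):
    #   model = MultiOmicsGraphAttentionAutoencoderModel(config)
    #   loader = create_data_loader(graphs, batch_size=2)
    #   model.fit(loader, loader, epochs=5, device="cpu")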