# esm3_ddg_v2 / ddg_predictor_modeling.py
# hazemessam's picture
# Upload StabilityPrediction
# e821e69 verified
from transformers import PreTrainedModel, PretrainedConfig
from torch import nn
from esm.pretrained import ESM3_sm_open_v0
import torch
class StabilityPredictionConfig(PretrainedConfig):
    """Configuration for ``StabilityPrediction``.

    Args:
        embed_dim: Embedding width; ``StabilityPrediction`` reads it as
            ``config.embed_dim`` to size its regression head. Defaults to
            1536.
        *args: Forwarded unchanged to ``PretrainedConfig.__init__``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, embed_dim=1536, *args, **kwargs):
        # Forward the caller's value. A literal 1536 used to be passed here,
        # which silently discarded any non-default ``embed_dim``.
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
class SingleMutationPooler(nn.Module):
def __init__(self, embed_dim=1536):
super().__init__()
self.wt_weight = nn.Parameter(torch.ones((1, embed_dim)), requires_grad=True)
self.mut_weight = nn.Parameter(-1 * torch.ones((1, embed_dim)), requires_grad=True)
self.norm = nn.LayerNorm(embed_dim, bias=False)
def forward(self, wt_embedding, mut_embedding, positions):
embed_shape = wt_embedding.shape[-1]
positions = positions.view(-1, 1).unsqueeze(2).repeat(1, 1, embed_shape) + 1
wt_residues = torch.gather(wt_embedding, 1, positions).squeeze(1)
mut_residues = torch.gather(mut_embedding, 1, positions).squeeze(1)
wt_residues = wt_residues * self.wt_weight
mut_residues = mut_residues * self.mut_weight
return self.norm(wt_residues + mut_residues)
class StabilityPrediction(PreTrainedModel):
    """Regression model over a wild-type / mutant sequence pair.

    Both token sequences are passed through the same ESM3 backbone, the
    embeddings at the given position are pooled by ``SingleMutationPooler``
    and a linear head maps the pooled vector to a single value per sample.
    """

    config_class = StabilityPredictionConfig

    def __init__(self, config=None):
        """Build the backbone, pooler and regression head.

        Args:
            config: A ``StabilityPredictionConfig``. When omitted a fresh
                default config is created for this instance.
        """
        # Create the default per call: a default-argument instance would be
        # built once at definition time and shared by every model.
        if config is None:
            config = StabilityPredictionConfig()
        super().__init__(config=config)
        self.backbone = ESM3_sm_open_v0(getattr(config, "device", "cpu"))
        # Size the pooler from the config so it always matches the regressor.
        self.pooler = SingleMutationPooler(config.embed_dim)
        self.regressor = nn.Linear(config.embed_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the regression head with small weights and zero bias."""
        nn.init.uniform_(self.regressor.weight, -0.01, 0.01)
        nn.init.zeros_(self.regressor.bias)

    def compute_loss(self, logits, labels):
        """Return the mean-squared error, or ``None`` when there are no labels.

        Args:
            logits: Output of the regression head, shape (batch, 1).
            labels: Target values with one entry per sample, or ``None``.

        Raises:
            RuntimeError: If ``labels`` cannot be reshaped to ``logits.shape``.
        """
        if labels is None:
            return None
        # Align the targets with the (batch, 1) logits: labels of shape
        # (batch,) would otherwise broadcast to (batch, batch) and silently
        # produce a wrong loss.
        targets = labels.reshape(logits.shape).to(logits.dtype)
        # ``nn.functional`` is used because no ``F`` alias is imported in
        # this module.
        return nn.functional.mse_loss(logits, targets)

    def forward(self, wt_input_ids, mut_input_ids, positions, labels=None):
        """Predict one value per (wild-type, mutant) pair.

        Args:
            wt_input_ids: Sequence tokens of the wild-type sequences.
            mut_input_ids: Sequence tokens of the mutant sequences.
            positions: One index per sample, forwarded to the pooler.
            labels: Optional regression targets.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"``.
        """
        # NOTE(review): assumes the token tensors are (batch, sequence) -
        # confirm against the data collator.
        wt_embeddings = self.backbone(sequence_tokens=wt_input_ids).embeddings
        mut_embeddings = self.backbone(sequence_tokens=mut_input_ids).embeddings
        aggregated_embeddings = self.pooler(wt_embeddings, mut_embeddings, positions)
        logits = self.regressor(aggregated_embeddings)
        loss = self.compute_loss(logits, labels)
        return {
            "loss": loss,
            "logits": logits,
        }