from transformers import PreTrainedModel, PretrainedConfig

import torch
from torch import nn
from torch.nn import functional as F

from esm.pretrained import ESM3_sm_open_v0


class StabilityPredictionConfig(PretrainedConfig):
    def __init__(self, embed_dim=1536, **kwargs):
        # Store the configured width instead of hardcoding 1536 in the
        # super() call, so non-default embed_dim values are honored.
        self.embed_dim = embed_dim
        super().__init__(**kwargs)


class SingleMutationPooler(nn.Module):
    def __init__(self, embed_dim=1536):
        super().__init__()
        # Learnable per-channel weights for the wild-type and mutant residue
        # embeddings; the (+1, -1) initialization makes the pooler start out
        # as a wt-minus-mut difference.
        self.wt_weight = nn.Parameter(torch.ones((1, embed_dim)), requires_grad=True)
        self.mut_weight = nn.Parameter(-1 * torch.ones((1, embed_dim)), requires_grad=True)
        self.norm = nn.LayerNorm(embed_dim, bias=False)

    def forward(self, wt_embedding, mut_embedding, positions):
        embed_dim = wt_embedding.shape[-1]
        # Shift positions by 1 to account for the BOS token, then expand the
        # index to (batch, 1, embed_dim) so gather selects the full embedding
        # vector of the mutated residue in each sequence.
        positions = (positions + 1).view(-1, 1, 1).repeat(1, 1, embed_dim)
        wt_residues = torch.gather(wt_embedding, 1, positions).squeeze(1)
        mut_residues = torch.gather(mut_embedding, 1, positions).squeeze(1)
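        # Worked illustration (an addition, not from the original): with
        # emb = torch.arange(24.).view(2, 4, 3) and shifted indices
        # idx = torch.tensor([1, 2]).view(-1, 1, 1).repeat(1, 1, 3),
        # torch.gather(emb, 1, idx).squeeze(1) yields rows emb[0, 1] and
        # emb[1, 2], i.e. tensor([[ 3.,  4.,  5.], [18., 19., 20.]]).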
        wt_residues = wt_residues * self.wt_weight
        mut_residues = mut_residues * self.mut_weight
        return self.norm(wt_residues + mut_residues)


class StabilityPrediction(PreTrainedModel):
    config_class = StabilityPredictionConfig

    def __init__(self, config=None):
        # Build the default config per instance rather than once at
        # function-definition time (mutable-default pitfall).
        if config is None:
            config = StabilityPredictionConfig()
        super().__init__(config=config)
        self.backbone = ESM3_sm_open_v0(getattr(config, "device", "cpu"))
        # Pass the configured width through instead of relying on the
        # pooler's default embed_dim.
        self.pooler = SingleMutationPooler(config.embed_dim)
        self.regressor = nn.Linear(config.embed_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        # Small, near-zero initialization so the regressor starts out
        # predicting values close to 0.
        nn.init.uniform_(self.regressor.weight, -0.01, 0.01)
        nn.init.zeros_(self.regressor.bias)

    def compute_loss(self, logits, labels):
        if labels is None:
            return None
        # Flatten both to (batch,) so (batch, 1) logits do not silently
        # broadcast against (batch,) labels.
        return F.mse_loss(logits.view(-1), labels.view(-1))

    def forward(self, wt_input_ids, mut_input_ids, positions, labels=None):
        # Embed wild-type and mutant sequences with the shared ESM3 backbone.
        wt_embeddings = self.backbone(sequence_tokens=wt_input_ids).embeddings
        mut_embeddings = self.backbone(sequence_tokens=mut_input_ids).embeddings
        aggregated_embeddings = self.pooler(wt_embeddings, mut_embeddings, positions)
        logits = self.regressor(aggregated_embeddings)
        loss = self.compute_loss(logits, labels)
        return {
            "loss": loss,
            "logits": logits,
        }
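# Minimal smoke-test sketch (an illustrative addition, not part of the
# original model code). Random token IDs stand in for properly tokenized
# wild-type/mutant sequences, so the predicted value is meaningless; this
# only checks that shapes flow end to end. Assumes ESM3_sm_open_v0 can
# load the open ESM3-small weights on CPU.
if __name__ == "__main__":
    model = StabilityPrediction().eval()
    batch, seq_len = 2, 10
    # (batch, seq_len + 2) mimics BOS/EOS framing; IDs are placeholders only.
    wt_ids = torch.randint(4, 24, (batch, seq_len + 2))
    mut_ids = wt_ids.clone()
    positions = torch.tensor([3, 7])  # 0-based positions of the mutations
    mut_ids[torch.arange(batch), positions + 1] = 5  # toy "mutation"
    with torch.no_grad():
        out = model(wt_ids, mut_ids, positions)
    print(out["logits"].shape)  # expected: torch.Size([2, 1])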