Isoformer / modeling_isoformer.py
isoformer-anonymous's picture
Update modeling_isoformer.py
13a7e78 verified
raw
history blame
6.64 kB
from .isoformer_config import IsoformerConfig
from transformers import PreTrainedModel
from .modeling_esm import NTForMaskedLM, MultiHeadAttention
from .esm_config import NTConfig
from .modeling_esm_original import EsmForMaskedLM
from transformers.models.esm.configuration_esm import EsmConfig
from enformer_pytorch import Enformer, str_to_one_hot, EnformerConfig
import torch
from torch import nn
class Isoformer(PreTrainedModel):
    """Multi-omics model predicting gene expression from DNA, RNA and protein.

    Three encoders produce per-modality embeddings:

    * DNA: a pretrained Enformer trunk loaded from
      ``EleutherAI/enformer-official-rough``.
    * RNA: ``NTForMaskedLM`` built from the ``config.nt_*`` fields.
    * Protein: ``EsmForMaskedLM`` built from the ``config.esm_*`` fields.

    The DNA embeddings cross-attend to the last RNA hidden states, then to the
    last protein hidden states. The fused DNA embeddings are mean-pooled over
    positions ``config.pool_window_start:config.pool_window_end`` and passed
    through a two-layer head (Linear -> softplus -> Linear) with 30 outputs.
    """

    config_class = IsoformerConfig

    # Feature width of the Enformer embeddings. The cross-attention layers use
    # it as the query ("omics of interest") size and the head consumes it, so it
    # must match the pretrained Enformer checkpoint.
    _DNA_EMBEDDING_SIZE = 3072
    # Number of values produced per example by the prediction head.
    _NUM_OUTPUTS = 30

    # Fields copied from ``config.<prefix>_<field>`` into each sequence-encoder
    # config (prefix is "esm" for the protein encoder, "nt" for the RNA encoder).
    _ENCODER_CONFIG_FIELDS = (
        "vocab_size",
        "mask_token_id",
        "pad_token_id",
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "intermediate_size",
        "max_position_embeddings",
        "token_dropout",
        "emb_layer_norm_before",
        "add_bias_fnn",
    )

    def __init__(self, config):
        super().__init__(config)  # also sets self.config
        self.esm_config = EsmConfig(**self._encoder_config_kwargs(config, "esm"))
        self.nt_config = NTConfig(**self._encoder_config_kwargs(config, "nt"))
        # self.enformer_config = EnformerConfig(
        #     dim=config.enformer_dim,
        #     depth=config.enformer_depth,
        #     heads=config.enformer_heads,
        #     output_heads=dict(
        #         human=1,
        #         mouse=1  # TODO CHANGE
        #     ),
        #     target_length=config.enformer_target_length,  # 896,
        #     attn_dim_key=config.enformer_attn_dim_key,
        #     dropout_rate=0.4,
        #     attn_dropout=0.05,
        #     pos_dropout=0.01,
        #     use_checkpointing=config.enformer_use_checkpointing,
        #     use_convnext=config.enformer_use_convnext,
        #     num_downsamples=config.enformer_num_downsamples,
        #     # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
        #     dim_divisible_by=config.enformer_dim_divisible_by,
        #     use_tf_gamma=False,
        # )
        self.esm_model = EsmForMaskedLM(self.esm_config)  # protein encoder
        self.nt_model = NTForMaskedLM(self.nt_config)  # rna encoder
        # dna encoder: architecture comes from the pretrained checkpoint, so the
        # config.enformer_* fields are not used here.
        self.enformer_model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        num_heads = config.num_heads_omics_cross_attention
        # Key/value widths are the hidden sizes of the encoders whose last hidden
        # states are attended to in forward().
        self.cross_attention_layer_rna = MultiHeadAttention(
            config=self._cross_attention_config(num_heads),
            omics_of_interest_size=self._DNA_EMBEDDING_SIZE,
            other_omic_size=config.nt_hidden_size,
        )
        self.cross_attention_layer_protein = MultiHeadAttention(
            config=self._cross_attention_config(num_heads),
            omics_of_interest_size=self._DNA_EMBEDDING_SIZE,
            other_omic_size=config.esm_hidden_size,
        )
        self.head_layer_1 = nn.Linear(self._DNA_EMBEDDING_SIZE, 2 * self._DNA_EMBEDDING_SIZE)
        self.head_layer_2 = nn.Linear(2 * self._DNA_EMBEDDING_SIZE, self._NUM_OUTPUTS)

    @classmethod
    def _encoder_config_kwargs(cls, config, prefix):
        """Return encoder-config kwargs read from ``config.<prefix>_*`` fields.

        Dropout is disabled, caching is off, rotary position embeddings are used
        and word embeddings are untied, identically for both encoders.
        """
        kwargs = {
            field: getattr(config, f"{prefix}_{field}")
            for field in cls._ENCODER_CONFIG_FIELDS
        }
        kwargs.update(
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )
        return kwargs

    @classmethod
    def _cross_attention_config(cls, num_heads):
        """Return the ``EsmConfig`` passed to each omics cross-attention layer."""
        return EsmConfig(
            num_attention_heads=num_heads,
            attention_head_size=cls._DNA_EMBEDDING_SIZE // num_heads,
            hidden_size=cls._DNA_EMBEDDING_SIZE,
            attention_probs_dropout_prob=0,
            max_position_embeddings=0,
        )

    def forward(
        self,
        tensor_dna,
        tensor_rna,
        tensor_protein,
        attention_mask_rna,
        attention_mask_protein
    ):
        """Predict gene expression from paired DNA, RNA and protein inputs.

        Args:
            tensor_dna: DNA input for Enformer. Its first position along dim 1
                is a CLS token, which is dropped before encoding. Presumably
                ``(batch, 1 + dna_len)`` token indices — TODO confirm.
            tensor_rna: RNA token ids for the NT encoder, assumed
                ``(batch, rna_len)``. Id ``1`` is treated as padding when
                building the RNA cross-attention mask.
            tensor_protein: Protein token ids for the ESM encoder.
            attention_mask_rna: Mask forwarded to the NT encoder.
            attention_mask_protein: Mask forwarded to the ESM encoder.

        Returns:
            dict with:
                ``"gene_expression_predictions"``: head outputs,
                ``(batch, 30)``.
                ``"final_dna_embeddings"``: DNA embeddings after both
                cross-attention steps, ``(batch, dna_positions, 3072)``.
        """
        dna_embedding = self.enformer_model(
            tensor_dna[:, 1:],  # remove CLS
            return_only_embeddings=True,
        )
        protein_embedding = self.esm_model(
            tensor_protein,
            attention_mask=attention_mask_protein,
            encoder_attention_mask=attention_mask_protein,
            output_hidden_states=True,
        )
        rna_embedding = self.nt_model(
            tensor_rna,
            attention_mask=attention_mask_rna,
            encoder_attention_mask=attention_mask_rna,
            output_hidden_states=True,
        )

        # Boolean key mask of shape (batch, 1, dna_positions, rna_len): True where
        # the RNA token is not padding. The batch axis stays first so the mask
        # broadcasts over attention heads for any batch size.
        # NOTE(review): pad id 1 is hard-coded; confirm it equals
        # config.nt_pad_token_id.
        num_dna_positions = dna_embedding.shape[1]
        rna_key_mask = (tensor_rna != 1).unsqueeze(1).unsqueeze(1)
        rna_key_mask = rna_key_mask.repeat(1, 1, num_dna_positions, 1)

        rna_to_dna = self.cross_attention_layer_rna(
            hidden_states=dna_embedding,
            encoder_hidden_states=rna_embedding["hidden_states"][-1],
            encoder_attention_mask=rna_key_mask,
        )
        # NOTE(review): no encoder_attention_mask is given here, so padded protein
        # positions are attended to. Left unchanged to avoid altering trained
        # behaviour — confirm whether a protein padding mask is wanted.
        final_dna_embeddings = self.cross_attention_layer_protein(
            hidden_states=rna_to_dna["embeddings"],
            encoder_hidden_states=protein_embedding["hidden_states"][-1],
        )["embeddings"]

        # Mean over the configured window of DNA positions. Slicing keeps the
        # computation on the embeddings' own device and dtype.
        window = slice(self.config.pool_window_start, self.config.pool_window_end)
        pooled = final_dna_embeddings[:, window].mean(dim=1)

        x = self.head_layer_1(pooled)
        x = torch.nn.functional.softplus(x)
        x = self.head_layer_2(x)
        return {
            "gene_expression_predictions": x,
            "final_dna_embeddings": final_dna_embeddings,
        }