import math
import os
import pickle
from collections import OrderedDict
from functools import partial
from io import BytesIO
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
import requests
import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
|
|
|
class GLU(nn.Module):
    """Gated MLP block: a pre-norm input projection followed by a SiLU-gated
    4x-wide hidden layer that is projected back to ``hidden_size``."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(hidden_size, hidden_size * 4, bias=False)
        self.gate_proj = nn.Linear(hidden_size, hidden_size * 4, bias=False)
        self.dense_4h_to_h = nn.Linear(hidden_size * 4, hidden_size, bias=False)

    def forward(self, x):
        x = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        # SiLU-activated gate multiplied element-wise with the up-projection.
        x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
        x = self.dense_4h_to_h(x)
        return x


def swiglu(x):
    """SwiGLU activation: split the last dimension in half and use the
    SiLU-activated first half to gate the second half."""
    x = torch.chunk(x, 2, dim=-1)
    return nn.functional.silu(x[0]) * x[1]
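
# Illustrative usage of GLU (a sketch, not part of the original module; the
# hidden size below is an arbitrary example value). The block preserves the
# input shape:
#
#   glu = GLU(hidden_size=512)
#   y = glu(torch.randn(4, 16, 512))   # -> torch.Size([4, 16, 512])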
|
|
|
class GLU_new(nn.Module):
    """SwiGLU MLP: a single up-projection produces both the gate and the value
    (hence ``intermediate_size * 2`` output features), followed by a
    down-projection and dropout."""

    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        # LLaMA-style sizing (two-thirds of 4*hidden, rounded to a multiple of 64)
        # is computed here but overridden below by a fixed width of 1280, so the
        # computed value is unused.
        intermediate_size = int((4 * hidden_size * 2 / 3) / 64) * 64
        intermediate_size = 1280

        self.act = swiglu
        self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
        self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.dense_h_to_4h(x)
        x = self.act(x)           # swiglu halves the last dimension
        x = self.dense_4h_to_h(x)
        x = self.dropout(x)
        return x
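
# Illustrative usage of GLU_new (a sketch, not part of the original module; the
# hidden size of 4096 is an arbitrary example value). The 2 * 1280-wide
# up-projection is split by swiglu into a gate and a value of 1280 each:
#
#   mlp = GLU_new(hidden_size=4096, dropout=0.0)
#   tokens = torch.randn(2, 64, 4096)   # (batch, length, hidden)
#   out = mlp(tokens)                   # -> torch.Size([2, 64, 4096])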
|
|
|
|
|
n_queries = 32


def get_abs_pos(abs_pos, tgt_size):
    """Resize a square grid of absolute position embeddings.

    ``abs_pos`` is a flattened square grid of shape (src_size**2, C);
    ``tgt_size`` is the target number of positions (assumed to form a square
    grid as well). Returns the embeddings bicubically resized to the target
    grid, shape (tgt_size, C), cast back to the original dtype.
    """
    src_size = int(math.sqrt(abs_pos.size(0)))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        return F.interpolate(
            abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
    else:
        return abs_pos
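
# Illustrative example (not part of the original module; sizes are arbitrary):
# resizing a 16x16 grid of position embeddings to a 24x24 grid, e.g. when the
# input resolution changes.
#
#   pos = torch.randn(16 * 16, 768)           # (256, C), flattened square grid
#   pos_resized = get_abs_pos(pos, 24 * 24)   # -> (576, C) via bicubic interpolation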
|
|
|
def get_1d_sincos_pos_embed(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega

    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb
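
# Illustrative example (not part of the original module; the sizes mirror the
# defaults used when the Resampler below builds its ``pos_embed`` buffer):
# each position is mapped to D/2 sine and D/2 cosine features.
#
#   positions = np.arange(1024, dtype=np.float32)
#   pe = get_1d_sincos_pos_embed(4096, positions)   # -> ndarray of shape (1024, 4096)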
|
|
|
class Resampler(nn.Module):
    """Perceiver-style resampler: a fixed set of learned latent queries
    cross-attends over variable-length encoder features and returns a
    fixed-length sequence of ``n_queries`` embeddings."""

    def __init__(
        self,
        kv_dim,
        embed_dim,
        num_heads=8,
        n_queries=64,
        max_seqlen=1024,
        perceiver_resampler_positional_emb=True,
        use_GLU=False,
        bos_init=False,
        dropout=0.0,
    ):
        super().__init__()
        self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb

        if self.perceiver_resampler_positional_emb:
            assert n_queries <= max_seqlen
            self.stride = max_seqlen // n_queries

            # Fixed 1D sin-cos positional embeddings shared by keys and queries.
            pos = np.arange(max_seqlen, dtype=np.float32)
            self.register_buffer(
                "pos_embed",
                torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float()
            )

        self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
        if bos_init:
            # Intended to initialise the latent queries from a pretrained BOS
            # embedding; the checkpoint path is not provided in this source.
            self.latents.load('')
        else:
            nn.init.trunc_normal_(self.latents, std=1e-3)

        self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)
        self.ln_post = nn.LayerNorm(embed_dim)
        if use_GLU:
            self.proj = GLU_new(embed_dim, dropout=dropout)
        else:
            self.proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, struc_x):
        """
        Args:
            struc_x (dict):
                "encoder_out" (torch.Tensor): protein structure features,
                    shape (B, L, kv_dim)
                "encoder_padding_mask" (torch.Tensor): boolean mask,
                    True at valid positions, shape (B, L)
        Returns:
            torch.Tensor of shape (B, n_queries, embed_dim)
        """
        x = struc_x["encoder_out"]
        mask = struc_x["encoder_padding_mask"]

        # Zero out NaNs that may come from missing residue embeddings.
        nan_mask = torch.isnan(x)
        if nan_mask.any():
            x = x.masked_fill(nan_mask, 0.0)

        x = self.kv_proj(x)
        x = self.ln_kv(x)

        b, seqlen = x.shape[:2]

        latents = self.ln_q(self.latents)
        if self.perceiver_resampler_positional_emb:
            # Queries take every ``stride``-th position; keys take the first ``seqlen``.
            latents = latents + self.pos_embed[::self.stride].contiguous()
            pos_emb = self.pos_embed[:seqlen].unsqueeze(0)
            x = x + pos_emb.contiguous()

        latents = repeat(latents, "n d -> b n d", b=b)
        # key_padding_mask expects True at positions to ignore, hence the inversion.
        out = self.attn(latents, x, x, key_padding_mask=~mask)[0]

        out = self.ln_post(out)
        out = self.proj(out)

        return out
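
# Illustrative usage of Resampler (a sketch, not part of the original module;
# all sizes are arbitrary example values): pool 100 per-residue features of
# width 640 down to 64 query embeddings of width 1024.
#
#   resampler = Resampler(kv_dim=640, embed_dim=1024, num_heads=8, n_queries=64)
#   feats = torch.randn(2, 100, 640)                  # (B, L, kv_dim)
#   valid = torch.ones(2, 100, dtype=torch.bool)      # True = real residue
#   pooled = resampler({"encoder_out": feats, "encoder_padding_mask": valid})
#   # pooled.shape == torch.Size([2, 64, 1024])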
|
|
|
class StructureTransformer(nn.Module):
    """Loads per-residue structure embeddings, pads them to ``max_seqlen`` and
    pools them with a :class:`Resampler` into ``n_queries`` output tokens."""

    def __init__(
        self,
        width: int = 640,
        n_queries: int = 32,
        output_dim: int = 4096,
        embedding_keys=set(["mpnn_emb"]),
        max_seqlen: int = 1024,
        num_heads: int = 8,
        structure_emb_path_prefix='structure_emb',
        **kwargs
    ):
        super().__init__()

        self.structure_emb_path_prefix = structure_emb_path_prefix

        self.embedding_keys = embedding_keys
        self.max_seqlen = max_seqlen
        self.width = width
        self.n_queries = n_queries

        self.attn_pool = Resampler(
            embed_dim=output_dim,
            kv_dim=width,
            n_queries=n_queries,
            max_seqlen=max_seqlen,
            num_heads=num_heads,
            **kwargs
        )

    def prepare_structure(self, sample):
        """Concatenate the requested embeddings of one sample and pad/mask them
        to ``max_seqlen`` rows of width ``self.width``."""
        emb_pad = torch.zeros((self.max_seqlen, self.width))
        emb_mask = torch.zeros((self.max_seqlen,), dtype=torch.bool)

        if "pifold_emb" in self.embedding_keys and "pifold_mask" in sample:
            # Scatter the PiFold embeddings back to full length, filling the
            # positions PiFold skipped with NaN (zeroed later in the resampler).
            mask = sample["pifold_mask"]
            pifold_emb = sample["pifold_emb"]
            new_pifold_emb = pifold_emb.new_zeros(mask.shape[0], pifold_emb.shape[1]).fill_(float("nan"))
            new_pifold_emb[mask > 0] = pifold_emb
            sample["pifold_emb"] = new_pifold_emb

        emb = []
        for ek in self.embedding_keys:
            if ek in sample:
                if isinstance(sample[ek], list):
                    emb.append(torch.cat(sample[ek]))
                else:
                    emb.append(sample[ek])

        emb = torch.cat(emb, dim=-1)

        emb_pad[:len(emb)] = emb
        emb_mask[:len(emb)] = True
        return emb_pad, emb_mask

    def forward(self, x):
        x = self.attn_pool(x)
        return x

    def encode(self, structure_paths):
        """Load pickled structure embeddings and pool them.

        ``structure_paths`` is an iterable of padded integer tensors holding the
        character codes of each file name (zeros are padding).
        """
        structure_embs = []
        structure_mask = []

        for structure_path in structure_paths:
            # Decode the first ``n_queries`` character codes back into a path.
            structure_path = [chr(s) for s in structure_path[:self.n_queries].tolist() if s > 0]
            structure_path = os.path.join(self.structure_emb_path_prefix, ''.join(structure_path))
            if not os.path.exists(structure_path):
                print('no structure found')
                return None

            with open(structure_path, 'rb') as f:
                structure, struc_mask = self.prepare_structure(pickle.load(f))

            structure_embs.append(structure)
            structure_mask.append(struc_mask)

        structure_embs = torch.stack(structure_embs, dim=0).to(
            device=next(self.attn_pool.parameters()).device,
            dtype=next(self.attn_pool.parameters()).dtype)
        structure_mask = torch.stack(structure_mask, dim=0).to(
            device=next(self.attn_pool.parameters()).device)

        return self({
            'encoder_out': structure_embs,
            'encoder_padding_mask': structure_mask
        })
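
# Minimal smoke test (an illustrative sketch, not part of the original module):
# runs the resampler on random features instead of pickled structure embeddings,
# so no files on disk are needed. All sizes are arbitrary example values.
if __name__ == "__main__":
    model = StructureTransformer(width=640, n_queries=32, output_dim=4096)
    feats = torch.randn(2, 100, 640)                 # (B, L, width) random "structure" features
    valid = torch.ones(2, 100, dtype=torch.bool)     # True marks valid residues
    tokens = model({"encoder_out": feats, "encoder_padding_mask": valid})
    print(tokens.shape)                              # expected: torch.Size([2, 32, 4096])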
|
|