Spaces:
Running
on
T4
Running
on
T4
File size: 14,067 Bytes
e1a6cd9 fb853ff e1a6cd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
import json, time, os, sys, glob
import gradio as gr
sys.path.append('/home/user/app/ProteinMPNN/vanilla_proteinmpnn')
import matplotlib.pyplot as plt
import shutil
import warnings
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN
import plotly.express as px
import urllib
# Run on the first GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoint selection. Available names: v_48_002, v_48_010, v_48_020, v_48_030,
# v_32_002, v_32_010, v_32_020, v_32_030 (e.g. v_48_010 = 48 edges, 0.10A noise).
model_name = "v_48_020"
backbone_noise = 0.00  # Standard deviation of Gaussian noise added to backbone atoms
path_to_model_weights = '/home/user/app/ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights'
hidden_dim = 128
num_layers = 3

# Make sure the folder path ends with '/' before appending the checkpoint file name.
model_folder_path = (
    path_to_model_weights
    if path_to_model_weights.endswith('/')
    else path_to_model_weights + '/'
)
checkpoint_path = f'{model_folder_path}{model_name}.pt'

# Load the checkpoint onto the selected device; its metadata sizes the network.
checkpoint = torch.load(checkpoint_path, map_location=device)
noise_level_print = checkpoint['noise_level']
model = ProteinMPNN(
    num_letters=21,
    node_features=hidden_dim,
    edge_features=hidden_dim,
    hidden_dim=hidden_dim,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    augment_eps=backbone_noise,
    k_neighbors=checkpoint['num_edges'],
)
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference only
import re
import numpy as np
def get_pdb(pdb_code="", filepath=""):
    """Return the path of the structure file to run ProteinMPNN on.

    When *pdb_code* is empty (``None``, ``""`` or only whitespace) the
    uploaded file object's ``.name`` is returned. Otherwise ``<pdb_code>.pdb``
    is downloaded from RCSB into the current working directory, unless a file
    of that name already exists (no-clobber, like ``wget -nc``). The relative
    filename is then returned.

    Raises:
        ValueError: if neither a PDB code nor an uploaded file is given, or if
            the code is not a 4-character (or extended ``pdb_XXXXXXXX``) id.
        urllib.error.URLError: if the download fails (e.g. unknown entry).
    """
    import urllib.request  # `import urllib` alone does not load the request submodule

    code = (pdb_code or "").strip()
    if not code:
        if not getattr(filepath, "name", None):
            raise ValueError("Provide a PDB code or upload a PDB file.")
        return filepath.name

    # The code comes straight from a user textbox and is used both in a URL and
    # as a local filename, so accept only well-formed PDB identifiers.
    if not re.fullmatch(r"[0-9A-Za-z]{4}|pdb_[0-9A-Za-z]{8}", code):
        raise ValueError(f"Invalid PDB code: {code!r}")

    local_path = f"{code}.pdb"
    if not os.path.exists(local_path):
        # Download under a temporary name. A failed transfer then never leaves
        # a truncated file that later calls would mistake for a complete one.
        partial_path = local_path + ".part"
        try:
            urllib.request.urlretrieve(
                f"https://files.rcsb.org/view/{code}.pdb", partial_path
            )
            os.replace(partial_path, local_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    return local_path
def _parse_chain_ids(text):
    """Split a free-form chain field such as ``"A, B"`` into ``["A", "B"]``.

    Any run of non-letter characters acts as a separator. Empty fragments
    (from input like ``"A,"`` or ``" A"``) are dropped so they are never
    passed on as a bogus chain id. ``None`` and ``""`` give ``[]``.
    """
    return [chain for chain in re.split(r"[^A-Za-z]+", text or "") if chain]


def _join_chains(seq, chain_lengths, chain_letters):
    """Reorder the per-chain segments of *seq* by chain letter and join them with ``"/"``.

    *seq* is the concatenation of one segment per chain. The segment lengths
    are given by *chain_lengths*, in the same order as *chain_letters*.
    """
    segments = []
    start = 0
    for length in chain_lengths:
        segments.append(seq[start:start + length])
        start += length
    ordered = sorted(zip(chain_letters, segments), key=lambda pair: pair[0])
    return "/".join(segment for _, segment in ordered)


def update(inp, file, designed_chain, fixed_chain, num_seqs, sampling_temp):
    """Run ProteinMPNN on the chosen structure and return ``(fasta_text, heatmap)``.

    Args:
        inp: PDB code typed by the user (may be empty when a file is uploaded).
        file: Uploaded structure file object, used when ``inp`` is empty.
        designed_chain: Chain ids to redesign, separated by any non-letters.
        fixed_chain: Chain ids kept fixed (reported as ``fixed_chains``), same format.
        num_seqs: Number of sequences to sample; converted with ``int()``.
        sampling_temp: Sampling temperature passed to ``model.sample``.

    Returns:
        A FASTA-like string (native sequence first, then one entry per
        sample), and a plotly heatmap of the per-position amino-acid
        probabilities averaged over all samples.

    Raises:
        ValueError: if no chain is to be designed, fewer than one sequence is
            requested, or the structure yields nothing to design.
    """
    pdb_path = get_pdb(pdb_code=inp, filepath=file)
    designed_chain_list = _parse_chain_ids(designed_chain)
    fixed_chain_list = _parse_chain_ids(fixed_chain)
    if not designed_chain_list:
        raise ValueError("Specify at least one chain to design.")
    # De-duplicate while keeping a deterministic order. The order of
    # list(set(...)) depends on string-hash randomisation and can vary between runs.
    chain_list = list(dict.fromkeys(designed_chain_list + fixed_chain_list))

    batch_size = 1  # Sequences per forward pass; reduce if running out of GPU memory
    max_length = 20000  # Max sequence length passed to StructureDatasetPDB
    # Gradio sliders may hand back floats (e.g. 3.0); range() requires an int.
    num_batches = int(num_seqs) // batch_size
    if num_batches < 1:
        raise ValueError("At least one sequence must be requested.")
    batch_copies = batch_size
    temperatures = [sampling_temp]
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    omit_AAs = 'X'  # Amino acids never sampled, e.g. 'AC' would omit alanine and cysteine
    omit_AAs_np = np.array([aa in omit_AAs for aa in alphabet]).astype(np.float32)
    bias_AAs_np = np.zeros(len(alphabet))
    pssm_multi = 0.0  # In [0.0, 1.0]: 0.0 does not use the PSSM, 1.0 ignores MPNN predictions
    pssm_threshold = 0.0  # Cut-off on PSSM log-odds used to restrict per-position AAs
    pssm_log_odds_flag = 0
    pssm_bias_flag = 0
    # Optional per-position constraints; none are used by this app.
    fixed_positions_dict = None
    omit_AA_dict = None
    tied_positions_dict = None
    pssm_dict = None
    bias_by_res_dict = None

    pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
    if not pdb_dict_list:
        raise ValueError(f"No chains {chain_list} could be parsed from {pdb_path}.")
    dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
    chain_id_dict = {pdb_dict_list[0]['name']: (designed_chain_list, fixed_chain_list)}

    message = ""
    all_probs_list = []
    with torch.no_grad():
        for protein in dataset_valid:
            # Results are reset per structure, so only the last one is reported.
            # Presumably one PDB file yields a single entry; TODO confirm.
            message = ""
            all_probs_list = []
            batch_clones = [copy.deepcopy(protein) for _ in range(batch_copies)]
            (X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list,
             visible_list_list, masked_list_list, masked_chain_length_list_list,
             chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask,
             tied_pos_list_of_lists_list, pssm_coef, pssm_bias,
             pssm_log_odds_all, bias_by_res_all, tied_beta) = tied_featurize(
                batch_clones, device, chain_id_dict, fixed_positions_dict,
                omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict)
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()  # 1.0 true, 0.0 false
            name_ = batch_clones[0]['name']
            # Scores and recovery only count positions selected by all three masks.
            mask_for_loss = mask * chain_M * chain_M_pos

            # Score the native sequence under the model.
            randn_1 = torch.randn(chain_M.shape, device=X.device)
            log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx,
                              chain_encoding_all, randn_1)
            native_score = _scores(S, log_probs, mask_for_loss).cpu().data.numpy()

            # Header entry: the native sequence of the designed chains.
            native_seq = _join_chains(
                _S_to_seq(S[0], chain_M[0]),
                masked_chain_length_list_list[0],
                masked_list_list[0],
            )
            native_score_print = np.format_float_positional(
                np.float32(native_score.mean()), unique=False, precision=4)
            line = '>{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n'.format(
                name_, native_score_print, sorted(visible_list_list[0]),
                sorted(masked_list_list[0]), model_name, native_seq)
            message += f"{line}\n"

            # Keyword arguments shared by model.sample and model.tied_sample.
            sample_kwargs = dict(
                mask=mask,
                omit_AAs_np=omit_AAs_np,
                bias_AAs_np=bias_AAs_np,
                chain_M_pos=chain_M_pos,
                omit_AA_mask=omit_AA_mask,
                pssm_coef=pssm_coef,
                pssm_bias=pssm_bias,
                pssm_multi=pssm_multi,
                pssm_log_odds_flag=bool(pssm_log_odds_flag),
                pssm_log_odds_mask=pssm_log_odds_mask,
                pssm_bias_flag=bool(pssm_bias_flag),
                bias_by_res=bias_by_res_all,
            )
            for temp in temperatures:
                for j in range(num_batches):
                    randn_2 = torch.randn(chain_M.shape, device=X.device)
                    if tied_positions_dict is None:
                        sample_dict = model.sample(
                            X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                            temperature=temp, **sample_kwargs)
                    else:
                        sample_dict = model.tied_sample(
                            X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                            temperature=temp, tied_pos=tied_pos_list_of_lists_list[0],
                            tied_beta=tied_beta, **sample_kwargs)
                    S_sample = sample_dict["S"]
                    # Re-score the sample using the decoding order it was generated in.
                    log_probs = model(
                        X, S_sample, mask, chain_M * chain_M_pos, residue_idx,
                        chain_encoding_all, randn_2, use_input_decoding_order=True,
                        decoding_order=sample_dict["decoding_order"])
                    scores = _scores(S_sample, log_probs, mask_for_loss).cpu().data.numpy()
                    all_probs_list.append(sample_dict["probs"].cpu().data.numpy())

                    for b_ix in range(batch_copies):
                        # Fraction of scored positions where the sample equals the native residue.
                        matches = (S[b_ix] == S_sample[b_ix]).float()
                        seq_recovery_rate = (torch.sum(matches * mask_for_loss[b_ix])
                                             / torch.sum(mask_for_loss[b_ix]))
                        seq = _join_chains(
                            _S_to_seq(S_sample[b_ix], chain_M[b_ix]),
                            masked_chain_length_list_list[b_ix],
                            masked_list_list[b_ix],
                        )
                        # Number samples across batches; b_ix alone restarts at 0 in every batch.
                        sample_number = j * batch_copies + b_ix
                        score_print = np.format_float_positional(
                            np.float32(scores[b_ix]), unique=False, precision=4)
                        seq_rec_print = np.format_float_positional(
                            np.float32(seq_recovery_rate.detach().cpu().numpy()),
                            unique=False, precision=4)
                        line = '>T={}, sample={}, score={}, seq_recovery={}\n{}\n'.format(
                            temp, sample_number, score_print, seq_rec_print, seq)
                        message += f"{line}\n"

    if not all_probs_list:
        raise ValueError(f"No designable structure was found in {pdb_path}.")
    # Average over all samples, then transpose so amino acids form the rows.
    mean_probs = np.concatenate(all_probs_list).mean(0).T
    fig = px.imshow(
        mean_probs,
        labels=dict(x="positions", y="amino acids", color="probability"),
        y=list(alphabet),
        template="simple_white",
    )
    fig.update_xaxes(side="top")
    return message, fig
# Gradio UI: an Input tab (PDB code or upload), a Settings tab (chains,
# sample count, temperature), a Run button, and the text/plot outputs.
proteinMPNN = gr.Blocks()
with proteinMPNN:
    gr.Markdown("# ProteinMPNN")
    gr.Markdown("""Citation: **Robust deep learning based protein sequence design using ProteinMPNN** <br>
Justas Dauparas, Ivan Anishchenko, Nathaniel Bennett, Hua Bai, Robert J. Ragotte, Lukas F. Milles, Basile I. M. Wicky, Alexis Courbet, Robbert J. de Haas, Neville Bethel, Philip J. Y. Leung, Timothy F. Huddy, Sam Pellock, Doug Tischer, Frederick Chan, Brian Koepnick, Hannah Nguyen, Alex Kang, Banumathi Sankaran, Asim Bera, Neil P. King, David Baker <br>
bioRxiv 2022.06.03.494563; doi: [10.1101/2022.06.03.494563](https://doi.org/10.1101/2022.06.03.494563) <br><br> Server built by [@simonduerr](https://twitter.com/simonduerr) and hosted by Huggingface""")
    with gr.Tabs():
        with gr.TabItem("Input"):
            inp = gr.Textbox(placeholder="PDB Code or upload file below", label="Input structure")
            file = gr.File(file_count="single", type="file")
        with gr.TabItem("Settings"):
            with gr.Row():
                designed_chain = gr.Textbox(value="A", label="Designed chain")
                fixed_chain = gr.Textbox(placeholder="Use commas to fix multiple chains", label="Fixed chain")
            with gr.Row():
                num_seqs = gr.Slider(minimum=1, maximum=50, value=1, step=1, label="Number of sequences")
                sampling_temp = gr.Radio(choices=[0.1, 0.15, 0.2, 0.25, 0.3], value=0.1, label="Sampling temperature")
    btn = gr.Button("Run")
    gr.Markdown(
        """ Sampling temperature for amino acids, `T=0.0` means taking argmax, `T>>1.0` means sample randomly. Suggested values `0.1, 0.15, 0.2, 0.25, 0.3`. Higher values will lead to more diversity.
"""
    )
    gr.Markdown("# Output")
    out = gr.Textbox(label="status")
    plot = gr.Plot()
    btn.click(
        fn=update,
        inputs=[inp, file, designed_chain, fixed_chain, num_seqs, sampling_temp],
        outputs=[out, plot],
    )

# Launch only when executed as a script, so importing this module (e.g. from
# tests or another app) builds the UI without starting a server.
if __name__ == "__main__":
    proteinMPNN.launch(share=True)
|