"""Gradio front-end that runs ProteinMPNN sequence design on a PDB structure."""
import json, time, os, sys, glob
import re
import gradio as gr
sys.path.append('/home/user/app/ProteinMPNN/vanilla_proteinmpnn')
import matplotlib.pyplot as plt
import shutil
import warnings
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN
import plotly.express as px
import urllib
import urllib.request

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ProteinMPNN model name: v_48_002, v_48_010, v_48_020, v_48_030, v_32_002,
# v_32_010, v_32_020, v_32_030; v_48_010 = version with 48 edges and 0.10A noise.
model_name = "v_48_020"
backbone_noise = 0.00  # Standard deviation of Gaussian noise to add to backbone atoms
path_to_model_weights = '/home/user/app/ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights'
hidden_dim = 128
num_layers = 3
model_folder_path = path_to_model_weights
if model_folder_path[-1] != '/':
    model_folder_path = model_folder_path + '/'
checkpoint_path = model_folder_path + f'{model_name}.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)
noise_level_print = checkpoint['noise_level']
model = ProteinMPNN(
    num_letters=21,
    node_features=hidden_dim,
    edge_features=hidden_dim,
    hidden_dim=hidden_dim,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    augment_eps=backbone_noise,
    k_neighbors=checkpoint['num_edges'],
)
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Only plain alphanumeric identifiers are accepted; the value comes from a
# user-facing textbox and is interpolated into a URL and a local file name.
_PDB_CODE_PATTERN = re.compile(r"[0-9A-Za-z]{4}")


def get_pdb(pdb_code="", filepath=""):
    """Return a local path to the input structure.

    If *pdb_code* is empty, the uploaded file object's ``.name`` is returned.
    Otherwise ``{pdb_code}.pdb`` is downloaded from files.rcsb.org into the
    working directory (skipped when the file already exists) and its path
    is returned.

    Raises:
        ValueError: if neither a code nor a file was supplied, or the code is
            not a four-character alphanumeric identifier.
    """
    if pdb_code is None or pdb_code.strip() == "":
        if not filepath:
            raise ValueError("Provide a PDB code or upload a structure file.")
        return filepath.name
    pdb_code = pdb_code.strip()
    # Never hand this user-supplied string to a shell: validate it, then
    # download without spawning a process.
    if not _PDB_CODE_PATTERN.fullmatch(pdb_code):
        raise ValueError(f"Invalid PDB code: {pdb_code!r}")
    local_path = f"{pdb_code}.pdb"
    if not os.path.exists(local_path):  # no-clobber: reuse an earlier download
        urllib.request.urlretrieve(f"https://files.rcsb.org/view/{pdb_code}.pdb", local_path)
    return local_path


def _parse_chain_list(chain_text):
    """Split free-form chain input such as "A, B" or "A B" into ["A", "B"].

    Every run of non-letters is a separator; empty pieces (e.g. from leading
    or trailing separators) are dropped.
    """
    if not chain_text:
        return []
    return [c for c in re.sub("[^A-Za-z]+", ",", chain_text).split(",") if c]


def _join_chains_in_letter_order(seq, masked_chain_length_list, masked_list):
    """Cut *seq* into per-chain segments, sort them by chain ID, join with '/'.

    Segment k is the next ``masked_chain_length_list[k]`` characters of *seq*
    and is labelled ``masked_list[k]``; output segments are ordered by label.
    """
    segments = []
    start = 0
    for length in masked_chain_length_list:
        segments.append(seq[start:start + length])
        start += length
    return "/".join(segments[i] for i in np.argsort(masked_list))


def update(inp, file, designed_chain, fixed_chain, num_seqs, sampling_temp):
    """Design sequences for the requested chains and report them.

    Args:
        inp: PDB code typed by the user (may be empty).
        file: uploaded structure file object, used when *inp* is empty.
        designed_chain: chain IDs to redesign, free-form separated.
        fixed_chain: chain IDs to keep fixed, free-form separated.
        num_seqs: number of sequences to sample.
        sampling_temp: sampling temperature for ``model.sample``.

    Returns:
        (message, fig): FASTA-style text with the native sequence followed by
        one record per sample, and a plotly heatmap of the mean per-position
        amino-acid probabilities over all samples.
    """
    pdb_path = get_pdb(pdb_code=inp, filepath=file)

    designed_chain_list = _parse_chain_list(designed_chain)
    fixed_chain_list = _parse_chain_list(fixed_chain)
    chain_list = list(set(designed_chain_list + fixed_chain_list))

    # int() guards against the slider delivering a float, which range() rejects.
    num_seq_per_target = int(num_seqs)
    batch_size = 1  # can set higher for large GPUs; reduce if running out of memory
    max_length = 20000  # Max sequence length
    omit_AAs = 'X'  # Amino acids omitted from generated sequences, e.g. 'AC'
    pssm_multi = 0.0  # In [0.0, 1.0]; 0.0 = do not use pssm, 1.0 = ignore MPNN predictions
    pssm_threshold = 0.0  # Value in (-inf, +inf) to restrict per-position AAs
    pssm_log_odds_flag = 0  # 0 for False, 1 for True
    pssm_bias_flag = 0  # 0 for False, 1 for True

    NUM_BATCHES = num_seq_per_target // batch_size
    BATCH_COPIES = batch_size
    temperatures = [sampling_temp]
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    omit_AAs_np = np.array([AA in omit_AAs for AA in alphabet]).astype(np.float32)
    bias_AAs_np = np.zeros(len(alphabet))
    fixed_positions_dict = None
    pssm_dict = None
    omit_AA_dict = None
    tied_positions_dict = None
    bias_by_res_dict = None

    pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
    dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
    chain_id_dict = {pdb_dict_list[0]['name']: (designed_chain_list, fixed_chain_list)}

    message = ""
    fig = None
    with torch.no_grad():
        for protein in dataset_valid:
            all_probs_list = []
            batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]
            (X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list,
             visible_list_list, masked_list_list, masked_chain_length_list_list,
             chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask,
             tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all,
             bias_by_res_all, tied_beta) = tied_featurize(
                batch_clones, device, chain_id_dict, fixed_positions_dict,
                omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict)
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()  # 1.0 for true, 0.0 for false
            name_ = batch_clones[0]['name']

            # Score the native sequence on the designed positions.
            randn_1 = torch.randn(chain_M.shape, device=X.device)
            log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1)
            mask_for_loss = mask * chain_M * chain_M_pos
            native_score = _scores(S, log_probs, mask_for_loss).cpu().data.numpy()

            # Sampling arguments that do not change between batches/temperatures.
            sample_kwargs = dict(
                mask=mask,
                omit_AAs_np=omit_AAs_np,
                bias_AAs_np=bias_AAs_np,
                chain_M_pos=chain_M_pos,
                omit_AA_mask=omit_AA_mask,
                pssm_coef=pssm_coef,
                pssm_bias=pssm_bias,
                pssm_multi=pssm_multi,
                pssm_log_odds_flag=bool(pssm_log_odds_flag),
                pssm_log_odds_mask=pssm_log_odds_mask,
                pssm_bias_flag=bool(pssm_bias_flag),
                bias_by_res=bias_by_res_all,
            )

            for temp in temperatures:
                for j in range(NUM_BATCHES):
                    randn_2 = torch.randn(chain_M.shape, device=X.device)
                    if tied_positions_dict is None:
                        sample_dict = model.sample(
                            X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                            temperature=temp, **sample_kwargs)
                    else:
                        sample_dict = model.tied_sample(
                            X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                            temperature=temp, tied_pos=tied_pos_list_of_lists_list[0],
                            tied_beta=tied_beta, **sample_kwargs)
                    S_sample = sample_dict["S"]

                    # Re-score the sample using the decoding order it was drawn with.
                    log_probs = model(
                        X, S_sample, mask, chain_M * chain_M_pos, residue_idx,
                        chain_encoding_all, randn_2, use_input_decoding_order=True,
                        decoding_order=sample_dict["decoding_order"])
                    scores = _scores(S_sample, log_probs, mask_for_loss).cpu().data.numpy()
                    all_probs_list.append(sample_dict["probs"].cpu().data.numpy())

                    for b_ix in range(BATCH_COPIES):
                        masked_chain_length_list = masked_chain_length_list_list[b_ix]
                        masked_list = masked_list_list[b_ix]
                        # Fraction of designed positions where the sample matches the native residue.
                        seq_recovery_rate = torch.sum(
                            torch.sum(F.one_hot(S[b_ix], 21) * F.one_hot(S_sample[b_ix], 21), axis=-1)
                            * mask_for_loss[b_ix]) / torch.sum(mask_for_loss[b_ix])
                        score = scores[b_ix]

                        # Emit the native record once, before the first sample.
                        if b_ix == 0 and j == 0 and temp == temperatures[0]:
                            native_seq = _join_chains_in_letter_order(
                                _S_to_seq(S[b_ix], chain_M[b_ix]), masked_chain_length_list, masked_list)
                            print_masked_chains = sorted(masked_list_list[0])
                            print_visible_chains = sorted(visible_list_list[0])
                            native_score_print = np.format_float_positional(
                                np.float32(native_score.mean()), unique=False, precision=4)
                            line = '>{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n'.format(
                                name_, native_score_print, print_visible_chains, print_masked_chains,
                                model_name, native_seq)
                            message += f"{line}\n"

                        seq = _join_chains_in_letter_order(
                            _S_to_seq(S_sample[b_ix], chain_M[b_ix]), masked_chain_length_list, masked_list)
                        score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
                        seq_rec_print = np.format_float_positional(
                            np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
                        line = '>T={}, sample={}, score={}, seq_recovery={}\n{}\n'.format(
                            temp, b_ix, score_print, seq_rec_print, seq)
                        message += f"{line}\n"

            all_probs_concat = np.concatenate(all_probs_list)
            fig = px.imshow(
                all_probs_concat.mean(0).T,
                labels=dict(x="positions", y="amino acids", color="probability"),
                y=list(alphabet),
                template="simple_white",
            )
            fig.update_xaxes(side="top")
    return message, fig


proteinMPNN = gr.Blocks()

with proteinMPNN:
    gr.Markdown("# ProteinMPNN")
    gr.Markdown("""Citation: **Robust deep learning based protein sequence design using ProteinMPNN**
Justas Dauparas, Ivan Anishchenko, Nathaniel Bennett, Hua Bai, Robert J. Ragotte, Lukas F. Milles, Basile I. M. Wicky, Alexis Courbet, Robbert J. de Haas, Neville Bethel, Philip J. Y. Leung, Timothy F. Huddy, Sam Pellock, Doug Tischer, Frederick Chan, Brian Koepnick, Hannah Nguyen, Alex Kang, Banumathi Sankaran, Asim Bera, Neil P. King, David Baker
bioRxiv 2022.06.03.494563; doi: [10.1101/2022.06.03.494563](https://doi.org/10.1101/2022.06.03.494563)

Server built by [@simonduerr](https://twitter.com/simonduerr) and hosted by Huggingface""")
    with gr.Tabs():
        with gr.TabItem("Input"):
            inp = gr.Textbox(
                placeholder="PDB Code or upload file below", label="Input structure"
            )
            file = gr.File(file_count="single", type="file")
        with gr.TabItem("Settings"):
            with gr.Row():
                designed_chain = gr.Textbox(value="A", label="Designed chain")
                fixed_chain = gr.Textbox(placeholder="Use commas to fix multiple chains", label="Fixed chain")
            with gr.Row():
                num_seqs = gr.Slider(minimum=1, maximum=50, value=1, step=1, label="Number of sequences")
                sampling_temp = gr.Radio(choices=[0.1, 0.15, 0.2, 0.25, 0.3], value=0.1, label="Sampling temperature")
    btn = gr.Button("Run")
    gr.Markdown(
        """ Sampling temperature for amino acids, `T=0.0` means taking argmax, `T>>1.0` means sample randomly. Suggested values `0.1, 0.15, 0.2, 0.25, 0.3`. Higher values will lead to more diversity. """
    )
    gr.Markdown("# Output")
    out = gr.Textbox(label="status")
    plot = gr.Plot()
    btn.click(fn=update, inputs=[inp, file, designed_chain, fixed_chain, num_seqs, sampling_temp], outputs=[out, plot])

proteinMPNN.launch(share=True)