import argparse
import os
import shutil
from zipfile import ZipFile

import gradio as gr
import numpy as np
import torch
from rdkit import Chem

import output
from src import const
from src.datasets import (
    get_dataloader, collate_with_fragment_edges, collate_with_fragment_without_pocket_edges, parse_molecule, MOADDataset
)
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier
from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket

# Bounds of the "Number of Denoising Steps" slider.
MIN_N_STEPS = 100
MAX_N_STEPS = 500
# Largest batch fed to the pocket-conditioned model at once.
MAX_BATCH_SIZE = 20

# Diffusion checkpoints loaded at startup: 'path' is the local checkpoint file,
# 'link' is where it can be downloaded from (only used in the error message
# raised below when the local file is missing).
MODELS_METADATA = {
    'geom_difflinker': {
        'link': 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
        'path': 'models/geom_difflinker.ckpt',
    },
    'geom_difflinker_given_anchors': {
        'link': 'https://zenodo.org/record/7775568/files/geom_difflinker_given_anchors.ckpt?download=1',
        'path': 'models/geom_difflinker_given_anchors.ckpt',
    },
    'pockets_difflinker': {
        'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
        'path': 'models/pockets_difflinker.ckpt',
    },
    'pockets_difflinker_given_anchors': {
        'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
        'path': 'models/pockets_difflinker_given_anchors.ckpt',
    },
}

parser = argparse.ArgumentParser()
parser.add_argument('--ip', type=str, default=None, help='server_name passed to demo.launch()')
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')
# Generated and intermediate files are written under results/.
os.makedirs("results", exist_ok=True)

# Linker-size classifier used when the user lets DiffLinker predict the linker size.
size_gnn_path = 'models/geom_size_gnn.ckpt'
size_nn = SizeClassifier.load_from_checkpoint(size_gnn_path, map_location=device).eval().to(device)
print('Loaded SizeGNN model')

diffusion_models = {}
for model_name, metadata in MODELS_METADATA.items():
    diffusion_path = metadata['path']
    if not os.path.exists(diffusion_path):
        # Fail early with an actionable message instead of a loader traceback.
        raise FileNotFoundError(
            f'Checkpoint for model {model_name} not found at {diffusion_path}; '
            f'download it from {metadata["link"]}'
        )
    diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
    print(f'Loaded model {model_name}')

print(os.curdir)
print(os.path.abspath(os.curdir))
print(os.listdir(os.curdir))

# File extensions (without the dot) accepted for each kind of input.
FRAGMENT_EXTENSIONS = ['sdf', 'pdb', 'mol', 'mol2']
PROTEIN_EXTENSIONS = ['pdb']

# Maximum number of atoms (as counted by GetNumAtoms) allowed in the input fragments.
MAX_INPUT_ATOMS = 100


def _file_path(file):
    """Return the path of ``file``, given either as a path string or as an object with a ``.name`` path."""
    return file if isinstance(file, str) else file.name


def _file_stem(path):
    """Return the file name of ``path`` without its directory and its last extension."""
    return '.'.join(path.split('/')[-1].split('.')[:-1])


def _sanitize_error(exc):
    """Return ``str(exc)`` with single quotes removed before it is embedded in the output templates."""
    return str(exc).replace('\'', '')


def _error_output(html):
    """Build the four outputs of the generate callbacks that show ``html`` (an error message) in the viewer."""
    return [output.IFRAME_TEMPLATE.format(html=html), None, None, None]


def _parse_selected_atoms(selected_atoms):
    """Parse the comma-separated atom indices produced by the selection JS into a list of ints.

    ``None``/empty input yields ``[]``; empty items (e.g. from a trailing comma) are ignored.
    """
    if not selected_atoms:
        return []
    return [int(token) for token in selected_atoms.split(',') if token.strip()]


def _anchor_range_error(selected_atoms, num_atoms):
    """Return an error message if any selected index lies outside ``[0, num_atoms)``, else ``None``."""
    for idx in selected_atoms:
        if not 0 <= idx < num_atoms:
            return f'Selected anchor atom index {idx} is out of range (valid indices: 0-{num_atoms - 1})'
    return None


def _make_sample_fn(n_atoms, with_pocket):
    """Build the linker-size sampler passed to ``generate_linkers``.

    If ``n_atoms`` is 0, sizes are drawn from the categorical distribution given
    by the softmax of the SizeGNN output and mapped through ``size_nn.linker_id2size``;
    otherwise every sample in the batch gets exactly ``n_atoms`` linker atoms.
    """
    if n_atoms != 0:
        def sample_fn(_data):
            return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
        return sample_fn

    def sample_fn(_data):
        if with_pocket:
            out, _ = size_nn.forward(_data, return_loss=False, with_pocket=True)
        else:
            out, _ = size_nn.forward(_data, return_loss=False)
        probabilities = torch.softmax(out, dim=1)
        distribution = torch.distributions.Categorical(probs=probabilities)
        samples = distribution.sample()
        sizes = [size_nn.linker_id2size[label] for label in samples.detach().cpu().numpy()]
        return torch.tensor(sizes, device=samples.device, dtype=torch.long)

    return sample_fn


def read_molecule_content(path):
    """Return the full text content of the file at ``path``."""
    with open(path, "r") as f:
        return f.read()


def read_molecule(path):
    """Read one molecule with RDKit, picking the reader by file extension.

    Molecules are read with ``sanitize=False`` and ``removeHs=True``; for .sdf
    files only the first record is returned.

    Raises:
        Exception: if the extension is not .pdb, .mol, .mol2 or .sdf.
        ValueError: if RDKit could not parse the file (its readers return
            ``None`` on failure instead of raising).
    """
    if path.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol'):
        mol = Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol2'):
        mol = Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif path.endswith('.sdf'):
        mol = Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    else:
        raise Exception('Unknown file extension')
    if mol is None:
        raise ValueError('RDKit could not parse the molecule file')
    return mol


def read_molecule_file(in_file, allowed_extentions):
    """Load a molecule file and serialize it for the 3D viewer.

    Args:
        in_file: a path string or an uploaded file object exposing ``.name``.
        allowed_extentions: extensions (without the dot) accepted for this input.

    Returns:
        ``(content, extension, None)`` on success, where ``content`` is a PDB
        block for .pdb input and a MOL block (extension reported as 'mol') for
        .mol/.mol2/.sdf input; ``(None, None, message)`` when the extension is
        not allowed or the file cannot be read.
    """
    path = _file_path(in_file)
    extension = path.split('.')[-1]
    if extension not in allowed_extentions:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return None, None, msg

    try:
        mol = read_molecule(path)
    except Exception as e:
        msg = output.ERROR_FORMAT_MSG.format(message=_sanitize_error(e))
        return None, None, msg

    if extension == 'pdb':
        content = Chem.MolToPDBBlock(mol)
    elif extension in ['mol', 'mol2', 'sdf']:
        content = Chem.MolToMolBlock(mol, kekulize=False)
        extension = 'mol'
    else:
        raise NotImplementedError
    return content, extension, None


def show_input(in_fragments, in_protein):
    """Render whichever inputs are present and reset the samples dropdown and the hidden selection box.

    Returns ``[viewer_html, dropdown_update, None]``.
    """
    vis = ''
    if in_fragments is not None and in_protein is None:
        vis = show_fragments(in_fragments)
    elif in_fragments is None and in_protein is not None:
        vis = show_target(in_protein)
    elif in_fragments is not None and in_protein is not None:
        vis = show_fragments_and_target(in_fragments, in_protein)
    return [vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]


def show_fragments(in_fragments):
    """Return viewer HTML for the fragments file (or the error message if it cannot be read)."""
    molecule, extension, html = read_molecule_file(in_fragments, allowed_extentions=FRAGMENT_EXTENSIONS)
    if molecule is not None:
        html = output.FRAGMENTS_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
    return output.IFRAME_TEMPLATE.format(html=html)


def show_target(in_protein):
    """Return viewer HTML for the target protein file (or the error message if it cannot be read)."""
    molecule, extension, html = read_molecule_file(in_protein, allowed_extentions=PROTEIN_EXTENSIONS)
    if molecule is not None:
        html = output.TARGET_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
    return output.IFRAME_TEMPLATE.format(html=html)


def show_fragments_and_target(in_fragments, in_protein):
    """Return viewer HTML showing the fragments together with the target protein.

    If either file cannot be read, its error message is shown instead.
    """
    fragments_molecule, fragments_extension, msg = read_molecule_file(in_fragments, FRAGMENT_EXTENSIONS)
    if fragments_molecule is None:
        return output.IFRAME_TEMPLATE.format(html=msg)

    target_molecule, target_extension, msg = read_molecule_file(in_protein, allowed_extentions=PROTEIN_EXTENSIONS)
    # The target must be checked here, otherwise an unreadable protein is rendered as target=None.
    if target_molecule is None:
        return output.IFRAME_TEMPLATE.format(html=msg)

    html = output.FRAGMENTS_AND_TARGET_RENDERING_TEMPLATE.format(
        molecule=fragments_molecule,
        fmt=fragments_extension,
        target=target_molecule,
        target_fmt=target_extension,
    )
    return output.IFRAME_TEMPLATE.format(html=html)


def clear_fragments_input(in_protein):
    """Handle clearing the fragments file: keep showing the target (if any) and reset the outputs."""
    vis = ''
    if in_protein is not None:
        vis = show_target(in_protein)
    return [None, vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]


def clear_protein_input(in_fragments):
    """Handle clearing the protein file: keep showing the fragments (if any) and reset the outputs."""
    vis = ''
    if in_fragments is not None:
        vis = show_fragments(in_fragments)
    return [None, vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]


def click_on_example(example):
    """Load an example pair into the file inputs and render it.

    ``example`` is a ``(fragments_file_name, target_file_name)`` pair whose names
    are resolved relative to the ``examples/`` directory; an empty name means
    the input is absent.
    """
    fragment_fname, target_fname = example
    fragment_path = f'examples/{fragment_fname}' if fragment_fname != '' else None
    target_path = f'examples/{target_fname}' if target_fname != '' else None
    return [fragment_path, target_path] + show_input(fragment_path, target_path)


def draw_sample(sample_path, out_files, num_samples):
    """Render one generated sample over the input fragments (and the target, if present).

    Args:
        sample_path: the generated molecule to show (path string or file object).
        out_files: files produced by the generate callbacks:
            ``[archive, input fragments, (target protein,) samples...]``.
        num_samples: number of requested samples, used to tell whether
            ``out_files`` includes the target protein.
    """
    # NOTE(review): the target's presence is inferred from the list length, so this
    # relies on num_samples matching the value used at generation time — confirm.
    with_protein = (len(out_files) == num_samples + 3)

    in_sdf = _file_path(out_files[1])
    input_fragments_content = read_molecule_content(in_sdf)
    fragments_fmt = in_sdf.split('.')[-1]

    input_target_content = None
    target_fmt = None
    if with_protein:
        in_pdb = _file_path(out_files[2])
        input_target_content = read_molecule_content(in_pdb)
        target_fmt = in_pdb.split('.')[-1]

    out_sdf = _file_path(sample_path)
    generated_molecule_content = read_molecule_content(out_sdf)
    molecule_fmt = out_sdf.split('.')[-1]

    if with_protein:
        html = output.SAMPLES_WITH_TARGET_RENDERING_TEMPLATE.format(
            fragments=input_fragments_content,
            fragments_fmt=fragments_fmt,
            molecule=generated_molecule_content,
            molecule_fmt=molecule_fmt,
            target=input_target_content,
            target_fmt=target_fmt,
        )
    else:
        html = output.SAMPLES_RENDERING_TEMPLATE.format(
            fragments=input_fragments_content,
            fragments_fmt=fragments_fmt,
            molecule=generated_molecule_content,
            molecule_fmt=molecule_fmt,
        )
    return output.IFRAME_TEMPLATE.format(html=html)


def compress(output_fnames, name):
    """Write ``output_fnames`` into ``results/all_files_{name}.zip`` and return the archive path.

    Entries are stored under the paths they were given with.
    """
    archive_path = f'results/all_files_{name}.zip'
    with ZipFile(archive_path, 'w') as archive:
        for fname in output_fnames:
            archive.write(fname)
    return archive_path


def generate(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms):
    """Button handler: generate with or without a pocket depending on whether a protein is given.

    Returns ``[viewer_html, output_files, dropdown_update, hidden_value]``; all
    four are ``None`` when no fragments file is provided.
    """
    if in_fragments is None:
        return [None, None, None, None]
    if in_protein is None:
        return generate_without_pocket(in_fragments, n_steps, n_atoms, num_samples, selected_atoms)
    else:
        return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms)


def generate_without_pocket(input_file, n_steps, n_atoms, num_samples, selected_atoms):
    """Generate linkers for the input fragments with a GEOM DiffLinker model (no protein).

    The anchor-conditioned model is used when ``selected_atoms`` is non-empty.
    Returns ``[viewer_html, output_files, dropdown_update, None]``, or an error
    message in the viewer with the other outputs set to ``None``.
    """
    # Parsing selected atoms (javascript output)
    selected_atoms = _parse_selected_atoms(selected_atoms)

    # Selecting model
    if len(selected_atoms) == 0:
        selected_model_name = 'geom_difflinker'
    else:
        selected_model_name = 'geom_difflinker_given_anchors'

    print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
    ddpm = diffusion_models[selected_model_name]

    path = _file_path(input_file)
    extension = path.split('.')[-1]
    if extension not in FRAGMENT_EXTENSIONS:
        return _error_output(output.INVALID_FORMAT_MSG.format(extension=extension))

    try:
        molecule = read_molecule(path)
    except Exception as e:
        error = f'Could not read the molecule: {_sanitize_error(e)}'
        return _error_output(output.ERROR_FORMAT_MSG.format(message=error))

    try:
        molecule = Chem.RemoveAllHs(molecule)
    except Exception:
        # Best effort: keep the molecule as read if RDKit cannot strip hydrogens.
        pass

    name = _file_stem(path)
    inp_sdf = f'results/input_{name}.sdf'

    if molecule.GetNumAtoms() > MAX_INPUT_ATOMS:
        error = f'Too large molecule: upper limit is {MAX_INPUT_ATOMS} heavy atoms'
        return _error_output(output.ERROR_FORMAT_MSG.format(message=error))

    with Chem.SDWriter(inp_sdf) as w:
        w.SetKekulize(False)
        w.write(molecule)

    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    anchors = np.zeros_like(charges)
    anchor_error = _anchor_range_error(selected_atoms, len(anchors))
    if anchor_error is not None:
        return _error_output(output.ERROR_FORMAT_MSG.format(message=anchor_error))
    anchors[selected_atoms] = 1
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

    # The same example is repeated num_samples times and generated in a single batch.
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
        'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
        'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
        'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
        'num_atoms': len(positions),
    }] * num_samples
    dataloader = get_dataloader(dataset, batch_size=num_samples, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    ddpm.edm.T = n_steps
    sample_fn = _make_sample_fn(n_atoms, with_pocket=False)

    for data in dataloader:
        try:
            generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=False)
        except Exception as e:
            error = f'Caught exception while generating linkers: {_sanitize_error(e)}'
            return _error_output(output.ERROR_FORMAT_MSG.format(message=error))

    # Resulting list: [archive, input fragments, samples...]
    out_files = try_to_convert_to_sdf(name, num_samples)
    out_files = [inp_sdf] + out_files
    out_files = [compress(out_files, name=name)] + out_files
    choice = out_files[2]
    return [
        draw_sample(choice, out_files, num_samples),
        out_files,
        gr.Dropdown.update(
            choices=out_files[2:],
            value=choice,
            visible=True,
        ),
        None
    ]


def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms):
    """Generate linkers for the input fragments conditioned on the target protein pocket.

    The anchor-conditioned model is used when ``selected_atoms`` is non-empty.
    Samples are generated in batches of at most ``MAX_BATCH_SIZE``. Returns
    ``[viewer_html, output_files, dropdown_update, None]``, or an error message
    in the viewer with the other outputs set to ``None``.
    """
    # Parsing selected atoms (javascript output)
    selected_atoms = _parse_selected_atoms(selected_atoms)

    # Selecting model
    if len(selected_atoms) == 0:
        selected_model_name = 'pockets_difflinker'
    else:
        selected_model_name = 'pockets_difflinker_given_anchors'

    print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
    ddpm = diffusion_models[selected_model_name]

    fragments_path = _file_path(in_fragments)
    fragments_extension = fragments_path.split('.')[-1]
    if fragments_extension not in FRAGMENT_EXTENSIONS:
        return _error_output(output.INVALID_FORMAT_MSG.format(extension=fragments_extension))

    protein_path = _file_path(in_protein)
    protein_extension = protein_path.split('.')[-1]
    if protein_extension not in PROTEIN_EXTENSIONS:
        return _error_output(output.INVALID_FORMAT_MSG.format(extension=protein_extension))

    try:
        fragments_mol = read_molecule(fragments_path)
    except Exception as e:
        error = f'Could not read the molecule: {_sanitize_error(e)}'
        return _error_output(output.ERROR_FORMAT_MSG.format(message=error))
    name = _file_stem(fragments_path)

    if fragments_mol.GetNumAtoms() > MAX_INPUT_ATOMS:
        error = f'Too large molecule: upper limit is {MAX_INPUT_ATOMS} heavy atoms'
        return _error_output(output.ERROR_FORMAT_MSG.format(message=error))

    inp_sdf = f'results/input_{name}.sdf'
    with Chem.SDWriter(inp_sdf) as w:
        w.SetKekulize(False)
        w.write(fragments_mol)

    inp_pdb = f'results/target_{name}.pdb'
    shutil.copy(protein_path, inp_pdb)

    frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments_mol, is_geom=True)
    try:
        pocket_pos, pocket_one_hot, pocket_charges = get_pocket(fragments_mol, protein_path)
    except Exception as e:
        error = f'Could not extract the pocket from the target protein: {_sanitize_error(e)}'
        return _error_output(output.ERROR_FORMAT_MSG.format(message=error))
    print(f'Detected pocket with {len(pocket_pos)} atoms')

    # Fragment atoms come first, pocket atoms after them.
    positions = np.concatenate([frag_pos, pocket_pos], axis=0)
    one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0)
    charges = np.concatenate([frag_charges, pocket_charges], axis=0)
    anchors = np.zeros_like(charges)
    anchor_error = _anchor_range_error(selected_atoms, len(anchors))
    if anchor_error is not None:
        return _error_output(output.ERROR_FORMAT_MSG.format(message=anchor_error))
    anchors[selected_atoms] = 1

    fragment_only_mask = np.concatenate([
        np.ones_like(frag_charges),
        np.zeros_like(pocket_charges),
    ])
    pocket_mask = np.concatenate([
        np.zeros_like(frag_charges),
        np.ones_like(pocket_charges),
    ])
    linker_mask = np.concatenate([
        np.zeros_like(frag_charges),
        np.zeros_like(pocket_charges),
    ])
    # Here fragment_mask covers both the fragments and the pocket.
    fragment_mask = np.concatenate([
        np.ones_like(frag_charges),
        np.ones_like(pocket_charges),
    ])
    print('Read and parsed molecule')

    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
        'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
        'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
        'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
        'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
        'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
        'num_atoms': len(positions),
    }] * num_samples
    dataset = MOADDataset(data=dataset)
    ddpm.val_dataset = dataset

    batch_size = min(num_samples, MAX_BATCH_SIZE)
    dataloader = get_dataloader(dataset, batch_size=batch_size, collate_fn=collate_with_fragment_without_pocket_edges)
    print('Created dataloader')

    ddpm.edm.T = n_steps
    sample_fn = _make_sample_fn(n_atoms, with_pocket=True)

    for batch_i, data in enumerate(dataloader):
        try:
            # offset_idx keeps sample indices distinct across batches.
            offset_idx = batch_i * batch_size
            generate_linkers(
                ddpm=ddpm,
                data=data,
                sample_fn=sample_fn,
                name=name,
                with_pocket=True,
                offset_idx=offset_idx,
            )
        except Exception as e:
            error = f'Caught exception while generating linkers: {_sanitize_error(e)}'
            return _error_output(output.ERROR_FORMAT_MSG.format(message=error))

    # Resulting list: [archive, input fragments, target protein, samples...]
    out_files = try_to_convert_to_sdf(name, num_samples)
    out_files = [inp_sdf, inp_pdb] + out_files
    out_files = [compress(out_files, name=name)] + out_files
    choice = out_files[3]
    return [
        draw_sample(choice, out_files, num_samples),
        out_files,
        gr.Dropdown.update(
            choices=out_files[3:],
            value=choice,
            visible=True,
        ),
        None
    ]


demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    gr.Markdown(
        'Given a set of disconnected fragments in 3D, '
        'DiffLinker places missing atoms in between and designs a molecule incorporating all the initial fragments. '
        'Our method can link an arbitrary number of fragments, requires no information on the attachment atoms '
        'and linker size, and can be conditioned on the protein pockets.'
    )
    gr.Markdown(
        '[**[Paper]**](https://arxiv.org/abs/2210.05274) '
        '[**[Code]**](https://github.com/igashov/DiffLinker)'
    )
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input')
                gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format:')
                input_fragments_file = gr.File(file_count='single', label='Input Fragments')
                gr.Markdown('Upload the file of the target protein in .pdb format (optionally):')
                input_protein_file = gr.File(file_count='single', label='Target Protein (Optional)')
                n_steps = gr.Slider(
                    minimum=MIN_N_STEPS,
                    maximum=MAX_N_STEPS,
                    label="Number of Denoising Steps",
                    step=10,
                )
                n_atoms = gr.Slider(
                    minimum=0,
                    maximum=20,
                    label="Linker Size: DiffLinker will predict it if set to 0",
                    step=1,
                )
                n_samples = gr.Slider(minimum=5, maximum=50, label="Number of Samples", step=5)
                examples = gr.Dataset(
                    components=[gr.File(visible=False), gr.File(visible=False)],
                    samples=[
                        ['examples/example_1.sdf', ''],
                        ['examples/example_2.sdf', ''],
                        ['examples/3hz1_fragments.sdf', 'examples/3hz1_protein.pdb'],
                        ['examples/5ou2_fragments.sdf', 'examples/5ou2_protein.pdb'],
                    ],
                    type='values',
                    headers=['Input Fragments', 'Target Protein'],
                )
                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                gr.Markdown('## Output Files')
                gr.Markdown('Download files with the generated molecules here:')
                output_files = gr.File(file_count='multiple', label='Output Files', interactive=False)
                # Receives the anchor selection from RETURN_SELECTION_JS when the button is clicked.
                hidden = gr.Textbox(visible=False)
            with gr.Column():
                gr.Markdown('## Visualization')
                gr.Markdown('**Hint:** click on atoms to select anchor points (optionally)')
                samples = gr.Dropdown(
                    choices=[],
                    value=None,
                    type='value',
                    multiselect=False,
                    visible=False,
                    interactive=True,
                    label='Samples'
                )
                visualization = gr.HTML()

    input_fragments_file.change(
        fn=show_input,
        inputs=[input_fragments_file, input_protein_file],
        outputs=[visualization, samples, hidden],
    )
    input_protein_file.change(
        fn=show_input,
        inputs=[input_fragments_file, input_protein_file],
        outputs=[visualization, samples, hidden],
    )
    input_fragments_file.clear(
        fn=clear_fragments_input,
        inputs=[input_protein_file],
        outputs=[input_fragments_file, visualization, samples, hidden],
    )
    input_protein_file.clear(
        fn=clear_protein_input,
        inputs=[input_fragments_file],
        outputs=[input_protein_file, visualization, samples, hidden],
    )
    examples.click(
        fn=click_on_example,
        inputs=[examples],
        outputs=[input_fragments_file, input_protein_file, visualization, samples, hidden]
    )
    button.click(
        fn=generate,
        inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, n_samples, hidden],
        outputs=[visualization, output_files, samples, hidden],
        _js=output.RETURN_SELECTION_JS,
    )
    samples.select(
        fn=draw_sample,
        inputs=[samples, output_files, n_samples],
        outputs=[visualization],
    )
    demo.load(_js=output.STARTUP_JS)

demo.launch(server_name=args.ip)