DiffLinker / app.py
igashov
fix
7c181a3
raw
history blame
6.2 kB
import gradio as gr
import os
import torch
import subprocess
from rdkit import Chem
from src import const
from src.visualizer import save_xyz_file
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier
# Standalone HTML page that renders one molecule with 3Dmol.js.
# Filled via str.format(molecule=<file text>, fmt=<3Dmol format name, e.g. "sdf">),
# so every literal CSS/JS brace is doubled.
# The molecule text is embedded in a JS template literal (backticks).
# NOTE(review): text containing a backtick or "${" would break the script.
# The viewer is created from the element id rather than a jQuery selector.
# This page loads no jQuery of its own, and 3Dmol.js 2.x no longer ships it,
# so `$(...)` may be undefined. `createViewer` accepts a plain id string in
# both 1.x and 2.x. The inline script sits after the container div, so the
# element already exists when it runs.
HTML_TEMPLATE = """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
.mol-container {{
width: 600px;
height: 600px;
position: relative;
}}
.mol-container select{{
background-image:None;
}}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
(function() {{
let config = {{ backgroundColor: "white" }};
let viewer = $3Dmol.createViewer("container", config);
viewer.addModel(`{molecule}`, "{fmt}");
viewer.getModel().setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }});
viewer.zoomTo();
viewer.render();
}})();
</script>
</body>
</html>
"""
# Sandboxed <iframe> that displays a complete HTML document through its
# `srcdoc` attribute. Filled via str.format(html=<document text>).
# The attribute value is delimited by single quotes, so the caller should
# HTML-escape the document (quotes and '&' in particular) before formatting.
# Otherwise a quote in the content would terminate the attribute early.
IFRAME_TEMPLATE = """<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)


def _download_checkpoint(url, path):
    """Download *url* to *path* with wget, unless *path* already exists.

    wget writes to a ``.part`` sibling first. The file is moved into place
    only after wget exits successfully, so a failed or interrupted download
    never leaves a truncated checkpoint at *path* for the loader to choke on.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    if os.path.exists(path):
        return
    partial_path = path + '.part'
    # Argument list instead of a shell string: the '?' in the URL is never
    # subject to shell globbing, and failures are surfaced via check=True.
    subprocess.run(['wget', url, '-O', partial_path], check=True)
    os.replace(partial_path, path)


_download_checkpoint(
    'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1',
    'models/geom_size_gnn.ckpt',
)
size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
print('Loaded SizeGNN model')

_download_checkpoint(
    'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
    'models/geom_difflinker.ckpt',
)
ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
print('Loaded diffusion model')
def sample_fn(_data):
    """Draw one linker size per batch item from the size classifier.

    ``size_nn.forward`` returns a pair. Its first element is turned into
    class probabilities with a softmax over dim 1, presumably shaped
    (batch, num_size_classes); TODO confirm. One class id per row is sampled
    from that categorical distribution. Each id is translated to a linker
    size through ``size_nn.linker_id2size``.

    Returns:
        A ``torch.long`` tensor of sizes on the same device as the samples.
    """
    scores, _ = size_nn.forward(_data)
    size_distribution = torch.distributions.Categorical(probs=scores.softmax(dim=1))
    drawn_ids = size_distribution.sample()
    linker_sizes = [
        size_nn.linker_id2size[class_id]
        for class_id in drawn_ids.detach().cpu().numpy()
    ]
    return torch.tensor(linker_sizes, device=drawn_ids.device, dtype=torch.long)
def read_molecule_content(path):
    """Return the entire text of the file at *path* as a single string."""
    with open(path, "r") as handle:
        content = handle.read()
    return content
def read_molecule(path):
    """Load a molecule from a structure file with RDKit.

    The reader is chosen from the file extension, matched case-insensitively:
    ``.pdb``, ``.mol``, ``.mol2`` or ``.sdf``. For ``.sdf`` only the first
    record is used. All readers are called with ``sanitize=False`` and
    ``removeHs=True``.

    Args:
        path: path to the structure file.

    Returns:
        The parsed RDKit molecule.

    Raises:
        ValueError: if the extension is not supported, or if RDKit could not
            parse the file. RDKit's readers return ``None`` instead of raising,
            which would otherwise surface later as a confusing error.
    """
    lowered = path.lower()
    if lowered.endswith('.pdb'):
        molecule = Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.mol'):
        molecule = Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.mol2'):
        molecule = Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.sdf'):
        molecule = Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    else:
        # ValueError subclasses Exception, so `except Exception` callers still catch it.
        raise ValueError('Unknown file extension')
    if molecule is None:
        raise ValueError(f'RDKit could not parse a molecule from {path}')
    return molecule
def generate(input_file):
    """Generate a linker for the uploaded fragments and return an HTML viewer.

    Steps:
    1. Read and parse the uploaded structure.
    2. Build a one-item dataset in which every input atom is a fragment atom.
    3. Sample one diffusion chain; ``sample_fn`` draws the linker size.
    4. Save the result as XYZ under ``results/``.
    5. Convert the XYZ to SDF with Open Babel.
    6. Embed the SDF text in a 3Dmol.js page wrapped in an iframe.

    Args:
        input_file: object from the ``gr.File`` component; only its ``.name``
            (path of the uploaded file) is used.

    Returns:
        An ``<iframe>`` HTML string with the 3D visualization, or a plain
        error message if the input could not be read or the generated
        structure could not be converted to SDF.
    """
    from html import escape as html_escape

    try:
        path = input_file.name
        molecule = read_molecule(path)
        name = os.path.splitext(os.path.basename(path))[0]
        out_sdf = f'results/{name}_generated.sdf'
        print(f'Input path={path}, name={name}')
        positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    except Exception as e:
        return f'Could not read the molecule: {e}'
    print('Read and parsed molecule')

    # parse_molecule does not return tensors (the original code had to wrap
    # them with torch.tensor). Convert `charges` once, so the masks below can
    # be derived from it with *_like, which only accepts tensors.
    charges = torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device)
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': charges,
        # Every input atom belongs to the fragments and none to the linker.
        # No anchor atoms are marked.
        'anchors': torch.zeros_like(charges),
        'fragment_mask': torch.ones_like(charges),
        'linker_mask': torch.zeros_like(charges),
        'num_atoms': len(positions),
    }]
    dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    # The dataset holds a single item, so one batch is all there is.
    data = next(iter(dataloader))
    chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]
    save_xyz_file('results', h, x, node_mask, names=[name], is_geom=True, suffix='generated')
    print('Saved XYZ file')

    # Argument list instead of a shell string: `name` comes from the uploaded
    # file name. Interpolating it into a shell command would allow command
    # injection and breaks on names containing spaces.
    out_xyz = f'results/{name}_generated.xyz'
    conversion = subprocess.run(['obabel', out_xyz, '-O', out_sdf])
    if conversion.returncode != 0 or not os.path.exists(out_sdf):
        return f'Could not convert the generated molecule to SDF (obabel exit code {conversion.returncode})'
    print('Converted to SDF')

    generated_molecule = read_molecule_content(out_sdf)
    page = HTML_TEMPLATE.format(molecule=generated_molecule, fmt='sdf')
    # `srcdoc` is a single-quoted attribute. Escape the page so quotes or '&'
    # in the molecule text cannot terminate or corrupt it. The browser decodes
    # the entities, so the rendered document is unchanged.
    return IFRAME_TEMPLATE.format(html=html_escape(page, quote=True))
# Gradio UI: an input column (file upload + button) and an HTML component
# that receives whatever `generate` returns (viewer iframe or error text).
with gr.Blocks() as demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input Fragments')
                # Lists every extension read_molecule accepts (including .mol).
                gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format')
                input_file = gr.File(file_count='single', label='Input fragments')
                button = gr.Button('Generate Linker!')
                gr.Markdown('')
            visualization = gr.HTML()
    # Event listeners must be registered inside the Blocks context.
    button.click(
        fn=generate,
        inputs=[input_file],
        outputs=[visualization],
    )
demo.launch()