Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import torch | |
import subprocess | |
from rdkit import Chem | |
from src import const | |
from src.visualizer import save_xyz_file | |
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule | |
from src.lightning import DDPM | |
from src.linker_size_lightning import SizeClassifier | |
# Standalone HTML page that renders a single molecule with 3Dmol.js.
# str.format placeholders:
#   {molecule} - the molecule file contents, inlined into a JS template literal
#   {fmt}      - the 3Dmol format name passed to addModel (e.g. "sdf")
# Literal CSS/JS braces are doubled ({{ }}) so str.format leaves them intact.
# The script uses plain DOM APIs rather than jQuery's $(...): the unversioned
# 3Dmol build URL serves newer releases that do not ship jQuery, and
# $3Dmol.createViewer accepts a DOM element directly.
HTML_TEMPLATE = """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
.mol-container {{
width: 600px;
height: 600px;
position: relative;
}}
.mol-container select{{
background-image:None;
}}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
document.addEventListener("DOMContentLoaded", function() {{
let element = document.getElementById("container");
let config = {{ backgroundColor: "white" }};
let viewer = $3Dmol.createViewer( element, config );
viewer.addModel(`{molecule}`, "{fmt}")
viewer.getModel().setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
viewer.zoomTo();
viewer.render();
}});
</script>
</body>
</html>
"""
# Wrapper that embeds a full HTML document in the page through an iframe's
# srcdoc attribute. The single str.format placeholder is {html}; it sits inside
# a single-quoted attribute, so callers should HTML-escape what they insert.
IFRAME_TEMPLATE = (
    '<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera;\n'
    'display-capture; encrypted-media;" sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups\n'
    'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""\n'
    'allowpaymentrequest="" frameborder="0" srcdoc=\'{html}\'></iframe>'
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)


def _download_checkpoint(url, path):
    """Download ``url`` to ``path`` with wget, unless a non-empty copy already exists.

    The download goes to ``path + '.part'`` and is renamed onto ``path`` only after
    wget exits successfully, so a failed or interrupted download never leaves a
    truncated checkpoint at ``path``.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    if os.path.isfile(path) and os.path.getsize(path) > 0:
        return
    partial = path + '.part'
    # Argument list instead of a shell string: the '?' in the URL would otherwise
    # be treated as a glob character by the shell.
    subprocess.run(['wget', url, '-O', partial], check=True)
    os.replace(partial, path)


_download_checkpoint(
    'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1',
    'models/geom_size_gnn.ckpt',
)
size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
print('Loaded SizeGNN model')

_download_checkpoint(
    'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
    'models/geom_difflinker.ckpt',
)
ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
print('Loaded diffusion model')
def sample_fn(_data):
    """Draw one linker size per batch item from the size classifier's prediction.

    The classifier's first output is turned into probabilities with a softmax over
    dim 1, a class label is sampled for each row, and every label is mapped to a
    size through ``size_nn.linker_id2size``.

    Returns:
        A ``torch.long`` tensor of sizes on the same device as the sampled labels.
    """
    logits, _ = size_nn.forward(_data)
    size_distribution = torch.distributions.Categorical(probs=torch.softmax(logits, dim=1))
    drawn_labels = size_distribution.sample()
    id_to_size = size_nn.linker_id2size
    return torch.tensor(
        [id_to_size[label] for label in drawn_labels.detach().cpu().numpy()],
        device=drawn_labels.device,
        dtype=torch.long,
    )
def read_molecule_content(path):
    """Return the entire text content of the file at ``path``."""
    with open(path, "r") as f:
        return f.read()
def read_molecule(path):
    """Load the first molecule from a structure file with RDKit.

    The loader is chosen from the file extension, matched case-insensitively:
    ``.pdb``, ``.mol``, ``.mol2`` or ``.sdf``. Files are read with
    ``sanitize=False, removeHs=True``; for ``.sdf`` only the first record is used.

    Raises:
        ValueError: if the extension is not supported, or RDKit returns ``None``
            because it could not parse the file.
    """
    lowered = path.lower()
    if lowered.endswith('.pdb'):
        molecule = Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.mol'):
        molecule = Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.mol2'):
        molecule = Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.sdf'):
        molecule = Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    else:
        raise ValueError('Unknown file extension')
    # RDKit signals a parse failure by returning None rather than raising;
    # surface it here so callers never receive a None molecule.
    if molecule is None:
        raise ValueError(f'RDKit could not parse a molecule from {path}')
    return molecule
def _build_dataset(molecule):
    """Turn an RDKit molecule into the single-item dataset passed to get_dataloader.

    Every input atom is marked as a fragment atom: ``fragment_mask`` is all ones,
    while ``anchors`` and ``linker_mask`` are all zeros.
    """
    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    # Convert charges once and derive the masks from that tensor: torch.zeros_like /
    # ones_like require a Tensor input, and they inherit its dtype and device.
    charges = torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device)
    return [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': charges,
        'anchors': torch.zeros_like(charges),
        'fragment_mask': torch.ones_like(charges),
        'linker_mask': torch.zeros_like(charges),
        'num_atoms': len(positions),
    }]


def _xyz_to_sdf(xyz_path, sdf_path):
    """Convert ``xyz_path`` to ``sdf_path`` with Open Babel; return True if the SDF exists afterwards."""
    # Remove any SDF left from an earlier run with the same name, so a failed
    # conversion is not mistaken for a successful one.
    if os.path.exists(sdf_path):
        os.remove(sdf_path)
    try:
        # Argument list, no shell: the file name is derived from the user's upload.
        subprocess.run(['obabel', xyz_path, '-O', sdf_path])
    except OSError as e:  # e.g. the obabel executable is not installed
        print(f'obabel failed: {e}')
        return False
    return os.path.isfile(sdf_path)


def generate(input_file):
    """Generate a linker for the uploaded fragments and return an HTML viewer.

    Args:
        input_file: the uploaded file object; only its ``.name`` (a filesystem
            path) is used.

    Returns:
        An HTML string with an iframe showing the generated molecule, or a
        plain error message string if any step fails.
    """
    import html as html_lib  # local alias: avoids clashing with HTML variables

    try:
        path = input_file.name
        molecule = read_molecule(path)
        name = os.path.splitext(os.path.basename(path))[0]
        out_sdf = f'results/{name}_generated.sdf'
        print(f'Input path={path}, name={name}')
    except Exception as e:
        return f'Could not read the molecule: {e}'

    try:
        dataset = _build_dataset(molecule)
    except Exception as e:
        return f'Could not parse the molecule: {e}'
    print('Read and parsed molecule')

    dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    # The dataset holds a single item, so the first batch is the only one.
    data = next(iter(dataloader))
    # Inference only; no_grad is harmless if sample_chain already disables grads.
    with torch.no_grad():
        chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')

    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]
    save_xyz_file('results', h, x, node_mask, names=[name], is_geom=True, suffix='generated')
    print('Saved XYZ file')

    if not _xyz_to_sdf(f'results/{name}_generated.xyz', out_sdf):
        return 'Could not convert the generated molecule to SDF'
    print('Converted to SDF')

    generated_molecule = read_molecule_content(out_sdf)
    viewer_html = HTML_TEMPLATE.format(molecule=generated_molecule, fmt='sdf')
    # The page is placed in a single-quoted srcdoc attribute; escaping keeps
    # quotes or '&' in the content from terminating or corrupting the attribute.
    # The browser decodes the entities back into the original document.
    return IFRAME_TEMPLATE.format(html=html_lib.escape(viewer_html))
# Gradio UI: upload a fragments file, click the button, and the generated
# molecule returned by `generate` is shown in the HTML component.
demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input Fragments')
                # Keep this list in sync with the extensions read_molecule accepts.
                gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format')
                input_file = gr.File(file_count='single', label='Input fragments')
                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                visualization = gr.HTML()

    button.click(
        fn=generate,
        inputs=[input_file],
        outputs=[visualization],
    )

if __name__ == '__main__':
    demo.launch()