Spaces:
Running
on
A10G
Running
on
A10G
igashov
commited on
Commit
•
3c26059
1
Parent(s):
aebc0d2
multiple samples
Browse files
app.py
CHANGED
@@ -14,6 +14,8 @@ from src.datasets import get_dataloader, collate_with_fragment_edges, parse_mole
|
|
14 |
from src.lightning import DDPM
|
15 |
from src.linker_size_lightning import SizeClassifier
|
16 |
|
|
|
|
|
17 |
parser = argparse.ArgumentParser()
|
18 |
parser.add_argument('--ip', type=str, default=None)
|
19 |
args = parser.parse_args()
|
@@ -103,10 +105,8 @@ def generate(input_file):
|
|
103 |
molecule = read_molecule(path)
|
104 |
molecule = Chem.RemoveAllHs(molecule)
|
105 |
name = '.'.join(path.split('/')[-1].split('.')[:-1])
|
106 |
-
inp_sdf = f'results/{name}
|
107 |
-
inp_xyz = f'results/{name}
|
108 |
-
out_sdf = f'results/{name}_output.sdf'
|
109 |
-
out_xyz = f'results/{name}_output.xyz'
|
110 |
except Exception as e:
|
111 |
return f'Could not read the molecule: {e}'
|
112 |
|
@@ -133,8 +133,8 @@ def generate(input_file):
|
|
133 |
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
|
134 |
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
|
135 |
'num_atoms': len(positions),
|
136 |
-
}]
|
137 |
-
dataloader = get_dataloader(dataset, batch_size=
|
138 |
print('Created dataloader')
|
139 |
|
140 |
for data in dataloader:
|
@@ -142,12 +142,21 @@ def generate(input_file):
|
|
142 |
print('Generated linker')
|
143 |
x = chain[0][:, :, :ddpm.n_dims]
|
144 |
h = chain[0][:, :, ddpm.n_dims:]
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
print('Converted to SDF')
|
149 |
break
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
input_fragments_content = read_molecule_content(inp_sdf)
|
152 |
generated_molecule_content = read_molecule_content(out_sdf)
|
153 |
html = output.SAMPLES_RENDERING_TEMPLATE.format(
|
@@ -158,7 +167,7 @@ def generate(input_file):
|
|
158 |
)
|
159 |
return [
|
160 |
output.IFRAME_TEMPLATE.format(html=html),
|
161 |
-
[inp_sdf, inp_xyz
|
162 |
]
|
163 |
|
164 |
|
|
|
14 |
from src.lightning import DDPM
|
15 |
from src.linker_size_lightning import SizeClassifier
|
16 |
|
17 |
+
N_SAMPLES = 5
|
18 |
+
|
19 |
parser = argparse.ArgumentParser()
|
20 |
parser.add_argument('--ip', type=str, default=None)
|
21 |
args = parser.parse_args()
|
|
|
105 |
molecule = read_molecule(path)
|
106 |
molecule = Chem.RemoveAllHs(molecule)
|
107 |
name = '.'.join(path.split('/')[-1].split('.')[:-1])
|
108 |
+
inp_sdf = f'results/input_{name}.sdf'
|
109 |
+
inp_xyz = f'results/input_{name}.xyz'
|
|
|
|
|
110 |
except Exception as e:
|
111 |
return f'Could not read the molecule: {e}'
|
112 |
|
|
|
133 |
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
|
134 |
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
|
135 |
'num_atoms': len(positions),
|
136 |
+
}] * N_SAMPLES
|
137 |
+
dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
|
138 |
print('Created dataloader')
|
139 |
|
140 |
for data in dataloader:
|
|
|
142 |
print('Generated linker')
|
143 |
x = chain[0][:, :, :ddpm.n_dims]
|
144 |
h = chain[0][:, :, ddpm.n_dims:]
|
145 |
+
names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
|
146 |
+
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
147 |
+
print('Saved XYZ files')
|
|
|
148 |
break
|
149 |
|
150 |
+
out_files = []
|
151 |
+
for i in range(N_SAMPLES):
|
152 |
+
out_xyz = f'results/output_{i+1}_{name}_.xyz'
|
153 |
+
out_sdf = f'results/output_{i+1}_{name}_.sdf'
|
154 |
+
subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)
|
155 |
+
out_files.append(out_xyz)
|
156 |
+
out_files.append(out_sdf)
|
157 |
+
print('Converted to SDF')
|
158 |
+
|
159 |
+
out_sdf = f'results/output_1_{name}_.sdf'
|
160 |
input_fragments_content = read_molecule_content(inp_sdf)
|
161 |
generated_molecule_content = read_molecule_content(out_sdf)
|
162 |
html = output.SAMPLES_RENDERING_TEMPLATE.format(
|
|
|
167 |
)
|
168 |
return [
|
169 |
output.IFRAME_TEMPLATE.format(html=html),
|
170 |
+
[inp_sdf, inp_xyz] + out_files,
|
171 |
]
|
172 |
|
173 |
|