Spaces:
Sleeping
Sleeping
Change max batch_size
Browse files- app.py +8 -4
- src/datasets.py +47 -0
- src/generation.py +1 -1
app.py
CHANGED
@@ -10,7 +10,11 @@ import output
|
|
10 |
|
11 |
from rdkit import Chem
|
12 |
from src import const
|
13 |
-
from src.datasets import
|
|
|
|
|
|
|
|
|
14 |
from src.lightning import DDPM
|
15 |
from src.linker_size_lightning import SizeClassifier
|
16 |
from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
|
@@ -19,7 +23,7 @@ from zipfile import ZipFile
|
|
19 |
|
20 |
MIN_N_STEPS = 100
|
21 |
MAX_N_STEPS = 500
|
22 |
-
MAX_BATCH_SIZE =
|
23 |
|
24 |
|
25 |
MODELS_METADATA = {
|
@@ -454,7 +458,7 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
|
|
454 |
ddpm.val_dataset = dataset
|
455 |
|
456 |
batch_size = min(num_samples, MAX_BATCH_SIZE)
|
457 |
-
dataloader = get_dataloader(dataset, batch_size=batch_size, collate_fn=
|
458 |
print('Created dataloader')
|
459 |
|
460 |
ddpm.edm.T = n_steps
|
@@ -532,7 +536,7 @@ with demo:
|
|
532 |
label="Number of Denoising Steps", step=10
|
533 |
)
|
534 |
n_atoms = gr.Slider(
|
535 |
-
minimum=0, maximum=20,
|
536 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
537 |
step=1
|
538 |
)
|
|
|
10 |
|
11 |
from rdkit import Chem
|
12 |
from src import const
|
13 |
+
from src.datasets import (
|
14 |
+
get_dataloader, collate_with_fragment_edges,
|
15 |
+
collate_with_fragment_without_pocket_edges,
|
16 |
+
parse_molecule, MOADDataset
|
17 |
+
)
|
18 |
from src.lightning import DDPM
|
19 |
from src.linker_size_lightning import SizeClassifier
|
20 |
from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
|
|
|
23 |
|
24 |
MIN_N_STEPS = 100
|
25 |
MAX_N_STEPS = 500
|
26 |
+
MAX_BATCH_SIZE = 20
|
27 |
|
28 |
|
29 |
MODELS_METADATA = {
|
|
|
458 |
ddpm.val_dataset = dataset
|
459 |
|
460 |
batch_size = min(num_samples, MAX_BATCH_SIZE)
|
461 |
+
dataloader = get_dataloader(dataset, batch_size=batch_size, collate_fn=collate_with_fragment_without_pocket_edges)
|
462 |
print('Created dataloader')
|
463 |
|
464 |
ddpm.edm.T = n_steps
|
|
|
536 |
label="Number of Denoising Steps", step=10
|
537 |
)
|
538 |
n_atoms = gr.Slider(
|
539 |
+
minimum=0, maximum=20, value=5,
|
540 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
541 |
step=1
|
542 |
)
|
src/datasets.py
CHANGED
@@ -317,6 +317,53 @@ def collate_with_fragment_edges(batch):
|
|
317 |
return out
|
318 |
|
319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
def get_dataloader(dataset, batch_size, collate_fn=collate, shuffle=False):
|
321 |
return DataLoader(dataset, batch_size, collate_fn=collate_fn, shuffle=shuffle)
|
322 |
|
|
|
317 |
return out
|
318 |
|
319 |
|
320 |
+
def collate_with_fragment_without_pocket_edges(batch):
    """Collate per-sample dicts into padded batch tensors (fragment-only edges, no pocket).

    Pads variable-length per-sample tensors to a common length, builds a
    pairwise edge mask restricted to fragment atoms (self-edges removed),
    a fully-connected edge index list per batch element, and an atom mask
    covering fragment and linker atoms.

    Args:
        batch: list of per-sample dicts; every key must appear in
            const.DATA_LIST_ATTRS (kept as a plain list) or
            const.DATA_ATTRS_TO_PAD (padded with zeros).

    Returns:
        dict with list attrs, padded tensors, and the derived
        'edge_mask', 'edges' and 'atom_mask' entries.

    Raises:
        Exception: if a sample contains a key not covered by the const lists.
    """
    out = {}

    # Filter out big molecules
    # batch = [data for data in batch if data['num_atoms'] <= 50]

    # Transpose list-of-dicts into dict-of-lists.
    for i, data in enumerate(batch):
        for key, value in data.items():
            out.setdefault(key, []).append(value)

    for key, value in out.items():
        if key in const.DATA_LIST_ATTRS:
            continue
        if key in const.DATA_ATTRS_TO_PAD:
            out[key] = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=0)
            continue
        raise Exception(f'Unknown batch key: {key}')

    # Pairwise mask over fragment atoms only; zero out the diagonal so a
    # node has no edge to itself.
    # NOTE(review): '~' on an integer-dtype eye is bitwise NOT; assumes
    # const.TORCH_INT produces the intended 0/1 (or bool) mask — confirm
    # against src/const.py.
    frag_mask = out['fragment_only_mask']
    edge_mask = frag_mask[:, None, :] * frag_mask[:, :, None]
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=const.TORCH_INT, device=frag_mask.device).unsqueeze(0)
    edge_mask *= diag_mask

    batch_size, n_nodes = frag_mask.size()
    out['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)

    # Building edges and covalent bond values: fully-connected (i, j)
    # pairs for each batch element, node indices offset by batch position.
    # Vectorized to avoid an O(batch_size * n_nodes^2) pure-Python loop;
    # result is identical to appending i/j pairs in nested loops.
    node_idx = torch.arange(n_nodes, device=frag_mask.device)
    offsets = (torch.arange(batch_size, device=frag_mask.device) * n_nodes).view(-1, 1)
    rows = (node_idx.repeat_interleave(n_nodes).view(1, -1) + offsets).flatten()
    cols = (node_idx.repeat(n_nodes).view(1, -1) + offsets).flatten()
    out['edges'] = [rows, cols]

    # Atoms that exist in each sample: fragment atoms or linker atoms.
    atom_mask = (out['fragment_mask'].bool() | out['linker_mask'].bool()).to(const.TORCH_INT)
    out['atom_mask'] = atom_mask[:, :, None]

    for key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
        if key in out.keys():
            out[key] = out[key][:, :, None]

    return out
|
365 |
+
|
366 |
+
|
367 |
def get_dataloader(dataset, batch_size, collate_fn=collate, shuffle=False):
    """Wrap *dataset* in a DataLoader using the given batching settings."""
    return DataLoader(
        dataset,
        batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
    )
|
369 |
|
src/generation.py
CHANGED
@@ -37,7 +37,7 @@ def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=
|
|
37 |
if with_pocket:
|
38 |
node_mask[torch.where(data['pocket_mask'])] = 0
|
39 |
|
40 |
-
batch_size = len(data)
|
41 |
names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
|
42 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
43 |
print('Saved XYZ files')
|
|
|
37 |
if with_pocket:
|
38 |
node_mask[torch.where(data['pocket_mask'])] = 0
|
39 |
|
40 |
+
batch_size = len(data['positions'])
|
41 |
names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
|
42 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
43 |
print('Saved XYZ files')
|