Spaces:
Running
on
A10G
Running
on
A10G
Change max batch_size
Browse files- app.py +17 -7
- src/generation.py +3 -2
app.py
CHANGED
@@ -17,6 +17,11 @@ from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
|
|
17 |
from zipfile import ZipFile
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
MODELS_METADATA = {
|
21 |
'geom_difflinker': {
|
22 |
'link': 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
|
@@ -329,9 +334,7 @@ def generate_without_pocket(input_file, n_steps, n_atoms, num_samples, selected_
|
|
329 |
|
330 |
for data in dataloader:
|
331 |
try:
|
332 |
-
generate_linkers(
|
333 |
-
ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=False
|
334 |
-
)
|
335 |
except Exception as e:
|
336 |
e = str(e).replace('\'', '')
|
337 |
error = f'Caught exception while generating linkers: {e}'
|
@@ -450,7 +453,8 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
|
|
450 |
dataset = MOADDataset(data=dataset)
|
451 |
ddpm.val_dataset = dataset
|
452 |
|
453 |
-
|
|
|
454 |
print('Created dataloader')
|
455 |
|
456 |
ddpm.edm.T = n_steps
|
@@ -470,10 +474,13 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
|
|
470 |
def sample_fn(_data):
|
471 |
return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
|
472 |
|
473 |
-
for data in dataloader:
|
474 |
try:
|
|
|
475 |
generate_linkers(
|
476 |
-
ddpm=ddpm, data=data,
|
|
|
|
|
477 |
)
|
478 |
except Exception as e:
|
479 |
e = str(e).replace('\'', '')
|
@@ -520,7 +527,10 @@ with demo:
|
|
520 |
gr.Markdown('Upload the file of the target protein in .pdb format (optionally):')
|
521 |
input_protein_file = gr.File(file_count='single', label='Target Protein (Optional)')
|
522 |
|
523 |
-
n_steps = gr.Slider(
|
|
|
|
|
|
|
524 |
n_atoms = gr.Slider(
|
525 |
minimum=0, maximum=20,
|
526 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
|
|
17 |
from zipfile import ZipFile
|
18 |
|
19 |
|
20 |
+
MIN_N_STEPS = 100
|
21 |
+
MAX_N_STEPS = 500
|
22 |
+
MAX_BATCH_SIZE = 5
|
23 |
+
|
24 |
+
|
25 |
MODELS_METADATA = {
|
26 |
'geom_difflinker': {
|
27 |
'link': 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
|
|
|
334 |
|
335 |
for data in dataloader:
|
336 |
try:
|
337 |
+
generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=False)
|
|
|
|
|
338 |
except Exception as e:
|
339 |
e = str(e).replace('\'', '')
|
340 |
error = f'Caught exception while generating linkers: {e}'
|
|
|
453 |
dataset = MOADDataset(data=dataset)
|
454 |
ddpm.val_dataset = dataset
|
455 |
|
456 |
+
batch_size = min(num_samples, MAX_BATCH_SIZE)
|
457 |
+
dataloader = get_dataloader(dataset, batch_size=batch_size, collate_fn=collate_with_fragment_edges)
|
458 |
print('Created dataloader')
|
459 |
|
460 |
ddpm.edm.T = n_steps
|
|
|
474 |
def sample_fn(_data):
|
475 |
return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
|
476 |
|
477 |
+
for batch_i, data in enumerate(dataloader):
|
478 |
try:
|
479 |
+
offset_idx = batch_i * batch_size
|
480 |
generate_linkers(
|
481 |
+
ddpm=ddpm, data=data,
|
482 |
+
sample_fn=sample_fn, name=name, with_pocket=True,
|
483 |
+
offset_idx=offset_idx,
|
484 |
)
|
485 |
except Exception as e:
|
486 |
e = str(e).replace('\'', '')
|
|
|
527 |
gr.Markdown('Upload the file of the target protein in .pdb format (optionally):')
|
528 |
input_protein_file = gr.File(file_count='single', label='Target Protein (Optional)')
|
529 |
|
530 |
+
n_steps = gr.Slider(
|
531 |
+
minimum=MIN_N_STEPS, maximum=MAX_N_STEPS,
|
532 |
+
label="Number of Denoising Steps", step=10
|
533 |
+
)
|
534 |
n_atoms = gr.Slider(
|
535 |
minimum=0, maximum=20,
|
536 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
src/generation.py
CHANGED
@@ -10,7 +10,7 @@ from src.utils import FoundNaNException
|
|
10 |
from src.datasets import get_one_hot
|
11 |
|
12 |
|
13 |
-
def generate_linkers(ddpm, data,
|
14 |
chain = node_mask = None
|
15 |
for i in range(5):
|
16 |
try:
|
@@ -37,7 +37,8 @@ def generate_linkers(ddpm, data, num_samples, sample_fn, name, with_pocket=False
|
|
37 |
if with_pocket:
|
38 |
node_mask[torch.where(data['pocket_mask'])] = 0
|
39 |
|
40 |
-
|
|
|
41 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
42 |
print('Saved XYZ files')
|
43 |
|
|
|
10 |
from src.datasets import get_one_hot
|
11 |
|
12 |
|
13 |
+
def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=0):
|
14 |
chain = node_mask = None
|
15 |
for i in range(5):
|
16 |
try:
|
|
|
37 |
if with_pocket:
|
38 |
node_mask[torch.where(data['pocket_mask'])] = 0
|
39 |
|
40 |
+
batch_size = len(data)
|
41 |
+
names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
|
42 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
43 |
print('Saved XYZ files')
|
44 |
|