igashov committed on
Commit
d8600ba
1 Parent(s): ff512d8

Change max batch_size

Browse files
Files changed (3) hide show
  1. app.py +8 -4
  2. src/datasets.py +47 -0
  3. src/generation.py +1 -1
app.py CHANGED
@@ -10,7 +10,11 @@ import output
10
 
11
  from rdkit import Chem
12
  from src import const
13
- from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule, MOADDataset
 
 
 
 
14
  from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
16
  from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
@@ -19,7 +23,7 @@ from zipfile import ZipFile
19
 
20
  MIN_N_STEPS = 100
21
  MAX_N_STEPS = 500
22
- MAX_BATCH_SIZE = 5
23
 
24
 
25
  MODELS_METADATA = {
@@ -454,7 +458,7 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
454
  ddpm.val_dataset = dataset
455
 
456
  batch_size = min(num_samples, MAX_BATCH_SIZE)
457
- dataloader = get_dataloader(dataset, batch_size=batch_size, collate_fn=collate_with_fragment_edges)
458
  print('Created dataloader')
459
 
460
  ddpm.edm.T = n_steps
@@ -532,7 +536,7 @@ with demo:
532
  label="Number of Denoising Steps", step=10
533
  )
534
  n_atoms = gr.Slider(
535
- minimum=0, maximum=20,
536
  label="Linker Size: DiffLinker will predict it if set to 0",
537
  step=1
538
  )
 
10
 
11
  from rdkit import Chem
12
  from src import const
13
+ from src.datasets import (
14
+ get_dataloader, collate_with_fragment_edges,
15
+ collate_with_fragment_without_pocket_edges,
16
+ parse_molecule, MOADDataset
17
+ )
18
  from src.lightning import DDPM
19
  from src.linker_size_lightning import SizeClassifier
20
  from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
 
23
 
24
  MIN_N_STEPS = 100
25
  MAX_N_STEPS = 500
26
+ MAX_BATCH_SIZE = 20
27
 
28
 
29
  MODELS_METADATA = {
 
458
  ddpm.val_dataset = dataset
459
 
460
  batch_size = min(num_samples, MAX_BATCH_SIZE)
461
+ dataloader = get_dataloader(dataset, batch_size=batch_size, collate_fn=collate_with_fragment_without_pocket_edges)
462
  print('Created dataloader')
463
 
464
  ddpm.edm.T = n_steps
 
536
  label="Number of Denoising Steps", step=10
537
  )
538
  n_atoms = gr.Slider(
539
+ minimum=0, maximum=20, value=5,
540
  label="Linker Size: DiffLinker will predict it if set to 0",
541
  step=1
542
  )
src/datasets.py CHANGED
@@ -317,6 +317,53 @@ def collate_with_fragment_edges(batch):
317
  return out
318
 
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
def get_dataloader(dataset, batch_size, collate_fn=collate, shuffle=False):
    """Build a DataLoader over *dataset* with the given collate function."""
    loader = DataLoader(dataset, batch_size, collate_fn=collate_fn, shuffle=shuffle)
    return loader
322
 
 
317
  return out
318
 
319
 
320
def collate_with_fragment_without_pocket_edges(batch):
    """Collate per-sample dicts into a padded batch with fragment-only edges.

    Like the sibling `collate_with_fragment_edges`, but the edge mask is built
    from `fragment_only_mask`, so pocket atoms contribute no edges.

    Args:
        batch: list of per-sample dicts. Keys in `const.DATA_LIST_ATTRS` are
            kept as Python lists; keys in `const.DATA_ATTRS_TO_PAD` must hold
            tensors and are zero-padded to a common length.

    Returns:
        dict of batched data, augmented with 'edge_mask', 'edges' and
        'atom_mask'.

    Raises:
        Exception: if a sample contains a key that is neither a list
            attribute nor a paddable attribute.
    """
    out = {}

    # Group values by key across all samples in the batch.
    for data in batch:
        for key, value in data.items():
            out.setdefault(key, []).append(value)

    # Pad tensor attributes to a common length; list attributes stay as lists.
    for key, value in out.items():
        if key in const.DATA_LIST_ATTRS:
            continue
        if key in const.DATA_ATTRS_TO_PAD:
            out[key] = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=0)
            continue
        raise Exception(f'Unknown batch key: {key}')

    # Pairwise edge mask between fragment atoms only (outer product of the
    # per-node mask), with self-edges removed via the inverted identity.
    # NOTE(review): `~` on a `const.TORCH_INT` tensor is a bitwise NOT; the
    # result is a clean 0/1 mask only if TORCH_INT is boolean-like — confirm
    # against `const` (with a plain int dtype, ~0 == -1 and ~1 == -2).
    frag_mask = out['fragment_only_mask']
    edge_mask = frag_mask[:, None, :] * frag_mask[:, :, None]
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=const.TORCH_INT, device=frag_mask.device).unsqueeze(0)
    edge_mask *= diag_mask

    batch_size, n_nodes = frag_mask.size()
    out['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)

    # Fully-connected edge index per sample, offset into a flat node
    # numbering across the batch. (The original declared an unused `bonds`
    # list here; removed.)
    rows, cols = [], []
    for batch_idx in range(batch_size):
        offset = batch_idx * n_nodes
        for i in range(n_nodes):
            for j in range(n_nodes):
                rows.append(offset + i)
                cols.append(offset + j)

    edges = [torch.LongTensor(rows).to(frag_mask.device), torch.LongTensor(cols).to(frag_mask.device)]
    out['edges'] = edges

    # Real atoms are those belonging to either the fragments or the linker.
    atom_mask = (out['fragment_mask'].bool() | out['linker_mask'].bool()).to(const.TORCH_INT)
    out['atom_mask'] = atom_mask[:, :, None]

    # Add a trailing singleton dimension where the model expects one.
    for key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
        if key in out.keys():
            out[key] = out[key][:, :, None]

    return out
365
+
366
+
367
def get_dataloader(dataset, batch_size, collate_fn=collate, shuffle=False):
    """Return a DataLoader for *dataset* using *collate_fn* to batch samples."""
    return DataLoader(
        dataset,
        batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
    )
369
 
src/generation.py CHANGED
@@ -37,7 +37,7 @@ def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=
37
  if with_pocket:
38
  node_mask[torch.where(data['pocket_mask'])] = 0
39
 
40
- batch_size = len(data)
41
  names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
42
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
43
  print('Saved XYZ files')
 
37
  if with_pocket:
38
  node_mask[torch.where(data['pocket_mask'])] = 0
39
 
40
+ batch_size = len(data['positions'])
41
  names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
42
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
43
  print('Saved XYZ files')