igashov commited on
Commit
4d1ca7b
1 Parent(s): d8600ba

Fix size nn for pocket-conditioned generation

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. src/linker_size_lightning.py +2 -2
app.py CHANGED
@@ -465,7 +465,7 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
465
 
466
  if n_atoms == 0:
467
  def sample_fn(_data):
468
- out, _ = size_nn.forward(_data, return_loss=False)
469
  probabilities = torch.softmax(out, dim=1)
470
  distribution = torch.distributions.Categorical(probs=probabilities)
471
  samples = distribution.sample()
@@ -536,7 +536,7 @@ with demo:
536
  label="Number of Denoising Steps", step=10
537
  )
538
  n_atoms = gr.Slider(
539
- minimum=0, maximum=20, value=5,
540
  label="Linker Size: DiffLinker will predict it if set to 0",
541
  step=1
542
  )
 
465
 
466
  if n_atoms == 0:
467
  def sample_fn(_data):
468
+ out, _ = size_nn.forward(_data, return_loss=False, with_pocket=True)
469
  probabilities = torch.softmax(out, dim=1)
470
  distribution = torch.distributions.Categorical(probs=probabilities)
471
  samples = distribution.sample()
 
536
  label="Number of Denoising Steps", step=10
537
  )
538
  n_atoms = gr.Slider(
539
+ minimum=0, maximum=20,
540
  label="Linker Size: DiffLinker will predict it if set to 0",
541
  step=1
542
  )
src/linker_size_lightning.py CHANGED
@@ -79,10 +79,10 @@ class SizeClassifier(pl.LightningModule):
79
  def test_dataloader(self):
80
  return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
81
 
82
- def forward(self, data, return_loss=True):
83
  h = data['one_hot']
84
  x = data['positions']
85
- fragment_mask = data['fragment_mask']
86
  linker_mask = data['linker_mask']
87
  edge_mask = data['edge_mask']
88
  edges = data['edges']
 
79
  def test_dataloader(self):
80
  return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
81
 
82
+ def forward(self, data, return_loss=True, with_pocket=False):
83
  h = data['one_hot']
84
  x = data['positions']
85
+ fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
86
  linker_mask = data['linker_mask']
87
  edge_mask = data['edge_mask']
88
  edges = data['edges']