Spaces:
Running
on
A10G
Running
on
A10G
Fix size nn for pocket-conditioned generation
Browse files- app.py +2 -2
- src/linker_size_lightning.py +2 -2
app.py
CHANGED
@@ -465,7 +465,7 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
|
|
465 |
|
466 |
if n_atoms == 0:
|
467 |
def sample_fn(_data):
|
468 |
-
out, _ = size_nn.forward(_data, return_loss=False)
|
469 |
probabilities = torch.softmax(out, dim=1)
|
470 |
distribution = torch.distributions.Categorical(probs=probabilities)
|
471 |
samples = distribution.sample()
|
@@ -536,7 +536,7 @@ with demo:
|
|
536 |
label="Number of Denoising Steps", step=10
|
537 |
)
|
538 |
n_atoms = gr.Slider(
|
539 |
-
minimum=0, maximum=20,
|
540 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
541 |
step=1
|
542 |
)
|
|
|
465 |
|
466 |
if n_atoms == 0:
|
467 |
def sample_fn(_data):
|
468 |
+
out, _ = size_nn.forward(_data, return_loss=False, with_pocket=True)
|
469 |
probabilities = torch.softmax(out, dim=1)
|
470 |
distribution = torch.distributions.Categorical(probs=probabilities)
|
471 |
samples = distribution.sample()
|
|
|
536 |
label="Number of Denoising Steps", step=10
|
537 |
)
|
538 |
n_atoms = gr.Slider(
|
539 |
+
minimum=0, maximum=20,
|
540 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
541 |
step=1
|
542 |
)
|
src/linker_size_lightning.py
CHANGED
@@ -79,10 +79,10 @@ class SizeClassifier(pl.LightningModule):
|
|
79 |
def test_dataloader(self):
|
80 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
81 |
|
82 |
-
def forward(self, data, return_loss=True):
|
83 |
h = data['one_hot']
|
84 |
x = data['positions']
|
85 |
-
fragment_mask = data['fragment_mask']
|
86 |
linker_mask = data['linker_mask']
|
87 |
edge_mask = data['edge_mask']
|
88 |
edges = data['edges']
|
|
|
79 |
def test_dataloader(self):
|
80 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
81 |
|
82 |
+
def forward(self, data, return_loss=True, with_pocket=False):
|
83 |
h = data['one_hot']
|
84 |
x = data['positions']
|
85 |
+
fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
|
86 |
linker_mask = data['linker_mask']
|
87 |
edge_mask = data['edge_mask']
|
88 |
edges = data['edges']
|