Spaces:
Sleeping
Sleeping
updated code
Browse files- app.py +8 -4
- src/datasets.py +114 -9
- src/egnn.py +48 -13
- src/lightning.py +8 -4
- src/linker_size.py +0 -4
- src/linker_size_lightning.py +6 -1
- src/utils.py +14 -0
app.py
CHANGED
@@ -35,12 +35,16 @@ MODELS_METADATA = {
|
|
35 |
'path': 'models/geom_difflinker_given_anchors.ckpt',
|
36 |
},
|
37 |
'pockets_difflinker': {
|
38 |
-
'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
|
39 |
-
'path': 'models/pockets_difflinker.ckpt',
|
|
|
|
|
40 |
},
|
41 |
'pockets_difflinker_given_anchors': {
|
42 |
-
'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
|
43 |
-
'path': 'models/pockets_difflinker_given_anchors.ckpt',
|
|
|
|
|
44 |
},
|
45 |
}
|
46 |
|
|
|
35 |
'path': 'models/geom_difflinker_given_anchors.ckpt',
|
36 |
},
|
37 |
'pockets_difflinker': {
|
38 |
+
# 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
|
39 |
+
# 'path': 'models/pockets_difflinker.ckpt',
|
40 |
+
'link': 'https://zenodo.org/records/10988017/files/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt?download=1',
|
41 |
+
'path': 'models/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt',
|
42 |
},
|
43 |
'pockets_difflinker_given_anchors': {
|
44 |
+
# 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
|
45 |
+
# 'path': 'models/pockets_difflinker_given_anchors.ckpt',
|
46 |
+
'link': 'https://zenodo.org/records/10988017/files/pockets_difflinker_full_fc_pdb_excluded.ckpt?download=1',
|
47 |
+
'path': 'models/pockets_difflinker_full_fc_pdb_excluded.ckpt',
|
48 |
},
|
49 |
}
|
50 |
|
src/datasets.py
CHANGED
@@ -148,6 +148,15 @@ class MOADDataset(Dataset):
|
|
148 |
total=len(table)
|
149 |
)
|
150 |
for (_, row), fragments, linker, pocket_data in generator:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
uuid = row['uuid']
|
152 |
name = row['molecule']
|
153 |
frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
|
@@ -212,16 +221,112 @@ class MOADDataset(Dataset):
|
|
212 |
|
213 |
return data
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
@staticmethod
|
216 |
-
def
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
-
|
224 |
-
|
|
|
|
|
225 |
|
226 |
|
227 |
def collate(batch):
|
@@ -231,7 +336,7 @@ def collate(batch):
|
|
231 |
# if 'pocket_mask' not in batch[0].keys():
|
232 |
# batch = [data for data in batch if data['num_atoms'] <= 50]
|
233 |
# else:
|
234 |
-
#
|
235 |
|
236 |
for i, data in enumerate(batch):
|
237 |
for key, value in data.items():
|
|
|
148 |
total=len(table)
|
149 |
)
|
150 |
for (_, row), fragments, linker, pocket_data in generator:
|
151 |
+
pdb = row['molecule_name'].split('_')[0]
|
152 |
+
if pdb in {
|
153 |
+
'5ou2', '5ou3', '6hay',
|
154 |
+
'5mo8', '5mo5', '5mo7', '5ctp', '5cu2', '5cu4', '5mmr', '5mmf',
|
155 |
+
'5moe', '3iw7', '4i9n', '3fi2', '3fi3',
|
156 |
+
}:
|
157 |
+
print(f'Skipping pdb={pdb}')
|
158 |
+
continue
|
159 |
+
|
160 |
uuid = row['uuid']
|
161 |
name = row['molecule']
|
162 |
frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
|
|
|
221 |
|
222 |
return data
|
223 |
|
224 |
+
|
225 |
+
class OptimisedMOADDataset(MOADDataset):
    """
    Variant of MOADDataset whose preprocessed data is kept in two parts.

    ``self.data`` is read as a dict with two entries (the same structure that
    ``preprocess`` returns):
      - 'fragmentation_level_data': list with one record per table row
        (uuid, name, anchors and the four masks);
      - 'protein_level_data': dict keyed by molecule name holding positions,
        one-hot features, charges and the atom count.

    How ``self.data`` gets populated is not visible here; presumably the parent
    class stores the result of ``preprocess`` - TODO confirm.
    """
    # TODO: finish testing

    def __len__(self):
        """Return the number of fragmentation-level records."""
        return len(self.data['fragmentation_level_data'])

    def __getitem__(self, item):
        """
        Return one sample: the fragmentation-level record at index ``item``
        merged with the protein-level record stored under the same name.
        """
        frag_entry = self.data['fragmentation_level_data'][item]
        protein_entry = self.data['protein_level_data'][frag_entry['name']]
        # Protein-level values take precedence on a key clash.
        sample = dict(frag_entry)
        sample.update(protein_entry)
        return sample

    @staticmethod
    def preprocess(data_path, prefix, pocket_mode, device):
        """
        Read the raw dataset files and build the two-level data structure.

        Files read from ``data_path``: ``{prefix}_table.csv``, ``{prefix}_frag.sdf``,
        ``{prefix}_link.sdf`` and ``{prefix}_pockets.pkl``. The table rows, the two
        SDF streams and the pocket list are iterated in lockstep.

        :param data_path: directory containing the dataset files
        :param prefix: file name prefix; anchors are read from the 'anchors' column
            when it contains 'multifrag', otherwise from 'anchor_1' / 'anchor_2'
        :param pocket_mode: key prefix selecting ``{pocket_mode}_coord`` and
            ``{pocket_mode}_types`` in every pocket record
        :param device: device on which all tensors are created
        :return: dict with 'fragmentation_level_data' (list) and
            'protein_level_data' (dict keyed by molecule name)
        """
        print('Preprocessing optimised version of the dataset')
        protein_level_data = {}
        fragmentation_level_data = []

        def dataset_file(suffix):
            # All inputs share the same directory and prefix.
            return os.path.join(data_path, f'{prefix}_{suffix}')

        def to_tensor(array):
            return torch.tensor(array, dtype=const.TORCH_FLOAT, device=device)

        table_path = dataset_file('table.csv')
        fragments_path = dataset_file('frag.sdf')
        linkers_path = dataset_file('link.sdf')
        pockets_path = dataset_file('pockets.pkl')

        is_geom = True
        is_multifrag = 'multifrag' in prefix

        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only use this with dataset files from a trusted source.
        with open(pockets_path, 'rb') as f:
            pockets = pickle.load(f)

        table = pd.read_csv(table_path)
        records = zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path), pockets)
        # NOTE(review): unlike the PDB filter visible in MOADDataset.preprocess,
        # no entries are skipped here - confirm whether that is intended.
        for (_, row), fragments, linker, pocket_data in tqdm(records, total=len(table)):
            uuid = row['uuid']
            name = row['molecule']
            frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
            link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)

            # Parsing pocket data
            pocket_pos = pocket_data[f'{pocket_mode}_coord']
            pocket_features = [
                (get_one_hot(atom_type, const.GEOM_ATOM2IDX), const.GEOM_CHARGES[atom_type])
                for atom_type in pocket_data[f'{pocket_mode}_types']
            ]
            pocket_one_hot = np.array([one_hot_row for one_hot_row, _ in pocket_features])
            pocket_charges = np.array([charge for _, charge in pocket_features])

            # Atom order everywhere below is: fragments, pocket, linker.
            positions = np.concatenate([frag_pos, pocket_pos, link_pos], axis=0)
            one_hot = np.concatenate([frag_one_hot, pocket_one_hot, link_one_hot], axis=0)
            charges = np.concatenate([frag_charges, pocket_charges, link_charges], axis=0)

            anchors = np.zeros_like(charges)
            if is_multifrag:
                for anchor_idx in map(int, row['anchors'].split('-')):
                    anchors[anchor_idx] = 1
            else:
                anchors[row['anchor_1']] = 1
                anchors[row['anchor_2']] = 1

            # Per-segment indicator pieces reused by the four masks.
            frag_on, frag_off = np.ones_like(frag_charges), np.zeros_like(frag_charges)
            pocket_on, pocket_off = np.ones_like(pocket_charges), np.zeros_like(pocket_charges)
            link_on, link_off = np.ones_like(link_charges), np.zeros_like(link_charges)

            fragment_only_mask = np.concatenate([frag_on, pocket_off, link_off])
            pocket_mask = np.concatenate([frag_off, pocket_on, link_off])
            linker_mask = np.concatenate([frag_off, pocket_off, link_on])
            # fragment_mask covers fragments AND pocket, i.e. everything but the linker.
            fragment_mask = np.concatenate([frag_on, pocket_on, link_off])

            fragmentation_level_data.append({
                'uuid': uuid,
                'name': name,
                'anchors': to_tensor(anchors),
                'fragment_only_mask': to_tensor(fragment_only_mask),
                'pocket_mask': to_tensor(pocket_mask),
                'fragment_mask': to_tensor(fragment_mask),
                'linker_mask': to_tensor(linker_mask),
            })
            # NOTE(review): this entry is keyed by molecule name, so rows that share
            # a name overwrite each other, although positions/one_hot/charges are
            # built from this row's fragment/linker split - verify this is intended.
            protein_level_data[name] = {
                'positions': to_tensor(positions),
                'one_hot': to_tensor(one_hot),
                'charges': to_tensor(charges),
                'num_atoms': len(positions),
            }

        return {
            'fragmentation_level_data': fragmentation_level_data,
            'protein_level_data': protein_level_data,
        }
|
330 |
|
331 |
|
332 |
def collate(batch):
|
|
|
336 |
# if 'pocket_mask' not in batch[0].keys():
|
337 |
# batch = [data for data in batch if data['num_atoms'] <= 50]
|
338 |
# else:
|
339 |
+
# batch = [data for data in batch if data['num_atoms'] <= 1000]
|
340 |
|
341 |
for i, data in enumerate(batch):
|
342 |
for key, value in data.items():
|
src/egnn.py
CHANGED
@@ -315,7 +315,7 @@ class Dynamics(nn.Module):
|
|
315 |
self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
|
316 |
n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
|
317 |
sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
|
318 |
-
normalization=None, centering=False,
|
319 |
):
|
320 |
super().__init__()
|
321 |
self.device = device
|
@@ -324,6 +324,7 @@ class Dynamics(nn.Module):
|
|
324 |
self.condition_time = condition_time
|
325 |
self.model = model
|
326 |
self.centering = centering
|
|
|
327 |
|
328 |
in_node_nf = in_node_nf + context_node_nf + condition_time
|
329 |
if self.model == 'egnn_dynamics':
|
@@ -369,6 +370,8 @@ class Dynamics(nn.Module):
|
|
369 |
- context: (B, N, C)
|
370 |
"""
|
371 |
|
|
|
|
|
372 |
bs, n_nodes = xh.shape[0], xh.shape[1]
|
373 |
edges = self.get_edges(n_nodes, bs) # (2, B*N)
|
374 |
node_mask = node_mask.view(bs * n_nodes, 1) # (B*N, 1)
|
@@ -421,16 +424,6 @@ class Dynamics(nn.Module):
|
|
421 |
if self.condition_time:
|
422 |
h_final = h_final[:, :-1]
|
423 |
|
424 |
-
if torch.any(torch.isnan(vel)):
|
425 |
-
print('Found NaN values in velocities')
|
426 |
-
nan_mask = torch.isnan(vel).float()
|
427 |
-
vel = x * nan_mask + torch.nan_to_num(vel) * (1 - nan_mask)
|
428 |
-
|
429 |
-
if torch.any(torch.isnan(h_final)):
|
430 |
-
print('Found NaN values in features')
|
431 |
-
nan_mask = torch.isnan(h_final).float()
|
432 |
-
h_final = h[:, :h_final.shape[1]] * nan_mask + torch.nan_to_num(h_final) * (1 - nan_mask)
|
433 |
-
|
434 |
vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
|
435 |
h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
|
436 |
node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
|
@@ -477,12 +470,21 @@ class DynamicsWithPockets(Dynamics):
|
|
477 |
if linker_mask is not None:
|
478 |
linker_mask = linker_mask.view(bs * n_nodes, 1) # (B*N, 1)
|
479 |
|
|
|
|
|
|
|
|
|
480 |
# Reshaping node features & adding time feature
|
481 |
xh = xh.view(bs * n_nodes, -1).clone() * node_mask # (B*N, D)
|
482 |
x = xh[:, :self.n_dims].clone() # (B*N, 3)
|
483 |
h = xh[:, self.n_dims:].clone() # (B*N, nf)
|
484 |
|
485 |
-
|
|
|
|
|
|
|
|
|
|
|
486 |
if self.condition_time:
|
487 |
if np.prod(t.size()) == 1:
|
488 |
# t is the same for all elements in batch.
|
@@ -537,7 +539,7 @@ class DynamicsWithPockets(Dynamics):
|
|
537 |
return torch.cat([vel, h_final], dim=2)
|
538 |
|
539 |
@staticmethod
|
540 |
-
def
|
541 |
node_mask = node_mask.squeeze().bool()
|
542 |
batch_adj = (batch_mask[:, None] == batch_mask[None, :])
|
543 |
nodes_adj = (node_mask[:, None] & node_mask[None, :])
|
@@ -546,3 +548,36 @@ class DynamicsWithPockets(Dynamics):
|
|
546 |
adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
|
547 |
edges = torch.stack(torch.where(adj))
|
548 |
return edges
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
|
316 |
n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
|
317 |
sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
|
318 |
+
normalization=None, centering=False, graph_type='FC',
|
319 |
):
|
320 |
super().__init__()
|
321 |
self.device = device
|
|
|
324 |
self.condition_time = condition_time
|
325 |
self.model = model
|
326 |
self.centering = centering
|
327 |
+
self.graph_type = graph_type
|
328 |
|
329 |
in_node_nf = in_node_nf + context_node_nf + condition_time
|
330 |
if self.model == 'egnn_dynamics':
|
|
|
370 |
- context: (B, N, C)
|
371 |
"""
|
372 |
|
373 |
+
assert self.graph_type == 'FC'
|
374 |
+
|
375 |
bs, n_nodes = xh.shape[0], xh.shape[1]
|
376 |
edges = self.get_edges(n_nodes, bs) # (2, B*N)
|
377 |
node_mask = node_mask.view(bs * n_nodes, 1) # (B*N, 1)
|
|
|
424 |
if self.condition_time:
|
425 |
h_final = h_final[:, :-1]
|
426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
427 |
vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
|
428 |
h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
|
429 |
node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
|
|
|
470 |
if linker_mask is not None:
|
471 |
linker_mask = linker_mask.view(bs * n_nodes, 1) # (B*N, 1)
|
472 |
|
473 |
+
fragment_only_mask = context[..., -2].view(bs * n_nodes, 1) # (B*N, 1)
|
474 |
+
pocket_only_mask = context[..., -1].view(bs * n_nodes, 1) # (B*N, 1)
|
475 |
+
assert torch.all(fragment_only_mask.bool() | pocket_only_mask.bool() | linker_mask.bool() == node_mask.bool())
|
476 |
+
|
477 |
# Reshaping node features & adding time feature
|
478 |
xh = xh.view(bs * n_nodes, -1).clone() * node_mask # (B*N, D)
|
479 |
x = xh[:, :self.n_dims].clone() # (B*N, 3)
|
480 |
h = xh[:, self.n_dims:].clone() # (B*N, nf)
|
481 |
|
482 |
+
assert self.graph_type in ['4A', 'FC-4A', 'FC-10A-4A']
|
483 |
+
if self.graph_type == '4A' or self.graph_type is None:
|
484 |
+
edges = self.get_dist_edges_4A(x, node_mask, edge_mask)
|
485 |
+
else:
|
486 |
+
edges = self.get_dist_edges(x, node_mask, edge_mask, linker_mask, fragment_only_mask, pocket_only_mask)
|
487 |
+
|
488 |
if self.condition_time:
|
489 |
if np.prod(t.size()) == 1:
|
490 |
# t is the same for all elements in batch.
|
|
|
539 |
return torch.cat([vel, h_final], dim=2)
|
540 |
|
541 |
@staticmethod
|
542 |
+
def get_dist_edges_4A(x, node_mask, batch_mask):
|
543 |
node_mask = node_mask.squeeze().bool()
|
544 |
batch_adj = (batch_mask[:, None] == batch_mask[None, :])
|
545 |
nodes_adj = (node_mask[:, None] & node_mask[None, :])
|
|
|
548 |
adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
|
549 |
edges = torch.stack(torch.where(adj))
|
550 |
return edges
|
551 |
+
|
552 |
+
def get_dist_edges(self, x, node_mask, batch_mask, linker_mask, fragment_only_mask, pocket_only_mask):
    """
    Build the edge list of the mixed ligand/pocket interaction graph.

    Two distinct nodes are connected when both are unmasked, both carry the same
    ``batch_mask`` value, and one of the following holds:
      - both are ligand nodes (linker or fragment): always connected;
      - both are pocket nodes: their distance is at most 4;
      - one is a ligand node and the other a pocket node: their distance is at
        most 4 when ``self.graph_type == 'FC-4A'``, otherwise at most 10.

    Distances are Euclidean distances between rows of ``x`` (presumably in
    Angstrom, going by the graph type names - TODO confirm).

    :param x: node coordinates, one row per node
    :param node_mask: per-node validity mask; flattened to one value per node
    :param batch_mask: per-node batch identifier, indexed as a 1-D tensor
    :param linker_mask: per-node indicator of linker nodes
    :param fragment_only_mask: per-node indicator of fragment nodes
    :param pocket_only_mask: per-node indicator of pocket nodes
    :return: tensor of shape (2, E) with the row and column indices of all edges
    """
    # Flatten with reshape(-1) instead of squeeze(): squeeze() collapses a single-node
    # (1, 1) mask into a 0-dim tensor, which breaks the [:, None] indexing below.
    node_mask = node_mask.reshape(-1).bool()
    linker_mask = linker_mask.reshape(-1).bool() & node_mask
    fragment_only_mask = fragment_only_mask.reshape(-1).bool() & node_mask
    pocket_only_mask = pocket_only_mask.reshape(-1).bool() & node_mask
    ligand_mask = linker_mask | fragment_only_mask

    # General constraints: same batch element, both nodes valid, no self-loops
    batch_adj = (batch_mask[:, None] == batch_mask[None, :])
    nodes_adj = (node_mask[:, None] & node_mask[None, :])
    rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
    constraints = batch_adj & nodes_adj & rm_self_loops

    # Both distance-based rules need the same pairwise distances; compute them once.
    dists = torch.cdist(x, x)

    # Ligand atoms - fully-connected graph
    ligand_adj = (ligand_mask[:, None] & ligand_mask[None, :])
    ligand_interactions = ligand_adj & constraints

    # Pocket atoms - within 4A
    pocket_adj = (pocket_only_mask[:, None] & pocket_only_mask[None, :])
    pocket_interactions = pocket_adj & (dists <= 4) & constraints

    # Pocket-ligand atoms - within 4A for 'FC-4A', within 10A otherwise
    pocket_ligand_cutoff = 4 if self.graph_type == 'FC-4A' else 10
    pocket_ligand_adj = (ligand_mask[:, None] & pocket_only_mask[None, :])
    pocket_ligand_adj = pocket_ligand_adj | (pocket_only_mask[:, None] & ligand_mask[None, :])
    pocket_ligand_interactions = pocket_ligand_adj & (dists <= pocket_ligand_cutoff) & constraints

    adj = ligand_interactions | pocket_interactions | pocket_ligand_interactions
    edges = torch.stack(torch.where(adj))
    return edges
|
src/lightning.py
CHANGED
@@ -44,7 +44,7 @@ class DDPM(pl.LightningModule):
|
|
44 |
normalize_factors, include_charges, model,
|
45 |
data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
|
46 |
normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
|
47 |
-
center_of_mass='fragments', inpainting=False, anchors_context=True,
|
48 |
):
|
49 |
super(DDPM, self).__init__()
|
50 |
|
@@ -54,7 +54,7 @@ class DDPM(pl.LightningModule):
|
|
54 |
self.val_data_prefix = val_data_prefix
|
55 |
self.batch_size = batch_size
|
56 |
self.lr = lr
|
57 |
-
self.torch_device =
|
58 |
self.include_charges = include_charges
|
59 |
self.test_epochs = test_epochs
|
60 |
self.n_stability_samples = n_stability_samples
|
@@ -72,6 +72,9 @@ class DDPM(pl.LightningModule):
|
|
72 |
|
73 |
self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
|
74 |
|
|
|
|
|
|
|
75 |
if type(activation) is str:
|
76 |
activation = get_activation(activation)
|
77 |
|
@@ -80,7 +83,7 @@ class DDPM(pl.LightningModule):
|
|
80 |
in_node_nf=in_node_nf,
|
81 |
n_dims=n_dims,
|
82 |
context_node_nf=context_node_nf,
|
83 |
-
device=
|
84 |
hidden_nf=hidden_nf,
|
85 |
activation=activation,
|
86 |
n_layers=n_layers,
|
@@ -94,6 +97,7 @@ class DDPM(pl.LightningModule):
|
|
94 |
model=model,
|
95 |
normalization=normalization,
|
96 |
centering=inpainting,
|
|
|
97 |
)
|
98 |
edm_class = InpaintingEDM if inpainting else EDM
|
99 |
self.edm = edm_class(
|
@@ -424,7 +428,7 @@ class DDPM(pl.LightningModule):
|
|
424 |
context = fragment_mask
|
425 |
|
426 |
# Add information about pocket to the context
|
427 |
-
if
|
428 |
fragment_pocket_mask = fragment_mask
|
429 |
fragment_only_mask = template_data['fragment_only_mask']
|
430 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
|
|
44 |
normalize_factors, include_charges, model,
|
45 |
data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
|
46 |
normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
|
47 |
+
center_of_mass='fragments', inpainting=False, anchors_context=True, graph_type=None,
|
48 |
):
|
49 |
super(DDPM, self).__init__()
|
50 |
|
|
|
54 |
self.val_data_prefix = val_data_prefix
|
55 |
self.batch_size = batch_size
|
56 |
self.lr = lr
|
57 |
+
self.torch_device = torch_device
|
58 |
self.include_charges = include_charges
|
59 |
self.test_epochs = test_epochs
|
60 |
self.n_stability_samples = n_stability_samples
|
|
|
72 |
|
73 |
self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
|
74 |
|
75 |
+
if graph_type is None:
|
76 |
+
graph_type = '4A' if '.' in train_data_prefix else 'FC'
|
77 |
+
|
78 |
if type(activation) is str:
|
79 |
activation = get_activation(activation)
|
80 |
|
|
|
83 |
in_node_nf=in_node_nf,
|
84 |
n_dims=n_dims,
|
85 |
context_node_nf=context_node_nf,
|
86 |
+
device=torch_device,
|
87 |
hidden_nf=hidden_nf,
|
88 |
activation=activation,
|
89 |
n_layers=n_layers,
|
|
|
97 |
model=model,
|
98 |
normalization=normalization,
|
99 |
centering=inpainting,
|
100 |
+
graph_type=graph_type,
|
101 |
)
|
102 |
edm_class = InpaintingEDM if inpainting else EDM
|
103 |
self.edm = edm_class(
|
|
|
428 |
context = fragment_mask
|
429 |
|
430 |
# Add information about pocket to the context
|
431 |
+
if '.' in self.train_data_prefix:
|
432 |
fragment_pocket_mask = fragment_mask
|
433 |
fragment_only_mask = template_data['fragment_only_mask']
|
434 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
src/linker_size.py
CHANGED
@@ -21,10 +21,6 @@ class DistributionNodes:
|
|
21 |
prob = prob/np.sum(prob)
|
22 |
|
23 |
self.prob = torch.from_numpy(prob).float()
|
24 |
-
|
25 |
-
entropy = torch.sum(self.prob * torch.log(self.prob + 1e-30))
|
26 |
-
print("Entropy of n_nodes: H[N]", entropy.item())
|
27 |
-
|
28 |
self.m = Categorical(torch.tensor(prob))
|
29 |
|
30 |
def sample(self, n_samples=1):
|
|
|
21 |
prob = prob/np.sum(prob)
|
22 |
|
23 |
self.prob = torch.from_numpy(prob).float()
|
|
|
|
|
|
|
|
|
24 |
self.m = Categorical(torch.tensor(prob))
|
25 |
|
26 |
def sample(self, n_samples=1):
|
src/linker_size_lightning.py
CHANGED
@@ -40,6 +40,7 @@ class SizeClassifier(pl.LightningModule):
|
|
40 |
self.lr = lr
|
41 |
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42 |
self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
|
|
|
43 |
self.gnn = SizeGNN(
|
44 |
in_node_nf=in_node_nf,
|
45 |
hidden_nf=hidden_nf,
|
@@ -79,7 +80,7 @@ class SizeClassifier(pl.LightningModule):
|
|
79 |
def test_dataloader(self):
|
80 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
81 |
|
82 |
-
def forward(self, data, return_loss=True, with_pocket=False):
|
83 |
h = data['one_hot']
|
84 |
x = data['positions']
|
85 |
fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
|
@@ -91,6 +92,10 @@ class SizeClassifier(pl.LightningModule):
|
|
91 |
x = x * fragment_mask
|
92 |
h = h * fragment_mask
|
93 |
|
|
|
|
|
|
|
|
|
94 |
# Reshaping
|
95 |
bs, n_nodes = x.shape[0], x.shape[1]
|
96 |
fragment_mask = fragment_mask.view(bs * n_nodes, 1)
|
|
|
40 |
self.lr = lr
|
41 |
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42 |
self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
|
43 |
+
self.in_node_nf = in_node_nf
|
44 |
self.gnn = SizeGNN(
|
45 |
in_node_nf=in_node_nf,
|
46 |
hidden_nf=hidden_nf,
|
|
|
80 |
def test_dataloader(self):
|
81 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
82 |
|
83 |
+
def forward(self, data, return_loss=True, with_pocket=False, adjust_shape=False):
|
84 |
h = data['one_hot']
|
85 |
x = data['positions']
|
86 |
fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
|
|
|
92 |
x = x * fragment_mask
|
93 |
h = h * fragment_mask
|
94 |
|
95 |
+
if h.shape[-1] != self.in_node_nf and adjust_shape:
|
96 |
+
assert torch.allclose(h[..., -1], torch.zeros_like(h[..., -1]))
|
97 |
+
h = h[..., :-1]
|
98 |
+
|
99 |
# Reshaping
|
100 |
bs, n_nodes = x.shape[0], x.shape[1]
|
101 |
fragment_mask = fragment_mask.view(bs * n_nodes, 1)
|
src/utils.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import sys
|
|
|
2 |
from datetime import datetime
|
3 |
|
4 |
import torch
|
@@ -21,9 +22,11 @@ class Logger(object):
|
|
21 |
# you might want to specify some extra behavior here.
|
22 |
pass
|
23 |
|
|
|
24 |
def log(*args):
|
25 |
print(f'[{datetime.now()}]', *args)
|
26 |
|
|
|
27 |
class EMA:
|
28 |
def __init__(self, beta):
|
29 |
super().__init__()
|
@@ -257,6 +260,17 @@ def disable_rdkit_logging():
|
|
257 |
rkrb.DisableLog('rdApp.error')
|
258 |
|
259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
class FoundNaNException(Exception):
|
261 |
def __init__(self, x, h):
|
262 |
x_nan_idx = self.find_nan_idx(x)
|
|
|
1 |
import sys
|
2 |
+
import random
|
3 |
from datetime import datetime
|
4 |
|
5 |
import torch
|
|
|
22 |
# you might want to specify some extra behavior here.
|
23 |
pass
|
24 |
|
25 |
+
|
26 |
def log(*args):
    """Print the given values, prefixed with the current ``datetime.now()`` value in square brackets."""
    timestamp = f'[{datetime.now()}]'
    print(timestamp, *args)
|
28 |
|
29 |
+
|
30 |
class EMA:
|
31 |
def __init__(self, beta):
|
32 |
super().__init__()
|
|
|
260 |
rkrb.DisableLog('rdApp.error')
|
261 |
|
262 |
|
263 |
+
def set_deterministic(seed):
    """
    Seed every random number generator used here and request deterministic cuDNN behaviour.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch (including all
    CUDA devices when CUDA is available), then sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``.

    :param seed: seed value passed unchanged to each generator
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
272 |
+
|
273 |
+
|
274 |
class FoundNaNException(Exception):
|
275 |
def __init__(self, x, h):
|
276 |
x_nan_idx = self.find_nan_idx(x)
|