Spaces: Running on A10G

igashov committed • Commit 95ba5bc • Parent: e361c7a

DiffLinker code

Files changed:
- app.py +83 -7
- src/__init__.py +0 -0
- src/const.py +218 -0
- src/datasets.py +350 -0
- src/delinker.py +278 -0
- src/delinker_utils/__init__.py +0 -0
- src/delinker_utils/calc_SC_RDKit.py +40 -0
- src/delinker_utils/frag_utils.py +413 -0
- src/delinker_utils/sascorer.py +173 -0
- src/edm.py +730 -0
- src/egnn.py +541 -0
- src/lightning.py +473 -0
- src/linker_size.py +95 -0
- src/linker_size_lightning.py +460 -0
- src/metrics.py +167 -0
- src/molecule_builder.py +102 -0
- src/noise.py +169 -0
- src/utils.py +348 -0
- src/visualizer.py +227 -0
app.py
CHANGED
@@ -1,6 +1,14 @@
 import gradio as gr
 import os
-import
+import torch
+import subprocess
+
+from rdkit import Chem
+from src import const
+from src.visualizer import save_xyz_file
+from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
+from src.lightning import DDPM
+from src.linker_size_lightning import SizeClassifier
 
 
 HTML_TEMPLATE = """<!DOCTYPE html>
@@ -43,20 +51,88 @@ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
 allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
 
 
-
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+os.makedirs("results", exist_ok=True)
+print('Created results directory')
+
+size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
+print('Loaded SizeGNN model')
+
+ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
+print('Loaded diffusion model')
+
+
+def sample_fn(_data):
+    output, _ = size_nn.forward(_data)
+    probabilities = torch.softmax(output, dim=1)
+    distribution = torch.distributions.Categorical(probs=probabilities)
+    samples = distribution.sample()
+    sizes = []
+    for label in samples.detach().cpu().numpy():
+        sizes.append(size_nn.linker_id2size[label])
+    sizes = torch.tensor(sizes, device=samples.device, dtype=torch.long)
+    return sizes
+
+
+def read_molecule_content(path):
     with open(path, "r") as f:
         return "".join(f.readlines())
 
 
+def read_molecule(path):
+    if path.endswith('.pdb'):
+        return Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
+    elif path.endswith('.mol'):
+        return Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
+    elif path.endswith('.mol2'):
+        return Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
+    elif path.endswith('.sdf'):
+        return Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
+    raise Exception('Unknown file extension')
+
+
 def generate(input_file):
     try:
         path = input_file.name
         molecule = read_molecule(path)
-
-
-
-
-
+        name = '.'.join(path.split('/')[-1].split('.')[:-1])  # derive the output name from the file path
+        out_sdf = f'results/{name}_generated.sdf'
+        print(f'Input path={path}, name={name}')
+    except Exception as e:
+        return f'Could not read the molecule: {e}'
+
+    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
+    positions = torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device)
+    one_hot = torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device)
+    print('Read and parsed molecule')
+
+    dataset = [{
+        'uuid': '0',
+        'name': '0',
+        'positions': positions,
+        'one_hot': one_hot,
+        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
+        'anchors': torch.zeros(len(charges), dtype=const.TORCH_FLOAT, device=device),
+        'fragment_mask': torch.ones(len(charges), dtype=const.TORCH_FLOAT, device=device),
+        'linker_mask': torch.zeros(len(charges), dtype=const.TORCH_FLOAT, device=device),
+        'num_atoms': len(positions),
+    }]
+    dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
+    print('Created dataloader')
+
+    for data in dataloader:
+        chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
+        print('Generated linker')
+        x = chain[0][:, :, :ddpm.n_dims]
+        h = chain[0][:, :, ddpm.n_dims:]
+        save_xyz_file('results', h, x, node_mask, names=[name], is_geom=True, suffix='generated')
+        print('Saved XYZ file')
+        subprocess.run(f'obabel results/{name}_generated.xyz -O {out_sdf}', shell=True)
+        print('Converted to SDF')
+        break
+
+    generated_molecule = read_molecule_content(out_sdf)
+    html = HTML_TEMPLATE.format(molecule=generated_molecule, fmt='sdf')
     return IFRAME_TEMPLATE.format(html=html)
 
 
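For illustration, the size-sampling step from sample_fn in isolation: the classifier's logits become a categorical distribution over candidate linker sizes, and sampled class indices are mapped back to atom counts. The logits below are fabricated; the id-to-size table mirrors ZINC_TRAIN_LINKER_ID2SIZE from src/const.py.

import torch

# Fabricated logits for a batch of 2 molecules over 10 candidate linker sizes
logits = torch.randn(2, 10)
probabilities = torch.softmax(logits, dim=1)
samples = torch.distributions.Categorical(probs=probabilities).sample()

# Map sampled class indices back to linker sizes, as size_nn.linker_id2size does
id2size = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # ZINC_TRAIN_LINKER_ID2SIZE in src/const.py
sizes = torch.tensor([id2size[i] for i in samples.tolist()], dtype=torch.long)
print(sizes)  # e.g. tensor([5, 8])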
src/__init__.py
ADDED
File without changes
src/const.py
ADDED
@@ -0,0 +1,218 @@

import torch

from rdkit import Chem


TORCH_FLOAT = torch.float32
TORCH_INT = torch.int8

# #################################################################################### #
# ####################################### ZINC ####################################### #
# #################################################################################### #

# Atom idx for one-hot encoding
ATOM2IDX = {'C': 0, 'O': 1, 'N': 2, 'F': 3, 'S': 4, 'Cl': 5, 'Br': 6, 'I': 7}
IDX2ATOM = {0: 'C', 1: 'O', 2: 'N', 3: 'F', 4: 'S', 5: 'Cl', 6: 'Br', 7: 'I'}

# Atomic numbers (Z)
CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53}

# One-hot atom types
NUMBER_OF_ATOM_TYPES = len(ATOM2IDX)


# #################################################################################### #
# ####################################### GEOM ####################################### #
# #################################################################################### #

# Atom idx for one-hot encoding
GEOM_ATOM2IDX = {'C': 0, 'O': 1, 'N': 2, 'F': 3, 'S': 4, 'Cl': 5, 'Br': 6, 'I': 7, 'P': 8}
GEOM_IDX2ATOM = {0: 'C', 1: 'O', 2: 'N', 3: 'F', 4: 'S', 5: 'Cl', 6: 'Br', 7: 'I', 8: 'P'}

# Atomic numbers (Z)
GEOM_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15}

# One-hot atom types
GEOM_NUMBER_OF_ATOM_TYPES = len(GEOM_ATOM2IDX)

# Dataset keys
DATA_LIST_ATTRS = {
    'uuid', 'name', 'fragments_smi', 'linker_smi', 'num_atoms'
}
DATA_ATTRS_TO_PAD = {
    'positions', 'one_hot', 'charges', 'anchors', 'fragment_mask', 'linker_mask', 'pocket_mask', 'fragment_only_mask'
}
DATA_ATTRS_TO_ADD_LAST_DIM = {
    'charges', 'anchors', 'fragment_mask', 'linker_mask', 'pocket_mask', 'fragment_only_mask'
}

# Distribution of linker size in train data
LINKER_SIZE_DIST = {
    4: 85540,
    3: 113928,
    6: 70946,
    7: 30408,
    5: 77671,
    9: 5177,
    10: 1214,
    8: 12712,
    11: 158,
    12: 7,
}


# Bond lengths from:
# http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html
# And:
# http://chemistry-reference.com/tables/Bond%20Lengths%20and%20Enthalpies.pdf
BONDS_1 = {
    'H': {
        'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92,
        'B': 119, 'Si': 148, 'P': 144, 'As': 152, 'S': 134,
        'Cl': 127, 'Br': 141, 'I': 161
    },
    'C': {
        'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135,
        'Si': 185, 'P': 184, 'S': 182, 'Cl': 177, 'Br': 194,
        'I': 214
    },
    'N': {
        'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136,
        'Cl': 175, 'Br': 214, 'S': 168, 'I': 222, 'P': 177
    },
    'O': {
        'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142,
        'Br': 172, 'S': 151, 'P': 163, 'Si': 163, 'Cl': 164,
        'I': 194
    },
    'F': {
        'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142,
        'S': 158, 'Si': 160, 'Cl': 166, 'Br': 178, 'P': 156,
        'I': 187
    },
    'B': {
        'H': 119, 'Cl': 175
    },
    'Si': {
        'Si': 233, 'H': 148, 'C': 185, 'O': 163, 'S': 200,
        'F': 160, 'Cl': 202, 'Br': 215, 'I': 243,
    },
    'Cl': {
        'Cl': 199, 'H': 127, 'C': 177, 'N': 175, 'O': 164,
        'P': 203, 'S': 207, 'B': 175, 'Si': 202, 'F': 166,
        'Br': 214
    },
    'S': {
        'H': 134, 'C': 182, 'N': 168, 'O': 151, 'S': 204,
        'F': 158, 'Cl': 207, 'Br': 225, 'Si': 200, 'P': 210,
        'I': 234
    },
    'Br': {
        'Br': 228, 'H': 141, 'C': 194, 'O': 172, 'N': 214,
        'Si': 215, 'S': 225, 'F': 178, 'Cl': 214, 'P': 222
    },
    'P': {
        'P': 221, 'H': 144, 'C': 184, 'O': 163, 'Cl': 203,
        'S': 210, 'F': 156, 'N': 177, 'Br': 222
    },
    'I': {
        'H': 161, 'C': 214, 'Si': 243, 'N': 222, 'O': 194,
        'S': 234, 'F': 187, 'I': 266
    },
    'As': {
        'H': 152
    }
}

BONDS_2 = {
    'C': {'C': 134, 'N': 129, 'O': 120, 'S': 160},
    'N': {'C': 129, 'N': 125, 'O': 121},
    'O': {'C': 120, 'N': 121, 'O': 121, 'P': 150},
    'P': {'O': 150, 'S': 186},
    'S': {'P': 186}
}

BONDS_3 = {
    'C': {'C': 120, 'N': 116, 'O': 113},
    'N': {'C': 116, 'N': 110},
    'O': {'C': 113}
}

BOND_DICT = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

BOND2IDX = {
    Chem.rdchem.BondType.SINGLE: 1,
    Chem.rdchem.BondType.DOUBLE: 2,
    Chem.rdchem.BondType.TRIPLE: 3,
    Chem.rdchem.BondType.AROMATIC: 4,
}

ALLOWED_BONDS = {
    'H': 1,
    'C': 4,
    'N': 3,
    'O': 2,
    'F': 1,
    'B': 3,
    'Al': 3,
    'Si': 4,
    'P': [3, 5],
    'S': 4,
    'Cl': 1,
    'As': 3,
    'Br': 1,
    'I': 1,
    'Hg': [1, 2],
    'Bi': [3, 5]
}

MARGINS_EDM = [10, 5, 2]

COLORS = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8']
# RADII = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
RADII = [0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77]

ZINC_TRAIN_LINKER_ID2SIZE = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
ZINC_TRAIN_LINKER_SIZE2ID = {
    size: idx
    for idx, size in enumerate(ZINC_TRAIN_LINKER_ID2SIZE)
}
ZINC_TRAIN_LINKER_SIZE_WEIGHTS = [
    3.47347831e-01,
    4.63079100e-01,
    5.12370917e-01,
    5.62392614e-01,
    1.30294388e+00,
    3.24247801e+00,
    8.12391184e+00,
    3.45634358e+01,
    2.72428571e+02,
    6.26585714e+03
]


GEOM_TRAIN_LINKER_ID2SIZE = [
    3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
    20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 36, 38, 41
]
GEOM_TRAIN_LINKER_SIZE2ID = {
    size: idx
    for idx, size in enumerate(GEOM_TRAIN_LINKER_ID2SIZE)
}
GEOM_TRAIN_LINKER_SIZE_WEIGHTS = [
    1.07790681e+00, 4.54693604e-01, 3.62575713e-01, 3.75199484e-01,
    3.67812588e-01, 3.92388528e-01, 3.83421054e-01, 4.26924670e-01,
    4.92768040e-01, 4.99761944e-01, 4.92342726e-01, 5.71456905e-01,
    7.30631393e-01, 8.45412928e-01, 9.97252243e-01, 1.25423985e+00,
    1.57316129e+00, 2.19902962e+00, 3.22640431e+00, 4.25481066e+00,
    6.34749573e+00, 9.00676236e+00, 1.43084017e+01, 2.25763173e+01,
    3.36867096e+01, 9.50713805e+01, 2.08693274e+02, 2.51659537e+02,
    7.77856749e+02, 8.55642424e+03, 8.55642424e+03, 4.27821212e+03,
    4.27821212e+03
]
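Two short sketches of how these tables are typically consumed (illustrative only, not part of the commit). First, LINKER_SIZE_DIST can be normalized into a categorical prior over linker sizes:

import torch
from src import const

sizes = sorted(const.LINKER_SIZE_DIST)
counts = torch.tensor([const.LINKER_SIZE_DIST[s] for s in sizes], dtype=torch.float)
prior = torch.distributions.Categorical(probs=counts / counts.sum())
print(sizes[prior.sample().item()])  # most draws land on 3-6 atoms

Second, the BONDS_* tables (lengths in pm) together with MARGINS_EDM suggest EDM-style distance-based bond typing; the code that actually does this is not in this diff (likely src/molecule_builder.py), so the margin-to-order assignment below is an assumption:

from src import const

def guess_bond_order(a1, a2, dist_pm):
    # Assumed margin assignment: 10 pm for single, 5 for double, 2 for triple
    checks = (
        (3, const.BONDS_3, const.MARGINS_EDM[2]),
        (2, const.BONDS_2, const.MARGINS_EDM[1]),
        (1, const.BONDS_1, const.MARGINS_EDM[0]),
    )
    for order, table, margin in checks:
        reference = table.get(a1, {}).get(a2)
        if reference is not None and dist_pm < reference + margin:
            return order
    return 0  # atoms considered unbonded

print(guess_bond_order('C', 'C', 152))  # 1: close to the 154 pm C-C single bond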
src/datasets.py
ADDED
@@ -0,0 +1,350 @@

import os
import numpy as np
import pandas as pd
import pickle
import torch

from rdkit import Chem
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from src import const


from pdb import set_trace


def read_sdf(sdf_path):
    with Chem.SDMolSupplier(sdf_path, sanitize=False) as supplier:
        for molecule in supplier:
            yield molecule


def get_one_hot(atom, atoms_dict):
    one_hot = np.zeros(len(atoms_dict))
    one_hot[atoms_dict[atom]] = 1
    return one_hot


def parse_molecule(mol, is_geom):
    one_hot = []
    charges = []
    atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
    charges_dict = const.GEOM_CHARGES if is_geom else const.CHARGES
    for atom in mol.GetAtoms():
        one_hot.append(get_one_hot(atom.GetSymbol(), atom2idx))
        charges.append(charges_dict[atom.GetSymbol()])
    positions = mol.GetConformer().GetPositions()
    return positions, np.array(one_hot), np.array(charges)


class ZincDataset(Dataset):
    def __init__(self, data_path, prefix, device):
        dataset_path = os.path.join(data_path, f'{prefix}.pt')
        if os.path.exists(dataset_path):
            self.data = torch.load(dataset_path, map_location=device)
        else:
            print(f'Preprocessing dataset with prefix {prefix}')
            self.data = ZincDataset.preprocess(data_path, prefix, device)
            torch.save(self.data, dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

    @staticmethod
    def preprocess(data_path, prefix, device):
        data = []
        table_path = os.path.join(data_path, f'{prefix}_table.csv')
        fragments_path = os.path.join(data_path, f'{prefix}_frag.sdf')
        linkers_path = os.path.join(data_path, f'{prefix}_link.sdf')

        is_geom = ('geom' in prefix) or ('MOAD' in prefix)
        is_multifrag = 'multifrag' in prefix

        table = pd.read_csv(table_path)
        generator = tqdm(zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path)), total=len(table))
        for (_, row), fragments, linker in generator:
            uuid = row['uuid']
            name = row['molecule']
            frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
            link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)

            positions = np.concatenate([frag_pos, link_pos], axis=0)
            one_hot = np.concatenate([frag_one_hot, link_one_hot], axis=0)
            charges = np.concatenate([frag_charges, link_charges], axis=0)
            anchors = np.zeros_like(charges)

            if is_multifrag:
                for anchor_idx in map(int, row['anchors'].split('-')):
                    anchors[anchor_idx] = 1
            else:
                anchors[row['anchor_1']] = 1
                anchors[row['anchor_2']] = 1
            fragment_mask = np.concatenate([np.ones_like(frag_charges), np.zeros_like(link_charges)])
            linker_mask = np.concatenate([np.zeros_like(frag_charges), np.ones_like(link_charges)])

            data.append({
                'uuid': uuid,
                'name': name,
                'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
                'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
                'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
                'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
                'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
                'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
                'num_atoms': len(positions),
            })

        return data


class MOADDataset(Dataset):
    def __init__(self, data_path, prefix, device):
        prefix, pocket_mode = prefix.split('.')

        dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
        if os.path.exists(dataset_path):
            self.data = torch.load(dataset_path, map_location=device)
        else:
            print(f'Preprocessing dataset with prefix {prefix}')
            self.data = MOADDataset.preprocess(data_path, prefix, pocket_mode, device)
            torch.save(self.data, dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

    @staticmethod
    def preprocess(data_path, prefix, pocket_mode, device):
        data = []
        table_path = os.path.join(data_path, f'{prefix}_table.csv')
        fragments_path = os.path.join(data_path, f'{prefix}_frag.sdf')
        linkers_path = os.path.join(data_path, f'{prefix}_link.sdf')
        pockets_path = os.path.join(data_path, f'{prefix}_pockets.pkl')

        is_geom = True
        is_multifrag = 'multifrag' in prefix

        with open(pockets_path, 'rb') as f:
            pockets = pickle.load(f)

        table = pd.read_csv(table_path)
        generator = tqdm(
            zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path), pockets),
            total=len(table)
        )
        for (_, row), fragments, linker, pocket_data in generator:
            uuid = row['uuid']
            name = row['molecule']
            frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
            link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)

            # Parsing pocket data
            pocket_pos = pocket_data[f'{pocket_mode}_coord']
            pocket_one_hot = []
            pocket_charges = []
            for atom_type in pocket_data[f'{pocket_mode}_types']:
                pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
                pocket_charges.append(const.GEOM_CHARGES[atom_type])
            pocket_one_hot = np.array(pocket_one_hot)
            pocket_charges = np.array(pocket_charges)

            positions = np.concatenate([frag_pos, pocket_pos, link_pos], axis=0)
            one_hot = np.concatenate([frag_one_hot, pocket_one_hot, link_one_hot], axis=0)
            charges = np.concatenate([frag_charges, pocket_charges, link_charges], axis=0)
            anchors = np.zeros_like(charges)

            if is_multifrag:
                for anchor_idx in map(int, row['anchors'].split('-')):
                    anchors[anchor_idx] = 1
            else:
                anchors[row['anchor_1']] = 1
                anchors[row['anchor_2']] = 1

            fragment_only_mask = np.concatenate([
                np.ones_like(frag_charges),
                np.zeros_like(pocket_charges),
                np.zeros_like(link_charges)
            ])
            pocket_mask = np.concatenate([
                np.zeros_like(frag_charges),
                np.ones_like(pocket_charges),
                np.zeros_like(link_charges)
            ])
            linker_mask = np.concatenate([
                np.zeros_like(frag_charges),
                np.zeros_like(pocket_charges),
                np.ones_like(link_charges)
            ])
            fragment_mask = np.concatenate([
                np.ones_like(frag_charges),
                np.ones_like(pocket_charges),
                np.zeros_like(link_charges)
            ])

            data.append({
                'uuid': uuid,
                'name': name,
                'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
                'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
                'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
                'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
                'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
                'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
                'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
                'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
                'num_atoms': len(positions),
            })

        return data

    @staticmethod
    def create_edges(positions, fragment_mask_only, linker_mask_only):
        ligand_mask = fragment_mask_only.astype(bool) | linker_mask_only.astype(bool)
        ligand_adj = ligand_mask[:, None] & ligand_mask[None, :]
        proximity_adj = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1) <= 6
        full_adj = ligand_adj | proximity_adj
        full_adj &= ~np.eye(len(positions)).astype(bool)

        curr_rows, curr_cols = np.where(full_adj)
        return [curr_rows, curr_cols]


def collate(batch):
    out = {}

    # Filter out big molecules
    if 'pocket_mask' not in batch[0].keys():
        batch = [data for data in batch if data['num_atoms'] <= 50]
    else:
        batch = [data for data in batch if data['num_atoms'] <= 1000]

    for i, data in enumerate(batch):
        for key, value in data.items():
            out.setdefault(key, []).append(value)

    for key, value in out.items():
        if key in const.DATA_LIST_ATTRS:
            continue
        if key in const.DATA_ATTRS_TO_PAD:
            out[key] = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=0)
            continue
        raise Exception(f'Unknown batch key: {key}')

    atom_mask = (out['fragment_mask'].bool() | out['linker_mask'].bool()).to(const.TORCH_INT)
    out['atom_mask'] = atom_mask[:, :, None]

    batch_size, n_nodes = atom_mask.size()

    # In case of MOAD edge_mask is batch_idx
    if 'pocket_mask' in batch[0].keys():
        batch_mask = torch.cat([
            torch.ones(n_nodes, dtype=const.TORCH_INT) * i
            for i in range(batch_size)
        ]).to(atom_mask.device)
        out['edge_mask'] = batch_mask
    else:
        edge_mask = atom_mask[:, None, :] * atom_mask[:, :, None]
        diag_mask = ~torch.eye(edge_mask.size(1), dtype=const.TORCH_INT, device=atom_mask.device).unsqueeze(0)
        edge_mask *= diag_mask
        out['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)

    for key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
        if key in out.keys():
            out[key] = out[key][:, :, None]

    return out


def collate_with_fragment_edges(batch):
    out = {}

    # Filter out big molecules
    batch = [data for data in batch if data['num_atoms'] <= 50]

    for i, data in enumerate(batch):
        for key, value in data.items():
            out.setdefault(key, []).append(value)

    for key, value in out.items():
        if key in const.DATA_LIST_ATTRS:
            continue
        if key in const.DATA_ATTRS_TO_PAD:
            out[key] = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=0)
            continue
        raise Exception(f'Unknown batch key: {key}')

    frag_mask = out['fragment_mask']
    edge_mask = frag_mask[:, None, :] * frag_mask[:, :, None]
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=const.TORCH_INT, device=frag_mask.device).unsqueeze(0)
    edge_mask *= diag_mask

    batch_size, n_nodes = frag_mask.size()
    out['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)

    # Building edges and covalent bond values
    rows, cols, bonds = [], [], []
    for batch_idx in range(batch_size):
        for i in range(n_nodes):
            for j in range(n_nodes):
                rows.append(i + batch_idx * n_nodes)
                cols.append(j + batch_idx * n_nodes)

    edges = [torch.LongTensor(rows).to(frag_mask.device), torch.LongTensor(cols).to(frag_mask.device)]
    out['edges'] = edges

    atom_mask = (out['fragment_mask'].bool() | out['linker_mask'].bool()).to(const.TORCH_INT)
    out['atom_mask'] = atom_mask[:, :, None]

    for key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
        if key in out.keys():
            out[key] = out[key][:, :, None]

    return out


def get_dataloader(dataset, batch_size, collate_fn=collate, shuffle=False):
    return DataLoader(dataset, batch_size, collate_fn=collate_fn, shuffle=shuffle)


def create_template(tensor, fragment_size, linker_size, fill=0):
    values_to_keep = tensor[:fragment_size]
    values_to_add = torch.ones(linker_size, tensor.shape[1], dtype=values_to_keep.dtype, device=values_to_keep.device)
    values_to_add = values_to_add * fill
    return torch.cat([values_to_keep, values_to_add], dim=0)


def create_templates_for_linker_generation(data, linker_sizes):
    """
    Takes data batch and new linker size and returns data batch where fragment-related data is the same
    but linker-related data is replaced with zero templates with new linker sizes
    """
    decoupled_data = []
    for i, linker_size in enumerate(linker_sizes):
        data_dict = {}
        fragment_mask = data['fragment_mask'][i].squeeze()
        fragment_size = fragment_mask.sum().int()
        for k, v in data.items():
            if k == 'num_atoms':
                # Computing new number of atoms (fragment_size + linker_size)
                data_dict[k] = fragment_size + linker_size
                continue
            if k in const.DATA_LIST_ATTRS:
                # These attributes are written without modification
                data_dict[k] = v[i]
                continue
            if k in const.DATA_ATTRS_TO_PAD:
                # Should write fragment-related data + (zeros x linker_size)
                fill_value = 1 if k == 'linker_mask' else 0
                template = create_template(v[i], fragment_size, linker_size, fill=fill_value)
                if k in const.DATA_ATTRS_TO_ADD_LAST_DIM:
                    template = template.squeeze(-1)
                data_dict[k] = template

        decoupled_data.append(data_dict)

    return collate(decoupled_data)
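To make the collate contract concrete, a toy batch of two "molecules" with different atom counts (all tensor values fabricated): list-type attributes pass through, padded attributes are stacked to the largest size, and the atom/edge masks are derived from the fragment and linker masks.

import torch
from src import const
from src.datasets import collate

def toy_item(n):
    # Minimal item carrying only the keys collate knows about
    return {
        'uuid': '0', 'name': 'toy', 'num_atoms': n,
        'positions': torch.randn(n, 3),
        'one_hot': torch.zeros(n, const.GEOM_NUMBER_OF_ATOM_TYPES),
        'charges': torch.full((n,), 6.0),
        'anchors': torch.zeros(n),
        'fragment_mask': torch.ones(n),
        'linker_mask': torch.zeros(n),
    }

batch = collate([toy_item(4), toy_item(6)])
print(batch['positions'].shape)  # torch.Size([2, 6, 3]) -- padded to the largest molecule
print(batch['atom_mask'].shape)  # torch.Size([2, 6, 1])
print(batch['edge_mask'].shape)  # torch.Size([72, 1]), i.e. 2 * 6 * 6 node pairs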
src/delinker.py
ADDED
@@ -0,0 +1,278 @@

import csv
import numpy as np

from rdkit import Chem
from rdkit.Chem import MolStandardize
from src import metrics
from src.delinker_utils import sascorer, calc_SC_RDKit
from tqdm import tqdm

from pdb import set_trace


def get_valid_as_in_delinker(data, progress=False):
    valid = []
    generator = tqdm(enumerate(data), total=len(data)) if progress else enumerate(data)
    for i, m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=False)
        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=False)
        frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=False)

        pred_mol_frags = Chem.GetMolFrags(pred_mol, asMols=True, sanitizeFrags=False)
        pred_mol_filtered = max(pred_mol_frags, default=pred_mol, key=lambda mol: mol.GetNumAtoms())

        try:
            Chem.SanitizeMol(pred_mol_filtered)
            Chem.SanitizeMol(true_mol)
            Chem.SanitizeMol(frag)
        except:
            continue

        if len(pred_mol_filtered.GetSubstructMatch(frag)) > 0:
            valid.append({
                'pred_mol': m['pred_mol'],
                'true_mol': m['true_mol'],
                'pred_mol_smi': Chem.MolToSmiles(pred_mol_filtered),
                'true_mol_smi': Chem.MolToSmiles(true_mol),
                'frag_smi': Chem.MolToSmiles(frag)
            })

    return valid


def extract_linker_smiles(molecule, fragments):
    match = molecule.GetSubstructMatch(fragments)
    elinker = Chem.EditableMol(molecule)
    for atom_id in sorted(match, reverse=True):
        elinker.RemoveAtom(atom_id)
    linker = elinker.GetMol()
    Chem.RemoveStereochemistry(linker)
    try:
        linker = MolStandardize.canonicalize_tautomer_smiles(Chem.MolToSmiles(linker))
    except:
        linker = Chem.MolToSmiles(linker)
    return linker


def compute_and_add_linker_smiles(data, progress=False):
    data_with_linkers = []
    generator = tqdm(data) if progress else data
    for m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=True)
        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=True)
        frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=True)

        pred_linker = extract_linker_smiles(pred_mol, frag)
        true_linker = extract_linker_smiles(true_mol, frag)
        data_with_linkers.append({
            **m,
            'pred_linker': pred_linker,
            'true_linker': true_linker,
        })

    return data_with_linkers


def compute_uniqueness(data, progress=False):
    mol_dictionary = {}
    generator = tqdm(data) if progress else data
    for m in generator:
        frag = m['frag_smi']
        pred_mol = m['pred_mol_smi']
        true_mol = m['true_mol_smi']

        key = f'{true_mol}.{frag}'
        mol_dictionary.setdefault(key, []).append(pred_mol)

    total_mol = 0
    unique_mol = 0
    for molecules in mol_dictionary.values():
        total_mol += len(molecules)
        unique_mol += len(set(molecules))

    return unique_mol / total_mol


def compute_novelty(data, progress=False):
    novel = 0
    true_linkers = set([m['true_linker'] for m in data])
    generator = tqdm(data) if progress else data
    for m in generator:
        pred_linker = m['pred_linker']
        if pred_linker in true_linkers:
            continue
        else:
            novel += 1

    return novel / len(data)


def compute_recovery_rate(data, progress=False):
    total = set()
    recovered = set()
    generator = tqdm(data) if progress else data
    for m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=True)
        Chem.RemoveStereochemistry(pred_mol)
        pred_mol = Chem.MolToSmiles(Chem.RemoveHs(pred_mol))

        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=True)
        Chem.RemoveStereochemistry(true_mol)
        true_mol = Chem.MolToSmiles(Chem.RemoveHs(true_mol))

        true_link = m['true_linker']
        total.add(f'{true_mol}.{true_link}')
        if pred_mol == true_mol:
            recovered.add(f'{true_mol}.{true_link}')

    return len(recovered) / len(total)


def calc_sa_score_mol(mol):
    if mol is None:
        return None
    return sascorer.calculateScore(mol)


def check_ring_filter(linker):
    check = True
    # Get linker rings
    ssr = Chem.GetSymmSSSR(linker)
    # Check rings
    for ring in ssr:
        for atom_idx in ring:
            for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
                if bond.GetBondType() == 2 and bond.GetBeginAtomIdx() in ring and bond.GetEndAtomIdx() in ring:
                    check = False
    return check


def check_pains(mol, pains_smarts):
    for pain in pains_smarts:
        if mol.HasSubstructMatch(pain):
            return False
    return True


def calc_2d_filters(toks, pains_smarts):
    pred_mol = Chem.MolFromSmiles(toks['pred_mol_smi'])
    frag = Chem.MolFromSmiles(toks['frag_smi'])
    linker = Chem.MolFromSmiles(toks['pred_linker'])

    result = [False, False, False]
    if len(pred_mol.GetSubstructMatch(frag)) > 0:
        sa_score = False
        ra_score = False
        pains_score = False

        try:
            sa_score = calc_sa_score_mol(pred_mol) < calc_sa_score_mol(frag)
        except Exception as e:
            print(f'Could not compute SA score: {e}')
        try:
            ra_score = check_ring_filter(linker)
        except Exception as e:
            print(f'Could not compute RA score: {e}')
        try:
            pains_score = check_pains(pred_mol, pains_smarts)
        except Exception as e:
            print(f'Could not compute PAINS score: {e}')

        result = [sa_score, ra_score, pains_score]

    return result


def calc_filters_2d_dataset(data):
    with open('models/wehi_pains.csv', 'r') as f:
        pains_smarts = [Chem.MolFromSmarts(line[0], mergeHs=True) for line in csv.reader(f)]

    pass_all = pass_SA = pass_RA = pass_PAINS = 0
    for m in data:
        filters_2d = calc_2d_filters(m, pains_smarts)
        pass_all += filters_2d[0] & filters_2d[1] & filters_2d[2]
        pass_SA += filters_2d[0]
        pass_RA += filters_2d[1]
        pass_PAINS += filters_2d[2]

    return pass_all / len(data), pass_SA / len(data), pass_RA / len(data), pass_PAINS / len(data)


def calc_sc_rdkit_full_mol(gen_mol, ref_mol):
    try:
        score = calc_SC_RDKit.calc_SC_RDKit_score(gen_mol, ref_mol)
        return score
    except:
        return -0.5


def sc_rdkit_score(data):
    scores = []
    for m in data:
        score = calc_sc_rdkit_full_mol(m['pred_mol'], m['true_mol'])
        scores.append(score)

    return np.mean(scores)


def get_delinker_metrics(pred_molecules, true_molecules, true_fragments):
    default_values = {
        'DeLinker/validity': 0,
        'DeLinker/uniqueness': 0,
        'DeLinker/novelty': 0,
        'DeLinker/recovery': 0,
        'DeLinker/2D_filters': 0,
        'DeLinker/2D_filters_SA': 0,
        'DeLinker/2D_filters_RA': 0,
        'DeLinker/2D_filters_PAINS': 0,
        'DeLinker/SC_RDKit': 0,
    }
    if len(pred_molecules) == 0:
        return default_values

    data = []
    for pred_mol, true_mol, true_frag in zip(pred_molecules, true_molecules, true_fragments):
        data.append({
            'pred_mol': pred_mol,
            'true_mol': true_mol,
            'pred_mol_smi': Chem.MolToSmiles(pred_mol),
            'true_mol_smi': Chem.MolToSmiles(true_mol),
            'frag_smi': Chem.MolToSmiles(true_frag)
        })

    # Validity according to DeLinker paper:
    # Passing rdkit.Chem.Sanitize and the biggest fragment contains both fragments
    valid_data = get_valid_as_in_delinker(data)
    validity_as_in_delinker = len(valid_data) / len(data)
    if len(valid_data) == 0:
        return default_values

    # Compute linkers and add to results
    valid_data = compute_and_add_linker_smiles(valid_data)

    # Compute uniqueness
    uniqueness = compute_uniqueness(valid_data)

    # Compute novelty
    novelty = compute_novelty(valid_data)

    # Compute recovered molecules
    recovery_rate = compute_recovery_rate(valid_data)

    # 2D filters
    pass_all, pass_SA, pass_RA, pass_PAINS = calc_filters_2d_dataset(valid_data)

    # 3D Filters
    sc_rdkit = sc_rdkit_score(valid_data)

    return {
        'DeLinker/validity': validity_as_in_delinker,
        'DeLinker/uniqueness': uniqueness,
        'DeLinker/novelty': novelty,
        'DeLinker/recovery': recovery_rate,
        'DeLinker/2D_filters': pass_all,
        'DeLinker/2D_filters_SA': pass_SA,
        'DeLinker/2D_filters_RA': pass_RA,
        'DeLinker/2D_filters_PAINS': pass_PAINS,
        'DeLinker/SC_RDKit': sc_rdkit,
    }
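A hedged usage sketch for get_delinker_metrics (all SMILES fabricated; it assumes the Space's models/wehi_pains.csv is present, since calc_filters_2d_dataset reads it, and that molecules without 3D conformers simply fall back to the -0.5 SC_RDKit default):

from rdkit import Chem
from src.delinker import get_delinker_metrics

pred = [Chem.MolFromSmiles('c1ccccc1CCNCCc1ccncc1')]   # generated molecule
true = [Chem.MolFromSmiles('c1ccccc1CCOCCc1ccncc1')]   # reference molecule
frags = [Chem.MolFromSmiles('CCc1ccccc1.CCc1ccncc1')]  # unlinked fragment pair

results = get_delinker_metrics(pred, true, frags)
for name, value in results.items():
    print(f'{name}: {value:.3f}')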
src/delinker_utils/__init__.py
ADDED
File without changes
src/delinker_utils/calc_SC_RDKit.py
ADDED
@@ -0,0 +1,40 @@

import os
from rdkit import Chem
from rdkit.Chem import AllChem, rdShapeHelpers
from rdkit.Chem.FeatMaps import FeatMaps
from rdkit import RDConfig

# Set up features to use in FeatureMap
fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
fdef = AllChem.BuildFeatureFactory(fdefName)

fmParams = {}
for k in fdef.GetFeatureFamilies():
    fparams = FeatMaps.FeatMapParams()
    fmParams[k] = fparams

keep = ('Donor', 'Acceptor', 'NegIonizable', 'PosIonizable',
        'ZnBinder', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe')


def get_FeatureMapScore(query_mol, ref_mol):
    featLists = []
    for m in [query_mol, ref_mol]:
        rawFeats = fdef.GetFeaturesForMol(m)
        # filter that list down to only include the ones we're interested in
        featLists.append([f for f in rawFeats if f.GetFamily() in keep])
    fms = [FeatMaps.FeatMap(feats=x, weights=[1] * len(x), params=fmParams) for x in featLists]
    fms[0].scoreMode = FeatMaps.FeatMapScoreMode.Best
    fm_score = fms[0].ScoreFeats(featLists[1]) / min(fms[0].GetNumFeatures(), len(featLists[1]))

    return fm_score


def calc_SC_RDKit_score(query_mol, ref_mol):
    fm_score = get_FeatureMapScore(query_mol, ref_mol)

    protrude_dist = rdShapeHelpers.ShapeProtrudeDist(query_mol, ref_mol,
                                                     allowReordering=False)
    SC_RDKit_score = 0.5 * fm_score + 0.5 * (1 - protrude_dist)

    return SC_RDKit_score
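Usage sketch: the score averages a pharmacophore FeatureMap overlap with (1 - shape protrusion), so both inputs need 3D conformers. The ETKDG embedding below is illustrative only; the app scores generated conformers directly.

from rdkit import Chem
from rdkit.Chem import AllChem
from src.delinker_utils.calc_SC_RDKit import calc_SC_RDKit_score

def embed(smiles):
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    AllChem.EmbedMolecule(mol, randomSeed=42)  # generate a 3D conformer
    return Chem.RemoveHs(mol)

score = calc_SC_RDKit_score(embed('OCCc1ccccc1'), embed('NCCc1ccccc1'))
print(f'SC_RDKit: {score:.3f}')  # 1.0 would mean identical shape and features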
src/delinker_utils/frag_utils.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import networkx as nx
|
3 |
+
|
4 |
+
from joblib import Parallel, delayed
|
5 |
+
from rdkit import Chem
|
6 |
+
from rdkit.Chem import AllChem
|
7 |
+
from src.delinker_utils import sascorer
|
8 |
+
|
9 |
+
|
10 |
+
def read_triples_file(filename):
|
11 |
+
'''Reads .smi file '''
|
12 |
+
'''Returns array containing smiles strings of molecules'''
|
13 |
+
smiles, names = [], []
|
14 |
+
with open(filename, 'r') as f:
|
15 |
+
for line in f:
|
16 |
+
if line:
|
17 |
+
smiles.append(line.strip().split(' ')[0:3])
|
18 |
+
return smiles
|
19 |
+
|
20 |
+
|
21 |
+
def remove_dummys(smi_string):
|
22 |
+
return Chem.MolToSmiles(Chem.RemoveHs(AllChem.ReplaceSubstructs(Chem.MolFromSmiles(smi_string),Chem.MolFromSmiles('*'),Chem.MolFromSmiles('[H]'),True)[0]))
|
23 |
+
|
24 |
+
|
25 |
+
def sa_filter(results, verbose=True):
|
26 |
+
count = 0
|
27 |
+
total = 0
|
28 |
+
for processed, res in enumerate(results):
|
29 |
+
total += len(res)
|
30 |
+
for m in res:
|
31 |
+
# Check SA score has improved
|
32 |
+
if calc_mol_props(m[1])[1] < calc_mol_props(m[0])[1]:
|
33 |
+
count += 1
|
34 |
+
# Progress
|
35 |
+
if verbose:
|
36 |
+
if processed % 10 == 0:
|
37 |
+
print("\rProcessed %d" % processed, end="")
|
38 |
+
print("\r",end="")
|
39 |
+
return count/total
|
40 |
+
|
41 |
+
|
42 |
+
def ring_check_res(res, clean_frag):
|
43 |
+
check = True
|
44 |
+
gen_mol = Chem.MolFromSmiles(res[1])
|
45 |
+
linker = Chem.DeleteSubstructs(gen_mol, clean_frag)
|
46 |
+
|
47 |
+
# Get linker rings
|
48 |
+
ssr = Chem.GetSymmSSSR(linker)
|
49 |
+
# Check rings
|
50 |
+
for ring in ssr:
|
51 |
+
for atom_idx in ring:
|
52 |
+
for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
|
53 |
+
if bond.GetBondType() == 2 and bond.GetBeginAtomIdx() in ring and bond.GetEndAtomIdx() in ring:
|
54 |
+
check = False
|
55 |
+
return check
|
56 |
+
|
57 |
+
|
58 |
+
def ring_filter(results, verbose=True):
|
59 |
+
count = 0
|
60 |
+
total = 0
|
61 |
+
du = Chem.MolFromSmiles('*')
|
62 |
+
for processed, res in enumerate(results):
|
63 |
+
total += len(res)
|
64 |
+
for m in res:
|
65 |
+
# Clean frags
|
66 |
+
clean_frag = Chem.RemoveHs(AllChem.ReplaceSubstructs(Chem.MolFromSmiles(m[0]),du,Chem.MolFromSmiles('[H]'),True)[0])
|
67 |
+
if ring_check_res(m, clean_frag):
|
68 |
+
count += 1
|
69 |
+
# Progress
|
70 |
+
if verbose:
|
71 |
+
if processed % 10 == 0:
|
72 |
+
print("\rProcessed %d" % processed, end="")
|
73 |
+
print("\r",end="")
|
74 |
+
return count/total
|
75 |
+
|
76 |
+
|
77 |
+
def check_ring_filter(linker):
|
78 |
+
check = True
|
79 |
+
# Get linker rings
|
80 |
+
ssr = Chem.GetSymmSSSR(linker)
|
81 |
+
# Check rings
|
82 |
+
for ring in ssr:
|
83 |
+
for atom_idx in ring:
|
84 |
+
for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
|
85 |
+
if bond.GetBondType() == 2 and bond.GetBeginAtomIdx() in ring and bond.GetEndAtomIdx() in ring:
|
86 |
+
check = False
|
87 |
+
return check
|
88 |
+
|
89 |
+
|
90 |
+
def check_pains(mol, pains_smarts):
|
91 |
+
for pain in pains_smarts:
|
92 |
+
if mol.HasSubstructMatch(pain):
|
93 |
+
return False
|
94 |
+
return True
|
95 |
+
|
96 |
+
|
97 |
+
def calc_2d_filters(toks, pains_smarts):
|
98 |
+
try:
|
99 |
+
# Input format: (Full Molecule (SMILES), Linker (SMILES), Unlinked Fragments (SMILES))
|
100 |
+
frags = Chem.MolFromSmiles(toks[2])
|
101 |
+
linker = Chem.MolFromSmiles(toks[1])
|
102 |
+
full_mol = Chem.MolFromSmiles(toks[0])
|
103 |
+
# Remove dummy atoms from unlinked fragments
|
104 |
+
du = Chem.MolFromSmiles('*')
|
105 |
+
clean_frag = Chem.RemoveHs(AllChem.ReplaceSubstructs(frags, du, Chem.MolFromSmiles('[H]'), True)[0])
|
106 |
+
|
107 |
+
res = []
|
108 |
+
# Check: Unlinked fragments in full molecule
|
109 |
+
if len(full_mol.GetSubstructMatch(clean_frag)) > 0:
|
110 |
+
# Check: SA score improved from unlinked fragments to full molecule
|
111 |
+
if calc_sa_score_mol(full_mol) < calc_sa_score_mol(frags):
|
112 |
+
res.append(True)
|
113 |
+
else:
|
114 |
+
res.append(False)
|
115 |
+
# Check: No non-aromatic rings with double bonds
|
116 |
+
if check_ring_filter(linker):
|
117 |
+
res.append(True)
|
118 |
+
else:
|
119 |
+
res.append(False)
|
120 |
+
# Check: Pass pains filters
|
121 |
+
if check_pains(full_mol, pains_smarts):
|
122 |
+
res.append(True)
|
123 |
+
else:
|
124 |
+
res.append(False)
|
125 |
+
return res
|
126 |
+
except:
|
127 |
+
return [False, False, False]
|
128 |
+
|
129 |
+
|
130 |
+
def calc_filters_2d_dataset(results, pains_smarts_loc, n_cores=1):
|
131 |
+
# Load pains filters
|
132 |
+
with open(pains_smarts_loc, 'r') as f:
|
133 |
+
pains_smarts = [Chem.MolFromSmarts(line[0], mergeHs=True) for line in csv.reader(f)]
|
134 |
+
# calc_2d_filters([results[0][2], results[0][4], results[0][1]], pains_smarts)
|
135 |
+
with Parallel(n_jobs=n_cores, backend='multiprocessing') as parallel:
|
136 |
+
filters_2d = parallel(delayed(calc_2d_filters)([toks[2], toks[4], toks[1]], pains_smarts) for toks in results)
|
137 |
+
|
138 |
+
return filters_2d
|
139 |
+
|
140 |
+
|
141 |
+
def calc_mol_props(smiles):
|
142 |
+
# Create RDKit mol
|
143 |
+
mol = Chem.MolFromSmiles(smiles)
|
144 |
+
if mol is None:
|
145 |
+
print("Error passing: %s" % smiles)
|
146 |
+
return None
|
147 |
+
|
148 |
+
# QED
|
149 |
+
qed = Chem.QED.qed(mol)
|
150 |
+
# Synthetic accessibility score - number of cycles (rings with > 6 atoms)
|
151 |
+
sas = sascorer.calculateScore(mol)
|
152 |
+
# Cyles with >6 atoms
|
153 |
+
ri = mol.GetRingInfo()
|
154 |
+
nMacrocycles = 0
|
155 |
+
for x in ri.AtomRings():
|
156 |
+
if len(x) > 6:
|
157 |
+
nMacrocycles += 1
|
158 |
+
|
159 |
+
prop_array = [qed, sas]
|
160 |
+
|
161 |
+
return prop_array
|
162 |
+
|
163 |
+
|
164 |
+
def calc_sa_score_mol(mol, verbose=False):
|
165 |
+
if mol is None:
|
166 |
+
if verbose:
|
167 |
+
print("Error passing: %s" % mol)
|
168 |
+
return None
|
169 |
+
# Synthetic accessibility score
|
170 |
+
return sascorer.calculateScore(mol)
|
171 |
+
|
172 |
+
|
173 |
+
def get_linker(full_mol, clean_frag, starting_point):
|
174 |
+
# INPUT FORMAT: molecule (RDKit mol object), clean fragments (RDKit mol object), starting fragments (SMILES)
|
175 |
+
|
176 |
+
# Get matches of fragments
|
177 |
+
matches = list(full_mol.GetSubstructMatches(clean_frag))
|
178 |
+
|
179 |
+
# If no matches, terminate
|
180 |
+
if len(matches) == 0:
|
181 |
+
print("No matches")
|
182 |
+
return ""
|
183 |
+
|
184 |
+
# Get number of atoms in linker
|
185 |
+
linker_len = full_mol.GetNumHeavyAtoms() - clean_frag.GetNumHeavyAtoms()
|
186 |
+
if linker_len == 0:
|
187 |
+
return ""
|
188 |
+
|
189 |
+
# Setup
|
190 |
+
mol_to_break = Chem.Mol(full_mol)
|
191 |
+
Chem.Kekulize(full_mol, clearAromaticFlags=True)
|
192 |
+
|
193 |
+
poss_linker = []
|
194 |
+
|
195 |
+
if len(matches) > 0:
|
196 |
+
# Loop over matches
|
197 |
+
for match in matches:
|
198 |
+
mol_rw = Chem.RWMol(full_mol)
|
199 |
+
# Get linker atoms
|
200 |
+
linker_atoms = list(set(list(range(full_mol.GetNumHeavyAtoms()))).difference(match))
|
201 |
+
linker_bonds = []
|
202 |
+
atoms_joined_to_linker = []
|
203 |
+
# Loop over starting fragments atoms
|
204 |
+
# Get (i) bonds between starting fragments and linker, (ii) atoms joined to linker
|
205 |
+
for idx_to_delete in sorted(match, reverse=True):
|
206 |
+
nei = [x.GetIdx() for x in mol_rw.GetAtomWithIdx(idx_to_delete).GetNeighbors()]
|
207 |
+
intersect = set(nei).intersection(set(linker_atoms))
|
208 |
+
if len(intersect) == 1:
|
209 |
+
linker_bonds.append(mol_rw.GetBondBetweenAtoms(idx_to_delete, list(intersect)[0]).GetIdx())
|
210 |
+
atoms_joined_to_linker.append(idx_to_delete)
|
211 |
+
elif len(intersect) > 1:
|
212 |
+
for idx_nei in list(intersect):
|
213 |
+
                        linker_bonds.append(mol_rw.GetBondBetweenAtoms(idx_to_delete, idx_nei).GetIdx())
                        atoms_joined_to_linker.append(idx_to_delete)

            # Check number of atoms joined to linker
            # If not == 2, check next match
            if len(set(atoms_joined_to_linker)) != 2:
                continue

            # Delete starting fragments atoms
            for idx_to_delete in sorted(match, reverse=True):
                mol_rw.RemoveAtom(idx_to_delete)

            linker = Chem.Mol(mol_rw)
            # Check linker required num atoms
            if linker.GetNumHeavyAtoms() == linker_len:
                mol_rw = Chem.RWMol(full_mol)
                # Delete linker atoms
                for idx_to_delete in sorted(linker_atoms, reverse=True):
                    mol_rw.RemoveAtom(idx_to_delete)
                frags = Chem.Mol(mol_rw)
                # Check there are two disconnected fragments
                if len(Chem.rdmolops.GetMolFrags(frags)) == 2:
                    # Fragment molecule into starting fragments and linker
                    fragmented_mol = Chem.FragmentOnBonds(mol_to_break, linker_bonds)
                    # Remove starting fragments from fragmentation
                    linker_to_return = Chem.Mol(fragmented_mol)
                    qp = Chem.AdjustQueryParameters()
                    qp.makeDummiesQueries = True
                    for f in starting_point.split('.'):
                        qfrag = Chem.AdjustQueryProperties(Chem.MolFromSmiles(f), qp)
                        linker_to_return = AllChem.DeleteSubstructs(linker_to_return, qfrag, onlyFrags=True)

                    # Check linker is connected and has two bonds to the outside molecule
                    if len(Chem.rdmolops.GetMolFrags(linker)) == 1 and len(linker_bonds) == 2:
                        Chem.Kekulize(linker_to_return, clearAromaticFlags=True)
                        # If for some reason a starting fragment isn't removed (and it's larger than the linker), remove it (happens v. occasionally)
                        if len(Chem.rdmolops.GetMolFrags(linker_to_return)) > 1:
                            for frag in Chem.MolToSmiles(linker_to_return).split('.'):
                                if Chem.MolFromSmiles(frag).GetNumHeavyAtoms() == linker_len:
                                    return frag
                        return Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(linker_to_return)))

                    # If not, add to possible linkers (above doesn't capture some complex cases)
                    else:
                        fragmented_mol = Chem.MolFromSmiles(Chem.MolToSmiles(fragmented_mol), sanitize=False)
                        linker_to_return = AllChem.DeleteSubstructs(fragmented_mol, Chem.MolFromSmiles(starting_point))
                        poss_linker.append(Chem.MolToSmiles(linker_to_return))

    # If only one possibility, return linker
    if len(poss_linker) == 1:
        return poss_linker[0]
    # If no possibilities, process failed
    elif len(poss_linker) == 0:
        print("FAIL:", Chem.MolToSmiles(full_mol), Chem.MolToSmiles(clean_frag), starting_point)
        return ""
    # If multiple possibilities, process probably failed
    else:
        print("More than one poss linker. ", poss_linker)
        return poss_linker[0]


def get_linker_v2(full_mol, clean_frag):
    # INPUT FORMAT: molecule (RDKit mol object), clean fragments (RDKit mol object)

    # Get matches of fragments
    matches = list(full_mol.GetSubstructMatches(clean_frag))

    # If no matches, terminate
    if len(matches) == 0:
        print("No matches")
        return ""

    # Get number of atoms in linker
    linker_len = full_mol.GetNumHeavyAtoms() - clean_frag.GetNumHeavyAtoms()
    if linker_len == 0:
        return ""

    # Setup
    mol_to_break = Chem.Mol(full_mol)
    Chem.Kekulize(full_mol, clearAromaticFlags=True)

    poss_linker = []

    if len(matches) > 0:
        # Loop over matches
        for match in matches:
            mol_rw = Chem.RWMol(full_mol)
            # Get linker atoms
            linker_atoms = list(set(list(range(full_mol.GetNumHeavyAtoms()))).difference(match))
            linker_bonds = []
            atoms_joined_to_linker = []
            # Loop over starting fragments atoms
            # Get (i) bonds between starting fragments and linker, (ii) atoms joined to linker
            for idx_to_delete in sorted(match, reverse=True):
                nei = [x.GetIdx() for x in mol_rw.GetAtomWithIdx(idx_to_delete).GetNeighbors()]
                intersect = set(nei).intersection(set(linker_atoms))
                if len(intersect) == 1:
                    linker_bonds.append(mol_rw.GetBondBetweenAtoms(idx_to_delete, list(intersect)[0]).GetIdx())
                    atoms_joined_to_linker.append(idx_to_delete)
                elif len(intersect) > 1:
                    for idx_nei in list(intersect):
                        linker_bonds.append(mol_rw.GetBondBetweenAtoms(idx_to_delete, idx_nei).GetIdx())
                        atoms_joined_to_linker.append(idx_to_delete)

            # Check number of atoms joined to linker
            # If not == 2, check next match
            if len(set(atoms_joined_to_linker)) != 2:
                continue

            # Delete starting fragments atoms
            for idx_to_delete in sorted(match, reverse=True):
                mol_rw.RemoveAtom(idx_to_delete)

            linker = Chem.Mol(mol_rw)
            # Check linker required num atoms
            if linker.GetNumHeavyAtoms() == linker_len:
                mol_rw = Chem.RWMol(full_mol)
                # Delete linker atoms
                for idx_to_delete in sorted(linker_atoms, reverse=True):
                    mol_rw.RemoveAtom(idx_to_delete)
                frags = Chem.Mol(mol_rw)

                # Check linker is connected and has two bonds to the outside molecule
                if len(Chem.rdmolops.GetMolFrags(linker)) == 1 and len(linker_bonds) == 2:
                    Chem.Kekulize(linker, clearAromaticFlags=True)
                    # If for some reason a starting fragment isn't removed (and it's larger than the linker), remove it (happens v. occasionally)
                    if len(Chem.rdmolops.GetMolFrags(linker)) > 1:
                        for frag in Chem.MolToSmiles(linker).split('.'):
                            if Chem.MolFromSmiles(frag).GetNumHeavyAtoms() == linker_len:
                                return frag
                    return Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(linker)))

                # If not, add to possible linkers (above doesn't capture some complex cases)
                else:
                    poss_linker.append(Chem.MolToSmiles(linker))

    # If only one possibility, return linker
    if len(poss_linker) == 1:
        return poss_linker[0]
    # If no possibilities, process failed
    elif len(poss_linker) == 0:
        print("FAIL:", Chem.MolToSmiles(full_mol), Chem.MolToSmiles(clean_frag))
        return ""
    # If multiple possibilities, process probably failed
    else:
        print("More than one poss linker. ", poss_linker)
        return poss_linker[0]


def unique(results):
    total_dupes = 0
    total = 0
    for res in results:
        original_num = len(res)
        test_data = set(res)
        new_num = len(test_data)
        total_dupes += original_num - new_num
        total += original_num
    return 1 - total_dupes / float(total)


def check_recovered_original_mol_with_idx(results):
    outcomes = []
    rec_idx = []
    for res in results:
        success = False
        # Load original mol and canonicalise
        orig_mol = Chem.MolFromSmiles(res[0][0][0])
        Chem.RemoveStereochemistry(orig_mol)
        orig_mol = Chem.MolToSmiles(Chem.RemoveHs(orig_mol))
        # orig_mol = MolStandardize.canonicalize_tautomer_smiles(orig_mol)
        # Check generated mols
        for m in res:
            gen_mol = Chem.MolFromSmiles(m[0][2])
            Chem.RemoveStereochemistry(gen_mol)
            gen_mol = Chem.MolToSmiles(Chem.RemoveHs(gen_mol))
            # gen_mol = MolStandardize.canonicalize_tautomer_smiles(gen_mol)
            if gen_mol == orig_mol:
                success = True
                rec_idx.append(m[1])
        if not success:
            outcomes.append(False)
        else:
            outcomes.append(True)
    return outcomes, rec_idx


def topology_from_rdkit(rdkit_molecule):
    topology = nx.Graph()
    for atom in rdkit_molecule.GetAtoms():
        # Add the atoms as nodes
        topology.add_node(atom.GetIdx(), atom_type=atom.GetAtomicNum())

    # Add the bonds as edges
    for bond in rdkit_molecule.GetBonds():
        topology.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond_type=bond.GetBondType())

    return topology
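For context, a minimal usage sketch of `get_linker_v2` (illustrative only, not part of the commit; it relies on the RDKit import at the top of `frag_utils.py`, and the SMILES below are made up for the example):

from rdkit import Chem

# Hypothetical example: bibenzyl. The two benzene rings play the role of the
# clean fragments, so the recovered linker should be the two-carbon bridge.
full_mol = Chem.MolFromSmiles('c1ccc(CCc2ccccc2)cc1')
clean_frag = Chem.MolFromSmiles('c1ccccc1.c1ccccc1')
print(get_linker_v2(full_mol, clean_frag))  # expected: a two-atom linker SMILES such as 'CC'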
src/delinker_utils/sascorer.py
ADDED
@@ -0,0 +1,173 @@
#
# calculation of synthetic accessibility score as described in:
#
# Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
# Peter Ertl and Ansgar Schuffenhauer
# Journal of Cheminformatics 1:8 (2009)
# http://www.jcheminf.com/content/1/1/8
#
# several small modifications to the original paper are included
# particularly a slightly different formula for the macrocyclic penalty
# and taking into account also molecule symmetry (fingerprint density)
#
# for a set of 10k diverse molecules the agreement between the original method
# as implemented in PipelinePilot and this implementation is r2 = 0.97
#
# peter ertl & greg landrum, september 2013
#
from __future__ import print_function

from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.six.moves import cPickle
from rdkit.six import iteritems
# NOTE: rdkit.six has been removed from recent RDKit releases; on new versions,
# use the standard pickle module and fps.items() instead.

import math
from collections import defaultdict

import os.path as op

_fscores = None


def readFragmentScores(name='models/fpscores'):
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    _fscores = cPickle.load(gzip.open('%s.pkl.gz' % name))
    outDict = {}
    for i in _fscores:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict


def numBridgeheadsAndSpiro(mol, ri=None):
    nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return nBridgehead, nSpiro


def calculateScore(m):
    if _fscores is None:
        readFragmentScores()

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m, 2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in iteritems(fps):
        nf += v
        sfp = bitId
        score1 += _fscores.get(sfp, -4) * v
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    # macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthesise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    min = -4.0
    max = 2.5
    sascore = 11. - (sascore - min + 1) / (max - min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore


def processMols(mols):
    print('smiles\tName\tsa_score')
    for i, m in enumerate(mols):
        if m is None:
            continue

        s = calculateScore(m)

        smiles = Chem.MolToSmiles(m)
        print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)


if __name__ == '__main__':
    import sys, time

    t1 = time.time()
    readFragmentScores("fpscores")
    t2 = time.time()

    suppl = Chem.SmilesMolSupplier(sys.argv[1])
    t3 = time.time()
    processMols(suppl)
    t4 = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
          file=sys.stderr)

#
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
#       copyright notice, this list of conditions and the following
#       disclaimer in the documentation and/or other materials provided
#       with the distribution.
#     * Neither the name of Novartis Institutes for BioMedical Research Inc.
#       nor the names of its contributors may be used to endorse or promote
#       products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
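A quick sanity-check sketch for the scorer (illustrative, not part of the commit): `readFragmentScores` defaults to `models/fpscores`, so the gzipped pickle `models/fpscores.pkl.gz` is assumed to exist relative to the working directory.

from rdkit import Chem

readFragmentScores('models/fpscores')                  # loads models/fpscores.pkl.gz
aspirin = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # illustrative input
print(round(calculateScore(aspirin), 2))               # SA score on a 1 (easy) to 10 (hard) scale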
src/edm.py
ADDED
@@ -0,0 +1,730 @@
import torch
import torch.nn.functional as F
import numpy as np
import math

from src import utils
from src.egnn import Dynamics
from src.noise import GammaNetwork, PredefinedNoiseSchedule
from typing import Union

from pdb import set_trace


class EDM(torch.nn.Module):
    def __init__(
            self,
            dynamics: Union[Dynamics],
            in_node_nf: int,
            n_dims: int,
            timesteps: int = 1000,
            noise_schedule='learned',
            noise_precision=1e-4,
            loss_type='vlb',
            norm_values=(1., 1., 1.),
            norm_biases=(None, 0., 0.),
    ):
        super().__init__()
        if noise_schedule == 'learned':
            assert loss_type == 'vlb', 'A noise schedule can only be learned with a vlb objective'
            self.gamma = GammaNetwork()
        else:
            self.gamma = PredefinedNoiseSchedule(noise_schedule, timesteps=timesteps, precision=noise_precision)

        self.dynamics = dynamics
        self.in_node_nf = in_node_nf
        self.n_dims = n_dims
        self.T = timesteps
        self.norm_values = norm_values
        self.norm_biases = norm_biases

    def forward(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context=None):
        # Normalization and concatenation
        x, h = self.normalize(x, h)
        xh = torch.cat([x, h], dim=2)

        # Volume change loss term
        delta_log_px = self.delta_log_px(linker_mask).mean()

        # Sample t
        t_int = torch.randint(0, self.T + 1, size=(x.size(0), 1), device=x.device).float()
        s_int = t_int - 1
        t = t_int / self.T
        s = s_int / self.T

        # Masks for t=0 and t>0
        t_is_zero = (t_int == 0).squeeze().float()
        t_is_not_zero = 1 - t_is_zero

        # Compute gamma_t and gamma_s according to the noise schedule
        gamma_t = self.inflate_batch_array(self.gamma(t), x)
        gamma_s = self.inflate_batch_array(self.gamma(s), x)

        # Compute alpha_t and sigma_t from gamma
        alpha_t = self.alpha(gamma_t, x)
        sigma_t = self.sigma(gamma_t, x)

        # Sample noise
        # Note: only for linker
        eps_t = self.sample_combined_position_feature_noise(n_samples=x.size(0), n_nodes=x.size(1), mask=linker_mask)

        # Sample z_t given x, h for timestep t, from q(z_t | x, h)
        # Note: keep fragments unchanged
        z_t = alpha_t * xh + sigma_t * eps_t
        z_t = xh * fragment_mask + z_t * linker_mask

        # Neural net prediction
        eps_t_hat = self.dynamics.forward(
            xh=z_t,
            t=t,
            node_mask=node_mask,
            linker_mask=linker_mask,
            context=context,
            edge_mask=edge_mask,
        )
        eps_t_hat = eps_t_hat * linker_mask

        # Computing basic error (further used for computing NLL and L2-loss)
        error_t = self.sum_except_batch((eps_t - eps_t_hat) ** 2)

        # Computing L2-loss for t>0
        normalization = (self.n_dims + self.in_node_nf) * self.numbers_of_nodes(linker_mask)
        l2_loss = error_t / normalization
        l2_loss = l2_loss.mean()

        # The KL between q(z_T | x) and p(z_T) = Normal(0, 1) (should be close to zero)
        kl_prior = self.kl_prior(xh, linker_mask).mean()

        # Computing NLL middle term
        SNR_weight = (self.SNR(gamma_s - gamma_t) - 1).squeeze(1).squeeze(1)
        loss_term_t = self.T * 0.5 * SNR_weight * error_t
        loss_term_t = (loss_term_t * t_is_not_zero).sum() / t_is_not_zero.sum()

        # Computing noise returned by dynamics
        noise = torch.norm(eps_t_hat, dim=[1, 2])
        noise_t = (noise * t_is_not_zero).sum() / t_is_not_zero.sum()

        if t_is_zero.sum() > 0:
            # The _constants_ depending on sigma_0 from the
            # cross entropy term E_q(z0 | x) [log p(x | z0)]
            neg_log_constants = -self.log_constant_of_p_x_given_z0(x, linker_mask)

            # Computes the L_0 term (even if gamma_t is not actually gamma_0)
            # and selects only the relevant entries via masking
            loss_term_0 = -self.log_p_xh_given_z0_without_constants(h, z_t, gamma_t, eps_t, eps_t_hat, linker_mask)
            loss_term_0 = loss_term_0 + neg_log_constants
            loss_term_0 = (loss_term_0 * t_is_zero).sum() / t_is_zero.sum()

            # Computing noise returned by dynamics
            noise_0 = (noise * t_is_zero).sum() / t_is_zero.sum()
        else:
            loss_term_0 = 0.
            noise_0 = 0.

        return delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0

    @torch.no_grad()
    def sample_chain(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context, keep_frames=None):
        n_samples = x.size(0)
        n_nodes = x.size(1)

        # Normalization and concatenation
        x, h = self.normalize(x, h)
        xh = torch.cat([x, h], dim=2)

        # Initial linker sampling from N(0, I)
        z = self.sample_combined_position_feature_noise(n_samples, n_nodes, mask=linker_mask)
        z = xh * fragment_mask + z * linker_mask

        if keep_frames is None:
            keep_frames = self.T
        else:
            assert keep_frames <= self.T
        chain = torch.zeros((keep_frames,) + z.size(), device=z.device)

        # Sample p(z_s | z_t)
        for s in reversed(range(0, self.T)):
            s_array = torch.full((n_samples, 1), fill_value=s, device=z.device)
            t_array = s_array + 1
            s_array = s_array / self.T
            t_array = t_array / self.T

            z = self.sample_p_zs_given_zt_only_linker(
                s=s_array,
                t=t_array,
                z_t=z,
                node_mask=node_mask,
                fragment_mask=fragment_mask,
                linker_mask=linker_mask,
                edge_mask=edge_mask,
                context=context,
            )
            write_index = (s * keep_frames) // self.T
            chain[write_index] = self.unnormalize_z(z)

        # Finally sample p(x, h | z_0)
        x, h = self.sample_p_xh_given_z0_only_linker(
            z_0=z,
            node_mask=node_mask,
            fragment_mask=fragment_mask,
            linker_mask=linker_mask,
            edge_mask=edge_mask,
            context=context,
        )
        chain[0] = torch.cat([x, h], dim=2)

        return chain

    def sample_p_zs_given_zt_only_linker(self, s, t, z_t, node_mask, fragment_mask, linker_mask, edge_mask, context):
        """Samples from zs ~ p(zs | zt). Only used during sampling. Samples only linker features and coords"""
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)

        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)
        sigma_s = self.sigma(gamma_s, target_tensor=z_t)
        sigma_t = self.sigma(gamma_t, target_tensor=z_t)

        # Neural net prediction.
        eps_hat = self.dynamics.forward(
            xh=z_t,
            t=t,
            node_mask=node_mask,
            linker_mask=linker_mask,
            context=context,
            edge_mask=edge_mask,
        )
        eps_hat = eps_hat * linker_mask

        # Compute mu for p(z_s | z_t)
        mu = z_t / alpha_t_given_s - (sigma2_t_given_s / alpha_t_given_s / sigma_t) * eps_hat

        # Compute sigma for p(z_s | z_t)
        sigma = sigma_t_given_s * sigma_s / sigma_t

        # Sample z_s given the parameters derived from z_t
        z_s = self.sample_normal(mu, sigma, linker_mask)
        z_s = z_t * fragment_mask + z_s * linker_mask

        return z_s

    def sample_p_xh_given_z0_only_linker(self, z_0, node_mask, fragment_mask, linker_mask, edge_mask, context):
        """Samples x ~ p(x|z0). Samples only linker features and coords"""
        zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
        gamma_0 = self.gamma(zeros)

        # Computes sqrt(sigma_0^2 / alpha_0^2)
        sigma_x = self.SNR(-0.5 * gamma_0).unsqueeze(1)
        eps_hat = self.dynamics.forward(
            t=zeros,
            xh=z_0,
            node_mask=node_mask,
            linker_mask=linker_mask,
            edge_mask=edge_mask,
            context=context
        )
        eps_hat = eps_hat * linker_mask

        mu_x = self.compute_x_pred(eps_t=eps_hat, z_t=z_0, gamma_t=gamma_0)
        xh = self.sample_normal(mu=mu_x, sigma=sigma_x, node_mask=linker_mask)
        xh = z_0 * fragment_mask + xh * linker_mask

        x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
        x, h = self.unnormalize(x, h)
        h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask

        return x, h

    def compute_x_pred(self, eps_t, z_t, gamma_t):
        """Computes x_pred, i.e. the most likely prediction of x."""
        sigma_t = self.sigma(gamma_t, target_tensor=eps_t)
        alpha_t = self.alpha(gamma_t, target_tensor=eps_t)
        x_pred = 1. / alpha_t * (z_t - sigma_t * eps_t)
        return x_pred

    def kl_prior(self, xh, mask):
        """
        Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1).
        This is essentially a lot of work for something that is in practice negligible in the loss.
        However, you compute it so that you see it when you've made a mistake in your noise schedule.
        """
        # Compute the last alpha value, alpha_T
        ones = torch.ones((xh.size(0), 1), device=xh.device)
        gamma_T = self.gamma(ones)
        alpha_T = self.alpha(gamma_T, xh)

        # Compute means
        mu_T = alpha_T * xh
        mu_T_x, mu_T_h = mu_T[:, :, :self.n_dims], mu_T[:, :, self.n_dims:]

        # Compute standard deviations (only batch axis for x-part, inflated for h-part)
        sigma_T_x = self.sigma(gamma_T, mu_T_x).view(-1)  # Remove inflate, only keep batch dimension for x-part
        sigma_T_h = self.sigma(gamma_T, mu_T_h)

        # Compute KL for h-part
        zeros, ones = torch.zeros_like(mu_T_h), torch.ones_like(sigma_T_h)
        kl_distance_h = self.gaussian_kl(mu_T_h, sigma_T_h, zeros, ones)

        # Compute KL for x-part
        zeros, ones = torch.zeros_like(mu_T_x), torch.ones_like(sigma_T_x)
        d = self.dimensionality(mask)
        kl_distance_x = self.gaussian_kl_for_dimension(mu_T_x, sigma_T_x, zeros, ones, d=d)

        return kl_distance_x + kl_distance_h

    def log_constant_of_p_x_given_z0(self, x, mask):
        batch_size = x.size(0)
        degrees_of_freedom_x = self.dimensionality(mask)
        zeros = torch.zeros((batch_size, 1), device=x.device)
        gamma_0 = self.gamma(zeros)

        # Recall that sigma_x = sqrt(sigma_0^2 / alpha_0^2) = SNR(-0.5 gamma_0)
        log_sigma_x = 0.5 * gamma_0.view(batch_size)

        return degrees_of_freedom_x * (- log_sigma_x - 0.5 * np.log(2 * np.pi))

    def log_p_xh_given_z0_without_constants(self, h, z_0, gamma_0, eps, eps_hat, mask, epsilon=1e-10):
        # Discrete properties are predicted directly from z_0
        z_h = z_0[:, :, self.n_dims:]

        # Take only part over x
        eps_x = eps[:, :, :self.n_dims]
        eps_hat_x = eps_hat[:, :, :self.n_dims]

        # Compute sigma_0 and rescale to the integer scale of the data
        sigma_0 = self.sigma(gamma_0, target_tensor=z_0) * self.norm_values[1]

        # Computes the error for the distribution N(x | 1 / alpha_0 z_0 + sigma_0/alpha_0 eps_0, sigma_0 / alpha_0),
        # the weighting in the epsilon parametrization is exactly '1'
        log_p_x_given_z_without_constants = -0.5 * self.sum_except_batch((eps_x - eps_hat_x) ** 2)

        # Categorical features
        # Compute delta indicator masks
        h = h * self.norm_values[1] + self.norm_biases[1]
        estimated_h = z_h * self.norm_values[1] + self.norm_biases[1]

        # Centered h_cat around 1, since onehot encoded
        centered_h = estimated_h - 1

        # Compute integrals from 0.5 to 1.5 of the normal distribution
        # N(mean=centered_h_cat, stdev=sigma_0_cat)
        log_p_h_proportional = torch.log(
            self.cdf_standard_gaussian((centered_h + 0.5) / sigma_0) -
            self.cdf_standard_gaussian((centered_h - 0.5) / sigma_0) +
            epsilon
        )

        # Normalize the distribution over the categories
        log_Z = torch.logsumexp(log_p_h_proportional, dim=2, keepdim=True)
        log_probabilities = log_p_h_proportional - log_Z

        # Select the log_prob of the current category using the onehot representation
        log_p_h_given_z = self.sum_except_batch(log_probabilities * h * mask)

        # Combine log probabilities for x and h
        log_p_xh_given_z = log_p_x_given_z_without_constants + log_p_h_given_z

        return log_p_xh_given_z

    def sample_combined_position_feature_noise(self, n_samples, n_nodes, mask):
        z_x = utils.sample_gaussian_with_mask(
            size=(n_samples, n_nodes, self.n_dims),
            device=mask.device,
            node_mask=mask
        )
        z_h = utils.sample_gaussian_with_mask(
            size=(n_samples, n_nodes, self.in_node_nf),
            device=mask.device,
            node_mask=mask
        )
        z = torch.cat([z_x, z_h], dim=2)
        return z

    def sample_normal(self, mu, sigma, node_mask):
        """Samples from a Normal distribution."""
        eps = self.sample_combined_position_feature_noise(mu.size(0), mu.size(1), node_mask)
        return mu + sigma * eps

    def normalize(self, x, h):
        new_x = x / self.norm_values[0]
        new_h = (h.float() - self.norm_biases[1]) / self.norm_values[1]
        return new_x, new_h

    def unnormalize(self, x, h):
        new_x = x * self.norm_values[0]
        new_h = h * self.norm_values[1] + self.norm_biases[1]
        return new_x, new_h

    def unnormalize_z(self, z):
        assert z.size(2) == self.n_dims + self.in_node_nf
        x, h = z[:, :, :self.n_dims], z[:, :, self.n_dims:]
        x, h = self.unnormalize(x, h)
        return torch.cat([x, h], dim=2)

    def delta_log_px(self, mask):
        return -self.dimensionality(mask) * np.log(self.norm_values[0])

    def dimensionality(self, mask):
        return self.numbers_of_nodes(mask) * self.n_dims

    def sigma(self, gamma, target_tensor):
        """Computes sigma given gamma."""
        return self.inflate_batch_array(torch.sqrt(torch.sigmoid(gamma)), target_tensor)

    def alpha(self, gamma, target_tensor):
        """Computes alpha given gamma."""
        return self.inflate_batch_array(torch.sqrt(torch.sigmoid(-gamma)), target_tensor)

    def SNR(self, gamma):
        """Computes signal to noise ratio (alpha^2/sigma^2) given gamma."""
        return torch.exp(-gamma)

    def sigma_and_alpha_t_given_s(self, gamma_t: torch.Tensor, gamma_s: torch.Tensor, target_tensor: torch.Tensor):
        """
        Computes sigma t given s, using gamma_t and gamma_s. Used during sampling.

        These are defined as:
            alpha t given s = alpha t / alpha s,
            sigma t given s = sqrt(1 - (alpha t given s) ^2 ).
        """
        sigma2_t_given_s = self.inflate_batch_array(
            -self.expm1(self.softplus(gamma_s) - self.softplus(gamma_t)),
            target_tensor
        )

        # alpha_t_given_s = alpha_t / alpha_s
        log_alpha2_t = F.logsigmoid(-gamma_t)
        log_alpha2_s = F.logsigmoid(-gamma_s)
        log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s

        alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s)
        alpha_t_given_s = self.inflate_batch_array(alpha_t_given_s, target_tensor)
        sigma_t_given_s = torch.sqrt(sigma2_t_given_s)

        return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s

    @staticmethod
    def numbers_of_nodes(mask):
        return torch.sum(mask.squeeze(2), dim=1)

    @staticmethod
    def inflate_batch_array(array, target):
        """
        Inflates the batch array (array) with only a single axis (i.e. shape = (batch_size,),
        or possibly more empty axes (i.e. shape (batch_size, 1, ..., 1)) to match the target shape.
        """
        target_shape = (array.size(0),) + (1,) * (len(target.size()) - 1)
        return array.view(target_shape)

    @staticmethod
    def sum_except_batch(x):
        return x.view(x.size(0), -1).sum(-1)

    @staticmethod
    def expm1(x: torch.Tensor) -> torch.Tensor:
        return torch.expm1(x)

    @staticmethod
    def softplus(x: torch.Tensor) -> torch.Tensor:
        return F.softplus(x)

    @staticmethod
    def cdf_standard_gaussian(x):
        return 0.5 * (1. + torch.erf(x / math.sqrt(2)))

    @staticmethod
    def gaussian_kl(q_mu, q_sigma, p_mu, p_sigma):
        """
        Computes the KL distance between two normal distributions.
        Args:
            q_mu: Mean of distribution q.
            q_sigma: Standard deviation of distribution q.
            p_mu: Mean of distribution p.
            p_sigma: Standard deviation of distribution p.
        Returns:
            The KL distance, summed over all dimensions except the batch dim.
        """
        kl = torch.log(p_sigma / q_sigma) + 0.5 * (q_sigma ** 2 + (q_mu - p_mu) ** 2) / (p_sigma ** 2) - 0.5
        return EDM.sum_except_batch(kl)

    @staticmethod
    def gaussian_kl_for_dimension(q_mu, q_sigma, p_mu, p_sigma, d):
        """
        Computes the KL distance between two normal distributions taking the dimension into account.
        Args:
            q_mu: Mean of distribution q.
            q_sigma: Standard deviation of distribution q.
            p_mu: Mean of distribution p.
            p_sigma: Standard deviation of distribution p.
            d: dimension
        Returns:
            The KL distance, summed over all dimensions except the batch dim.
        """
        mu_norm_2 = EDM.sum_except_batch((q_mu - p_mu) ** 2)
        return d * torch.log(p_sigma / q_sigma) + 0.5 * (d * q_sigma ** 2 + mu_norm_2) / (p_sigma ** 2) - 0.5 * d


class InpaintingEDM(EDM):
    def forward(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context=None):
        # Normalization and concatenation
        x, h = self.normalize(x, h)
        xh = torch.cat([x, h], dim=2)

        # Volume change loss term
        delta_log_px = self.delta_log_px(node_mask).mean()

        # Sample t
        t_int = torch.randint(0, self.T + 1, size=(x.size(0), 1), device=x.device).float()
        s_int = t_int - 1
        t = t_int / self.T
        s = s_int / self.T

        # Masks for t=0 and t>0
        t_is_zero = (t_int == 0).squeeze().float()
        t_is_not_zero = 1 - t_is_zero

        # Compute gamma_t and gamma_s according to the noise schedule
        gamma_t = self.inflate_batch_array(self.gamma(t), x)
        gamma_s = self.inflate_batch_array(self.gamma(s), x)

        # Compute alpha_t and sigma_t from gamma
        alpha_t = self.alpha(gamma_t, x)
        sigma_t = self.sigma(gamma_t, x)

        # Sample noise
        eps_t = self.sample_combined_position_feature_noise(n_samples=x.size(0), n_nodes=x.size(1), mask=node_mask)

        # Sample z_t given x, h for timestep t, from q(z_t | x, h)
        z_t = alpha_t * xh + sigma_t * eps_t

        # Neural net prediction
        eps_t_hat = self.dynamics.forward(
            xh=z_t,
            t=t,
            node_mask=node_mask,
            linker_mask=None,
            context=context,
            edge_mask=edge_mask,
        )

        # Computing basic error (further used for computing NLL and L2-loss)
        error_t = self.sum_except_batch((eps_t - eps_t_hat) ** 2)

        # Computing L2-loss for t>0
        normalization = (self.n_dims + self.in_node_nf) * self.numbers_of_nodes(node_mask)
        l2_loss = error_t / normalization
        l2_loss = l2_loss.mean()

        # The KL between q(z_T | x) and p(z_T) = Normal(0, 1) (should be close to zero)
        kl_prior = self.kl_prior(xh, node_mask).mean()

        # Computing NLL middle term
        SNR_weight = (self.SNR(gamma_s - gamma_t) - 1).squeeze(1).squeeze(1)
        loss_term_t = self.T * 0.5 * SNR_weight * error_t
        loss_term_t = (loss_term_t * t_is_not_zero).sum() / t_is_not_zero.sum()

        # Computing noise returned by dynamics
        noise = torch.norm(eps_t_hat, dim=[1, 2])
        noise_t = (noise * t_is_not_zero).sum() / t_is_not_zero.sum()

        if t_is_zero.sum() > 0:
            # The _constants_ depending on sigma_0 from the
            # cross entropy term E_q(z0 | x) [log p(x | z0)]
            neg_log_constants = -self.log_constant_of_p_x_given_z0(x, node_mask)

            # Computes the L_0 term (even if gamma_t is not actually gamma_0)
            # and selects only the relevant entries via masking
            loss_term_0 = -self.log_p_xh_given_z0_without_constants(h, z_t, gamma_t, eps_t, eps_t_hat, node_mask)
            loss_term_0 = loss_term_0 + neg_log_constants
            loss_term_0 = (loss_term_0 * t_is_zero).sum() / t_is_zero.sum()

            # Computing noise returned by dynamics
            noise_0 = (noise * t_is_zero).sum() / t_is_zero.sum()
        else:
            loss_term_0 = 0.
            noise_0 = 0.

        return delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0

    @torch.no_grad()
    def sample_chain(self, x, h, node_mask, edge_mask, fragment_mask, linker_mask, context, keep_frames=None):
        n_samples = x.size(0)
        n_nodes = x.size(1)

        # Normalization and concatenation
        x, h = self.normalize(x, h)
        xh = torch.cat([x, h], dim=2)

        # Sampling initial noise
        z = self.sample_combined_position_feature_noise(n_samples, n_nodes, node_mask)

        if keep_frames is None:
            keep_frames = self.T
        else:
            assert keep_frames <= self.T
        chain = torch.zeros((keep_frames,) + z.size(), device=z.device)

        # Sample p(z_s | z_t)
        for s in reversed(range(0, self.T)):
            s_array = torch.full((n_samples, 1), fill_value=s, device=z.device)
            t_array = s_array + 1
            s_array = s_array / self.T
            t_array = t_array / self.T

            z_linker_only_sampled = self.sample_p_zs_given_zt(
                s=s_array,
                t=t_array,
                z_t=z,
                node_mask=node_mask,
                edge_mask=edge_mask,
                context=context,
            )
            z_fragments_only_sampled = self.sample_q_zs_given_zt_and_x(
                s=s_array,
                t=t_array,
                z_t=z,
                x=xh * fragment_mask,
                node_mask=fragment_mask,
            )
            z = z_linker_only_sampled * linker_mask + z_fragments_only_sampled * fragment_mask

            # Project down to avoid numerical runaway of the center of gravity
            z_x = utils.remove_mean_with_mask(z[:, :, :self.n_dims], node_mask)
            z_h = z[:, :, self.n_dims:]
            z = torch.cat([z_x, z_h], dim=2)

            # Saving step to the chain
            write_index = (s * keep_frames) // self.T
            chain[write_index] = self.unnormalize_z(z)

        # Finally sample p(x, h | z_0)
        x_out_linker, h_out_linker = self.sample_p_xh_given_z0(
            z_0=z,
            node_mask=node_mask,
            edge_mask=edge_mask,
            context=context,
        )
        x_out_fragments, h_out_fragments = self.sample_q_xh_given_z0_and_x(z_0=z, node_mask=node_mask)

        xh_out_linker = torch.cat([x_out_linker, h_out_linker], dim=2)
        xh_out_fragments = torch.cat([x_out_fragments, h_out_fragments], dim=2)
        xh_out = xh_out_linker * linker_mask + xh_out_fragments * fragment_mask

        # Overwrite last frame with the resulting x and h
        chain[0] = xh_out

        return chain

    def sample_p_zs_given_zt(self, s, t, z_t, node_mask, edge_mask, context):
        """Samples from zs ~ p(zs | zt). Only used during sampling"""
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)
        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)

        sigma_s = self.sigma(gamma_s, target_tensor=z_t)
        sigma_t = self.sigma(gamma_t, target_tensor=z_t)

        # Neural net prediction.
        eps_hat = self.dynamics.forward(
            xh=z_t,
            t=t,
            node_mask=node_mask,
            linker_mask=None,
            edge_mask=edge_mask,
            context=context
        )

        # Checking that epsilon is centered around linker COM
        utils.assert_mean_zero_with_mask(eps_hat[:, :, :self.n_dims], node_mask)

        # Compute mu for p(z_s | z_t)
        mu = z_t / alpha_t_given_s - (sigma2_t_given_s / alpha_t_given_s / sigma_t) * eps_hat

        # Compute sigma for p(z_s | z_t)
        sigma = sigma_t_given_s * sigma_s / sigma_t

        # Sample z_s given the parameters derived from z_t
        z_s = self.sample_normal(mu, sigma, node_mask)
        return z_s

    def sample_q_zs_given_zt_and_x(self, s, t, z_t, x, node_mask):
        """Samples from zs ~ q(zs | zt, x). Only used during sampling. Samples only linker features and coords"""
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)
        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)

        sigma_s = self.sigma(gamma_s, target_tensor=z_t)
        sigma_t = self.sigma(gamma_t, target_tensor=z_t)
        alpha_s = self.alpha(gamma_s, x)

        mu = (
            alpha_t_given_s * (sigma_s ** 2) / (sigma_t ** 2) * z_t +
            alpha_s * sigma2_t_given_s / (sigma_t ** 2) * x
        )

        # Compute sigma for p(zs | zt)
        sigma = sigma_t_given_s * sigma_s / sigma_t

        # Sample zs given the parameters derived from zt
        z_s = self.sample_normal(mu, sigma, node_mask)
        return z_s

    def sample_p_xh_given_z0(self, z_0, node_mask, edge_mask, context):
        """Samples x ~ p(x|z0). Samples only linker features and coords"""
        zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
        gamma_0 = self.gamma(zeros)

        # Computes sqrt(sigma_0^2 / alpha_0^2)
        sigma_x = self.SNR(-0.5 * gamma_0).unsqueeze(1)
        eps_hat = self.dynamics.forward(
            xh=z_0,
            t=zeros,
            node_mask=node_mask,
            linker_mask=None,
            edge_mask=edge_mask,
            context=context
        )
        utils.assert_mean_zero_with_mask(eps_hat[:, :, :self.n_dims], node_mask)

        mu_x = self.compute_x_pred(eps_hat, z_0, gamma_0)
        xh = self.sample_normal(mu=mu_x, sigma=sigma_x, node_mask=node_mask)

        x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
        x, h = self.unnormalize(x, h)
        h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask

        return x, h

    def sample_q_xh_given_z0_and_x(self, z_0, node_mask):
        """Samples x ~ q(x|z0). Samples only linker features and coords"""
        zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
        gamma_0 = self.gamma(zeros)
        alpha_0 = self.alpha(gamma_0, z_0)
        sigma_0 = self.sigma(gamma_0, z_0)

        eps = self.sample_combined_position_feature_noise(z_0.size(0), z_0.size(1), node_mask)

        xh = (1 / alpha_0) * z_0 - (sigma_0 / alpha_0) * eps

        x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
        x, h = self.unnormalize(x, h)
        h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask

        return x, h

    def sample_combined_position_feature_noise(self, n_samples, n_nodes, mask):
        z_x = utils.sample_center_gravity_zero_gaussian_with_mask(
            size=(n_samples, n_nodes, self.n_dims),
            device=mask.device,
            node_mask=mask
        )
        z_h = utils.sample_gaussian_with_mask(
            size=(n_samples, n_nodes, self.in_node_nf),
            device=mask.device,
            node_mask=mask
        )
        z = torch.cat([z_x, z_h], dim=2)
        return z

    def dimensionality(self, mask):
        return (self.numbers_of_nodes(mask) - 1) * self.n_dims
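One detail worth making explicit: the whole schedule hangs on a single scalar gamma per timestep. `alpha` and `sigma` above are both sigmoids of it, so alpha_t^2 + sigma_t^2 = 1 (the variance-preserving property) and alpha^2/sigma^2 = exp(-gamma), which is exactly what `SNR()` returns. A self-contained check of these identities (illustrative only, not part of the commit):

import torch

gamma = torch.linspace(-10., 10., steps=5)   # arbitrary gamma values
alpha2 = torch.sigmoid(-gamma)               # alpha_t^2, cf. EDM.alpha
sigma2 = torch.sigmoid(gamma)                # sigma_t^2, cf. EDM.sigma
assert torch.allclose(alpha2 + sigma2, torch.ones_like(gamma))  # variance preserving
assert torch.allclose(alpha2 / sigma2, torch.exp(-gamma))       # SNR(gamma) = exp(-gamma)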
src/egnn.py
ADDED
@@ -0,0 +1,541 @@
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from src import utils
|
7 |
+
from pdb import set_trace
|
8 |
+
|
9 |
+
|
10 |
+
class GCL(nn.Module):
|
11 |
+
def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method, activation,
|
12 |
+
edges_in_d=0, nodes_att_dim=0, attention=False, normalization=None):
|
13 |
+
super(GCL, self).__init__()
|
14 |
+
input_edge = input_nf * 2
|
15 |
+
self.normalization_factor = normalization_factor
|
16 |
+
self.aggregation_method = aggregation_method
|
17 |
+
self.attention = attention
|
18 |
+
|
19 |
+
self.edge_mlp = nn.Sequential(
|
20 |
+
nn.Linear(input_edge + edges_in_d, hidden_nf),
|
21 |
+
activation,
|
22 |
+
nn.Linear(hidden_nf, hidden_nf),
|
23 |
+
activation)
|
24 |
+
|
25 |
+
if normalization is None:
|
26 |
+
self.node_mlp = nn.Sequential(
|
27 |
+
nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
|
28 |
+
activation,
|
29 |
+
nn.Linear(hidden_nf, output_nf)
|
30 |
+
)
|
31 |
+
elif normalization == 'batch_norm':
|
32 |
+
self.node_mlp = nn.Sequential(
|
33 |
+
nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
|
34 |
+
nn.BatchNorm1d(hidden_nf),
|
35 |
+
activation,
|
36 |
+
nn.Linear(hidden_nf, output_nf),
|
37 |
+
nn.BatchNorm1d(output_nf),
|
38 |
+
)
|
39 |
+
else:
|
40 |
+
raise NotImplementedError
|
41 |
+
|
42 |
+
if self.attention:
|
43 |
+
self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())
|
44 |
+
|
45 |
+
def edge_model(self, source, target, edge_attr, edge_mask):
|
46 |
+
if edge_attr is None: # Unused.
|
47 |
+
out = torch.cat([source, target], dim=1)
|
48 |
+
else:
|
49 |
+
out = torch.cat([source, target, edge_attr], dim=1)
|
50 |
+
mij = self.edge_mlp(out)
|
51 |
+
|
52 |
+
if self.attention:
|
53 |
+
att_val = self.att_mlp(mij)
|
54 |
+
out = mij * att_val
|
55 |
+
else:
|
56 |
+
out = mij
|
57 |
+
|
58 |
+
if edge_mask is not None:
|
59 |
+
out = out * edge_mask
|
60 |
+
return out, mij
|
61 |
+
|
62 |
+
def node_model(self, x, edge_index, edge_attr, node_attr):
|
63 |
+
row, col = edge_index
|
64 |
+
agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
|
65 |
+
normalization_factor=self.normalization_factor,
|
66 |
+
aggregation_method=self.aggregation_method)
|
67 |
+
if node_attr is not None:
|
68 |
+
agg = torch.cat([x, agg, node_attr], dim=1)
|
69 |
+
else:
|
70 |
+
agg = torch.cat([x, agg], dim=1)
|
71 |
+
out = x + self.node_mlp(agg)
|
72 |
+
return out, agg
|
73 |
+
|
74 |
+
def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
|
75 |
+
row, col = edge_index
|
76 |
+
edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
|
77 |
+
h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
|
78 |
+
if node_mask is not None:
|
79 |
+
h = h * node_mask
|
80 |
+
return h, mij
|
81 |
+
|
82 |
+
|
83 |
+
class EquivariantUpdate(nn.Module):
|
84 |
+
def __init__(self, hidden_nf, normalization_factor, aggregation_method,
|
85 |
+
edges_in_d=1, activation=nn.SiLU(), tanh=False, coords_range=10.0):
|
86 |
+
super(EquivariantUpdate, self).__init__()
|
87 |
+
self.tanh = tanh
|
88 |
+
self.coords_range = coords_range
|
89 |
+
input_edge = hidden_nf * 2 + edges_in_d
|
90 |
+
layer = nn.Linear(hidden_nf, 1, bias=False)
|
91 |
+
torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
|
92 |
+
self.coord_mlp = nn.Sequential(
|
93 |
+
nn.Linear(input_edge, hidden_nf),
|
94 |
+
activation,
|
95 |
+
nn.Linear(hidden_nf, hidden_nf),
|
96 |
+
activation,
|
97 |
+
layer)
|
98 |
+
self.normalization_factor = normalization_factor
|
99 |
+
self.aggregation_method = aggregation_method
|
100 |
+
|
101 |
+
def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask, linker_mask):
|
102 |
+
row, col = edge_index
|
103 |
+
input_tensor = torch.cat([h[row], h[col], edge_attr], dim=1)
|
104 |
+
if self.tanh:
|
105 |
+
trans = coord_diff * torch.tanh(self.coord_mlp(input_tensor)) * self.coords_range
|
106 |
+
else:
|
107 |
+
trans = coord_diff * self.coord_mlp(input_tensor)
|
108 |
+
if edge_mask is not None:
|
109 |
+
trans = trans * edge_mask
|
110 |
+
agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
|
111 |
+
normalization_factor=self.normalization_factor,
|
112 |
+
aggregation_method=self.aggregation_method)
|
113 |
+
if linker_mask is not None:
|
114 |
+
agg = agg * linker_mask
|
115 |
+
|
116 |
+
coord = coord + agg
|
117 |
+
return coord
|
118 |
+
|
119 |
+
def forward(
|
120 |
+
self, h, coord, edge_index, coord_diff, edge_attr=None, linker_mask=None, node_mask=None, edge_mask=None
|
121 |
+
):
|
122 |
+
coord = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask, linker_mask)
|
123 |
+
if node_mask is not None:
|
124 |
+
coord = coord * node_mask
|
125 |
+
return coord
|
126 |
+
|
127 |
+
|
128 |
+
class EquivariantBlock(nn.Module):
|
129 |
+
def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', activation=nn.SiLU(), n_layers=2, attention=True,
|
130 |
+
norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
|
131 |
+
normalization_factor=100, aggregation_method='sum'):
|
132 |
+
super(EquivariantBlock, self).__init__()
|
133 |
+
self.hidden_nf = hidden_nf
|
134 |
+
self.device = device
|
135 |
+
self.n_layers = n_layers
|
136 |
+
self.coords_range_layer = float(coords_range)
|
137 |
+
self.norm_diff = norm_diff
|
138 |
+
self.norm_constant = norm_constant
|
139 |
+
self.sin_embedding = sin_embedding
|
140 |
+
self.normalization_factor = normalization_factor
|
141 |
+
self.aggregation_method = aggregation_method
|
142 |
+
|
143 |
+
for i in range(0, n_layers):
|
144 |
+
self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_feat_nf,
|
145 |
+
activation=activation, attention=attention,
|
146 |
+
normalization_factor=self.normalization_factor,
|
147 |
+
aggregation_method=self.aggregation_method))
|
148 |
+
self.add_module("gcl_equiv", EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, activation=activation, tanh=tanh,
|
149 |
+
coords_range=self.coords_range_layer,
|
150 |
+
normalization_factor=self.normalization_factor,
|
151 |
+
aggregation_method=self.aggregation_method))
|
152 |
+
self.to(self.device)
|
153 |
+
|
154 |
+
def forward(self, h, x, edge_index, node_mask=None, linker_mask=None, edge_mask=None, edge_attr=None):
|
155 |
+
# Edit Emiel: Remove velocity as input
|
156 |
+
distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)
|
157 |
+
if self.sin_embedding is not None:
|
158 |
+
distances = self.sin_embedding(distances)
|
159 |
+
edge_attr = torch.cat([distances, edge_attr], dim=1)
|
160 |
+
for i in range(0, self.n_layers):
|
161 |
+
h, _ = self._modules["gcl_%d" % i](h, edge_index, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
|
162 |
+
x = self._modules["gcl_equiv"](
|
163 |
+
h, x,
|
164 |
+
edge_index=edge_index,
|
165 |
+
coord_diff=coord_diff,
|
166 |
+
edge_attr=edge_attr,
|
167 |
+
linker_mask=linker_mask,
|
168 |
+
node_mask=node_mask,
|
169 |
+
edge_mask=edge_mask,
|
170 |
+
)
|
171 |
+
|
172 |
+
# Important, the bias of the last linear might be non-zero
|
173 |
+
if node_mask is not None:
|
174 |
+
h = h * node_mask
|
175 |
+
return h, x
|
176 |
+
|
177 |
+
|
178 |
+
class EGNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', activation=nn.SiLU(), n_layers=3, attention=False,
                 norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
                 sin_embedding=False, normalization_factor=100, aggregation_method='sum'):
        super(EGNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range / n_layers)
        self.norm_diff = norm_diff
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        if sin_embedding:
            self.sin_embedding = SinusoidsEmbeddingNew()
            edge_feat_nf = self.sin_embedding.dim * 2
        else:
            self.sin_embedding = None
            edge_feat_nf = 2

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("e_block_%d" % i, EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
                                                               activation=activation, n_layers=inv_sublayers,
                                                               attention=attention, norm_diff=norm_diff, tanh=tanh,
                                                               coords_range=coords_range, norm_constant=norm_constant,
                                                               sin_embedding=self.sin_embedding,
                                                               normalization_factor=self.normalization_factor,
                                                               aggregation_method=self.aggregation_method))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, linker_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        distances, _ = coord2diff(x, edge_index)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)

        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, x = self._modules["e_block_%d" % i](
                h, x, edge_index,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=edge_mask,
                edge_attr=distances
            )

        # Important, the bias of the last linear might be non-zero
        h = self.embedding_out(h)
        if node_mask is not None:
            h = h * node_mask
        return h, x

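The EGNN above is equivariant by construction: edge features are built from invariant distances, and coordinates are only updated along relative directions. Below is a minimal sketch of checking this property, assuming the repo is on PYTHONPATH and that GCL and EquivariantUpdate (defined earlier in this file) accept the default None masks:

# Editor's illustrative sketch, not part of this diff.
import torch
from src.egnn import EGNN

torch.manual_seed(0)
n, nf = 5, 8
model = EGNN(in_node_nf=nf, in_edge_nf=1, hidden_nf=16, n_layers=2).eval()

# Fully connected edge index for a single 5-node graph (no self-loops)
rows, cols = zip(*[(i, j) for i in range(n) for j in range(n) if i != j])
edges = [torch.tensor(rows), torch.tensor(cols)]

h, x = torch.randn(n, nf), torch.randn(n, 3)
q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix

with torch.no_grad():
    h1, x1 = model(h, x, edges)
    h2, x2 = model(h, x @ q, edges)

print(torch.allclose(h1, h2, atol=1e-4))      # node features are invariant
print(torch.allclose(x1 @ q, x2, atol=1e-4))  # coordinates are equivariant
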
class GNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, aggregation_method='sum', device='cpu',
                 activation=nn.SiLU(), n_layers=4, attention=False, normalization_factor=1,
                 out_node_nf=None, normalization=None):
        super(GNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        # Encoder
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, GCL(
                self.hidden_nf, self.hidden_nf, self.hidden_nf,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                edges_in_d=in_edge_nf, activation=activation,
                attention=attention, normalization=normalization))
        self.to(self.device)

    def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h

class SinusoidsEmbeddingNew(nn.Module):
    def __init__(self, max_res=15., min_res=15. / 2000., div_factor=4):
        super().__init__()
        self.n_frequencies = int(math.log(max_res / min_res, div_factor)) + 1
        self.frequencies = 2 * math.pi * div_factor ** torch.arange(self.n_frequencies) / max_res
        self.dim = len(self.frequencies) * 2

    def forward(self, x):
        x = torch.sqrt(x + 1e-8)
        emb = x * self.frequencies[None, :].to(x.device)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb.detach()

def coord2diff(x, edge_index, norm_constant=1):
    row, col = edge_index
    coord_diff = x[row] - x[col]
    radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)
    norm = torch.sqrt(radial + 1e-8)
    coord_diff = coord_diff / (norm + norm_constant)
    return radial, coord_diff

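For reference, a quick numerical check of coord2diff (assuming `from src.egnn import coord2diff`): `radial` holds squared edge lengths and `coord_diff` holds direction vectors softly normalized by (|d| + norm_constant):

# Editor's illustrative sketch, not part of this diff.
import torch
from src.egnn import coord2diff

x = torch.tensor([[0., 0., 0.],
                  [3., 0., 0.]])
edges = [torch.tensor([0, 1]), torch.tensor([1, 0])]

radial, coord_diff = coord2diff(x, edges)
print(radial.squeeze())  # tensor([9., 9.]) -- squared distances
print(coord_diff)        # [[-0.75, 0., 0.], [0.75, 0., 0.]] since 3 / (3 + 1) = 0.75
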
def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.
    Normalization: 'sum' or 'mean'.
    """
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    if aggregation_method == 'sum':
        result = result / normalization_factor

    if aggregation_method == 'mean':
        norm = data.new_zeros(result.shape)
        norm.scatter_add_(0, segment_ids, data.new_ones(data.shape))
        norm[norm == 0] = 1
        result = result / norm
    return result

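A small sketch of unsorted_segment_sum, which is how edge messages are pooled onto nodes throughout this file; note that an empty segment (here segment 1) stays zero:

# Editor's illustrative sketch, not part of this diff.
import torch
from src.egnn import unsorted_segment_sum

data = torch.tensor([[1., 1.],
                     [2., 2.],
                     [4., 4.]])
segment_ids = torch.tensor([0, 0, 2])

# 'sum' divides the per-segment totals by normalization_factor
print(unsorted_segment_sum(data, segment_ids, num_segments=3,
                           normalization_factor=1, aggregation_method='sum'))
# tensor([[3., 3.], [0., 0.], [4., 4.]])

# 'mean' divides by the number of contributions in each segment
print(unsorted_segment_sum(data, segment_ids, num_segments=3,
                           normalization_factor=1, aggregation_method='mean'))
# tensor([[1.5000, 1.5000], [0., 0.], [4., 4.]])
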
class Dynamics(nn.Module):
    def __init__(
            self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
            n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
            sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
            normalization=None, centering=False,
    ):
        super().__init__()
        self.device = device
        self.n_dims = n_dims
        self.context_node_nf = context_node_nf
        self.condition_time = condition_time
        self.model = model
        self.centering = centering

        in_node_nf = in_node_nf + context_node_nf + condition_time
        if self.model == 'egnn_dynamics':
            self.dynamics = EGNN(
                in_node_nf=in_node_nf,
                in_edge_nf=1,
                hidden_nf=hidden_nf, device=device,
                activation=activation,
                n_layers=n_layers,
                attention=attention,
                tanh=tanh,
                norm_constant=norm_constant,
                inv_sublayers=inv_sublayers,
                sin_embedding=sin_embedding,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
            )
        elif self.model == 'gnn_dynamics':
            self.dynamics = GNN(
                in_node_nf=in_node_nf + 3,
                in_edge_nf=0,
                hidden_nf=hidden_nf,
                out_node_nf=in_node_nf + 3,
                device=device,
                activation=activation,
                n_layers=n_layers,
                attention=attention,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                normalization=normalization,
            )
        else:
            raise NotImplementedError

        self.edge_cache = {}

    def forward(self, t, xh, node_mask, linker_mask, edge_mask, context):
        """
        - t: (B)
        - xh: (B, N, D), where D = 3 + nf
        - node_mask: (B, N, 1)
        - edge_mask: (B*N*N, 1)
        - context: (B, N, C)
        """

        bs, n_nodes = xh.shape[0], xh.shape[1]
        edges = self.get_edges(n_nodes, bs)  # (2, B*N*N)
        node_mask = node_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        if linker_mask is not None:
            linker_mask = linker_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        # Reshaping node features & adding time feature
        xh = xh.view(bs * n_nodes, -1).clone() * node_mask  # (B*N, D)
        x = xh[:, :self.n_dims].clone()  # (B*N, 3)
        h = xh[:, self.n_dims:].clone()  # (B*N, nf)
        if self.condition_time:
            if np.prod(t.size()) == 1:
                # t is the same for all elements in batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t is different over the batch dimension.
                h_time = t.view(bs, 1).repeat(1, n_nodes)
                h_time = h_time.view(bs * n_nodes, 1)
            h = torch.cat([h, h_time], dim=1)  # (B*N, nf+1)
        if context is not None:
            context = context.view(bs * n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        # Forward EGNN
        # Output: h_final (B*N, nf), x_final (B*N, 3), vel (B*N, 3)
        if self.model == 'egnn_dynamics':
            h_final, x_final = self.dynamics(
                h,
                x,
                edges,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=edge_mask
            )
            vel = (x_final - x) * node_mask  # This masking operation is redundant but just in case
        elif self.model == 'gnn_dynamics':
            xh = torch.cat([x, h], dim=1)
            output = self.dynamics(xh, edges, node_mask=node_mask)
            vel = output[:, 0:3] * node_mask
            h_final = output[:, 3:]
        else:
            raise NotImplementedError

        # Slice off context size
        if context is not None:
            h_final = h_final[:, :-self.context_node_nf]

        # Slice off last dimension which represented time.
        if self.condition_time:
            h_final = h_final[:, :-1]

        vel = vel.view(bs, n_nodes, -1)  # (B, N, 3)
        h_final = h_final.view(bs, n_nodes, -1)  # (B, N, D)
        node_mask = node_mask.view(bs, n_nodes, 1)  # (B, N, 1)

        if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final)):
            raise utils.FoundNaNException(vel, h_final)

        if self.centering:
            vel = utils.remove_mean_with_mask(vel, node_mask)

        return torch.cat([vel, h_final], dim=2)

    def get_edges(self, n_nodes, batch_size):
        if n_nodes in self.edge_cache:
            edges_dic_b = self.edge_cache[n_nodes]
            if batch_size in edges_dic_b:
                return edges_dic_b[batch_size]
            else:
                # get edges for a single sample
                rows, cols = [], []
                for batch_idx in range(batch_size):
                    for i in range(n_nodes):
                        for j in range(n_nodes):
                            rows.append(i + batch_idx * n_nodes)
                            cols.append(j + batch_idx * n_nodes)
                edges = [torch.LongTensor(rows).to(self.device), torch.LongTensor(cols).to(self.device)]
                edges_dic_b[batch_size] = edges
                return edges
        else:
            self.edge_cache[n_nodes] = {}
            return self.get_edges(n_nodes, batch_size)

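get_edges builds the fully connected batched edge index with nested Python loops and caches the result per (n_nodes, batch_size). For readers who want to avoid the O(B*N^2) Python loop, here is a vectorized equivalent (a hypothetical helper, not part of this diff); like the original, it keeps self-loops, which are presumably zeroed out by edge_mask downstream:

# Editor's illustrative sketch, not part of this diff.
import torch

def fully_connected_edges(n_nodes: int, batch_size: int, device='cpu'):
    idx = torch.arange(n_nodes, device=device)
    rows = idx.repeat_interleave(n_nodes)   # 0,0,...,0,1,1,...
    cols = idx.repeat(n_nodes)              # 0,1,...,n-1,0,1,...
    offsets = torch.arange(batch_size, device=device).repeat_interleave(n_nodes * n_nodes) * n_nodes
    return [rows.repeat(batch_size) + offsets, cols.repeat(batch_size) + offsets]

# fully_connected_edges(3, 2) should match Dynamics.get_edges(3, 2) elementwise.
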
class DynamicsWithPockets(Dynamics):
    def forward(self, t, xh, node_mask, linker_mask, edge_mask, context):
        """
        - t: (B)
        - xh: (B, N, D), where D = 3 + nf
        - node_mask: (B, N, 1)
        - edge_mask: (B*N*N, 1)
        - context: (B, N, C)
        """

        bs, n_nodes = xh.shape[0], xh.shape[1]
        node_mask = node_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        if linker_mask is not None:
            linker_mask = linker_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        # Reshaping node features & adding time feature
        xh = xh.view(bs * n_nodes, -1).clone() * node_mask  # (B*N, D)
        x = xh[:, :self.n_dims].clone()  # (B*N, 3)
        h = xh[:, self.n_dims:].clone()  # (B*N, nf)

        edges = self.get_dist_edges(x, node_mask, edge_mask)
        if self.condition_time:
            if np.prod(t.size()) == 1:
                # t is the same for all elements in batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t is different over the batch dimension.
                h_time = t.view(bs, 1).repeat(1, n_nodes)
                h_time = h_time.view(bs * n_nodes, 1)
            h = torch.cat([h, h_time], dim=1)  # (B*N, nf+1)
        if context is not None:
            context = context.view(bs * n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        # Forward EGNN
        # Output: h_final (B*N, nf), x_final (B*N, 3), vel (B*N, 3)
        if self.model == 'egnn_dynamics':
            h_final, x_final = self.dynamics(
                h,
                x,
                edges,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=None
            )
            vel = (x_final - x) * node_mask  # This masking operation is redundant but just in case
        elif self.model == 'gnn_dynamics':
            xh = torch.cat([x, h], dim=1)
            output = self.dynamics(xh, edges, node_mask=node_mask)
            vel = output[:, 0:3] * node_mask
            h_final = output[:, 3:]
        else:
            raise NotImplementedError

        # Slice off context size
        if context is not None:
            h_final = h_final[:, :-self.context_node_nf]

        # Slice off last dimension which represented time.
        if self.condition_time:
            h_final = h_final[:, :-1]

        vel = vel.view(bs, n_nodes, -1)  # (B, N, 3)
        h_final = h_final.view(bs, n_nodes, -1)  # (B, N, D)
        node_mask = node_mask.view(bs, n_nodes, 1)  # (B, N, 1)

        if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final)):
            raise utils.FoundNaNException(vel, h_final)

        if self.centering:
            vel = utils.remove_mean_with_mask(vel, node_mask)

        return torch.cat([vel, h_final], dim=2)

    @staticmethod
    def get_dist_edges(x, node_mask, batch_mask):
        node_mask = node_mask.squeeze().bool()
        batch_adj = (batch_mask[:, None] == batch_mask[None, :])
        nodes_adj = (node_mask[:, None] & node_mask[None, :])
        dists_adj = (torch.cdist(x, x) <= 4)
        rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
        adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
        edges = torch.stack(torch.where(adj))
        return edges
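get_dist_edges builds sparse, distance-based connectivity instead of a fully connected graph: nodes are connected only within the same batch element, within the 4 Å cutoff, and without self-loops. A toy check of that behaviour:

# Editor's illustrative sketch, not part of this diff.
import torch
from src.egnn import DynamicsWithPockets

x = torch.tensor([[0., 0., 0.],
                  [3., 0., 0.],
                  [9., 0., 0.]])
node_mask = torch.ones(3, 1)
batch_mask = torch.tensor([0, 0, 0])  # all nodes in one sample

edges = DynamicsWithPockets.get_dist_edges(x, node_mask, batch_mask)
print(edges)  # tensor([[0, 1], [1, 0]]) -- node 2 is farther than 4 A from both
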
src/lightning.py
ADDED
@@ -0,0 +1,473 @@
import numpy as np
import os
import pytorch_lightning as pl
import torch
import wandb

from src import metrics, utils, delinker
from src.const import LINKER_SIZE_DIST
from src.egnn import Dynamics, DynamicsWithPockets
from src.edm import EDM, InpaintingEDM
from src.datasets import (
    ZincDataset, MOADDataset, create_templates_for_linker_generation, get_dataloader, collate
)
from src.linker_size import DistributionNodes
from src.molecule_builder import build_molecules
from src.visualizer import save_xyz_file, visualize_chain
from typing import Dict, List, Optional
from tqdm import tqdm

from pdb import set_trace


def get_activation(activation):
    print(activation)
    if activation == 'silu':
        return torch.nn.SiLU()
    else:
        raise Exception("activation fn not supported yet. Add it here.")


class DDPM(pl.LightningModule):
    train_dataset = None
    val_dataset = None
    test_dataset = None
    starting_epoch = None
    metrics: Dict[str, List[float]] = {}

    FRAMES = 100

    def __init__(
        self,
        in_node_nf, n_dims, context_node_nf, hidden_nf, activation, tanh, n_layers, attention, norm_constant,
        inv_sublayers, sin_embedding, normalization_factor, aggregation_method,
        diffusion_steps, diffusion_noise_schedule, diffusion_noise_precision, diffusion_loss_type,
        normalize_factors, include_charges, model,
        data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
        normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
        center_of_mass='fragments', inpainting=False, anchors_context=True,
    ):
        super(DDPM, self).__init__()

        self.save_hyperparameters()
        self.data_path = data_path
        self.train_data_prefix = train_data_prefix
        self.val_data_prefix = val_data_prefix
        self.batch_size = batch_size
        self.lr = lr
        self.torch_device = torch_device
        self.include_charges = include_charges
        self.test_epochs = test_epochs
        self.n_stability_samples = n_stability_samples
        self.log_iterations = log_iterations
        self.samples_dir = samples_dir
        self.data_augmentation = data_augmentation
        self.center_of_mass = center_of_mass
        self.inpainting = inpainting
        self.loss_type = diffusion_loss_type

        self.n_dims = n_dims
        self.num_classes = in_node_nf - include_charges
        self.include_charges = include_charges
        self.anchors_context = anchors_context

        self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)

        if type(activation) is str:
            activation = get_activation(activation)

        dynamics_class = DynamicsWithPockets if '.' in train_data_prefix else Dynamics
        dynamics = dynamics_class(
            in_node_nf=in_node_nf,
            n_dims=n_dims,
            context_node_nf=context_node_nf,
            device=torch_device,
            hidden_nf=hidden_nf,
            activation=activation,
            n_layers=n_layers,
            attention=attention,
            tanh=tanh,
            norm_constant=norm_constant,
            inv_sublayers=inv_sublayers,
            sin_embedding=sin_embedding,
            normalization_factor=normalization_factor,
            aggregation_method=aggregation_method,
            model=model,
            normalization=normalization,
            centering=inpainting,
        )
        edm_class = InpaintingEDM if inpainting else EDM
        self.edm = edm_class(
            dynamics=dynamics,
            in_node_nf=in_node_nf,
            n_dims=n_dims,
            timesteps=diffusion_steps,
            noise_schedule=diffusion_noise_schedule,
            noise_precision=diffusion_noise_precision,
            loss_type=diffusion_loss_type,
            norm_values=normalize_factors,
        )
        self.linker_size_sampler = DistributionNodes(LINKER_SIZE_DIST)

    def setup(self, stage: Optional[str] = None):
        dataset_type = MOADDataset if '.' in self.train_data_prefix else ZincDataset
        if stage == 'fit':
            self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
            self.train_dataset = dataset_type(
                data_path=self.data_path,
                prefix=self.train_data_prefix,
                device=self.torch_device
            )
            self.val_dataset = dataset_type(
                data_path=self.data_path,
                prefix=self.val_data_prefix,
                device=self.torch_device
            )
        elif stage == 'val':
            self.is_geom = ('geom' in self.val_data_prefix) or ('MOAD' in self.val_data_prefix)
            self.val_dataset = dataset_type(
                data_path=self.data_path,
                prefix=self.val_data_prefix,
                device=self.torch_device
            )
        else:
            raise NotImplementedError

    def train_dataloader(self, collate_fn=collate):
        return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_fn, shuffle=True)

    def val_dataloader(self, collate_fn=collate):
        return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_fn)

    def test_dataloader(self, collate_fn=collate):
        return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_fn)

    def forward(self, data, training):
        x = data['positions']
        h = data['one_hot']
        node_mask = data['atom_mask']
        edge_mask = data['edge_mask']
        anchors = data['anchors']
        fragment_mask = data['fragment_mask']
        linker_mask = data['linker_mask']

        # Anchors and fragments labels are used as context
        if self.anchors_context:
            context = torch.cat([anchors, fragment_mask], dim=-1)
        else:
            context = fragment_mask

        # Add information about pocket to the context
        if '.' in self.train_data_prefix:
            fragment_pocket_mask = fragment_mask
            fragment_only_mask = data['fragment_only_mask']
            pocket_only_mask = fragment_pocket_mask - fragment_only_mask
            if self.anchors_context:
                context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
            else:
                context = torch.cat([fragment_only_mask, pocket_only_mask], dim=-1)

        # Removing COM of fragment from the atom coordinates
        if self.inpainting:
            center_of_mass_mask = node_mask
        elif self.center_of_mass == 'fragments':
            center_of_mass_mask = fragment_mask
        elif self.center_of_mass == 'anchors':
            center_of_mass_mask = anchors
        else:
            raise NotImplementedError(self.center_of_mass)
        x = utils.remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask)
        utils.assert_partial_mean_zero_with_mask(x, node_mask, center_of_mass_mask)

        # Applying random rotation
        if training and self.data_augmentation:
            x = utils.random_rotation(x)

        return self.edm.forward(
            x=x,
            h=h,
            node_mask=node_mask,
            fragment_mask=fragment_mask,
            linker_mask=linker_mask,
            edge_mask=edge_mask,
            context=context
        )

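The partial-COM step above shifts every molecule so that the centroid of the masked subset (fragments, anchors, or all atoms) sits at the origin. The repo's implementation lives in src/utils.py; the standalone equivalent below is an assumption-based sketch of that behaviour:

# Editor's illustrative sketch, not part of this diff; the repo's own
# remove_partial_mean_with_mask may differ in detail.
import torch

def remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask):
    # x: (B, N, 3); masks: (B, N, 1) with 0/1 entries
    mean = torch.sum(x * center_of_mass_mask, dim=1, keepdim=True) \
         / torch.sum(center_of_mass_mask, dim=1, keepdim=True)
    return (x - mean) * node_mask

x = torch.randn(2, 5, 3)
node_mask = torch.ones(2, 5, 1)
com_mask = torch.zeros(2, 5, 1)
com_mask[:, :2] = 1  # treat the first two atoms as the "fragment"

x_centered = remove_partial_mean_with_mask(x, node_mask, com_mask)
print((x_centered * com_mask).sum(1))  # ~0: the fragment COM is at the origin
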
    def training_step(self, data, *args):
        delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=True)
        vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
        if self.loss_type == 'l2':
            loss = l2_loss
        elif self.loss_type == 'vlb':
            loss = vlb_loss
        else:
            raise NotImplementedError(self.loss_type)

        training_metrics = {
            'loss': loss,
            'delta_log_px': delta_log_px,
            'kl_prior': kl_prior,
            'loss_term_t': loss_term_t,
            'loss_term_0': loss_term_0,
            'l2_loss': l2_loss,
            'vlb_loss': vlb_loss,
            'noise_t': noise_t,
            'noise_0': noise_0
        }
        if self.log_iterations is not None and self.global_step % self.log_iterations == 0:
            for metric_name, metric in training_metrics.items():
                self.metrics.setdefault(f'{metric_name}/train', []).append(metric)
                self.log(f'{metric_name}/train', metric, prog_bar=True)
        return training_metrics

    def validation_step(self, data, *args):
        delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=False)
        vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
        if self.loss_type == 'l2':
            loss = l2_loss
        elif self.loss_type == 'vlb':
            loss = vlb_loss
        else:
            raise NotImplementedError(self.loss_type)
        return {
            'loss': loss,
            'delta_log_px': delta_log_px,
            'kl_prior': kl_prior,
            'loss_term_t': loss_term_t,
            'loss_term_0': loss_term_0,
            'l2_loss': l2_loss,
            'vlb_loss': vlb_loss,
            'noise_t': noise_t,
            'noise_0': noise_0
        }

    def test_step(self, data, *args):
        delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=False)
        vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
        if self.loss_type == 'l2':
            loss = l2_loss
        elif self.loss_type == 'vlb':
            loss = vlb_loss
        else:
            raise NotImplementedError(self.loss_type)
        return {
            'loss': loss,
            'delta_log_px': delta_log_px,
            'kl_prior': kl_prior,
            'loss_term_t': loss_term_t,
            'loss_term_0': loss_term_0,
            'l2_loss': l2_loss,
            'vlb_loss': vlb_loss,
            'noise_t': noise_t,
            'noise_0': noise_0
        }

    def training_epoch_end(self, training_step_outputs):
        for metric in training_step_outputs[0].keys():
            avg_metric = self.aggregate_metric(training_step_outputs, metric)
            self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
            self.log(f'{metric}/train', avg_metric, prog_bar=True)

    def validation_epoch_end(self, validation_step_outputs):
        for metric in validation_step_outputs[0].keys():
            avg_metric = self.aggregate_metric(validation_step_outputs, metric)
            self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
            self.log(f'{metric}/val', avg_metric, prog_bar=True)

        if (self.current_epoch + 1) % self.test_epochs == 0:
            sampling_results = self.sample_and_analyze(self.val_dataloader())
            for metric_name, metric_value in sampling_results.items():
                self.log(f'{metric_name}/val', metric_value, prog_bar=True)
                self.metrics.setdefault(f'{metric_name}/val', []).append(metric_value)

            # Logging the results corresponding to the best validity_and_connectivity
            best_metrics, best_epoch = self.compute_best_validation_metrics()
            self.log('best_epoch', int(best_epoch), prog_bar=True, batch_size=self.batch_size)
            for metric, value in best_metrics.items():
                self.log(f'best_{metric}', value, prog_bar=True, batch_size=self.batch_size)

    def test_epoch_end(self, test_step_outputs):
        for metric in test_step_outputs[0].keys():
            avg_metric = self.aggregate_metric(test_step_outputs, metric)
            self.metrics.setdefault(f'{metric}/test', []).append(avg_metric)
            self.log(f'{metric}/test', avg_metric, prog_bar=True)

        if (self.current_epoch + 1) % self.test_epochs == 0:
            sampling_results = self.sample_and_analyze(self.test_dataloader())
            for metric_name, metric_value in sampling_results.items():
                self.log(f'{metric_name}/test', metric_value, prog_bar=True)
                self.metrics.setdefault(f'{metric_name}/test', []).append(metric_value)

    def generate_animation(self, chain_batch, node_mask, batch_i):
        batch_indices, mol_indices = utils.get_batch_idx_for_animation(self.batch_size, batch_i)
        for bi, mi in zip(batch_indices, mol_indices):
            chain = chain_batch[:, bi, :, :]
            name = f'mol_{mi}'
            chain_output = os.path.join(self.samples_dir, f'epoch_{self.current_epoch}', name)
            os.makedirs(chain_output, exist_ok=True)

            one_hot = chain[:, :, 3:-1] if self.include_charges else chain[:, :, 3:]
            positions = chain[:, :, :3]
            chain_node_mask = torch.cat([node_mask[bi].unsqueeze(0) for _ in range(self.FRAMES)], dim=0)
            names = [f'{name}_{j}' for j in range(self.FRAMES)]

            save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=self.is_geom)
            visualize_chain(chain_output, wandb=wandb, mode=name, is_geom=self.is_geom)

    def sample_and_analyze(self, dataloader):
        pred_molecules = []
        true_molecules = []
        true_fragments = []

        for b, data in tqdm(enumerate(dataloader), total=len(dataloader), desc='Sampling'):
            atom_mask = data['atom_mask']
            fragment_mask = data['fragment_mask']

            # Save molecules without pockets
            if '.' in self.train_data_prefix:
                atom_mask = data['atom_mask'] - data['pocket_mask']
                fragment_mask = data['fragment_only_mask']

            true_molecules_batch = build_molecules(
                data['one_hot'],
                data['positions'],
                atom_mask,
                is_geom=self.is_geom,
            )
            true_fragments_batch = build_molecules(
                data['one_hot'],
                data['positions'],
                fragment_mask,
                is_geom=self.is_geom,
            )

            for sample_idx in tqdm(range(self.n_stability_samples)):
                try:
                    chain_batch, node_mask = self.sample_chain(data, keep_frames=self.FRAMES)
                except utils.FoundNaNException as e:
                    for idx in e.x_h_nan_idx:
                        smiles = data['name'][idx]
                        print(f'FoundNaNException: [xh], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
                    for idx in e.only_x_nan_idx:
                        smiles = data['name'][idx]
                        print(f'FoundNaNException: [x ], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
                    for idx in e.only_h_nan_idx:
                        smiles = data['name'][idx]
                        print(f'FoundNaNException: [ h], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
                    continue

                # Get final molecules from chains – for computing metrics
                x, h = utils.split_features(
                    z=chain_batch[0],
                    n_dims=self.n_dims,
                    num_classes=self.num_classes,
                    include_charges=self.include_charges,
                )

                # Save molecules without pockets
                if '.' in self.train_data_prefix:
                    node_mask = node_mask - data['pocket_mask']

                one_hot = h['categorical']
                pred_molecules_batch = build_molecules(one_hot, x, node_mask, is_geom=self.is_geom)

                # Adding only results for valid ground truth molecules
                for pred_mol, true_mol, frag in zip(pred_molecules_batch, true_molecules_batch, true_fragments_batch):
                    if metrics.is_valid(true_mol):
                        pred_molecules.append(pred_mol)
                        true_molecules.append(true_mol)
                        true_fragments.append(frag)

                # Generate animation – will always do it for molecules with idx 0, 110 and 360
                if self.samples_dir is not None and sample_idx == 0:
                    self.generate_animation(chain_batch=chain_batch, node_mask=node_mask, batch_i=b)

        # Our own & DeLinker metrics
        our_metrics = metrics.compute_metrics(
            pred_molecules=pred_molecules,
            true_molecules=true_molecules
        )
        delinker_metrics = delinker.get_delinker_metrics(
            pred_molecules=pred_molecules,
            true_molecules=true_molecules,
            true_fragments=true_fragments
        )
        return {
            **our_metrics,
            **delinker_metrics
        }

    def sample_chain(self, data, sample_fn=None, keep_frames=None):
        if sample_fn is None:
            linker_sizes = data['linker_mask'].sum(1).view(-1).int()
        else:
            linker_sizes = sample_fn(data)

        if self.inpainting:
            template_data = data
        else:
            template_data = create_templates_for_linker_generation(data, linker_sizes)

        x = template_data['positions']
        node_mask = template_data['atom_mask']
        edge_mask = template_data['edge_mask']
        h = template_data['one_hot']
        anchors = template_data['anchors']
        fragment_mask = template_data['fragment_mask']
        linker_mask = template_data['linker_mask']

        # Anchors and fragments labels are used as context
        if self.anchors_context:
            context = torch.cat([anchors, fragment_mask], dim=-1)
        else:
            context = fragment_mask

        # Add information about pocket to the context
        if '.' in self.train_data_prefix:
            fragment_pocket_mask = fragment_mask
            fragment_only_mask = data['fragment_only_mask']
            pocket_only_mask = fragment_pocket_mask - fragment_only_mask
            if self.anchors_context:
                context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
            else:
                context = torch.cat([fragment_only_mask, pocket_only_mask], dim=-1)

        # Removing COM of fragment from the atom coordinates
        if self.inpainting:
            center_of_mass_mask = node_mask
        elif self.center_of_mass == 'fragments':
            center_of_mass_mask = fragment_mask
        elif self.center_of_mass == 'anchors':
            center_of_mass_mask = anchors
        else:
            raise NotImplementedError(self.center_of_mass)
        x = utils.remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask)

        chain = self.edm.sample_chain(
            x=x,
            h=h,
            node_mask=node_mask,
            edge_mask=edge_mask,
            fragment_mask=fragment_mask,
            linker_mask=linker_mask,
            context=context,
            keep_frames=keep_frames,
        )
        return chain, node_mask

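sample_chain takes the linker size from the ground-truth linker_mask unless a `sample_fn` is supplied. A sketch of plugging in a custom size sampler, here drawing from the model's own size prior (`ddpm` is assumed to be an already-loaded DDPM instance, and `batch` a collated data batch):

# Editor's illustrative sketch, not part of this diff.
def prior_size_sample_fn(batch):
    batch_size = batch['positions'].shape[0]
    return ddpm.linker_size_sampler.sample(batch_size)

chain, node_mask = ddpm.sample_chain(batch, sample_fn=prior_size_sample_fn, keep_frames=1)
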
    def configure_optimizers(self):
        return torch.optim.AdamW(self.edm.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)

    def compute_best_validation_metrics(self):
        # Select the epoch with the highest validity_and_connectivity on validation
        validity_and_connectivity = self.metrics['validity_and_connectivity/val']
        best_epoch = np.argmax(validity_and_connectivity)
        best_metrics = {
            metric_name: metric_values[best_epoch]
            for metric_name, metric_values in self.metrics.items()
            if metric_name.endswith('/val')
        }
        return best_metrics, best_epoch

    @staticmethod
    def aggregate_metric(step_outputs, metric):
        return torch.tensor([out[metric] for out in step_outputs]).mean()
src/linker_size.py
ADDED
@@ -0,0 +1,95 @@
import numpy as np
import torch
import torch.nn as nn

from torch.distributions.categorical import Categorical
from src.egnn import GCL


class DistributionNodes:
    def __init__(self, histogram):

        self.n_nodes = []
        prob = []
        self.keys = {}
        for i, nodes in enumerate(histogram):
            self.n_nodes.append(nodes)
            self.keys[nodes] = i
            prob.append(histogram[nodes])
        self.n_nodes = torch.tensor(self.n_nodes)
        prob = np.array(prob)
        prob = prob / np.sum(prob)

        self.prob = torch.from_numpy(prob).float()

        # Entropy H[N] = -sum_n p(n) log p(n); the minus sign makes the printed value a proper entropy
        entropy = -torch.sum(self.prob * torch.log(self.prob + 1e-30))
        print("Entropy of n_nodes: H[N]", entropy.item())

        self.m = Categorical(torch.tensor(prob))

    def sample(self, n_samples=1):
        idx = self.m.sample((n_samples,))
        return self.n_nodes[idx]

    def log_prob(self, batch_n_nodes):
        assert len(batch_n_nodes.size()) == 1

        idcs = [self.keys[i.item()] for i in batch_n_nodes]
        idcs = torch.tensor(idcs).to(batch_n_nodes.device)

        log_p = torch.log(self.prob + 1e-30)

        log_p = log_p.to(batch_n_nodes.device)

        log_probs = log_p[idcs]

        return log_probs

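A toy usage of DistributionNodes with a hypothetical histogram (the real one, LINKER_SIZE_DIST, is defined in src/const.py):

# Editor's illustrative sketch, not part of this diff.
import torch
from src.linker_size import DistributionNodes

histogram = {3: 10, 4: 25, 5: 40, 6: 25}
dist = DistributionNodes(histogram)

sizes = dist.sample(n_samples=4)          # e.g. tensor([5, 4, 5, 6])
log_p = dist.log_prob(torch.tensor([5]))  # log(40 / 100)
print(sizes, log_p)
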
class SizeGNN(nn.Module):
    def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_layers, normalization, device='cpu'):
        super(SizeGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.out_node_nf = out_node_nf
        self.device = device

        self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
        self.gcl1 = GCL(
            input_nf=self.hidden_nf,
            output_nf=self.hidden_nf,
            hidden_nf=self.hidden_nf,
            normalization_factor=1,
            aggregation_method='sum',
            edges_in_d=1,
            activation=nn.ReLU(),
            attention=False,
            normalization=normalization
        )

        layers = []
        for i in range(n_layers - 1):
            layer = GCL(
                input_nf=self.hidden_nf,
                output_nf=self.hidden_nf,
                hidden_nf=self.hidden_nf,
                normalization_factor=1,
                aggregation_method='sum',
                edges_in_d=1,
                activation=nn.ReLU(),
                attention=False,
                normalization=normalization
            )
            layers.append(layer)

        self.gcl_layers = nn.ModuleList(layers)
        self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf)
        self.to(self.device)

    def forward(self, h, edges, distances, node_mask, edge_mask):
        h = self.embedding_in(h)
        h, _ = self.gcl1(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)
        for gcl in self.gcl_layers:
            h, _ = gcl(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)

        h = self.embedding_out(h)
        return h
src/linker_size_lightning.py
ADDED
@@ -0,0 +1,460 @@
import pytorch_lightning as pl
import torch

from src.const import ZINC_TRAIN_LINKER_ID2SIZE, ZINC_TRAIN_LINKER_SIZE2ID
from src.linker_size import SizeGNN
from src.egnn import coord2diff
from src.datasets import ZincDataset, get_dataloader, collate_with_fragment_edges
from typing import Dict, List, Optional
from torch.nn.functional import cross_entropy, mse_loss, sigmoid

from pdb import set_trace


class SizeClassifier(pl.LightningModule):
    train_dataset = None
    val_dataset = None
    test_dataset = None
    metrics: Dict[str, List[float]] = {}

    def __init__(
        self, data_path, train_data_prefix, val_data_prefix,
        in_node_nf, hidden_nf, out_node_nf, n_layers, batch_size, lr, torch_device,
        normalization=None,
        loss_weights=None,
        min_linker_size=None,
        linker_size2id=ZINC_TRAIN_LINKER_SIZE2ID,
        linker_id2size=ZINC_TRAIN_LINKER_ID2SIZE,
        task='classification',
    ):
        super(SizeClassifier, self).__init__()

        self.save_hyperparameters()
        self.data_path = data_path
        self.train_data_prefix = train_data_prefix
        self.val_data_prefix = val_data_prefix
        self.min_linker_size = min_linker_size
        self.linker_size2id = linker_size2id
        self.linker_id2size = linker_id2size
        self.batch_size = batch_size
        self.lr = lr
        self.torch_device = torch_device
        self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=torch_device)
        self.gnn = SizeGNN(
            in_node_nf=in_node_nf,
            hidden_nf=hidden_nf,
            out_node_nf=out_node_nf,
            n_layers=n_layers,
            device=torch_device,
            normalization=normalization,
        )

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit':
            self.train_dataset = ZincDataset(
                data_path=self.data_path,
                prefix=self.train_data_prefix,
                device=self.torch_device
            )
            self.val_dataset = ZincDataset(
                data_path=self.data_path,
                prefix=self.val_data_prefix,
                device=self.torch_device
            )
        elif stage == 'val':
            self.val_dataset = ZincDataset(
                data_path=self.data_path,
                prefix=self.val_data_prefix,
                device=self.torch_device
            )
        else:
            raise NotImplementedError

    def train_dataloader(self):
        return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_with_fragment_edges, shuffle=True)

    def val_dataloader(self):
        return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)

    def test_dataloader(self):
        return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)

    def forward(self, data):
        h = data['one_hot']
        x = data['positions']
        fragment_mask = data['fragment_mask']
        linker_mask = data['linker_mask']
        edge_mask = data['edge_mask']
        edges = data['edges']

        # Considering only fragments
        x = x * fragment_mask
        h = h * fragment_mask

        # Reshaping
        bs, n_nodes = x.shape[0], x.shape[1]
        fragment_mask = fragment_mask.view(bs * n_nodes, 1)
        x = x.view(bs * n_nodes, -1)
        h = h.view(bs * n_nodes, -1)

        # Prediction
        distances, _ = coord2diff(x, edges)
        distance_edge_mask = (edge_mask.bool() & (distances < 6)).long()
        output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
        output = output.view(bs, n_nodes, -1).mean(1)

        true = self.get_true_labels(linker_mask)
        loss = cross_entropy(output, true, weight=self.loss_weights)

        return output, loss

    def get_true_labels(self, linker_mask):
        labels = []
        sizes = linker_mask.squeeze().sum(-1).long().detach().cpu().numpy()
        for size in sizes:
            label = self.linker_size2id.get(size)
            if label is None:
                # Unseen sizes are mapped to the class of the largest known size
                label = self.linker_size2id[max(self.linker_id2size)]
            labels.append(label)
        labels = torch.tensor(labels, device=linker_mask.device, dtype=torch.long)
        return labels

    def training_step(self, data, *args):
        _, loss = self.forward(data)
        return {'loss': loss}

    def validation_step(self, data, *args):
        _, loss = self.forward(data)
        return {'loss': loss}

    def test_step(self, data, *args):
        # Unpack the (output, loss) tuple; the original `loss = self.forward(data)`
        # would have returned the whole tuple as the loss
        _, loss = self.forward(data)
        return {'loss': loss}

    def training_epoch_end(self, training_step_outputs):
        for metric in training_step_outputs[0].keys():
            avg_metric = self.aggregate_metric(training_step_outputs, metric)
            self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
            self.log(f'{metric}/train', avg_metric, prog_bar=True)

    def validation_epoch_end(self, validation_step_outputs):
        for metric in validation_step_outputs[0].keys():
            avg_metric = self.aggregate_metric(validation_step_outputs, metric)
            self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
            self.log(f'{metric}/val', avg_metric, prog_bar=True)

        correct = 0
        total = 0
        for data in self.val_dataloader():
            output, _ = self.forward(data)
            pred = output.argmax(dim=-1)
            true = self.get_true_labels(data['linker_mask'])
            correct += (pred == true).sum()
            total += len(pred)

        accuracy = correct / total
        self.metrics.setdefault('accuracy/val', []).append(accuracy)
        self.log('accuracy/val', accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.gnn.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)

    @staticmethod
    def aggregate_metric(step_outputs, metric):
        return torch.tensor([out[metric] for out in step_outputs]).mean()

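get_true_labels clamps linker sizes missing from the training vocabulary to the class of the largest known size. A toy illustration with hypothetical mappings (the real ones are ZINC_TRAIN_LINKER_SIZE2ID / ZINC_TRAIN_LINKER_ID2SIZE from src/const.py):

# Editor's illustrative sketch, not part of this diff.
size2id = {3: 0, 4: 1, 5: 2}
id2size = [3, 4, 5]

for size in (4, 12):
    label = size2id.get(size)
    if label is None:
        label = size2id[max(id2size)]  # clamp unseen sizes to the largest class
    print(size, '->', label)  # 4 -> 1, 12 -> 2
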
class SizeOrdinalClassifier(pl.LightningModule):
    train_dataset = None
    val_dataset = None
    test_dataset = None
    metrics: Dict[str, List[float]] = {}

    def __init__(
        self, data_path, train_data_prefix, val_data_prefix,
        in_node_nf, hidden_nf, out_node_nf, n_layers, batch_size, lr, torch_device,
        normalization=None,
        min_linker_size=None,
        linker_size2id=ZINC_TRAIN_LINKER_SIZE2ID,
        linker_id2size=ZINC_TRAIN_LINKER_ID2SIZE,
        task='ordinal',
    ):
        super(SizeOrdinalClassifier, self).__init__()

        self.save_hyperparameters()
        self.data_path = data_path
        self.train_data_prefix = train_data_prefix
        self.val_data_prefix = val_data_prefix
        self.min_linker_size = min_linker_size
        self.batch_size = batch_size
        self.lr = lr
        self.torch_device = torch_device
        self.linker_size2id = linker_size2id
        self.linker_id2size = linker_id2size
        self.gnn = SizeGNN(
            in_node_nf=in_node_nf,
            hidden_nf=hidden_nf,
            out_node_nf=out_node_nf,
            n_layers=n_layers,
            device=torch_device,
            normalization=normalization,
        )

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit':
            self.train_dataset = ZincDataset(
                data_path=self.data_path,
                prefix=self.train_data_prefix,
                device=self.torch_device
            )
            self.val_dataset = ZincDataset(
                data_path=self.data_path,
                prefix=self.val_data_prefix,
                device=self.torch_device
            )
        elif stage == 'val':
            self.val_dataset = ZincDataset(
                data_path=self.data_path,
                prefix=self.val_data_prefix,
                device=self.torch_device
            )
        else:
            raise NotImplementedError

    def train_dataloader(self):
        return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_with_fragment_edges, shuffle=True)

    def val_dataloader(self):
        return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)

    def test_dataloader(self):
        return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)

    def forward(self, data):
        h = data['one_hot']
        x = data['positions']
        fragment_mask = data['fragment_mask']
        linker_mask = data['linker_mask']
        edge_mask = data['edge_mask']
        edges = data['edges']

        # Considering only fragments
        x = x * fragment_mask
        h = h * fragment_mask

        # Reshaping
        bs, n_nodes = x.shape[0], x.shape[1]
        fragment_mask = fragment_mask.view(bs * n_nodes, 1)
        x = x.view(bs * n_nodes, -1)
        h = h.view(bs * n_nodes, -1)

        # Prediction
        distances, _ = coord2diff(x, edges)
        distance_edge_mask = (edge_mask.bool() & (distances < 6)).long()
        output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
        output = output.view(bs, n_nodes, -1).mean(1)
        output = sigmoid(output)

        true = self.get_true_labels(linker_mask)
        loss = self.ordinal_loss(output, true)

        return output, loss

    def ordinal_loss(self, pred, true):
        # Ordinal targets: class k is encoded as k+1 leading ones, e.g. 2 -> [1, 1, 1, 0, ...]
        target = torch.zeros_like(pred, device=self.torch_device)
        for i, label in enumerate(true):
            target[i, 0:label + 1] = 1

        return mse_loss(pred, target, reduction='none').sum(1).mean()

    def get_true_labels(self, linker_mask):
        labels = []
        sizes = linker_mask.squeeze().sum(-1).long().detach().cpu().numpy()
        for size in sizes:
            label = self.linker_size2id.get(size)
            if label is None:
                label = self.linker_size2id[max(self.linker_id2size)]
            labels.append(label)
        labels = torch.tensor(labels, device=linker_mask.device, dtype=torch.long)
        return labels

    @staticmethod
    def prediction2label(pred):
        # Count the leading run of probabilities above 0.5
        return torch.cumprod(pred > 0.5, dim=1).sum(dim=1) - 1

    def training_step(self, data, *args):
        _, loss = self.forward(data)
        return {'loss': loss}

    def validation_step(self, data, *args):
        _, loss = self.forward(data)
        return {'loss': loss}

    def test_step(self, data, *args):
        # Unpack the (output, loss) tuple, as in training_step
        _, loss = self.forward(data)
        return {'loss': loss}

    def training_epoch_end(self, training_step_outputs):
        for metric in training_step_outputs[0].keys():
            avg_metric = self.aggregate_metric(training_step_outputs, metric)
            self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
            self.log(f'{metric}/train', avg_metric, prog_bar=True)

    def validation_epoch_end(self, validation_step_outputs):
        for metric in validation_step_outputs[0].keys():
            avg_metric = self.aggregate_metric(validation_step_outputs, metric)
            self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
            self.log(f'{metric}/val', avg_metric, prog_bar=True)

        correct = 0
        total = 0
        for data in self.val_dataloader():
            output, _ = self.forward(data)
            pred = self.prediction2label(output)
            true = self.get_true_labels(data['linker_mask'])
            correct += (pred == true).sum()
            total += len(pred)

        accuracy = correct / total
        self.metrics.setdefault('accuracy/val', []).append(accuracy)
        self.log('accuracy/val', accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.gnn.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)

    @staticmethod
    def aggregate_metric(step_outputs, metric):
        return torch.tensor([out[metric] for out in step_outputs]).mean()

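The ordinal head encodes class k as k+1 leading ones and decodes by counting the leading run of confident outputs. A worked example of this encode/decode pair (the `.long()` cast is added here only so the standalone demo runs on any dtype):

# Editor's illustrative sketch, not part of this diff.
import torch

num_classes = 5
true = torch.tensor([0, 2, 4])

# Encoding, as in ordinal_loss: label k -> k+1 leading ones
target = torch.zeros(len(true), num_classes)
for i, label in enumerate(true):
    target[i, 0:label + 1] = 1
print(target)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])

# Decoding, mirroring prediction2label
pred = torch.tensor([[0.9, 0.2, 0.8, 0.1, 0.0],   # run of 1 -> class 0
                     [0.9, 0.8, 0.7, 0.3, 0.1],   # run of 3 -> class 2
                     [0.9, 0.9, 0.9, 0.9, 0.6]])  # run of 5 -> class 4
labels = torch.cumprod((pred > 0.5).long(), dim=1).sum(dim=1) - 1
print(labels)  # tensor([0, 2, 4])
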
class SizeRegressor(pl.LightningModule):
|
331 |
+
train_dataset = None
|
332 |
+
val_dataset = None
|
333 |
+
test_dataset = None
|
334 |
+
metrics: Dict[str, List[float]] = {}
|
335 |
+
|
336 |
+
def __init__(
|
337 |
+
self, data_path, train_data_prefix, val_data_prefix,
|
338 |
+
in_node_nf, hidden_nf, n_layers, batch_size, lr, torch_device,
|
339 |
+
normalization=None, task='regression',
|
340 |
+
):
|
341 |
+
super(SizeRegressor, self).__init__()
|
342 |
+
|
343 |
+
self.save_hyperparameters()
|
344 |
+
self.data_path = data_path
|
345 |
+
self.train_data_prefix = train_data_prefix
|
346 |
+
self.val_data_prefix = val_data_prefix
|
347 |
+
self.batch_size = batch_size
|
348 |
+
self.lr = lr
|
349 |
+
self.torch_device = torch_device
|
350 |
+
self.gnn = SizeGNN(
|
351 |
+
in_node_nf=in_node_nf,
|
352 |
+
hidden_nf=hidden_nf,
|
353 |
+
out_node_nf=1,
|
354 |
+
n_layers=n_layers,
|
355 |
+
device=torch_device,
|
356 |
+
normalization=normalization,
|
357 |
+
)
|
358 |
+
|
359 |
+
def setup(self, stage: Optional[str] = None):
|
360 |
+
if stage == 'fit':
|
361 |
+
self.train_dataset = ZincDataset(
|
362 |
+
data_path=self.data_path,
|
363 |
+
prefix=self.train_data_prefix,
|
364 |
+
device=self.torch_device
|
365 |
+
)
|
366 |
+
self.val_dataset = ZincDataset(
|
367 |
+
data_path=self.data_path,
|
368 |
+
prefix=self.val_data_prefix,
|
369 |
+
device=self.torch_device
|
370 |
+
)
|
371 |
+
elif stage == 'val':
|
372 |
+
self.val_dataset = ZincDataset(
|
373 |
+
data_path=self.data_path,
|
374 |
+
prefix=self.val_data_prefix,
|
375 |
+
device=self.torch_device
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
raise NotImplementedError
|
379 |
+
|
380 |
+
def train_dataloader(self):
|
381 |
+
return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_with_fragment_edges, shuffle=True)
|
382 |
+
|
383 |
+
def val_dataloader(self):
|
384 |
+
return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
385 |
+
|
386 |
+
def test_dataloader(self):
|
387 |
+
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
388 |
+
|
389 |
+
def forward(self, data):
|
390 |
+
h = data['one_hot']
|
391 |
+
x = data['positions']
|
392 |
+
fragment_mask = data['fragment_mask']
|
393 |
+
linker_mask = data['linker_mask']
|
394 |
+
edge_mask = data['edge_mask']
|
395 |
+
edges = data['edges']
|
396 |
+
|
397 |
+
# Considering only fragments
|
398 |
+
x = x * fragment_mask
|
399 |
+
h = h * fragment_mask
|
400 |
+
|
401 |
+
# Reshaping
|
402 |
+
bs, n_nodes = x.shape[0], x.shape[1]
|
403 |
+
fragment_mask = fragment_mask.view(bs * n_nodes, 1)
|
404 |
+
x = x.view(bs * n_nodes, -1)
|
405 |
+
h = h.view(bs * n_nodes, -1)
|
406 |
+
|
407 |
+
# Prediction
|
408 |
+
distances, _ = coord2diff(x, edges)
|
409 |
+
distance_edge_mask = (edge_mask.bool() & (distances < 6)).long()
|
410 |
+
output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
|
411 |
+
output = output.view(bs, n_nodes, -1).mean(1).squeeze()
|
412 |
+
|
413 |
+
true = linker_mask.squeeze().sum(-1).float()
|
414 |
+
loss = mse_loss(output, true)
|
415 |
+
|
416 |
+
return output, loss
|
417 |
+
|
418 |
+
def training_step(self, data, *args):
|
419 |
+
_, loss = self.forward(data)
|
420 |
+
return {'loss': loss}
|
421 |
+
|
422 |
+
def validation_step(self, data, *args):
|
423 |
+
_, loss = self.forward(data)
|
424 |
+
return {'loss': loss}
|
425 |
+
|
426 |
+
def test_step(self, data, *args):
|
427 |
+
loss = self.forward(data)
|
428 |
+
return {'loss': loss}
|
429 |
+
|
430 |
+
def training_epoch_end(self, training_step_outputs):
|
431 |
+
for metric in training_step_outputs[0].keys():
|
432 |
+
avg_metric = self.aggregate_metric(training_step_outputs, metric)
|
433 |
+
self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
|
434 |
+
self.log(f'{metric}/train', avg_metric, prog_bar=True)
|
435 |
+
|
436 |
+
def validation_epoch_end(self, validation_step_outputs):
|
437 |
+
for metric in validation_step_outputs[0].keys():
|
438 |
+
avg_metric = self.aggregate_metric(validation_step_outputs, metric)
|
439 |
+
self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
|
440 |
+
self.log(f'{metric}/val', avg_metric, prog_bar=True)
|
441 |
+
|
442 |
+
correct = 0
|
443 |
+
total = 0
|
444 |
+
for data in self.val_dataloader():
|
445 |
+
output, _ = self.forward(data)
|
446 |
+
pred = torch.round(output).long()
|
447 |
+
true = data['linker_mask'].squeeze().sum(-1).long()
|
448 |
+
correct += (pred == true).sum()
|
449 |
+
total += len(pred)
|
450 |
+
|
451 |
+
accuracy = correct / total
|
452 |
+
self.metrics.setdefault(f'accuracy/val', []).append(accuracy)
|
453 |
+
self.log(f'accuracy/val', accuracy, prog_bar=True)
|
454 |
+
|
455 |
+
def configure_optimizers(self):
|
456 |
+
return torch.optim.AdamW(self.gnn.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)
|
457 |
+
|
458 |
+
@staticmethod
|
459 |
+
def aggregate_metric(step_outputs, metric):
|
460 |
+
return torch.tensor([out[metric] for out in step_outputs]).mean()
|
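Note: the regressor predicts a continuous linker size that the validation loop above rounds to the nearest integer. A minimal, hypothetical sketch of querying a trained model at inference time (`regressor` and `batch` are assumed to be a loaded SizeRegressor and one batch from its dataloader; this helper is not part of the commit):

import torch

def predict_linker_sizes(regressor, batch):
    # forward() returns (output, loss); only the continuous size estimate is needed
    with torch.no_grad():
        output, _ = regressor.forward(batch)
    # Round to an integer number of linker atoms, as in validation_epoch_end
    return torch.round(output).long()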
src/metrics.py
ADDED
@@ -0,0 +1,167 @@
import numpy as np

from rdkit import Chem
from rdkit.Chem import AllChem
from src import const
from src.molecule_builder import get_bond_order
from scipy.stats import wasserstein_distance

from pdb import set_trace


def is_valid(mol):
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return False
    return True


def is_connected(mol):
    try:
        mol_frags = Chem.GetMolFrags(mol, asMols=True)
    except Chem.rdchem.AtomValenceException:
        return False
    if len(mol_frags) != 1:
        return False
    return True


def get_valid_molecules(molecules):
    valid = []
    for mol in molecules:
        if is_valid(mol):
            valid.append(mol)
    return valid


def get_connected_molecules(molecules):
    connected = []
    for mol in molecules:
        if is_connected(mol):
            connected.append(mol)
    return connected


def get_unique_smiles(valid_molecules):
    unique = set()
    for mol in valid_molecules:
        unique.add(Chem.MolToSmiles(mol))
    return list(unique)


def get_novel_smiles(unique_true_smiles, unique_pred_smiles):
    return list(set(unique_pred_smiles).difference(set(unique_true_smiles)))


def compute_energy(mol):
    mp = AllChem.MMFFGetMoleculeProperties(mol)
    energy = AllChem.MMFFGetMoleculeForceField(mol, mp, confId=0).CalcEnergy()
    return energy


def wasserstein_distance_between_energies(true_molecules, pred_molecules):
    true_energy_dist = []
    for mol in true_molecules:
        try:
            energy = compute_energy(mol)
            true_energy_dist.append(energy)
        except Exception:
            continue

    pred_energy_dist = []
    for mol in pred_molecules:
        try:
            energy = compute_energy(mol)
            pred_energy_dist.append(energy)
        except Exception:
            continue

    if len(true_energy_dist) > 0 and len(pred_energy_dist) > 0:
        return wasserstein_distance(true_energy_dist, pred_energy_dist)
    else:
        return 0


def compute_metrics(pred_molecules, true_molecules):
    if len(pred_molecules) == 0:
        return {
            'validity': 0,
            'validity_and_connectivity': 0,
            'validity_as_in_delinker': 0,  # NOTE: only reported in this empty-input case
            'uniqueness': 0,
            'novelty': 0,
            'energies': 0,
        }

    # Passing rdkit.Chem.Sanitize filter
    true_valid = get_valid_molecules(true_molecules)
    pred_valid = get_valid_molecules(pred_molecules)
    validity = len(pred_valid) / len(pred_molecules)

    # Checking if molecule consists of a single connected part
    true_valid_and_connected = get_connected_molecules(true_valid)
    pred_valid_and_connected = get_connected_molecules(pred_valid)
    validity_and_connectivity = len(pred_valid_and_connected) / len(pred_molecules)

    # Unique molecules
    true_unique = get_unique_smiles(true_valid_and_connected)
    pred_unique = get_unique_smiles(pred_valid_and_connected)
    uniqueness = len(pred_unique) / len(pred_valid_and_connected) if len(pred_valid_and_connected) > 0 else 0

    # Novel molecules
    pred_novel = get_novel_smiles(true_unique, pred_unique)
    novelty = len(pred_novel) / len(pred_unique) if len(pred_unique) > 0 else 0

    # Difference between energy distributions
    energies = wasserstein_distance_between_energies(true_valid_and_connected, pred_valid_and_connected)

    return {
        'validity': validity,
        'validity_and_connectivity': validity_and_connectivity,
        'uniqueness': uniqueness,
        'novelty': novelty,
        'energies': energies,
    }


# def check_stability(positions, atom_types):
#     assert len(positions.shape) == 2
#     assert positions.shape[1] == 3
#     x = positions[:, 0]
#     y = positions[:, 1]
#     z = positions[:, 2]
#
#     nr_bonds = np.zeros(len(x), dtype='int')
#     for i in range(len(x)):
#         for j in range(i + 1, len(x)):
#             p1 = np.array([x[i], y[i], z[i]])
#             p2 = np.array([x[j], y[j], z[j]])
#             dist = np.sqrt(np.sum((p1 - p2) ** 2))
#             atom1, atom2 = const.IDX2ATOM[atom_types[i].item()], const.IDX2ATOM[atom_types[j].item()]
#             order = get_bond_order(atom1, atom2, dist)
#             nr_bonds[i] += order
#             nr_bonds[j] += order
#     nr_stable_bonds = 0
#     for atom_type_i, nr_bonds_i in zip(atom_types, nr_bonds):
#         possible_bonds = const.ALLOWED_BONDS[const.IDX2ATOM[atom_type_i.item()]]
#         if type(possible_bonds) == int:
#             is_stable = possible_bonds == nr_bonds_i
#         else:
#             is_stable = nr_bonds_i in possible_bonds
#         nr_stable_bonds += int(is_stable)
#
#     molecule_stable = nr_stable_bonds == len(x)
#     return molecule_stable, nr_stable_bonds, len(x)
#
#
# def count_stable_molecules(one_hot, x, node_mask):
#     stable_molecules = 0
#     for i in range(len(one_hot)):
#         mol_size = node_mask[i].sum()
#         atom_types = one_hot[i][:mol_size, :].argmax(dim=1).detach().cpu()
#         positions = x[i][:mol_size, :].detach().cpu()
#         stable, _, _ = check_stability(positions, atom_types)
#         stable_molecules += int(stable)
#
#     return stable_molecules
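A quick, hypothetical illustration of compute_metrics on toy inputs (the SMILES below are stand-ins; in the real pipeline both lists come from build_molecules with 3D conformers attached, and without conformers the MMFF energy term falls back to 0):

from rdkit import Chem
from src.metrics import compute_metrics

# Toy stand-ins for generated and reference molecules (no 3D conformers,
# so compute_energy fails for every molecule and 'energies' ends up 0).
pred = [Chem.MolFromSmiles(s) for s in ['CCO', 'c1ccccc1', 'CC(=O)O']]
true = [Chem.MolFromSmiles(s) for s in ['CCO', 'CCN']]

print(compute_metrics(pred_molecules=pred, true_molecules=true))
# Expected: validity 1.0, validity_and_connectivity 1.0, uniqueness 1.0,
# novelty 2/3 (benzene and acetic acid are absent from the reference set)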
src/molecule_builder.py
ADDED
@@ -0,0 +1,102 @@
import torch
import numpy as np

from rdkit import Chem, Geometry

from src import const


def create_conformer(coords):
    conformer = Chem.Conformer()
    for i, (x, y, z) in enumerate(coords):
        conformer.SetAtomPosition(i, Geometry.Point3D(x, y, z))
    return conformer


def build_molecules(one_hot, x, node_mask, is_geom, margins=const.MARGINS_EDM):
    molecules = []
    for i in range(len(one_hot)):
        mask = node_mask[i].squeeze() == 1
        atom_types = one_hot[i][mask].argmax(dim=1).detach().cpu()
        positions = x[i][mask].detach().cpu()
        mol = build_molecule(positions, atom_types, is_geom, margins=margins)
        molecules.append(mol)

    return molecules


def build_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    X, A, E = build_xae_molecule(positions, atom_types, is_geom=is_geom, margins=margins)
    mol = Chem.RWMol()
    for atom in X:
        a = Chem.Atom(idx2atom[atom.item()])
        mol.AddAtom(a)

    all_bonds = torch.nonzero(A)
    for bond in all_bonds:
        mol.AddBond(bond[0].item(), bond[1].item(), const.BOND_DICT[E[bond[0], bond[1]].item()])

    mol.AddConformer(create_conformer(positions.detach().cpu().numpy().astype(np.float64)))
    return mol


def build_xae_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
    """ Returns a triplet (X, A, E): atom_types, adjacency matrix, edge_types
    args:
        positions: N x 3 (already masked to keep the final number of nodes)
        atom_types: N
    returns:
        X: N (int)
        A: N x N (bool) (binary adjacency matrix)
        E: N x N (int) (bond type, 0 if no bond) such that A = E.bool()
    """
    n = positions.shape[0]
    X = atom_types
    A = torch.zeros((n, n), dtype=torch.bool)
    E = torch.zeros((n, n), dtype=torch.int)

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    pos = positions.unsqueeze(0)
    dists = torch.cdist(pos, pos, p=2).squeeze(0)
    for i in range(n):
        for j in range(i):
            pair = sorted([atom_types[i], atom_types[j]])
            order = get_bond_order(idx2atom[pair[0].item()], idx2atom[pair[1].item()], dists[i, j], margins=margins)

            # TODO: a batched version of get_bond_order to avoid the for loop
            if order > 0:
                # Warning: the graph should be DIRECTED
                A[i, j] = 1
                E[i, j] = order

    return X, A, E


def get_bond_order(atom1, atom2, distance, check_exists=True, margins=const.MARGINS_EDM):
    distance = 100 * distance  # We change the metric: Angstrom -> pm, matching the bond-length tables

    # Check exists for large molecules where some atom pairs do not have a
    # typical bond length.
    if check_exists:
        if atom1 not in const.BONDS_1:
            return 0
        if atom2 not in const.BONDS_1[atom1]:
            return 0

    # margin1, margin2 and margin3 have been tuned to maximize the stability of the QM9 true samples
    if distance < const.BONDS_1[atom1][atom2] + margins[0]:

        # Check if atoms are in the BONDS_2 dictionary.
        if atom1 in const.BONDS_2 and atom2 in const.BONDS_2[atom1]:
            thr_bond2 = const.BONDS_2[atom1][atom2] + margins[1]
            if distance < thr_bond2:
                if atom1 in const.BONDS_3 and atom2 in const.BONDS_3[atom1]:
                    thr_bond3 = const.BONDS_3[atom1][atom2] + margins[2]
                    if distance < thr_bond3:
                        return 3  # Triple
                return 2  # Double
        return 1  # Single
    return 0  # No bond
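A small sanity check of the distance-based bond typing (assuming the default MARGINS_EDM and the usual covalent bond-length tables in src/const, which are in picometers; input distances are in Angstrom):

from src.molecule_builder import get_bond_order

print(get_bond_order('C', 'C', 1.54))  # typical single C-C length -> expected 1
print(get_bond_order('C', 'C', 1.34))  # typical double C=C length -> expected 2
print(get_bond_order('C', 'C', 3.00))  # far beyond any threshold -> 0 (no bond)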
src/noise.py
ADDED
@@ -0,0 +1,169 @@
import torch
import torch.nn.functional as F
import math
import numpy as np


def clip_noise_schedule(alphas2, clip_value=0.001):
    """
    For a noise schedule given by alpha^2, this clips alpha_t / alpha_t-1. This may help improve stability during
    sampling.
    """
    alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)

    alphas_step = (alphas2[1:] / alphas2[:-1])

    alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.)
    alphas2 = np.cumprod(alphas_step, axis=0)

    return alphas2


def polynomial_schedule(timesteps: int, s=1e-4, power=3.):
    """
    A noise schedule based on a simple polynomial equation: 1 - x^power.
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas2 = (1 - np.power(x / steps, power)) ** 2

    alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)

    precision = 1 - 2 * s

    alphas2 = precision * alphas2 + s

    return alphas2


def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
    """
    Cosine schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 2
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas = np.clip(betas, a_min=0, a_max=0.999)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    if raise_to_power != 1:
        alphas_cumprod = np.power(alphas_cumprod, raise_to_power)

    return alphas_cumprod


class PositiveLinear(torch.nn.Module):
    """Linear layer with weights forced to be positive."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 weight_init_offset: int = -2):
        super(PositiveLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.empty((out_features, in_features)))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.weight_init_offset = weight_init_offset
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        with torch.no_grad():
            self.weight.add_(self.weight_init_offset)

        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        positive_weight = F.softplus(self.weight)
        return F.linear(x, positive_weight, self.bias)


class PredefinedNoiseSchedule(torch.nn.Module):
    """
    Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules.
    """

    def __init__(self, noise_schedule, timesteps, precision):
        super(PredefinedNoiseSchedule, self).__init__()
        self.timesteps = timesteps

        if noise_schedule == 'cosine':
            alphas2 = cosine_beta_schedule(timesteps)
        elif 'polynomial' in noise_schedule:
            splits = noise_schedule.split('_')
            assert len(splits) == 2
            power = float(splits[1])
            alphas2 = polynomial_schedule(timesteps, s=precision, power=power)
        else:
            raise ValueError(noise_schedule)

        # print('alphas2', alphas2)

        sigmas2 = 1 - alphas2

        log_alphas2 = np.log(alphas2)
        log_sigmas2 = np.log(sigmas2)

        log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2

        # print('gamma', -log_alphas2_to_sigmas2)

        self.gamma = torch.nn.Parameter(
            torch.from_numpy(-log_alphas2_to_sigmas2).float(),
            requires_grad=False)

    def forward(self, t):
        t_int = torch.round(t * self.timesteps).long()
        return self.gamma[t_int]


class GammaNetwork(torch.nn.Module):
    """The gamma network models a monotonic increasing function. Construction as in the VDM paper."""

    def __init__(self):
        super().__init__()

        self.l1 = PositiveLinear(1, 1)
        self.l2 = PositiveLinear(1, 1024)
        self.l3 = PositiveLinear(1024, 1)

        self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.]))
        self.gamma_1 = torch.nn.Parameter(torch.tensor([10.]))
        self.show_schedule()

    def show_schedule(self, num_steps=50):
        t = torch.linspace(0, 1, num_steps).view(num_steps, 1)
        gamma = self.forward(t)
        print('Gamma schedule:')
        print(gamma.detach().cpu().numpy().reshape(num_steps))

    def gamma_tilde(self, t):
        l1_t = self.l1(t)
        return l1_t + self.l3(torch.sigmoid(self.l2(l1_t)))

    def forward(self, t):
        zeros, ones = torch.zeros_like(t), torch.ones_like(t)
        # Not super efficient.
        gamma_tilde_0 = self.gamma_tilde(zeros)
        gamma_tilde_1 = self.gamma_tilde(ones)
        gamma_tilde_t = self.gamma_tilde(t)

        # Normalize to [0, 1]
        normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / (
                gamma_tilde_1 - gamma_tilde_0)

        # Rescale to [gamma_0, gamma_1]
        gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma

        return gamma
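For orientation: gamma follows the VDM parameterization gamma(t) = -log(alpha_t^2 / sigma_t^2), so a variance-preserving (alpha, sigma) pair can be read off with a sigmoid. A minimal sketch of consuming the schedule (the actual usage lives in src/edm.py):

import torch
from src.noise import PredefinedNoiseSchedule

schedule = PredefinedNoiseSchedule('polynomial_2', timesteps=1000, precision=1e-5)
t = torch.tensor([0.0, 0.5, 1.0])   # normalized time in [0, 1]
gamma = schedule(t)
alpha2 = torch.sigmoid(-gamma)      # alpha_t^2 = sigmoid(-gamma_t)
sigma2 = torch.sigmoid(gamma)       # sigma_t^2 = sigmoid(gamma_t)
print(alpha2 + sigma2)              # ~1 everywhere (variance preserving)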
src/utils.py
ADDED
@@ -0,0 +1,348 @@
import sys
from datetime import datetime

import torch
import numpy as np


class Logger(object):
    def __init__(self, logpath, syspart=sys.stdout):
        self.terminal = syspart
        self.log = open(logpath, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass


def log(*args):
    print(f'[{datetime.now()}]', *args)


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def sum_except_batch(x):
    return x.reshape(x.size(0), -1).sum(dim=-1)


def remove_mean(x):
    mean = torch.mean(x, dim=1, keepdim=True)
    x = x - mean
    return x


def remove_mean_with_mask(x, node_mask):
    masked_max_abs_value = (x * (1 - node_mask)).abs().sum().item()
    assert masked_max_abs_value < 1e-5, f'Error {masked_max_abs_value} too high'
    N = node_mask.sum(1, keepdims=True)

    mean = torch.sum(x, dim=1, keepdim=True) / N
    x = x - mean * node_mask
    return x


def remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask):
    """
    Subtract center of mass of fragments from coordinates of all atoms
    """
    x_masked = x * center_of_mass_mask
    N = center_of_mass_mask.sum(1, keepdims=True)
    mean = torch.sum(x_masked, dim=1, keepdim=True) / N
    x = x - mean * node_mask
    return x


def assert_mean_zero(x):
    mean = torch.mean(x, dim=1, keepdim=True)
    assert mean.abs().max().item() < 1e-4


def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
    assert_correctly_masked(x, node_mask)
    largest_value = x.abs().max().item()
    error = torch.sum(x, dim=1, keepdim=True).abs().max().item()
    rel_error = error / (largest_value + eps)
    assert rel_error < 1e-2, f'Mean is not zero, relative_error {rel_error}'


def assert_partial_mean_zero_with_mask(x, node_mask, center_of_mass_mask, eps=1e-10):
    assert_correctly_masked(x, node_mask)
    x_masked = x * center_of_mass_mask
    largest_value = x_masked.abs().max().item()
    error = torch.sum(x_masked, dim=1, keepdim=True).abs().max().item()
    rel_error = error / (largest_value + eps)
    assert rel_error < 1e-2, f'Partial mean is not zero, relative_error {rel_error}'


def assert_correctly_masked(variable, node_mask):
    assert (variable * (1 - node_mask)).abs().max().item() < 1e-4, \
        'Variables not masked properly.'


def check_mask_correct(variables, node_mask):
    for i, variable in enumerate(variables):
        if len(variable) > 0:
            assert_correctly_masked(variable, node_mask)


def center_gravity_zero_gaussian_log_likelihood(x):
    assert len(x.size()) == 3
    B, N, D = x.size()
    assert_mean_zero(x)

    # r is invariant to a basis change in the relevant hyperplane.
    r2 = sum_except_batch(x.pow(2))

    # The relevant hyperplane is (N-1) * D dimensional.
    degrees_of_freedom = (N - 1) * D

    # Normalizing constant and logpx are computed:
    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2 * np.pi)
    log_px = -0.5 * r2 + log_normalizing_constant

    return log_px


def sample_center_gravity_zero_gaussian(size, device):
    assert len(size) == 3
    x = torch.randn(size, device=device)

    # This projection only works because Gaussian is rotation invariant around
    # zero and samples are independent!
    x_projected = remove_mean(x)
    return x_projected


def center_gravity_zero_gaussian_log_likelihood_with_mask(x, node_mask):
    assert len(x.size()) == 3
    B, N_embedded, D = x.size()
    assert_mean_zero_with_mask(x, node_mask)

    # r is invariant to a basis change in the relevant hyperplane, the masked
    # out values will have zero contribution.
    r2 = sum_except_batch(x.pow(2))

    # The relevant hyperplane is (N-1) * D dimensional.
    N = node_mask.squeeze(2).sum(1)  # N has shape [B]
    degrees_of_freedom = (N - 1) * D

    # Normalizing constant and logpx are computed:
    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2 * np.pi)
    log_px = -0.5 * r2 + log_normalizing_constant

    return log_px


def sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask):
    assert len(size) == 3
    x = torch.randn(size, device=device)

    x_masked = x * node_mask

    # This projection only works because Gaussian is rotation invariant around
    # zero and samples are independent!
    # TODO: check it
    x_projected = remove_mean_with_mask(x_masked, node_mask)
    return x_projected


def standard_gaussian_log_likelihood(x):
    # Normalizing constant and logpx are computed:
    log_px = sum_except_batch(-0.5 * x * x - 0.5 * np.log(2 * np.pi))
    return log_px


def sample_gaussian(size, device):
    x = torch.randn(size, device=device)
    return x


def standard_gaussian_log_likelihood_with_mask(x, node_mask):
    # Normalizing constant and logpx are computed:
    log_px_elementwise = -0.5 * x * x - 0.5 * np.log(2 * np.pi)
    log_px = sum_except_batch(log_px_elementwise * node_mask)
    return log_px


def sample_gaussian_with_mask(size, device, node_mask):
    x = torch.randn(size, device=device)
    x_masked = x * node_mask
    return x_masked


def concatenate_features(x, h):
    xh = torch.cat([x, h['categorical']], dim=2)
    if 'integer' in h:
        xh = torch.cat([xh, h['integer']], dim=2)
    return xh


def split_features(z, n_dims, num_classes, include_charges):
    assert z.size(2) == n_dims + num_classes + include_charges
    x = z[:, :, 0:n_dims]
    h = {'categorical': z[:, :, n_dims:n_dims + num_classes]}
    if include_charges:
        h['integer'] = z[:, :, n_dims + num_classes:n_dims + num_classes + 1]

    return x, h


# For gradient clipping

class Queue:
    def __init__(self, max_len=50):
        self.items = []
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def add(self, item):
        self.items.insert(0, item)
        if len(self) > self.max_len:
            self.items.pop()

    def mean(self):
        return np.mean(self.items)

    def std(self):
        return np.std(self.items)


def gradient_clipping(flow, gradnorm_queue):
    # Allow gradient norm to be 150% + 2 * stdev of the recent history.
    max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()

    # Clips gradient and returns the norm
    grad_norm = torch.nn.utils.clip_grad_norm_(
        flow.parameters(), max_norm=max_grad_norm, norm_type=2.0)

    if float(grad_norm) > max_grad_norm:
        gradnorm_queue.add(float(max_grad_norm))
    else:
        gradnorm_queue.add(float(grad_norm))

    if float(grad_norm) > max_grad_norm:
        print(f'Clipped gradient with value {grad_norm:.1f} while allowed {max_grad_norm:.1f}')
    return grad_norm


def disable_rdkit_logging():
    """
    Disables RDKit whiny logging.
    """
    import rdkit.rdBase as rkrb
    import rdkit.RDLogger as rkl
    logger = rkl.logger()
    logger.setLevel(rkl.ERROR)
    rkrb.DisableLog('rdApp.error')


class FoundNaNException(Exception):
    def __init__(self, x, h):
        x_nan_idx = self.find_nan_idx(x)
        h_nan_idx = self.find_nan_idx(h)

        self.x_h_nan_idx = x_nan_idx & h_nan_idx
        self.only_x_nan_idx = x_nan_idx.difference(h_nan_idx)
        self.only_h_nan_idx = h_nan_idx.difference(x_nan_idx)

    @staticmethod
    def find_nan_idx(z):
        idx = set()
        for i in range(z.shape[0]):
            if torch.any(torch.isnan(z[i])):
                idx.add(i)
        return idx


def get_batch_idx_for_animation(batch_size, batch_idx):
    batch_indices = []
    mol_indices = []
    for idx in [0, 110, 360]:
        if idx // batch_size == batch_idx:
            batch_indices.append(idx % batch_size)
            mol_indices.append(idx)
    return batch_indices, mol_indices


# Rotation data augmentation
def random_rotation(x):
    bs, n_nodes, n_dims = x.size()
    device = x.device
    angle_range = np.pi * 2
    if n_dims == 2:
        theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)
        R_row0 = torch.cat([cos_theta, -sin_theta], dim=2)
        R_row1 = torch.cat([sin_theta, cos_theta], dim=2)
        R = torch.cat([R_row0, R_row1], dim=1)

        x = x.transpose(1, 2)
        x = torch.matmul(R, x)
        x = x.transpose(1, 2)

    elif n_dims == 3:

        # Build Rx
        Rx = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
        theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
        cos = torch.cos(theta)
        sin = torch.sin(theta)
        Rx[:, 1:2, 1:2] = cos
        Rx[:, 1:2, 2:3] = sin
        Rx[:, 2:3, 1:2] = -sin
        Rx[:, 2:3, 2:3] = cos

        # Build Ry
        Ry = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
        theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
        cos = torch.cos(theta)
        sin = torch.sin(theta)
        Ry[:, 0:1, 0:1] = cos
        Ry[:, 0:1, 2:3] = -sin
        Ry[:, 2:3, 0:1] = sin
        Ry[:, 2:3, 2:3] = cos

        # Build Rz
        Rz = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
        theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
        cos = torch.cos(theta)
        sin = torch.sin(theta)
        Rz[:, 0:1, 0:1] = cos
        Rz[:, 0:1, 1:2] = sin
        Rz[:, 1:2, 0:1] = -sin
        Rz[:, 1:2, 1:2] = cos

        x = x.transpose(1, 2)
        x = torch.matmul(Rx, x)
        # x = torch.matmul(Rx.transpose(1, 2), x)
        x = torch.matmul(Ry, x)
        # x = torch.matmul(Ry.transpose(1, 2), x)
        x = torch.matmul(Rz, x)
        # x = torch.matmul(Rz.transpose(1, 2), x)
        x = x.transpose(1, 2)
    else:
        raise Exception("Not implemented Error")

    return x.contiguous()
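A short sketch of the adaptive clipping loop built from Queue and gradient_clipping (the linear model and the single training step are hypothetical placeholders):

import torch
from src.utils import Queue, gradient_clipping

model = torch.nn.Linear(8, 1)
gradnorm_queue = Queue(max_len=50)
gradnorm_queue.add(3.0)  # seed the history so mean()/std() are defined

x, y = torch.randn(16, 8), torch.randn(16, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
# Caps the step at 1.5 * mean + 2 * std of recent norms and updates the queue
grad_norm = gradient_clipping(model, gradnorm_queue)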
src/visualizer.py
ADDED
@@ -0,0 +1,227 @@
import torch
import os
import imageio
import matplotlib.pyplot as plt
import numpy as np
import glob
import random

from sklearn.decomposition import PCA
from src import const
from src.molecule_builder import get_bond_order


def save_xyz_file(path, one_hot, positions, node_mask, names, is_geom, suffix=''):
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    for batch_i in range(one_hot.size(0)):
        mask = node_mask[batch_i].squeeze()
        n_atoms = mask.sum()
        atom_idx = torch.where(mask)[0]

        f = open(os.path.join(path, f'{names[batch_i]}_{suffix}.xyz'), "w")
        f.write("%d\n\n" % n_atoms)
        atoms = torch.argmax(one_hot[batch_i], dim=1)
        for atom_i in atom_idx:
            atom = atoms[atom_i].item()
            atom = idx2atom[atom]
            f.write("%s %.9f %.9f %.9f\n" % (
                atom, positions[batch_i, atom_i, 0], positions[batch_i, atom_i, 1], positions[batch_i, atom_i, 2]
            ))
        f.close()


def load_xyz_files(path, suffix=''):
    files = []
    for fname in os.listdir(path):
        if fname.endswith(f'_{suffix}.xyz'):
            files.append(fname)
    files = sorted(files, key=lambda f: -int(f.replace(f'_{suffix}.xyz', '').split('_')[-1]))
    return [os.path.join(path, fname) for fname in files]


def load_molecule_xyz(file, is_geom):
    atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        one_hot = torch.zeros(n_atoms, len(idx2atom))
        charges = torch.zeros(n_atoms, 1)
        positions = torch.zeros(n_atoms, 3)
        f.readline()
        atoms = f.readlines()
        for i in range(n_atoms):
            atom = atoms[i].split(' ')
            atom_type = atom[0]
            one_hot[i, atom2idx[atom_type]] = 1
            position = torch.Tensor([float(e) for e in atom[1:]])
            positions[i, :] = position
        return positions, one_hot, charges


def draw_sphere(ax, x, y, z, size, color, alpha):
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)

    xs = size * np.outer(np.cos(u), np.sin(v))
    ys = size * np.outer(np.sin(u), np.sin(v))  # * 0.8
    zs = size * np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2, color=color, alpha=alpha)


def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None):
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    # Hydrogen, Carbon, Nitrogen, Oxygen, Fluorine

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    colors_dic = np.array(const.COLORS)
    radius_dic = np.array(const.RADII)
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(len(x))

    for i in range(len(x)):
        for j in range(i + 1, len(x)):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            draw_edge_int = get_bond_order(atom1, atom2, dist)
            line_width = (3 - 2) * 2 * 2
            draw_edge = draw_edge_int > 0
            if draw_edge:
                if draw_edge_int == 4:
                    linewidth_factor = 1.5
                else:
                    linewidth_factor = 1
                linewidth_factor *= 0.5
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=line_width * linewidth_factor * 2,
                    c=hex_bg_color,
                    alpha=alpha
                )

    # from pdb import set_trace
    # set_trace()

    if spheres_3d:
        # idx = torch.where(fragment_mask[:len(x)] == 0)[0]
        # ax.scatter(
        #     x[idx],
        #     y[idx],
        #     z[idx],
        #     alpha=0.9 * alpha,
        #     edgecolors='#FCBA03',
        #     facecolors='none',
        #     linewidths=2,
        #     s=900
        # )
        for i, j, k, s, c, f in zip(x, y, z, radii, colors, fragment_mask):
            if f == 1:
                alpha = 1.0

            draw_sphere(ax, i.item(), j.item(), k.item(), 0.5 * s, c, alpha)

    else:
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)


def plot_data3d(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None):
    black = (0, 0, 0)
    white = (1, 1, 1)
    hex_bg_color = '#FFFFFF' if bg == 'black' else '#000000'  # '#666666'

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect('auto')
    ax.view_init(elev=camera_elev, azim=camera_azim)
    if bg == 'black':
        ax.set_facecolor(black)
    else:
        ax.set_facecolor(white)
    ax.xaxis.pane.set_alpha(0)
    ax.yaxis.pane.set_alpha(0)
    ax.zaxis.pane.set_alpha(0)
    ax._axis3don = False

    if bg == 'black':
        ax.w_xaxis.line.set_color("black")
    else:
        ax.w_xaxis.line.set_color("white")

    plot_molecule(
        ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask
    )

    max_value = positions.abs().max().item()
    axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
    ax.set_xlim(-axis_lim, axis_lim)
    ax.set_ylim(-axis_lim, axis_lim)
    ax.set_zlim(-axis_lim, axis_lim)
    dpi = 120 if spheres_3d else 50

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)
        # plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi, transparent=True)

        if spheres_3d:
            img = imageio.imread(save_path)
            img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
            imageio.imsave(save_path, img_brighter)
    else:
        plt.show()
    plt.close()


def visualize_chain(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None
):
    files = load_xyz_files(path)
    save_paths = []

    # Fit PCA to the final molecule to obtain the best orientation for visualization
    positions, one_hot, charges = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(positions)

    for i in range(len(files)):
        file = files[i]

        positions, one_hot, charges = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Transform positions of each frame according to the best orientation of the last frame
        positions = pca.transform(positions)
        positions = torch.tensor(positions)

        fn = file[:-4] + '.png'
        plot_data3d(
            positions, atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    dirname = os.path.dirname(save_paths[0])
    gif_path = dirname + '/output.gif'
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})
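The tail of visualize_chain is just per-frame rendering plus gif assembly; a minimal standalone sketch of that last step (the results/chain directory and its .png frames are hypothetical):

import glob
import imageio

frames = sorted(glob.glob('results/chain/*.png'))   # per-frame renders
imgs = [imageio.imread(fn) for fn in frames]
imageio.mimsave('results/chain/output.gif', imgs, subrectangles=True)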