Spaces:
Running
on
A10G
Running
on
A10G
import numpy as np | |
from rdkit import Chem | |
from rdkit.Chem import AllChem | |
from src import const | |
from src.molecule_builder import get_bond_order | |
from scipy.stats import wasserstein_distance | |
from pdb import set_trace | |
def is_valid(mol):
    """Report whether RDKit can sanitize *mol*.

    ``Chem.SanitizeMol`` is called directly on the given molecule (RDKit
    performs sanitization in place). A ``ValueError`` raised during
    sanitization means "invalid"; any other exception propagates.

    Returns:
        True if sanitization succeeded, False if it raised ``ValueError``.
    """
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return False
    else:
        return True
def is_connected(mol):
    """Report whether *mol* is made of exactly one connected fragment.

    An ``AtomValenceException`` raised while splitting the molecule into
    fragments is treated as "not connected".

    Returns:
        True when ``Chem.GetMolFrags`` yields a single fragment, else False.
    """
    try:
        fragments = Chem.GetMolFrags(mol, asMols=True)
    except Chem.rdchem.AtomValenceException:
        return False
    return len(fragments) == 1
def get_valid_molecules(molecules):
    """Return, in input order, the molecules for which ``is_valid`` is True."""
    return [candidate for candidate in molecules if is_valid(candidate)]
def get_connected_molecules(molecules):
    """Return, in input order, the molecules for which ``is_connected`` is True."""
    return [candidate for candidate in molecules if is_connected(candidate)]
def get_unique_smiles(valid_molecules):
    """Return the distinct SMILES strings produced by ``Chem.MolToSmiles``.

    The list is built from a set, so duplicates are removed and the order of
    the returned strings is arbitrary.
    """
    smiles_set = {Chem.MolToSmiles(molecule) for molecule in valid_molecules}
    return list(smiles_set)
def get_novel_smiles(unique_true_smiles, unique_pred_smiles):
    """Return the predicted SMILES that are absent from the reference SMILES.

    Both inputs are turned into sets, so duplicates collapse and the returned
    list has no guaranteed order.
    """
    predicted = set(unique_pred_smiles)
    reference = set(unique_true_smiles)
    return list(predicted - reference)
def compute_energy(mol, conf_id=0):
    """Return the MMFF force-field energy of one conformer of *mol*.

    Args:
        mol: RDKit molecule; it must have a conformer with id *conf_id*.
        conf_id: Id of the conformer to evaluate. Defaults to 0, the
            conformer the original implementation always used.

    Returns:
        The value returned by the force field's ``CalcEnergy()``.

    Raises:
        ValueError: if MMFF parameters cannot be assigned to *mol* or the
            force field cannot be constructed.
    """
    props = AllChem.MMFFGetMoleculeProperties(mol)
    # RDKit reports "this molecule cannot be MMFF-typed" by returning None
    # instead of raising; without this check the failure shows up later as an
    # unhelpful AttributeError on None.
    if props is None:
        raise ValueError('MMFF parameters could not be assigned to molecule')
    force_field = AllChem.MMFFGetMoleculeForceField(mol, props, confId=conf_id)
    if force_field is None:
        raise ValueError('could not set up MMFF force field for molecule')
    return force_field.CalcEnergy()
def wasserstein_distance_between_energies(true_molecules, pred_molecules): | |
true_energy_dist = [] | |
for mol in true_molecules: | |
try: | |
energy = compute_energy(mol) | |
true_energy_dist.append(energy) | |
except: | |
continue | |
pred_energy_dist = [] | |
for mol in pred_molecules: | |
try: | |
energy = compute_energy(mol) | |
pred_energy_dist.append(energy) | |
except: | |
continue | |
if len(true_energy_dist) > 0 and len(pred_energy_dist) > 0: | |
return wasserstein_distance(true_energy_dist, pred_energy_dist) | |
else: | |
return 0 | |
def compute_metrics(pred_molecules, true_molecules): | |
if len(pred_molecules) == 0: | |
return { | |
'validity': 0, | |
'validity_and_connectivity': 0, | |
'validity_as_in_delinker': 0, | |
'uniqueness': 0, | |
'novelty': 0, | |
'energies': 0, | |
} | |
# Passing rdkit.Chem.Sanitize filter | |
true_valid = get_valid_molecules(true_molecules) | |
pred_valid = get_valid_molecules(pred_molecules) | |
validity = len(pred_valid) / len(pred_molecules) | |
# Checking if molecule consists of a single connected part | |
true_valid_and_connected = get_connected_molecules(true_valid) | |
pred_valid_and_connected = get_connected_molecules(pred_valid) | |
validity_and_connectivity = len(pred_valid_and_connected) / len(pred_molecules) | |
# Unique molecules | |
true_unique = get_unique_smiles(true_valid_and_connected) | |
pred_unique = get_unique_smiles(pred_valid_and_connected) | |
uniqueness = len(pred_unique) / len(pred_valid_and_connected) if len(pred_valid_and_connected) > 0 else 0 | |
# Novel molecules | |
pred_novel = get_novel_smiles(true_unique, pred_unique) | |
novelty = len(pred_novel) / len(pred_unique) if len(pred_unique) > 0 else 0 | |
# Difference between Energy distributions | |
energies = wasserstein_distance_between_energies(true_valid_and_connected, pred_valid_and_connected) | |
return { | |
'validity': validity, | |
'validity_and_connectivity': validity_and_connectivity, | |
'uniqueness': uniqueness, | |
'novelty': novelty, | |
'energies': energies, | |
} | |
# def check_stability(positions, atom_types): | |
# assert len(positions.shape) == 2 | |
# assert positions.shape[1] == 3 | |
# x = positions[:, 0] | |
# y = positions[:, 1] | |
# z = positions[:, 2] | |
# | |
# nr_bonds = np.zeros(len(x), dtype='int') | |
# for i in range(len(x)): | |
# for j in range(i + 1, len(x)): | |
# p1 = np.array([x[i], y[i], z[i]]) | |
# p2 = np.array([x[j], y[j], z[j]]) | |
# dist = np.sqrt(np.sum((p1 - p2) ** 2)) | |
# atom1, atom2 = const.IDX2ATOM[atom_types[i].item()], const.IDX2ATOM[atom_types[j].item()] | |
# order = get_bond_order(atom1, atom2, dist) | |
# nr_bonds[i] += order | |
# nr_bonds[j] += order | |
# nr_stable_bonds = 0 | |
# for atom_type_i, nr_bonds_i in zip(atom_types, nr_bonds): | |
# possible_bonds = const.ALLOWED_BONDS[const.IDX2ATOM[atom_type_i.item()]] | |
# if type(possible_bonds) == int: | |
# is_stable = possible_bonds == nr_bonds_i | |
# else: | |
# is_stable = nr_bonds_i in possible_bonds | |
# nr_stable_bonds += int(is_stable) | |
# | |
# molecule_stable = nr_stable_bonds == len(x) | |
# return molecule_stable, nr_stable_bonds, len(x) | |
# | |
# | |
# def count_stable_molecules(one_hot, x, node_mask): | |
# stable_molecules = 0 | |
# for i in range(len(one_hot)): | |
# mol_size = node_mask[i].sum() | |
# atom_types = one_hot[i][:mol_size, :].argmax(dim=1).detach().cpu() | |
# positions = x[i][:mol_size, :].detach().cpu() | |
# stable, _, _ = check_stability(positions, atom_types) | |
# stable_molecules += int(stable) | |
# | |
# return stable_molecules | |