import os
from collections import Counter
from functools import partial

import numpy as np
import pandas as pd
import scipy.sparse
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Crippen
from rdkit.Chem import MACCSkeys
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
from rdkit.Chem.QED import qed
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import Descriptors

from moses.metrics.SA_Score import sascorer
from moses.metrics.NP_Score import npscorer
from moses.utils import mapper, get_mol

_base_dir = os.path.split(__file__)[0]
_mcf = pd.read_csv(os.path.join(_base_dir, 'mcf.csv'))
_pains = pd.read_csv(os.path.join(_base_dir, 'wehi_pains.csv'),
                     names=['smarts', 'names'])
# Compiled MCF + PAINS SMARTS patterns used by mol_passes_filters.
# DataFrame.append was removed in pandas 2.0; pd.concat with the same
# `sort` flag is the equivalent operation.
_filters = [
    Chem.MolFromSmarts(x)
    for x in pd.concat([_mcf, _pains], sort=True)['smarts'].values
]


def canonic_smiles(smiles_or_mol):
    """
    Returns the RDKit SMILES of the given SMILES/molecule,
    or None if it cannot be parsed
    """
    mol = get_mol(smiles_or_mol)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol)


def logP(mol):
    """
    Computes RDKit's logP
    """
    # Crippen is imported explicitly rather than relying on it being
    # loaded as a side effect of other rdkit imports.
    return Crippen.MolLogP(mol)


def SA(mol):
    """
    Computes RDKit's Synthetic Accessibility score
    """
    return sascorer.calculateScore(mol)


def NP(mol):
    """
    Computes RDKit's Natural Product-likeness score
    """
    return npscorer.scoreMol(mol)


def QED(mol):
    """
    Computes RDKit's QED score
    """
    return qed(mol)


def weight(mol):
    """
    Computes molecular weight for given molecule.
    Returns float.
    """
    return Descriptors.MolWt(mol)


def get_n_rings(mol):
    """
    Computes the number of rings in a molecule
    """
    return mol.GetRingInfo().NumRings()


def fragmenter(mol):
    """
    Fragments mol using BRICS and returns a list of fragment SMILES.
    Returns an empty list if get_mol cannot build a molecule from the input.
    """
    mol = get_mol(mol)
    if mol is None:
        return []
    fgs = AllChem.FragmentOnBRICSBonds(mol)
    # All fragments are written as one dot-separated SMILES string.
    return Chem.MolToSmiles(fgs).split(".")


def compute_fragments(mol_list, n_jobs=1):
    """
    Fragments a list of mols using BRICS and returns a Counter
    of fragment SMILES
    """
    fragments = Counter()
    for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
        fragments.update(mol_frag)
    return fragments


def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
    """
    Extracts a scaffold from each molecule in a form of a canonic SMILES
    and returns a Counter of them. Molecules without a valid scaffold
    (see compute_scaffold) are not counted.
    """
    map_ = mapper(n_jobs)
    scaffolds = Counter(
        map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
    # compute_scaffold returns None for rejected molecules; drop that bucket.
    scaffolds.pop(None, None)
    return scaffolds


def compute_scaffold(mol, min_rings=2):
    """
    Returns the Murcko scaffold SMILES of mol, or None if the input is
    invalid, the scaffold cannot be computed, the scaffold is empty,
    or it has fewer than min_rings rings.
    """
    mol = get_mol(mol)
    if mol is None:
        return None
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    except (ValueError, RuntimeError):
        return None
    n_rings = get_n_rings(scaffold)
    scaffold_smiles = Chem.MolToSmiles(scaffold)
    if scaffold_smiles == '' or n_rings < min_rings:
        return None
    return scaffold_smiles


def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    For each molecule in gen_vecs finds closest molecule in stock_vecs.
    Returns average tanimoto score between these molecules

    Parameters:
        stock_vecs: numpy array
        gen_vecs: numpy array
        batch_size: number of rows of each array processed per step
        agg: max or mean
        device: torch device on which the pairwise products are computed
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))
    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        for i in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            # For 0/1 fingerprints, tp[a, b] is the count of shared on-bits.
            tp = torch.mm(x_stock, y_gen)
            jac = (tp / (x_stock.sum(1, keepdim=True) +
                         y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
            # For 0/1 fingerprints a 0/0 only arises when both vectors are
            # all-zero; such pairs are treated as identical.
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac**p
            stop = i + y_gen.shape[1]
            if agg == 'max':
                agg_tanimoto[i:stop] = np.maximum(
                    agg_tanimoto[i:stop], jac.max(0))
            elif agg == 'mean':
                agg_tanimoto[i:stop] += jac.sum(0)
                total[i:stop] += jac.shape[0]
    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = (agg_tanimoto)**(1/p)
    return np.mean(agg_tanimoto)


def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """
    Generates fingerprint for SMILES
    If smiles is invalid, returns None
    Returns numpy array of fingerprint bits

    Parameters:
        smiles: SMILES string
        type: type of fingerprint: [MACCS|morgan]
        dtype: if not None, specifies the dtype of returned array
    """
    fp_type = fp_type.lower()
    molecule = get_mol(smiles_or_mol, *args, **kwargs)
    if molecule is None:
        return None
    if fp_type == 'maccs':
        keys = MACCSkeys.GenMACCSKeys(molecule)
        keys = np.array(keys.GetOnBits())
        fingerprint = np.zeros(166, dtype='uint8')
        if len(keys) != 0:
            fingerprint[keys - 1] = 1  # We drop 0-th key that is always zero
    elif fp_type == 'morgan':
        fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
                                 dtype='uint8')
    else:
        raise ValueError("Unknown fingerprint type {}".format(fp_type))
    if dtype is not None:
        fingerprint = fingerprint.astype(dtype)
    return fingerprint


def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    '''
    Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
    e.g.fingerprints(smiles_mols_array, type='morgan', n_jobs=10)
    Inserts np.nan to rows corresponding to incorrect smiles.
    IMPORTANT: if there is at least one np.nan, the dtype would be float
    Parameters:
        smiles_mols_array: list/array/pd.Series of smiles or already computed
            RDKit molecules
        n_jobs: number of parallel workers to execute
        already_unique: flag for performance reasons, if smiles array is big
            and already unique. Its value is set to True if smiles_mols_array
            contain RDKit molecules already.
    '''
    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    # Deduplication via np.unique is only done for string (SMILES) input.
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True

    if not already_unique:
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)

    # Materialized so it can safely be iterated twice below.
    fps = list(mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    ))

    # The width of the NaN placeholder rows comes from the first valid
    # fingerprint; if every input is invalid, fall back to width 1.
    length = 1
    first_fp = None
    for fp in fps:
        if fp is not None:
            length = fp.shape[-1]
            first_fp = fp
            break
    fps = [fp if fp is not None else np.full((1, length), np.nan)
           for fp in fps]
    if scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        # Expand unique rows back to the original input order.
        return fps[inv_index]
    return fps


def mol_passes_filters(mol,
                       allowed=None,
                       isomericSmiles=False):
    """
    Checks if mol
    * passes MCF and PAINS filters,
    * has only allowed atoms
    * is not charged
    * has no ring with 8 or more atoms
    * round-trips through SMILES
    """
    allowed = allowed or {'C', 'N', 'S', 'O', 'F', 'Cl', 'Br', 'H'}
    mol = get_mol(mol)
    if mol is None:
        return False
    ring_info = mol.GetRingInfo()
    if any(len(x) >= 8 for x in ring_info.AtomRings()):
        return False
    if any(atom.GetFormalCharge() != 0 for atom in mol.GetAtoms()):
        return False
    if any(atom.GetSymbol() not in allowed for atom in mol.GetAtoms()):
        return False
    # Substructure filters are matched against the molecule with explicit
    # hydrogens; built only once the cheaper checks have passed.
    h_mol = Chem.AddHs(mol)
    if any(h_mol.HasSubstructMatch(smarts) for smarts in _filters):
        return False
    smiles = Chem.MolToSmiles(mol, isomericSmiles=isomericSmiles)
    if smiles is None or len(smiles) == 0:
        return False
    if Chem.MolFromSmiles(smiles) is None:
        return False
    return True