Spaces:
Running
on
T4
Running
on
T4
import numpy as np | |
from alphafold.common import residue_constants | |
def make_atom14_positions(prot):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Reads ``prot["aatype"]``, ``prot["all_atom_mask"]`` and
    ``prot["all_atom_positions"]`` and stores the following entries in
    ``prot`` (the mapping is modified in place):

      * ``atom14_atom_exists``: atom14 slots that are used by each residue.
      * ``atom14_gt_exists``: atom14 slots that also have a ground truth atom.
      * ``atom14_gt_positions``: ground truth positions in atom14 layout,
        zeroed where ``atom14_gt_exists`` is 0.
      * ``residx_atom14_to_atom37``: gather indices (residx, atom14) -> atom37.
      * ``residx_atom37_to_atom14``: gather indices (residx, atom37) -> atom14.
      * ``atom37_atom_exists``: atom37 slots that are used by each residue.
      * ``atom14_alt_gt_positions`` / ``atom14_alt_gt_exists``: ground truth
        positions / mask with the ambiguously named atoms swapped.
      * ``atom14_atom_is_ambiguous``: 1 for atoms taking part in such a swap.

    Args:
      prot: mapping holding the three input entries listed above.
        NOTE(review): "aatype" is used as an index into 21-row tables
        (20 residue types + 'UNK'), so it is presumably an integer array with
        values in [0, 20] -- confirm against the caller.

    Returns:
      The same ``prot`` object, with the entries above added.
    """
    rc = residue_constants
    atom14_names_by_resname = rc.restype_name_to_atom14_names
    # Three-letter names of the standard residue types, in restype order.
    resnames = [rc.restype_1to3[letter] for letter in rc.restypes]

    # ------------------------------------------------------------------
    # Static lookup tables, one row per residue type (last row is 'UNK').
    # ------------------------------------------------------------------
    atom14_to_atom37_rows = []  # (restype, atom14) --> atom37
    atom37_to_atom14_rows = []  # (restype, atom37) --> atom14
    atom14_mask_rows = []
    for resname in resnames:
        names14 = atom14_names_by_resname[resname]
        slot_of = {name: slot for slot, name in enumerate(names14)}
        # Empty names mark unused atom14 slots; they map to index 0 and are
        # switched off through the mask row.
        atom14_to_atom37_rows.append(
            [rc.atom_order[name] if name else 0 for name in names14])
        atom37_to_atom14_rows.append(
            [slot_of.get(name, 0) for name in rc.atom_types])
        atom14_mask_rows.append([1. if name else 0. for name in names14])

    # Dummy all-zero rows for restype 'UNK'.
    atom14_to_atom37_rows.append([0] * 14)
    atom37_to_atom14_rows.append([0] * 37)
    atom14_mask_rows.append([0.] * 14)

    restype_atom14_to_atom37 = np.array(atom14_to_atom37_rows, dtype=np.int32)
    restype_atom37_to_atom14 = np.array(atom37_to_atom14_rows, dtype=np.int32)
    restype_atom14_mask = np.array(atom14_mask_rows, dtype=np.float32)

    # atom37 existence table; the 'UNK' row (index 20) stays all zero.
    restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
    for row, resname in enumerate(resnames):
        for atom_name in rc.residue_atoms[resname]:
            restype_atom37_mask[row, rc.atom_order[atom_name]] = 1

    # As the atom naming is ambiguous for some residue types, build for each
    # residue type a (14, 14) matrix that swaps the ambiguous atoms (identity
    # where nothing is ambiguous) and a (21, 14) table flagging those atoms.
    all_resnames = resnames + ["UNK"]
    matrix_by_resname = {
        name: np.eye(14, dtype=np.float32) for name in all_resnames
    }
    restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        names14 = atom14_names_by_resname[resname]
        row = rc.restype_order[rc.restype_3to1[resname]]
        permutation = np.arange(14)
        for first_name, second_name in swap.items():
            first = names14.index(first_name)
            second = names14.index(second_name)
            permutation[first] = second
            permutation[second] = first
            restype_atom14_is_ambiguous[row, first] = 1
            restype_atom14_is_ambiguous[row, second] = 1
        # Row i of the result has its single 1 in column permutation[i].
        matrix_by_resname[resname] = np.eye(14, dtype=np.float32)[permutation]
    renaming_matrices = np.stack(
        [matrix_by_resname[name] for name in all_resnames])

    # ------------------------------------------------------------------
    # Per-protein gathers; entries are written in a fixed order.
    # ------------------------------------------------------------------
    aatype = prot["aatype"]

    # (residx, atom14) --> atom37, i.e. an array with shape (num_res, 14)
    # containing the atom37 indices for this protein.
    residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
    residx_atom14_mask = restype_atom14_mask[aatype]

    # Mask for known ground truth positions.
    gathered_mask = np.take_along_axis(
        prot["all_atom_mask"], residx_atom14_to_atom37, axis=1)
    residx_atom14_gt_mask = (
        residx_atom14_mask * gathered_mask.astype(np.float32))

    # Ground truth positions, zeroed where no ground truth atom exists.
    gathered_positions = np.take_along_axis(
        prot["all_atom_positions"],
        residx_atom14_to_atom37[..., None],
        axis=1)
    residx_atom14_gt_positions = (
        residx_atom14_gt_mask[:, :, None] * gathered_positions)

    prot["atom14_atom_exists"] = residx_atom14_mask
    prot["atom14_gt_exists"] = residx_atom14_gt_mask
    prot["atom14_gt_positions"] = residx_atom14_gt_positions
    prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37

    # Gather indices for mapping back, and the corresponding mask.
    prot["residx_atom37_to_atom14"] = restype_atom37_to_atom14[aatype]
    prot["atom37_atom_exists"] = restype_atom37_mask[aatype]

    # Transformation matrices for the given residue sequence,
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[aatype]

    # Alternative ground truth positions, shape (num_res, 14, 3).
    prot["atom14_alt_gt_positions"] = np.einsum(
        "rac,rab->rbc", residx_atom14_gt_positions, renaming_transform)

    # Mask for the alternative ground truth (differs from the ground truth
    # mask, if only one of the atoms in an ambiguous pair has a ground truth
    # position).
    prot["atom14_alt_gt_exists"] = np.einsum(
        "ra,rab->rb", residx_atom14_gt_mask, renaming_transform)

    # Ambiguous-atom mask for the given sequence.
    prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[aatype]

    return prot