Spaces:
Running
on
T4
Running
on
T4
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""lDDT protein distance score."""
import jax.numpy as jnp | |
def lddt(predicted_points,
         true_points,
         true_points_mask,
         cutoff=15.,
         per_residue=False):
    """Measure (approximate) lDDT for a batch of coordinates.

    lDDT reference:
    Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
    superposition-free score for comparing protein structures and models using
    distance difference tests. Bioinformatics 29, 2722-2728 (2013).

    lDDT is a measure of the difference between the true distance matrix and
    the distance matrix of the predicted points. The difference is computed
    only on points closer than cutoff *in the true structure*.

    This function does not compute the exact lDDT value that the original
    paper describes because it does not include terms for physical feasibility
    (e.g. bond length violations). Therefore this is only an approximate
    lDDT score.

    Args:
      predicted_points: (batch, length, 3) array of predicted 3D points.
      true_points: (batch, length, 3) array of true 3D points.
      true_points_mask: (batch, length, 1) binary-valued float array. This
        mask should be 1 for points that exist in the true points.
      cutoff: Maximum distance (in the true structure) for a pair of points
        to be included in the score.
      per_residue: If true, return score for each residue. Note that the
        overall lDDT is not exactly the mean of the per_residue lDDT's because
        some residues have more contacts than others.

    Returns:
      An (approximate, see above) lDDT score in the range 0-1. The result has
      shape (batch, length) when `per_residue` is true and (batch,) otherwise.
      When no pair qualifies for scoring, the epsilon terms in the
      normalization make the result evaluate to approximately 1.

    Raises:
      AssertionError: If `predicted_points` or `true_points` is not a rank-3
        array with a trailing dimension of 3, or if `true_points_mask` is not
        a rank-3 array with a trailing dimension of 1.
    """
    assert len(predicted_points.shape) == 3
    assert predicted_points.shape[-1] == 3
    # Without these checks a rank-2 `true_points` would broadcast silently
    # into a wrongly shaped "distance matrix" and produce a meaningless score.
    assert len(true_points.shape) == 3
    assert true_points.shape[-1] == 3
    assert true_points_mask.shape[-1] == 1
    assert len(true_points_mask.shape) == 3

    # Compute true and predicted pairwise distance matrices, each of shape
    # (batch, length, length). The small epsilon is added before the sqrt.
    dmat_true = jnp.sqrt(1e-10 + jnp.sum(
        (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1))

    dmat_predicted = jnp.sqrt(1e-10 + jnp.sum(
        (predicted_points[:, :, None] -
         predicted_points[:, None, :])**2, axis=-1))

    # 0/1 weights selecting the pairs that contribute to the score: both
    # points present in the true structure, closer than `cutoff` in the true
    # structure, and not the same point.
    dists_to_score = (
        (dmat_true < cutoff).astype(jnp.float32) * true_points_mask *
        jnp.transpose(true_points_mask, [0, 2, 1]) *
        (1. - jnp.eye(dmat_true.shape[1]))  # Exclude self-interaction.
    )

    # Absolute error between the true and predicted distance for every pair.
    dist_l1 = jnp.abs(dmat_true - dmat_predicted)

    # True lDDT uses a number of fixed bins.
    # We ignore the physical plausibility correction to lDDT, though.
    score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) +
                    (dist_l1 < 1.0).astype(jnp.float32) +
                    (dist_l1 < 2.0).astype(jnp.float32) +
                    (dist_l1 < 4.0).astype(jnp.float32))

    # Normalize over the appropriate axes: the partner axis only for a
    # per-residue score, both pair axes for a single score per batch element.
    reduce_axes = (-1,) if per_residue else (-2, -1)
    # The epsilon in both numerator and denominator avoids division by zero
    # when no pair is selected.
    norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes))
    score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes))

    return score