Spaces:
Running
on
T4
Running
on
T4
# Copyright 2021 DeepMind Technologies Limited | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Datasets consisting of proteins.""" | |
from typing import Dict, Mapping, Optional, Sequence | |
from alphafold.model.tf import protein_features | |
import numpy as np | |
import tensorflow.compat.v1 as tf | |
TensorDict = Dict[str, tf.Tensor] | |
def parse_tfexample(
    raw_data: bytes,
    features: protein_features.FeaturesMetadata,
    key: Optional[str] = None) -> TensorDict:
  """Read a single TF Example proto and return a subset of its features.

  Each requested feature is parsed as a flat sequence of scalars and then
  handed to `parse_reshape_logic`, which gives it its final shape.

  Args:
    raw_data: A serialized tf.Example proto.
    features: A dictionary of features, mapping string feature names to a tuple
      (dtype, shape). This dictionary should be a subset of
      protein_features.FEATURES (or the dictionary itself for all features).
    key: Optional string with the SSTable key of that tf.Example. This will be
      added into features as a 'key' but only if requested in features.

  Returns:
    A dictionary mapping feature names to reshaped tensors. Only the given
    features are returned, all other ones are filtered out.
  """
  # Only the dtype (first metadata entry) is used for parsing; the shape part
  # of the metadata is applied later by parse_reshape_logic.
  feature_map = {
      name: tf.io.FixedLenSequenceFeature(
          shape=(), dtype=metadata[0], allow_missing=True)
      for name, metadata in features.items()
  }
  parsed_features = tf.io.parse_single_example(raw_data, feature_map)
  return parse_reshape_logic(parsed_features, features, key=key)
def _first(tensor: tf.Tensor) -> tf.Tensor:
  """Flattens the input (a tensor or a scalar) and returns its 1st element."""
  flattened = tf.reshape(tensor, shape=(-1,))
  return flattened[0]
def parse_reshape_logic(
    parsed_features: TensorDict,
    features: protein_features.FeaturesMetadata,
    key: Optional[str] = None) -> TensorDict:
  """Reshapes parsed serial features to the shapes given by the metadata.

  The target shape of every feature comes from `protein_features.shape`, fed
  with the number of residues, alignments and templates read from
  `parsed_features`. The dictionary is updated in place and also returned.

  Args:
    parsed_features: Mapping from feature name to its parsed value. Must
      contain "seq_length".
    features: Feature metadata forwarded to `protein_features.shape`.
    key: Optional key. Stored under "key" (as a one-element list) before
      reshaping, but only when "key" is also present in `features`.

  Returns:
    The `parsed_features` dictionary with every entry reshaped.
  """
  # Sizes the target shapes depend on. Counts that are not present in the
  # parsed features default to zero.
  num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)

  num_msa = 0
  if "num_alignments" in parsed_features:
    num_msa = tf.cast(
        _first(parsed_features["num_alignments"]), dtype=tf.int32)

  num_templates = 0
  if "template_domain_names" in parsed_features:
    num_templates = tf.cast(
        tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)

  if key is not None and "key" in features:
    parsed_features["key"] = [key]  # Expand dims from () to (1,).

  # Only existing keys are reassigned below, so iterating while updating the
  # dictionary is safe.
  for name, value in parsed_features.items():
    target_shape = protein_features.shape(
        feature_name=name,
        num_residues=num_residues,
        msa_length=num_msa,
        num_templates=num_templates,
        features=features)

    # Number of elements implied by the target shape.
    expected_size = tf.constant(1, dtype=tf.int32)
    for dim in target_shape:
      expected_size = expected_size * tf.cast(dim, tf.int32)

    size_matches = tf.assert_equal(
        tf.size(value), expected_size,
        name="assert_%s_shape_correct" % name,
        message="The size of feature %s (%s) could not be reshaped "
        "into %s" % (name, tf.size(value), target_shape))

    if "template" in name:
      # Features with "template" in their name skip the non-empty check.
      dependencies = [size_matches]
    else:
      # Make sure the feature we are reshaping is not empty.
      is_non_empty = tf.assert_greater(
          tf.size(value), 0, name="assert_%s_non_empty" % name,
          message="The feature %s is not set in the tf.Example. Either do not "
          "request the feature or use a tf.Example that has the "
          "feature set." % name)
      dependencies = [is_non_empty, size_matches]

    # The reshape only runs once the assertions above have been evaluated.
    with tf.control_dependencies(dependencies):
      parsed_features[name] = tf.reshape(
          value, target_shape, name="reshape_%s" % name)

  return parsed_features
def _make_features_metadata(
    feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
  """Builds a feature name to metadata mapping for the requested names.

  The features "aatype", "sequence" and "seq_length" are always included,
  whether or not they were requested. Metadata is looked up in
  `protein_features.FEATURES`.
  """
  # Features that must always be read.
  always_read = ("aatype", "sequence", "seq_length")
  selected_names = set(feature_names) | set(always_read)
  return {name: protein_features.FEATURES[name] for name in selected_names}
def create_tensor_dict(
    raw_data: bytes,
    features: Sequence[str],
    key: Optional[str] = None,
    ) -> TensorDict:
  """Builds a dictionary of tensor features from a serialized tf.Example.

  Args:
    raw_data: A serialized tf.Example proto.
    features: A list of strings of feature names to be returned in the dataset.
    key: Optional string with the SSTable key of that tf.Example. This will be
      added into features as a 'key' but only if requested in features.

  Returns:
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
  """
  # Expand the requested names into metadata, then parse and reshape.
  return parse_tfexample(raw_data, _make_features_metadata(features), key)
def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
    ) -> TensorDict:
  """Converts a dict of NumPy arrays into a dict of reshaped tensors.

  Args:
    np_example: A dict of NumPy feature arrays.
    features: A list of strings of feature names to be returned in the dataset.

  Returns:
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
  """
  features_metadata = _make_features_metadata(features)

  # Keep only the entries that have metadata, converting each to a constant.
  tensor_dict = {}
  for name, array in np_example.items():
    if name in features_metadata:
      tensor_dict[name] = tf.constant(array)

  # Ensures shapes are as expected. Needed for setting size of empty features
  # e.g. when no template hits were found.
  return parse_reshape_logic(tensor_dict, features_metadata)