# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for creating a tf.train.Example proto of image triplets."""

import io
import os
from typing import Any, List, Mapping, Optional

from absl import logging
import apache_beam as beam
import numpy as np
import PIL.Image
import six
from skimage import transform
import tensorflow as tf

# Maximum value representable by a uint8 pixel (255), as a float.
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
# Display gamma used to linearize pixel values before resampling.
_GAMMA = 2.2


def _resample_image(image: np.ndarray, resample_image_width: int,
                    resample_image_height: int) -> np.ndarray:
  """Re-samples and returns an `image` to be `resample_image_size`.

  Pixels are converted to linear light before filtering so the averaging done
  by the resampler is radiometrically correct, then converted back to
  gamma-encoded uint8.

  Args:
    image: Input uint8 image array.
    resample_image_width: Target width in pixels.
    resample_image_height: Target height in pixels.

  Returns:
    The re-sampled uint8 image.
  """
  # Convert image from uint8 gamma [0..255] to float linear [0..1].
  image = image.astype(np.float32) / _UINT8_MAX_F
  image = np.power(np.clip(image, 0, 1), _GAMMA)

  # Re-size the image with a local-mean (area-averaging) filter.
  resample_image_size = (resample_image_height, resample_image_width)
  image = transform.resize_local_mean(image, resample_image_size)

  # Convert back from float linear [0..1] to uint8 gamma [0..255].
  # The +0.5 rounds to nearest before the truncating uint8 cast.
  image = np.power(np.clip(image, 0, 1), 1.0 / _GAMMA)
  image = np.clip(image * _UINT8_MAX_F + 0.5, 0.0,
                  _UINT8_MAX_F).astype(np.uint8)
  return image


def _encode_to_png(pil_image: PIL.Image.Image,
                   image_path: str) -> Optional[bytes]:
  """PNG-encodes `pil_image`, returning None (after logging) on failure.

  Args:
    pil_image: The image to encode.
    image_path: Source path of the image, used only for error logging.

  Returns:
    The PNG byte string, or None if encoding failed.
  """
  buffer = io.BytesIO()
  try:
    pil_image.save(buffer, format='PNG')
  except OSError:
    logging.exception('Cannot encode image file: %s', image_path)
    return None
  return buffer.getvalue()


def generate_image_triplet_example(
    triplet_dict: Mapping[str, str],
    scale_factor: int = 1,
    center_crop_factor: int = 1) -> Optional[tf.train.Example]:
  """Generates and serializes a tf.train.Example proto from an image triplet.

  Default setting creates a triplet Example with the input images unchanged.
  Images are processed in the order of center-crop then downscale.

  Args:
    triplet_dict: A dict of image key to filepath of the triplet images.
    scale_factor: An integer scale factor to isotropically downsample images.
    center_crop_factor: An integer cropping factor to center crop images with
      the original resolution but isotropically downsized by the factor.

  Returns:
    tf.train.Example proto, or None upon error.

  Raises:
    ValueError if triplet_dict length is different from three or the scale
      input arguments are non-positive.
  """
  if len(triplet_dict) != 3:
    raise ValueError(
        f'Length of triplet_dict must be exactly 3, not {len(triplet_dict)}.')
  if scale_factor <= 0 or center_crop_factor <= 0:
    raise ValueError(f'(scale_factor, center_crop_factor) must be positive, '
                     f'Not ({scale_factor}, {center_crop_factor}).')

  feature = {}

  # Keep track of the path where the images came from for debugging purposes.
  mid_frame_path = os.path.dirname(triplet_dict['frame_1'])
  feature['path'] = tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[six.ensure_binary(mid_frame_path)]))

  for image_key, image_path in triplet_dict.items():
    if not tf.io.gfile.exists(image_path):
      logging.error('File not found: %s', image_path)
      return None

    # Note: we need both the raw bytes and the image size.
    # PIL.Image does not expose a method to grab the original bytes.
    # (Also it is not aware of non-local file systems.)
    # So we read with tf.io.gfile.GFile to get the bytes, and then wrap the
    # bytes in BytesIO to let PIL.Image open the image.
    try:
      # `with` ensures the file handle is closed even if read() raises
      # (the original code leaked the handle).
      with tf.io.gfile.GFile(image_path, 'rb') as f:
        byte_array = f.read()
    except tf.errors.InvalidArgumentError:
      logging.exception('Cannot read image file: %s', image_path)
      return None

    try:
      pil_image = PIL.Image.open(io.BytesIO(byte_array))
    except PIL.UnidentifiedImageError:
      logging.exception('Cannot decode image file: %s', image_path)
      return None
    width, height = pil_image.size
    pil_image_format = pil_image.format
    # Tracks whether pixels were altered and therefore must be re-encoded.
    image_modified = False

    # Optionally center-crop images and downsize images
    # by `center_crop_factor`.
    if center_crop_factor > 1:
      image = np.array(pil_image)
      quarter_height = image.shape[0] // (2 * center_crop_factor)
      quarter_width = image.shape[1] // (2 * center_crop_factor)
      # Use explicit end indices rather than negative ones: if a dimension is
      # smaller than 2*center_crop_factor, the quarter is 0 and a negative
      # index ([0:-0]) would produce an empty crop and crash fromarray().
      # Slicing only the first two axes also keeps grayscale (2-D) images
      # working.
      image = image[quarter_height:image.shape[0] - quarter_height,
                    quarter_width:image.shape[1] - quarter_width]
      pil_image = PIL.Image.fromarray(image)
      # Update image properties.
      height, width = image.shape[:2]
      image_modified = True

    # Optionally downsample images by `scale_factor`.
    if scale_factor > 1:
      image = np.array(pil_image)
      image = _resample_image(image, image.shape[1] // scale_factor,
                              image.shape[0] // scale_factor)
      pil_image = PIL.Image.fromarray(image)
      # Update image properties.
      height, width = image.shape[:2]
      image_modified = True

    # Encode once, after all processing (the original encoded after each
    # step, discarding the first result when both steps applied).
    if image_modified:
      byte_array = _encode_to_png(pil_image, image_path)
      if byte_array is None:
        return None
      # Bug fix: the re-encoded payload is PNG regardless of the source
      # format, so the format feature must say so — the original reported
      # e.g. 'jpeg' for PNG bytes.
      pil_image_format = 'PNG'

    # Create tf Features.
    image_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[byte_array]))
    height_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[height]))
    width_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[width]))
    encoding = tf.train.Feature(
        bytes_list=tf.train.BytesList(
            value=[six.ensure_binary(pil_image_format.lower())]))

    # Update feature map.
    feature[f'{image_key}/encoded'] = image_feature
    feature[f'{image_key}/format'] = encoding
    feature[f'{image_key}/height'] = height_feature
    feature[f'{image_key}/width'] = width_feature

  # Create tf Example.
  features = tf.train.Features(feature=feature)
  example = tf.train.Example(features=features)
  return example


class ExampleGenerator(beam.DoFn):
  """Generate a tf.train.Example per input image triplet filepaths."""

  def __init__(self,
               images_map: Mapping[str, Any],
               scale_factor: int = 1,
               center_crop_factor: int = 1):
    """Initializes the map of 3 images to add to each tf.train.Example.

    Args:
      images_map: Map from image key to image filepath.
      scale_factor: A scale factor to downsample frames.
      center_crop_factor: A factor to centercrop and downsize frames.
    """
    super().__init__()
    self._images_map = images_map
    self._scale_factor = scale_factor
    self._center_crop_factor = center_crop_factor

  def process(self, triplet_dict: Mapping[str, str]) -> List[bytes]:
    """Generates a serialized tf.train.Example for a triplet of images.

    Args:
      triplet_dict: A dict of image key to filepath of the triplet images.

    Returns:
      A serialized tf.train.Example proto. No shuffling is applied.
    """
    example = generate_image_triplet_example(triplet_dict, self._scale_factor,
                                             self._center_crop_factor)
    if example:
      return [example.SerializeToString()]
    else:
      # Upstream error already logged; emit nothing for this triplet.
      return []