# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Evaluate the frame interpolation model from a tfrecord and store results.

This script runs inference on examples in a tfrecord and generates images and
numeric results according to the gin config. For details, see the
run_evaluation() function below.

Usage example:
  python3 -m frame_interpolation.eval.eval_cli -- \
    --gin_config <path to the gin config file> \
    --base_folder <root folder of training sessions> \
    --label <the foldername of the training session>

  or

  python3 -m frame_interpolation.eval.eval_cli -- \
    --gin_config <path to the gin config file> \
    --model_path <path to the TF2 saved model>

The output is saved at the parent directory of the `model_path`:
<parent directory of the model_path>/batch_eval.

The evaluation is run on a GPU by default. Add the `--mode` argument for
others.
"""
import collections
import os
from typing import Any, Dict

from . import util
from absl import app
from absl import flags
from absl import logging
import gin.tf
from ..losses import losses
import numpy as np
import tensorflow as tf
from ..training import data_lib

_GIN_CONFIG = flags.DEFINE_string('gin_config', None, 'Gin config file.')
_LABEL = flags.DEFINE_string(
    'label', None, 'Descriptive label for the training session to eval.')
_BASE_FOLDER = flags.DEFINE_string('base_folder', None,
                                   'Root folder of training sessions.')
_MODEL_PATH = flags.DEFINE_string(
    name='model_path',
    default=None,
    help='The path of the TF2 saved model to use. If the _MODEL_PATH argument '
    'is directly specified, the _LABEL and _BASE_FOLDER arguments will be '
    'ignored.')
_OUTPUT_FRAMES = flags.DEFINE_boolean(
    name='output_frames',
    default=False,
    help='If true, saves the inputs, ground-truth and interpolated frames.')
_MODE = flags.DEFINE_enum('mode', 'gpu', ['cpu', 'gpu'],
                          'Device to run evaluations.')


@gin.configurable('experiment')
def _get_experiment_config(name) -> Dict[str, Any]:
  """Fetches the gin config."""
  return {
      'name': name,
  }


def _set_visible_devices():
  """Sets the visible devices according to the running mode."""
  mode_devices = tf.config.list_physical_devices(_MODE.value.upper())
  tf.config.set_visible_devices([], 'GPU')
  tf.config.set_visible_devices([], 'TPU')
  tf.config.set_visible_devices(mode_devices, _MODE.value.upper())
  return


@gin.configurable('evaluation')
def run_evaluation(model_path, tfrecord, output_dir, max_examples, metrics):
  """Runs the eval loop for examples in the tfrecord.

  The evaluation is run for the first 'max_examples' examples, and the
  resulting images are stored into the given output_dir. Any tensor that
  appears like an image is stored with its name -- this may include
  intermediate results, depending on what the model outputs.

  Additionally, numeric results are stored into a results.csv file within the
  same directory. This includes per-example metrics and the mean across the
  whole dataset.

  Args:
    model_path: Directory of the TF2 saved model.
    tfrecord: Path to the tfrecord with the eval data.
    output_dir: Directory to store the results into.
    max_examples: Maximum number of examples to evaluate.
    metrics: The names of the loss functions to use.
  """
  model = tf.saved_model.load(model_path)

  # Store a 'readme.txt' that contains information on where the data came from.
  with tf.io.gfile.GFile(
      os.path.join(output_dir, 'readme.txt'), mode='w') as f:
    print('Results for:', file=f)
    print(f'  model: {model_path}', file=f)
    print(f'  tfrecord: {tfrecord}', file=f)

  with tf.io.gfile.GFile(
      os.path.join(output_dir, 'results.csv'), mode='w') as csv_file:
    test_losses = losses.test_losses(metrics, [1.0] * len(metrics))
    title_row = ['key'] + list(test_losses)
    print(', '.join(title_row), file=csv_file)

    datasets = data_lib.create_eval_datasets(
        batch_size=1,
        files=[tfrecord],
        names=[os.path.basename(output_dir)],
        max_examples=max_examples)
    dataset = datasets[os.path.basename(output_dir)]

    all_losses = collections.defaultdict(list)
    for example in dataset:
      inputs = {
          'x0': example['x0'],
          'x1': example['x1'],
          'time': example['time'][..., tf.newaxis],
      }
      prediction = model(inputs, training=False)

      # Get the key from the encoded mid-frame path.
      path = example['path'][0].numpy().decode('utf-8')
      key = path.rsplit('.', 1)[0].rsplit(os.sep)[-1]

      # Combine both inputs and outputs into a single dictionary:
      combined = {**prediction, **example} if _OUTPUT_FRAMES.value else {}
      for name in combined:
        image = combined[name]
        if isinstance(image, tf.Tensor):
          # This saves any tensor that has a shape that can be interpreted as
          # an image, e.g. (1, H, W, C), where the batch dimension is always 1,
          # H and W are the image height and width, and C is either 1 or 3
          # (grayscale or color image).
          if len(image.shape) == 4 and (image.shape[-1] == 1 or
                                        image.shape[-1] == 3):
            util.write_image(
                os.path.join(output_dir, f'{key}_{name}.png'),
                image[0].numpy())

      # Evaluate losses if the dataset has ground truth 'y', otherwise just do
      # a visual eval.
      if 'y' in example:
        loss_values = []
        # Clip the interpolator output to the range [0, 1]. Clipping is done
        # only in the eval loop to get better metrics, but not in the training
        # loop, so gradients are not killed.
        prediction['image'] = tf.clip_by_value(prediction['image'], 0., 1.)
        for loss_name, (loss_value_fn, loss_weight_fn) in test_losses.items():
          loss_value = loss_value_fn(example, prediction) * loss_weight_fn(0)
          loss_values.append(loss_value.numpy())
          all_losses[loss_name].append(loss_value.numpy())
        print(f'{key}, {str(loss_values)[1:-1]}', file=csv_file)

    if all_losses:
      totals = [np.mean(all_losses[loss_name]) for loss_name in test_losses]
      print(f'mean, {str(totals)[1:-1]}', file=csv_file)
      totals_dict = {
          loss_name: np.mean(all_losses[loss_name]) for loss_name in test_losses
      }
      logging.info('mean, %s', totals_dict)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if _MODEL_PATH.value is not None:
    model_path = _MODEL_PATH.value
  else:
    model_path = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'saved_model')

  gin.parse_config_files_and_bindings(
      config_files=[_GIN_CONFIG.value], bindings=None, skip_unknown=True)

  config = _get_experiment_config()  # pylint: disable=no-value-for-parameter
  eval_name = config['name']

  output_dir = os.path.join(
      os.path.dirname(model_path), 'batch_eval', eval_name)
  logging.info('Creating output_dir @ %s ...', output_dir)

  # Copy config file to <output_dir>/