|
import logging |
|
import os |
|
import sys |
|
import time |
|
import accelerate |
|
from absl import app |
|
import gin |
|
from internal import configs |
|
from internal import datasets |
|
from internal import image |
|
from internal import models |
|
from internal import raw_utils |
|
from internal import ref_utils |
|
from internal import train_utils |
|
from internal import checkpoints |
|
from internal import utils |
|
from internal import vis |
|
import numpy as np |
|
import torch |
|
import tensorboardX |
|
from torch.utils._pytree import tree_map |
|
|
|
configs.define_common_flags() |
|
|
|
|
|
def summarize_results(folder, scene_names, num_buckets): |
|
metric_names = ['psnrs', 'ssims', 'lpips'] |
|
num_iters = 1000000 |
|
precisions = [3, 4, 4, 4] |
|
|
|
results = [] |
|
for scene_name in scene_names: |
|
test_preds_folder = os.path.join(folder, scene_name, 'test_preds') |
|
values = [] |
|
for metric_name in metric_names: |
|
filename = os.path.join(folder, scene_name, 'test_preds', f'{metric_name}_{num_iters}.txt') |
|
with utils.open_file(filename) as f: |
|
v = np.array([float(s) for s in f.readline().split(' ')]) |
|
values.append(np.mean(np.reshape(v, [-1, num_buckets]), 0)) |
|
results.append(np.concatenate(values)) |
|
avg_results = np.mean(np.array(results), 0) |
|
|
|
psnr, ssim, lpips = np.mean(np.reshape(avg_results, [-1, num_buckets]), 1) |
|
|
|
mse = np.exp(-0.1 * np.log(10.) * psnr) |
|
dssim = np.sqrt(1 - ssim) |
|
avg_avg = np.exp(np.mean(np.log(np.array([mse, dssim, lpips])))) |
|
|
|
s = [] |
|
for i, v in enumerate(np.reshape(avg_results, [-1, num_buckets])): |
|
s.append(' '.join([f'{s:0.{precisions[i]}f}' for s in v])) |
|
s.append(f'{avg_avg:0.{precisions[-1]}f}') |
|
return ' | '.join(s) |
|
|
|
|
|
def main(unused_argv): |
|
config = configs.load_config() |
|
config.exp_path = os.path.join('exp', config.exp_name) |
|
config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints') |
|
config.render_dir = os.path.join(config.exp_path, 'render') |
|
|
|
accelerator = accelerate.Accelerator() |
|
|
|
|
|
logging.basicConfig( |
|
format="%(asctime)s: %(message)s", |
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
force=True, |
|
handlers=[logging.StreamHandler(sys.stdout), |
|
logging.FileHandler(os.path.join(config.exp_path, 'log_eval.txt'))], |
|
level=logging.INFO, |
|
) |
|
sys.excepthook = utils.handle_exception |
|
logger = accelerate.logging.get_logger(__name__) |
|
logger.info(config) |
|
logger.info(accelerator.state, main_process_only=False) |
|
|
|
config.world_size = accelerator.num_processes |
|
config.global_rank = accelerator.process_index |
|
accelerate.utils.set_seed(config.seed, device_specific=True) |
|
model = models.Model(config=config) |
|
model.eval() |
|
model.to(accelerator.device) |
|
|
|
dataset = datasets.load_dataset('test', config.data_dir, config) |
|
dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)), |
|
shuffle=False, |
|
batch_size=1, |
|
collate_fn=dataset.collate_fn, |
|
) |
|
tb_process_fn = lambda x: x.transpose(2, 0, 1) if len(x.shape) == 3 else x[None] |
|
if config.rawnerf_mode: |
|
postprocess_fn = dataset.metadata['postprocess_fn'] |
|
else: |
|
postprocess_fn = lambda z: z |
|
|
|
if config.eval_raw_affine_cc: |
|
cc_fun = raw_utils.match_images_affine |
|
else: |
|
cc_fun = image.color_correct |
|
|
|
model = accelerator.prepare(model) |
|
|
|
metric_harness = image.MetricHarness() |
|
|
|
last_step = 0 |
|
out_dir = os.path.join(config.exp_path, |
|
'path_renders' if config.render_path else 'test_preds') |
|
path_fn = lambda x: os.path.join(out_dir, x) |
|
|
|
if not config.eval_only_once: |
|
summary_writer = tensorboardX.SummaryWriter( |
|
os.path.join(config.exp_path, 'eval')) |
|
while True: |
|
step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger) |
|
if step <= last_step: |
|
logger.info(f'Checkpoint step {step} <= last step {last_step}, sleeping.') |
|
time.sleep(10) |
|
continue |
|
logger.info(f'Evaluating checkpoint at step {step}.') |
|
if config.eval_save_output and (not utils.isdir(out_dir)): |
|
utils.makedirs(out_dir) |
|
|
|
num_eval = min(dataset.size, config.eval_dataset_limit) |
|
perm = np.random.permutation(num_eval) |
|
showcase_indices = np.sort(perm[:config.num_showcase_images]) |
|
metrics = [] |
|
metrics_cc = [] |
|
showcases = [] |
|
render_times = [] |
|
for idx, batch in enumerate(dataloader): |
|
batch = accelerate.utils.send_to_device(batch, accelerator.device) |
|
eval_start_time = time.time() |
|
if idx >= num_eval: |
|
logger.info(f'Skipping image {idx + 1}/{dataset.size}') |
|
continue |
|
logger.info(f'Evaluating image {idx + 1}/{dataset.size}') |
|
rendering = models.render_image(model, accelerator, |
|
batch, False, 1, config) |
|
|
|
if not accelerator.is_main_process: |
|
continue |
|
|
|
render_times.append((time.time() - eval_start_time)) |
|
logger.info(f'Rendered in {render_times[-1]:0.3f}s') |
|
|
|
cc_start_time = time.time() |
|
rendering['rgb_cc'] = cc_fun(rendering['rgb'], batch['rgb']) |
|
|
|
rendering = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, rendering) |
|
batch = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, batch) |
|
|
|
gt_rgb = batch['rgb'] |
|
logger.info(f'Color corrected in {(time.time() - cc_start_time):0.3f}s') |
|
|
|
if not config.eval_only_once and idx in showcase_indices: |
|
showcase_idx = idx if config.deterministic_showcase else len(showcases) |
|
showcases.append((showcase_idx, rendering, batch)) |
|
if not config.render_path: |
|
rgb = postprocess_fn(rendering['rgb']) |
|
rgb_cc = postprocess_fn(rendering['rgb_cc']) |
|
rgb_gt = postprocess_fn(gt_rgb) |
|
|
|
if config.eval_quantize_metrics: |
|
|
|
rgb = np.round(rgb * 255) / 255 |
|
rgb_cc = np.round(rgb_cc * 255) / 255 |
|
|
|
if config.eval_crop_borders > 0: |
|
crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c] |
|
rgb = crop_fn(rgb) |
|
rgb_cc = crop_fn(rgb_cc) |
|
rgb_gt = crop_fn(rgb_gt) |
|
|
|
metric = metric_harness(rgb, rgb_gt) |
|
metric_cc = metric_harness(rgb_cc, rgb_gt) |
|
|
|
if config.compute_disp_metrics: |
|
for tag in ['mean', 'median']: |
|
key = f'distance_{tag}' |
|
if key in rendering: |
|
disparity = 1 / (1 + rendering[key]) |
|
metric[f'disparity_{tag}_mse'] = float( |
|
((disparity - batch['disps']) ** 2).mean()) |
|
|
|
if config.compute_normal_metrics: |
|
weights = rendering['acc'] * batch['alphas'] |
|
normalized_normals_gt = ref_utils.l2_normalize_np(batch['normals']) |
|
for key, val in rendering.items(): |
|
if key.startswith('normals') and val is not None: |
|
normalized_normals = ref_utils.l2_normalize_np(val) |
|
metric[key + '_mae'] = ref_utils.compute_weighted_mae_np( |
|
weights, normalized_normals, normalized_normals_gt) |
|
|
|
for m, v in metric.items(): |
|
logger.info(f'{m:30s} = {v:.4f}') |
|
|
|
metrics.append(metric) |
|
metrics_cc.append(metric_cc) |
|
|
|
if config.eval_save_output and (config.eval_render_interval > 0): |
|
if (idx % config.eval_render_interval) == 0: |
|
utils.save_img_u8(postprocess_fn(rendering['rgb']), |
|
path_fn(f'color_{idx:03d}.png')) |
|
utils.save_img_u8(postprocess_fn(rendering['rgb_cc']), |
|
path_fn(f'color_cc_{idx:03d}.png')) |
|
|
|
for key in ['distance_mean', 'distance_median']: |
|
if key in rendering: |
|
utils.save_img_f32(rendering[key], |
|
path_fn(f'{key}_{idx:03d}.tiff')) |
|
|
|
for key in ['normals']: |
|
if key in rendering: |
|
utils.save_img_u8(rendering[key] / 2. + 0.5, |
|
path_fn(f'{key}_{idx:03d}.png')) |
|
|
|
utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx:03d}.tiff')) |
|
|
|
if (not config.eval_only_once) and accelerator.is_main_process: |
|
summary_writer.add_scalar('eval_median_render_time', np.median(render_times), |
|
step) |
|
for name in metrics[0]: |
|
scores = [m[name] for m in metrics] |
|
summary_writer.add_scalar('eval_metrics/' + name, np.mean(scores), step) |
|
summary_writer.add_histogram('eval_metrics/' + 'perimage_' + name, scores, |
|
step) |
|
for name in metrics_cc[0]: |
|
scores = [m[name] for m in metrics_cc] |
|
summary_writer.add_scalar('eval_metrics_cc/' + name, np.mean(scores), step) |
|
summary_writer.add_histogram('eval_metrics_cc/' + 'perimage_' + name, |
|
scores, step) |
|
|
|
for i, r, b in showcases: |
|
if config.vis_decimate > 1: |
|
d = config.vis_decimate |
|
decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d] |
|
else: |
|
decimate_fn = lambda x: x |
|
r = tree_map(decimate_fn, r) |
|
b = tree_map(decimate_fn, b) |
|
visualizations = vis.visualize_suite(r, b) |
|
for k, v in visualizations.items(): |
|
if k == 'color': |
|
v = postprocess_fn(v) |
|
summary_writer.add_image(f'output_{k}_{i}', tb_process_fn(v), step) |
|
if not config.render_path: |
|
target = postprocess_fn(b['rgb']) |
|
summary_writer.add_image(f'true_color_{i}', tb_process_fn(target), step) |
|
pred = postprocess_fn(visualizations['color']) |
|
residual = np.clip(pred - target + 0.5, 0, 1) |
|
summary_writer.add_image(f'true_residual_{i}', tb_process_fn(residual), step) |
|
if config.compute_normal_metrics: |
|
summary_writer.add_image(f'true_normals_{i}', tb_process_fn(b['normals']) / 2. + 0.5, |
|
step) |
|
|
|
if (config.eval_save_output and (not config.render_path) and |
|
accelerator.is_main_process): |
|
with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f: |
|
f.write(' '.join([str(r) for r in render_times])) |
|
logger.info(f'metrics:') |
|
results = {} |
|
num_buckets = config.multiscale_levels if config.multiscale else 1 |
|
for name in metrics[0]: |
|
with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f: |
|
ms = [m[name] for m in metrics] |
|
f.write(' '.join([str(m) for m in ms])) |
|
results[name] = ' | '.join( |
|
list(map(str, np.mean(np.array(ms).reshape([-1, num_buckets]), 0).tolist()))) |
|
with utils.open_file(path_fn(f'metric_avg_{step}.txt'), 'w') as f: |
|
for name in metrics[0]: |
|
f.write(f'{name}: {results[name]}\n') |
|
logger.info(f'{name}: {results[name]}') |
|
logger.info(f'metrics_cc:') |
|
results_cc = {} |
|
for name in metrics_cc[0]: |
|
with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'), 'w') as f: |
|
ms = [m[name] for m in metrics_cc] |
|
f.write(' '.join([str(m) for m in ms])) |
|
results_cc[name] = ' | '.join( |
|
list(map(str, np.mean(np.array(ms).reshape([-1, num_buckets]), 0).tolist()))) |
|
with utils.open_file(path_fn(f'metric_cc_avg_{step}.txt'), 'w') as f: |
|
for name in metrics[0]: |
|
f.write(f'{name}: {results_cc[name]}\n') |
|
logger.info(f'{name}: {results_cc[name]}') |
|
if config.eval_save_ray_data: |
|
for i, r, b in showcases: |
|
rays = {k: v for k, v in r.items() if 'ray_' in k} |
|
np.set_printoptions(threshold=sys.maxsize) |
|
with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f: |
|
f.write(repr(rays)) |
|
|
|
if config.eval_only_once: |
|
break |
|
if config.early_exit_steps is not None: |
|
num_steps = config.early_exit_steps |
|
else: |
|
num_steps = config.max_steps |
|
if int(step) >= num_steps: |
|
break |
|
last_step = step |
|
logger.info('Finish evaluation.') |
|
|
|
|
|
if __name__ == '__main__': |
|
with gin.config_scope('eval'): |
|
app.run(main) |
|
|