import glob
import logging
import os
import random
import shutil
import sys
import time

import accelerate
import gin
import numpy as np
import tensorboardX
import torch
from absl import app
from torch.utils._pytree import tree_map
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

from internal import checkpoints
from internal import configs
from internal import datasets
from internal import image
from internal import models
from internal import train_utils
from internal import utils
from internal import vis

configs.define_common_flags()

# Elapsed wall-clock time is accumulated as an integer in units of 1/TIME_PRECISION
# seconds (i.e. milliseconds) and divided back out when logging timed metrics.
TIME_PRECISION = 1000

def main(unused_argv):
    config = configs.load_config()
    config.exp_path = os.path.join('exp', config.exp_name)
    config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints')
    utils.makedirs(config.exp_path)
    with utils.open_file(os.path.join(config.exp_path, 'config.gin'), 'w') as f:
        f.write(gin.config_str())
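
    # Accelerate handles device placement and multi-process (DDP) execution; mixed
    # precision follows whatever was chosen when configuring `accelerate`.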
    accelerator = accelerate.Accelerator()

    logging.basicConfig(
        format="%(asctime)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
        handlers=[logging.StreamHandler(sys.stdout),
                  logging.FileHandler(os.path.join(config.exp_path, 'log_train.txt'))],
        level=logging.INFO,
    )
    sys.excepthook = utils.handle_exception
    logger = accelerate.logging.get_logger(__name__)
    logger.info(config)
    logger.info(accelerator.state, main_process_only=False)

    config.world_size = accelerator.num_processes
    config.global_rank = accelerator.process_index
    if config.batch_size % accelerator.num_processes != 0:
        # Round the global batch size down so it divides evenly across processes.
        config.batch_size -= config.batch_size % accelerator.num_processes
        logger.info('Adjusted batch size to %d so it divides the process count.',
                    config.batch_size)

    # Seed each process differently (device_specific=True offsets the seed by process index).
    accelerate.utils.set_seed(config.seed, device_specific=True)
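
    # Build the model and an optimizer; lr_fn maps a training step to a learning rate
    # and is applied manually each iteration below.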
    model = models.Model(config=config)
    optimizer, lr_fn = train_utils.create_optimizer(config, model)

    dataset = datasets.load_dataset('train', config.data_dir, config)
    test_dataset = datasets.load_dataset('test', config.data_dir, config)
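    # The DataLoaders iterate over integer indices; each dataset's collate_fn turns an
    # index into an actual ray batch, so batch_size=1 means one ray batch per step.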
    dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)),
                                             num_workers=8,
                                             shuffle=True,
                                             batch_size=1,
                                             collate_fn=dataset.collate_fn,
                                             persistent_workers=True,
                                             )
    test_dataloader = torch.utils.data.DataLoader(np.arange(len(test_dataset)),
                                                  num_workers=4,
                                                  shuffle=False,
                                                  batch_size=1,
                                                  persistent_workers=True,
                                                  collate_fn=test_dataset.collate_fn,
                                                  )
    if config.rawnerf_mode:
        postprocess_fn = test_dataset.metadata['postprocess_fn']
    else:
        postprocess_fn = lambda z, _=None: z

    model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
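
    # Optionally resume from the newest checkpoint; a fresh run starts at step 0.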
    if config.resume_from_checkpoint:
        init_step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger)
    else:
        init_step = 0

    module = accelerator.unwrap_model(model)
    dataiter = iter(dataloader)
    test_dataiter = iter(test_dataloader)

    num_params = train_utils.tree_len(list(model.parameters()))
    logger.info(f'Number of parameters being optimized: {num_params}')

    if dataset.size > module.num_glo_embeddings and module.num_glo_features > 0:
        raise ValueError(f'Number of glo embeddings {module.num_glo_embeddings} '
                         f'must be at least equal to number of train images '
                         f'{dataset.size}')
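
    # MetricHarness computes image-quality metrics (e.g. PSNR) between a rendering
    # and the ground-truth test image.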
    metric_harness = image.MetricHarness()

    if accelerator.is_main_process:
        summary_writer = tensorboardX.SummaryWriter(config.exp_path)
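
    # TensorBoard expects CHW images; transpose HWC arrays and add a channel axis to 2-D ones.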
    tb_process_fn = lambda x: x.transpose(2, 0, 1) if len(x.shape) == 3 else x[None]

    if config.rawnerf_mode and accelerator.is_main_process:
        for name, data in zip(['train', 'test'], [dataset, test_dataset]):
            # Log exposure/shutter metadata for raw-mode datasets.
            for key in ['exposure_idx', 'exposure_values', 'unique_shutters']:
                summary_writer.add_text(f'{name}_{key}', str(data.metadata[key]), 0)

    logger.info("Begin training...")
    step = init_step + 1
    total_time = 0
    total_steps = 0
    reset_stats = True
    if config.early_exit_steps is not None:
        num_steps = config.early_exit_steps
    else:
        num_steps = config.max_steps
    with logging_redirect_tqdm():
        tbar = tqdm(range(init_step + 1, num_steps + 1),
                    desc='Training', initial=init_step, total=num_steps,
                    disable=not accelerator.is_main_process)
        for step in tbar:
            try:
                batch = next(dataiter)
            except StopIteration:
                dataiter = iter(dataloader)
                batch = next(dataiter)
            batch = accelerate.utils.send_to_device(batch, accelerator.device)
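            # (Re)start the stats buffer and timer at the beginning of each logging window.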
            if reset_stats and accelerator.is_main_process:
                stats_buffer = []
                train_start_time = time.time()
                reset_stats = False

            learning_rate = lr_fn(step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            # Fraction of training completed, in [0, 1], passed to the model each step.
            train_frac = np.clip((step - 1) / (config.max_steps - 1), 0, 1)

            compute_extras = (config.compute_disp_metrics or config.compute_normal_metrics)
            optimizer.zero_grad()
            with accelerator.autocast():
                renderings, ray_history = model(
                    True,
                    batch,
                    train_frac=train_frac,
                    compute_extras=compute_extras,
                    zero_glo=False)
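
            # Accumulate the data term and every enabled regularizer into the loss dict.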
            losses = {}

            data_loss, stats = train_utils.compute_data_loss(batch, renderings, config)
            losses['data'] = data_loss

            if config.interlevel_loss_mult > 0 and not module.single_mlp:
                losses['interlevel'] = train_utils.interlevel_loss(ray_history, config)

            if config.anti_interlevel_loss_mult > 0 and not module.single_mlp:
                losses['anti_interlevel'] = train_utils.anti_interlevel_loss(ray_history, config)

            if config.distortion_loss_mult > 0:
                losses['distortion'] = train_utils.distortion_loss(ray_history, config)

            if config.opacity_loss_mult > 0:
                losses['opacity'] = train_utils.opacity_loss(renderings, config)

            if (config.orientation_coarse_loss_mult > 0 or
                    config.orientation_loss_mult > 0):
                losses['orientation'] = train_utils.orientation_loss(batch, module, ray_history,
                                                                     config)

            if config.hash_decay_mults > 0:
                losses['hash_decay'] = train_utils.hash_decay_loss(ray_history, config)

            if (config.predicted_normal_coarse_loss_mult > 0 or
                    config.predicted_normal_loss_mult > 0):
                losses['predicted_normals'] = train_utils.predicted_normal_loss(
                    module, ray_history, config)

            loss = sum(losses.values())
            stats['loss'] = loss.item()
            stats['losses'] = tree_map(lambda x: x.item(), losses)
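
            # Backpropagate, clip gradients, and take an optimizer step.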
            accelerator.backward(loss)
            train_utils.clip_gradients(model, accelerator, config)
            optimizer.step()

            stats['psnrs'] = image.mse_to_psnr(stats['mses'])
            stats['psnr'] = stats['psnrs'][-1]
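
            # Log training summaries (main process only).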
            if accelerator.is_main_process:
                stats_buffer.append(stats)
                if step == init_step + 1 or step % config.print_every == 0:
                    elapsed_time = time.time() - train_start_time
                    steps_per_sec = config.print_every / elapsed_time
                    rays_per_sec = config.batch_size * steps_per_sec

                    # Keep a running total of elapsed training time for timed metrics.
                    total_time += int(round(TIME_PRECISION * elapsed_time))
                    total_steps += config.print_every
                    approx_total_time = int(round(step * total_time / total_steps))

                    # Stack the individual statistics dicts into a dict of arrays.
                    fs = [utils.flatten_dict(s, sep='/') for s in stats_buffer]
                    stats_stacked = {k: np.stack([f[k] for f in fs]) for k in fs[0].keys()}

                    # Split statistics that are vectors into individual scalar statistics.
                    stats_split = {}
                    for k, v in stats_stacked.items():
                        if v.ndim not in [1, 2] or v.shape[0] != len(stats_buffer):
                            raise ValueError('statistics must be of size [n], or [n, k].')
                        if v.ndim == 1:
                            stats_split[k] = v
                        elif v.ndim == 2:
                            for i, vi in enumerate(tuple(v.T)):
                                stats_split[f'{k}/{i}'] = vi

                    # Summarize the histogram of each statistic.
                    for k, v in stats_split.items():
                        summary_writer.add_histogram('train_' + k, v, step)

                    avg_stats = {k: np.mean(v) for k, v in stats_split.items()}
                    max_stats = {k: np.max(v) for k, v in stats_split.items()}

                    summ_fn = lambda s, v: summary_writer.add_scalar(s, v, step)

                    # Summarize the mean and max of each statistic.
                    for k, v in avg_stats.items():
                        summ_fn(f'train_avg_{k}', v)
                    for k, v in max_stats.items():
                        summ_fn(f'train_max_{k}', v)

                    summ_fn('train_num_params', num_params)
                    summ_fn('train_learning_rate', learning_rate)
                    summ_fn('train_steps_per_sec', steps_per_sec)
                    summ_fn('train_rays_per_sec', rays_per_sec)

                    summary_writer.add_scalar('train_avg_psnr_timed', avg_stats['psnr'],
                                              total_time // TIME_PRECISION)
                    summary_writer.add_scalar('train_avg_psnr_timed_approx', avg_stats['psnr'],
                                              approx_total_time // TIME_PRECISION)

                    if dataset.metadata is not None and module.learned_exposure_scaling:
                        # Also log the learned per-shutter exposure scaling offsets.
                        scalings = module.exposure_scaling_offsets.weight
                        num_shutter_speeds = dataset.metadata['unique_shutters'].shape[0]
                        for i_s in range(num_shutter_speeds):
                            for j_s, value in enumerate(scalings[i_s]):
                                summary_name = f'exposure/scaling_{i_s}_{j_s}'
                                summary_writer.add_scalar(summary_name, value, step)

                    precision = int(np.ceil(np.log10(config.max_steps))) + 1
                    avg_loss = avg_stats['loss']
                    avg_psnr = avg_stats['psnr']
                    str_losses = {  # Grab each "losses/{x}" field and print it as "x[:4]".
                        k[7:11]: (f'{v:0.5f}' if 1e-4 <= v < 10 else f'{v:0.1e}')
                        for k, v in avg_stats.items()
                        if k.startswith('losses/')
                    }
                    logger.info(f'{step:{precision}d}' + f'/{config.max_steps:d}:' +
                                f'loss={avg_loss:0.5f},' + f'psnr={avg_psnr:.3f},' +
                                f'lr={learning_rate:0.2e} | ' +
                                ','.join([f'{k}={s}' for k, s in str_losses.items()]) +
                                f',{rays_per_sec:0.0f} r/s')

                    # Reset everything we are tracking between summarizations.
                    reset_stats = True
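
            # Periodically checkpoint from the main process.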
            if step > 0 and step % config.checkpoint_every == 0 and accelerator.is_main_process:
                checkpoints.save_checkpoint(config.checkpoint_dir,
                                            accelerator, step,
                                            config.checkpoints_total_limit)
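
            # Periodically render a held-out test view and log eval metrics and images.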
            if config.train_render_every > 0 and step % config.train_render_every == 0:
                eval_start_time = time.time()
                try:
                    test_batch = next(test_dataiter)
                except StopIteration:
                    test_dataiter = iter(test_dataloader)
                    test_batch = next(test_dataiter)
                test_batch = accelerate.utils.send_to_device(test_batch, accelerator.device)
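
                # Render the full test image and move everything to host memory as numpy arrays.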
                rendering = models.render_image(model, accelerator,
                                                test_batch, False,
                                                train_frac, config)
                rendering = tree_map(lambda x: x.detach().cpu().numpy(), rendering)
                test_batch = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, test_batch)

                if accelerator.is_main_process:
                    eval_time = time.time() - eval_start_time
                    num_rays = np.prod(test_batch['directions'].shape[:-1])
                    rays_per_sec = num_rays / eval_time
                    summary_writer.add_scalar('test_rays_per_sec', rays_per_sec, step)

                    metric_start_time = time.time()
                    metric = metric_harness(
                        postprocess_fn(rendering['rgb']), postprocess_fn(test_batch['rgb']))
                    logger.info(f'Eval {step}: {eval_time:0.3f}s, {rays_per_sec:0.0f} rays/sec')
                    logger.info(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s')
                    for name, val in metric.items():
                        if not np.isnan(val):
                            logger.info(f'{name} = {val:.4f}')
                            summary_writer.add_scalar('train_metrics/' + name, val, step)
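
                    # Optionally decimate the images to make visualization cheaper, then build
                    # the full suite of visualization images.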
                    if config.vis_decimate > 1:
                        d = config.vis_decimate
                        decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]
                    else:
                        decimate_fn = lambda x: x
                    rendering = tree_map(decimate_fn, rendering)
                    test_batch = tree_map(decimate_fn, test_batch)
                    vis_start_time = time.time()
                    vis_suite = vis.visualize_suite(rendering, test_batch)
                    with tqdm.external_write_mode():
                        logger.info(f'Visualized in {(time.time() - vis_start_time):0.3f}s')
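
                    # In raw mode, also log the raw output and versions postprocessed at the
                    # dataset's exposure levels, alongside the matching ground-truth images.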
                    if config.rawnerf_mode:
                        vis_suite['color_raw'] = rendering['rgb']
                        vis_suite['color_auto'] = postprocess_fn(rendering['rgb'], None)
                        summary_writer.add_image('test_true_auto',
                                                 tb_process_fn(postprocess_fn(test_batch['rgb'], None)), step)
                        exposures = test_dataset.metadata['exposure_levels']
                        for p, x in list(exposures.items()):
                            vis_suite[f'color/{p}'] = postprocess_fn(rendering['rgb'], x)
                            summary_writer.add_image(f'test_true_color/{p}',
                                                     tb_process_fn(postprocess_fn(test_batch['rgb'], x)), step)
                    summary_writer.add_image('test_true_color', tb_process_fn(test_batch['rgb']), step)
                    if config.compute_normal_metrics:
                        summary_writer.add_image('test_true_normals',
                                                 tb_process_fn(test_batch['normals']) / 2. + 0.5, step)
                    for k, v in vis_suite.items():
                        summary_writer.add_image('test_output_' + k, tb_process_fn(v), step)
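
    # Save a final checkpoint once training completes.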
    if accelerator.is_main_process and config.max_steps > init_step:
        logger.info('Saving last checkpoint at step {} to {}'.format(step, config.checkpoint_dir))
        checkpoints.save_checkpoint(config.checkpoint_dir,
                                    accelerator, step,
                                    config.checkpoints_total_limit)
    logger.info('Finished training.')


if __name__ == '__main__':
    with gin.config_scope('train'):
        app.run(main)