"""Shared training utilities.

Config loading, text/CSV logging, image-grid and HTML dumps, latent
interpolation, checkpoint lookup, LR schedulers, weight initialisation and
small filesystem / tensor helpers.
"""
import csv
import json
import logging
import math
import os
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision.utils as vutils
import yaml
from easydict import EasyDict as edict
from PIL import Image
# from lib.network.munit import Vgg16
from torch.autograd import Variable
from torch.optim import lr_scheduler


def get_config(config_path):
    """Load a YAML config file into an ``EasyDict``.

    The returned config gets an extra ``name`` attribute holding the file's
    base name without its extension.

    :param config_path: path to the YAML file
    :return: ``EasyDict`` with the parsed content plus ``name``
    """
    with open(config_path, 'r') as stream:
        config = yaml.load(stream, Loader=yaml.SafeLoader)
    config = edict(config)
    _, config_filename = os.path.split(config_path)
    config_name, _ = os.path.splitext(config_filename)
    config.name = config_name
    return config


class TextLogger:
    """Append-only plain-text logger that truncates its file on creation."""

    def __init__(self, log_path):
        """
        :param log_path: file to log into; any existing content is erased
        """
        self.log_path = log_path
        with open(self.log_path, "w") as f:
            f.write("")

    def log(self, log):
        """Append ``log`` followed by a newline to the log file."""
        with open(self.log_path, "a+") as f:
            f.write(log + "\n")


def eformat(f, prec):
    """Format ``f`` in scientific notation with a compact exponent.

    The mantissa has ``prec`` decimals; the exponent is printed as a plain
    integer (no ``+`` sign, no zero padding), e.g. ``eformat(1234.5, 2)``
    gives ``'1.23e3'``.
    """
    s = "%.*e" % (prec, f)
    mantissa, exp = s.split('e')
    return "%se%d" % (mantissa, int(exp))


def __write_images(image_outputs, display_image_num, file_name):
    """Save the first ``display_image_num`` images of every batch as one grid.

    :param image_outputs: iterable of image batches (tensors)
    :param display_image_num: images taken per batch, also the grid row width
    :param file_name: destination image path
    """
    # expand gray-scale images to 3 channels
    image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs]
    image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
    image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True)
    vutils.save_image(image_grid, file_name, nrow=1)


def write_2images(image_outputs, display_image_num, image_directory, postfix):
    """Write the first half of ``image_outputs`` as the a2b grid and the
    second half as the b2a grid, both into ``image_directory``."""
    n = len(image_outputs)
    __write_images(image_outputs[0:n // 2], display_image_num,
                   '%s/gen_a2b_%s.jpg' % (image_directory, postfix))
    __write_images(image_outputs[n // 2:n], display_image_num,
                   '%s/gen_b2a_%s.jpg' % (image_directory, postfix))


def write_one_row_html(html_file, iterations, img_filename, all_size):
    """Write one heading plus one linked image to an open HTML file.

    :param html_file: writable text file object
    :param iterations: iteration number shown in the heading
    :param img_filename: image path used for both the link and the image
    :param all_size: displayed image width in pixels
    """
    # NOTE(review): the markup below is reconstructed (heading + linked image
    # with a fixed width) from the arguments each string is formatted with;
    # confirm against the intended page layout.
    html_file.write("<h3>iteration [%d] (%s)</h3>" % (iterations, img_filename.split('/')[-1]))
    html_file.write("""
        <p><a href="%s">
          <img src="%s" style="width:%dpx">
        </a><br>
        <p>
        """ % (img_filename, img_filename, all_size))
    return


def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
    """Generate an HTML page listing the current and all saved image grids.

    The page first shows the two ``*_train_current.jpg`` grids, then, from the
    most recent down to the first one, the test and train grids of every
    iteration that is a multiple of ``image_save_iterations``.

    :param filename: output HTML path
    :param iterations: current iteration count
    :param image_save_iterations: interval at which grids were saved; a
        non-positive value skips the history section
    :param image_directory: directory prefix used in the image links
    :param all_size: displayed image width in pixels
    """
    # NOTE(review): page skeleton (doctype/head/body) is reconstructed —
    # confirm against the intended page layout.
    with open(filename, "w") as html_file:
        html_file.write('''
    <!DOCTYPE html>
    <html>
    <head>
      <title>Experiment name = %s</title>
    </head>
    <body>
    ''' % os.path.basename(filename))
        html_file.write("<h3>current</h3>")
        write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size)
        write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size)
        if image_save_iterations > 0:
            # Start at the largest multiple of the interval that is <= iterations
            # and walk down one interval at a time.
            newest = iterations - iterations % image_save_iterations
            for j in range(newest, image_save_iterations - 1, -image_save_iterations):
                write_one_row_html(html_file, j, '%s/gen_a2b_test_%08d.jpg' % (image_directory, j), all_size)
                write_one_row_html(html_file, j, '%s/gen_b2a_test_%08d.jpg' % (image_directory, j), all_size)
                write_one_row_html(html_file, j, '%s/gen_a2b_train_%08d.jpg' % (image_directory, j), all_size)
                write_one_row_html(html_file, j, '%s/gen_b2a_train_%08d.jpg' % (image_directory, j), all_size)
        html_file.write("</body></html>")


def write_loss(iterations, trainer, train_writer):
    """Log every loss-like attribute of ``trainer`` as a scalar.

    An attribute is logged when it is not callable, does not start with
    ``__`` and its name contains ``loss``, ``grad`` or ``nwd``. Scalars are
    written at step ``iterations + 1``.
    """
    members = [attr for attr in dir(trainer)
               if not callable(getattr(trainer, attr))
               and not attr.startswith("__")
               and ('loss' in attr or 'grad' in attr or 'nwd' in attr)]
    for m in members:
        train_writer.add_scalar(m, getattr(trainer, m), iterations + 1)


def slerp(val, low, high):
    """Spherical linear interpolation between two vectors.

    original: Animating Rotation with Quaternion Curves, Ken Shoemake
    https://arxiv.org/abs/1609.04468
    Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White

    :param val: interpolation factor (0 gives ``low``, 1 gives ``high``)
    :param low: start vector
    :param high: end vector
    :return: interpolated vector
    """
    cos_omega = np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high))
    # Rounding can push the cosine slightly outside [-1, 1], which would make
    # arccos return NaN.
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    so = np.sin(omega)
    if so == 0:
        # Vectors point the same way: the spherical formula divides by zero,
        # so fall back to plain linear interpolation.
        return (1.0 - val) * low + val * high
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high


def get_slerp_interp(nb_latents, nb_interp, z_dim):
    """Sample random latent pairs and slerp between each pair.

    modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot
    https://github.com/ptrblck/prog_gans_pytorch_inference

    :param nb_latents: number of random (low, high) pairs
    :param nb_interp: number of interpolation steps per pair, from 0 to 1 inclusive
    :param z_dim: latent dimensionality
    :return: float32 array of shape ``(nb_latents * nb_interp, z_dim, 1, 1)``
    """
    # The leading empty chunk fixes the result's shape/dtype when nb_latents is 0.
    chunks = [np.empty(shape=(0, z_dim), dtype=np.float32)]
    for _ in range(nb_latents):
        low = np.random.randn(z_dim)
        high = np.random.randn(z_dim)  # low + np.random.randn(512) * 0.7
        interp_vals = np.linspace(0, 1, num=nb_interp)
        chunks.append(np.array([slerp(v, low, high) for v in interp_vals], dtype=np.float32))
    latent_interps = np.vstack(chunks)
    return latent_interps[:, :, np.newaxis, np.newaxis]


# Get model list for resume
def get_model_list(dirname, key):
    """Return the last (by sorted path) checkpoint in ``dirname``.

    A file qualifies when its name contains both ``key`` and ``".pt"``.

    :return: path of the last matching file, or ``None`` when the directory
        does not exist or holds no matching file
    """
    if not os.path.exists(dirname):
        return None
    gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname)
                  if os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
    if not gen_models:
        return None
    gen_models.sort()
    return gen_models[-1]


def get_scheduler(optimizer, hyperparameters, iterations=-1):
    """Build the learning-rate scheduler selected by ``hyperparameters``.

    :param optimizer: optimizer to schedule
    :param hyperparameters: mapping; ``lr_policy`` may be missing or
        ``'constant'`` (no scheduler) or ``'step'`` (needs ``step_size`` and
        ``gamma``)
    :param iterations: passed to the scheduler as ``last_epoch``
    :return: a ``StepLR`` instance, or ``None`` for a constant learning rate
    :raises NotImplementedError: for any other ``lr_policy``
    """
    if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
        return None  # constant scheduler
    if hyperparameters['lr_policy'] == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
                                   gamma=hyperparameters['gamma'], last_epoch=iterations)
    raise NotImplementedError(
        'learning rate policy [%s] is not implemented' % hyperparameters['lr_policy'])


def weights_init(init_type='gaussian'):
    """Return a function that initialises one module, for ``module.apply``.

    The returned function touches only modules whose class name starts with
    ``Conv`` or ``Linear`` and that have a ``weight``; their bias, when
    present, is set to zero.

    :param init_type: ``'gaussian'``, ``'xavier'``, ``'kaiming'``,
        ``'orthogonal'`` or ``'default'`` (leave weights untouched)
    :raises AssertionError: from the returned function, when it meets a
        matching module and ``init_type`` is not one of the above
    """
    def init_fun(m):
        classname = m.__class__.__name__
        if classname.startswith(('Conv', 'Linear')) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                # Raised explicitly so the check survives ``python -O``.
                raise AssertionError("Unsupported initialization: {}".format(init_type))
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    return init_fun


class Timer:
    """Context manager that prints the elapsed wall-clock time on exit.

    ``msg`` is a %-format string with one placeholder for the seconds.
    """

    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()
        # Returned so ``with Timer(...) as t`` binds the timer instead of None.
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        print(self.msg % (time.time() - self.start_time))


class TrainClock(object):
    """Counts epochs, minibatches within the epoch, and global steps."""

    def __init__(self):
        self.epoch = 1
        self.minibatch = 0
        self.step = 0

    def tick(self):
        """Advance by one minibatch (and one global step)."""
        self.minibatch += 1
        self.step += 1

    def tock(self):
        """Advance to the next epoch and reset the minibatch counter."""
        self.epoch += 1
        self.minibatch = 0

    def make_checkpoint(self):
        """Return the clock state as a plain dict."""
        return {
            'epoch': self.epoch,
            'minibatch': self.minibatch,
            'step': self.step
        }

    def restore_checkpoint(self, clock_dict):
        """Restore the state produced by :meth:`make_checkpoint`."""
        self.epoch = clock_dict['epoch']
        self.minibatch = clock_dict['minibatch']
        self.step = clock_dict['step']


class Table(object):
    def __init__(self, filename):
        '''
        create a table to record experiment results that can be opened by excel
        :param filename: using '.csv' as postfix
        '''
        if '.csv' not in filename:
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError
        self.filename = filename

    @staticmethod
    def merge_headers(header1, header2):
        '''
        pick the longer of the two header lists (header2 on a tie)
        NOTE: this is not a union; fields present only in the shorter list are dropped
        '''
        # return list(set(header1 + header2))
        if len(header1) > len(header2):
            return header1
        else:
            return header2

    def write(self, ordered_dict):
        '''
        write an entry; the whole file is rewritten with the previous rows
        followed by the new one
        :param ordered_dict: something like {'name':'exp1', 'acc':90.5, 'epoch':50}
        :return:
        '''
        new_keys = list(ordered_dict.keys())
        if not os.path.exists(self.filename):
            headers = new_keys
            prev_rec = None
        else:
            with open(self.filename, newline='') as f:
                reader = csv.DictReader(f)
                # fieldnames is None when the existing file is empty
                headers = reader.fieldnames or []
                prev_rec = list(reader)
            headers = self.merge_headers(headers, new_keys)

        with open(self.filename, 'w', newline='') as f:
            writer = csv.DictWriter(f, headers)
            writer.writeheader()
            if prev_rec is not None:
                writer.writerows(prev_rec)
            writer.writerow(ordered_dict)


class WorklogLogger:
    """Thin wrapper that points the root logger at ``log_file``."""

    def __init__(self, log_file):
        logging.basicConfig(filename=log_file,
                            level=logging.DEBUG,
                            format='%(asctime)s - %(threadName)s -  %(levelname)s - %(message)s')
        self.logger = logging.getLogger()

    def put_line(self, line):
        """Log ``line`` at INFO level."""
        self.logger.info(line)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def save_args(args, save_dir):
    """Dump the attributes of ``args`` to ``<save_dir>/params.json``."""
    param_path = os.path.join(save_dir, 'params.json')
    with open(param_path, 'w') as fp:
        json.dump(vars(args), fp, indent=4, sort_keys=True)


def ensure_dir(path):
    """
    create path by first checking its existence,
    :param paths: path
    :return:
    """
    if not os.path.exists(path):
        # exist_ok tolerates another process creating the directory between
        # the check above and this call.
        os.makedirs(path, exist_ok=True)


def ensure_dirs(paths):
    """
    create paths by first checking their existence
    :param paths: list (or tuple) of path, or a single path
    :return:
    """
    if isinstance(paths, (list, tuple)):
        for path in paths:
            ensure_dir(path)
    else:
        ensure_dir(paths)


def remkdir(path):
    """
    if dir exists, remove it and create a new one
    :param path:
    :return:
    """
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def cycle(iterable):
    """Yield the items of ``iterable`` forever, re-iterating it each pass."""
    while True:
        for x in iterable:
            yield x


def save_image(image_numpy, image_path):
    """Save a numpy image array to ``image_path`` through PIL."""
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def pad_to_16x(x):
    """Round ``x`` up to the next multiple of 16 (unchanged if already one)."""
    if x % 16 > 0:
        return x - x % 16 + 16
    return x


def pad_to_height(tar_height, img_height, img_width):
    """Compute a 16-aligned size for an image rescaled to ``tar_height``.

    :return: ``(h, w, scale)`` where ``scale = tar_height / img_height`` and
        ``h`` / ``w`` are the target height and scaled width, each rounded up
        to a multiple of 16
    """
    scale = tar_height / img_height
    h = pad_to_16x(tar_height)
    w = pad_to_16x(int(img_width * scale))
    return h, w, scale


def to_gpu(data):
    """Move every tensor value of the dict ``data`` to CUDA, in place.

    Non-tensor values are left untouched.

    :return: the same dict object
    """
    for key, item in data.items():
        if torch.is_tensor(item):
            data[key] = item.cuda()
    return data