from PIL import Image
import os
import json
import logging
import shutil
import csv
# from lib.network.munit import Vgg16
from torch.autograd import Variable
from torch.optim import lr_scheduler
from easydict import EasyDict as edict
import torch
import torch.nn as nn
import os
import math
import torchvision.utils as vutils
import yaml
import numpy as np
import torch.nn.init as init
import time
def get_config(config_path):
    """Load a YAML config file into an attribute-accessible EasyDict.

    The file is parsed with ``yaml.SafeLoader``. The returned config also gets a
    ``name`` field holding the file's base name without its extension
    (e.g. ``configs/foo.yaml`` -> ``"foo"``).
    """
    with open(config_path, 'r') as stream:
        parsed = yaml.load(stream, Loader=yaml.SafeLoader)
    config = edict(parsed)
    config.name = os.path.splitext(os.path.basename(config_path))[0]
    return config
class TextLogger:
    """Line-oriented text log backed by a single file.

    Constructing the logger truncates any existing file at ``log_path``; each
    ``log`` call appends one line.
    """

    def __init__(self, log_path):
        self.log_path = log_path
        # Mode "w" truncates: every new logger starts from an empty file.
        self._open_and_write("w", "")

    def log(self, log):
        """Append ``log`` followed by a newline to the file."""
        self._open_and_write("a+", log + "\n")

    def _open_and_write(self, mode, text):
        # Open the log file with ``mode`` and write ``text``; closed immediately.
        with open(self.log_path, mode) as handle:
            handle.write(text)
def eformat(f, prec):
    """Format ``f`` in scientific notation with a compact exponent.

    The mantissa keeps ``prec`` digits after the decimal point; the exponent is
    written as a plain integer (no '+' sign, no zero padding), for example
    ``eformat(1234.5, 2) == '1.23e3'`` and ``eformat(0.00012, 1) == '1.2e-4'``.
    """
    # "%.*e" yields e.g. '1.23e+03'; re-render the exponent part via int().
    mantissa, _sep, exponent = ("%.*e" % (prec, f)).partition('e')
    return "%se%d" % (mantissa, int(exponent))
def __write_images(image_outputs, display_image_num, file_name):
    """Tile several image batches into one grid image and save it to ``file_name``.

    Each entry of ``image_outputs`` is a 4-D tensor whose dim 1 (channels) is
    1 or 3; single-channel batches are expanded to 3 channels. The first
    ``display_image_num`` images of every batch are concatenated and laid out
    ``display_image_num`` per row, so a batch with at least that many images
    fills exactly one row. ``normalize=True`` lets torchvision rescale values.
    """
    selected = []
    for batch in image_outputs:
        rgb_batch = batch.expand(-1, 3, -1, -1)  # gray-scale -> 3 channels
        selected.append(rgb_batch[:display_image_num])
    stacked = torch.cat(selected, 0)
    grid = vutils.make_grid(stacked.data, nrow=display_image_num, padding=0, normalize=True)
    vutils.save_image(grid, file_name, nrow=1)
def write_2images(image_outputs, display_image_num, image_directory, postfix):
    """Save the two halves of ``image_outputs`` as separate grid images.

    The first half (``len // 2`` entries) goes to
    ``<image_directory>/gen_a2b_<postfix>.jpg`` and the remainder to
    ``<image_directory>/gen_b2a_<postfix>.jpg``; with an odd count the extra
    entry lands in the b2a half.
    """
    half = len(image_outputs) // 2
    a2b_path = '%s/gen_a2b_%s.jpg' % (image_directory, postfix)
    b2a_path = '%s/gen_b2a_%s.jpg' % (image_directory, postfix)
    __write_images(image_outputs[:half], display_image_num, a2b_path)
    __write_images(image_outputs[half:], display_image_num, b2a_path)
def write_one_row_html(html_file, iterations, img_filename, all_size):
    """Append one captioned, clickable image to an open HTML page.

    :param html_file: writable text file object.
    :param iterations: iteration number shown in the caption.
    :param img_filename: image path used for both the link target and the
        ``<img>`` source; its last '/'-separated component appears in the caption.
    :param all_size: displayed image width in pixels (formatted with %d).
    """
    caption = img_filename.split('/')[-1]
    html_file.write("<h3>iteration [%d] (%s)</h3>\n" % (iterations, caption))
    # The three format arguments fill href, src and the CSS width; without
    # these placeholders the %-formatting raises TypeError.
    html_file.write("""<p><a href="%s">
  <img src="%s" style="width:%dpx">
</a><br>
<p>
""" % (img_filename, img_filename, all_size))
def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
    """Write an HTML page showing the saved training/test sample images.

    The page shows the "current" a2b/b2a training images first. It then lists,
    newest first, the a2b/b2a test and train images for every iteration ``j``
    from ``iterations`` down to ``image_save_iterations`` that is a multiple of
    ``image_save_iterations``.

    :param filename: HTML file to (over)write; its base name goes in the title.
    :param iterations: current training iteration.
    :param image_save_iterations: interval at which snapshot images were saved.
    :param image_directory: directory prefix prepended to every image name.
    :param all_size: displayed image width in pixels.
    """
    # The context manager closes the file even if a write raises.
    with open(filename, "w") as html_file:
        html_file.write('''<!DOCTYPE html>
<html>
<head>
  <title>Experiment name = %s</title>
</head>
<body>
''' % os.path.basename(filename))
        html_file.write("<h3>current</h3>\n")
        write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size)
        write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size)
        for j in range(iterations, image_save_iterations - 1, -1):
            if j % image_save_iterations == 0:
                for stem in ('gen_a2b_test', 'gen_b2a_test', 'gen_a2b_train', 'gen_b2a_train'):
                    write_one_row_html(html_file, j, '%s/%s_%08d.jpg' % (image_directory, stem, j), all_size)
        html_file.write("</body></html>\n")
def write_loss(iterations, trainer, train_writer):
    """Log every non-callable loss/grad/nwd attribute of ``trainer``.

    Attribute names are taken from ``dir(trainer)`` (sorted), excluding dunder
    names and callables, and kept when they contain 'loss', 'grad' or 'nwd'.
    Each is written with ``train_writer.add_scalar(name, value, iterations + 1)``.
    """
    def _is_logged(attr_name):
        # Same check order as before: callable test, dunder test, keyword test.
        return (not callable(getattr(trainer, attr_name))
                and not attr_name.startswith("__")
                and any(tag in attr_name for tag in ('loss', 'grad', 'nwd')))

    selected = [attr_name for attr_name in dir(trainer) if _is_logged(attr_name)]
    for attr_name in selected:
        train_writer.add_scalar(attr_name, getattr(trainer, attr_name), iterations + 1)
def slerp(val, low, high):
    """Spherical linear interpolation between vectors ``low`` and ``high``.

    original: Animating Rotation with Quaternion Curves, Ken Shoemake
    See also: https://arxiv.org/abs/1609.04468
    Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White

    :param val: interpolation fraction; 0 gives ``low`` and 1 gives ``high``.
    :param low: start vector (1-D numpy array, as produced by get_slerp_interp).
    :param high: end vector with the same shape as ``low``.
    :return: the interpolated vector.
    """
    cos_omega = np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high))
    # Rounding can push the cosine of (anti)parallel vectors slightly outside
    # [-1, 1], where arccos returns nan.
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    so = np.sin(omega)
    if abs(so) < 1e-8:
        # (Anti)parallel inputs: the slerp formula would divide by ~0, so fall
        # back to linear interpolation, the limit of slerp as omega -> 0.
        return (1.0 - val) * low + val * high
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
def get_slerp_interp(nb_latents, nb_interp, z_dim):
    """Draw random latent pairs and slerp between the two ends of each pair.

    modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot
    https://github.com/ptrblck/prog_gans_pytorch_inference

    For each of ``nb_latents`` pairs, both endpoints are drawn from NumPy's global
    standard-normal RNG and ``nb_interp`` evenly spaced fractions on [0, 1] are
    interpolated with ``slerp``. Returns a float32 array of shape
    (nb_latents * nb_interp, z_dim, 1, 1).
    """
    fractions = np.linspace(0, 1, num=nb_interp)
    chunks = [np.empty(shape=(0, z_dim), dtype=np.float32)]
    for _ in range(nb_latents):
        start = np.random.randn(z_dim)
        end = np.random.randn(z_dim)  # alternative: start + np.random.randn(512) * 0.7
        chunks.append(np.array([slerp(t, start, end) for t in fractions], dtype=np.float32))
    stacked = np.vstack(chunks)
    return stacked[:, :, np.newaxis, np.newaxis]
# Get model list for resume
def get_model_list(dirname, key):
    """Return the path of the last checkpoint in ``dirname`` whose name contains ``key``.

    Candidates are regular files whose names contain both ``key`` and ".pt".
    "Last" means greatest in plain string order, so it only picks the newest
    checkpoint when iteration numbers in the names have a fixed width.

    :return: the full path, or None when ``dirname`` does not exist or no
        file matches.
    """
    if not os.path.exists(dirname):
        return None
    candidates = [os.path.join(dirname, f) for f in os.listdir(dirname)
                  if os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
    # An empty list means nothing to resume from (previously an IndexError).
    if not candidates:
        return None
    return max(candidates)
def get_scheduler(optimizer, hyperparameters, iterations=-1):
    """Build the learning-rate scheduler selected by ``hyperparameters['lr_policy']``.

    :param optimizer: torch optimizer whose learning rate is scheduled.
    :param hyperparameters: mapping with an optional 'lr_policy' key; the
        'step' policy additionally reads 'step_size' and 'gamma'.
    :param iterations: forwarded to ``StepLR`` as ``last_epoch`` (-1 starts a
        fresh schedule). NOTE: per PyTorch, other values expect the optimizer's
        param groups to already carry 'initial_lr'.
    :return: a ``StepLR`` for 'step'; None for 'constant' or a missing policy.
    :raises NotImplementedError: for any other policy.
    """
    if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
        return None  # constant learning rate: no scheduler
    policy = hyperparameters['lr_policy']
    if policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
                                   gamma=hyperparameters['gamma'], last_epoch=iterations)
    # Raise (not return) so a misconfigured policy fails loudly instead of
    # handing callers an exception object in place of a scheduler.
    raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)
def weights_init(init_type='gaussian'):
    """Return a per-module callback (e.g. for ``nn.Module.apply``) that initializes weights.

    Only modules whose class name starts with 'Conv' or 'Linear' and that have a
    ``weight`` attribute are touched. Their weight is re-initialized according
    to ``init_type``, and their bias, if present and not None, is zeroed.

    :param init_type: 'gaussian' (normal, mean 0, std 0.02), 'xavier',
        'kaiming', 'orthogonal', or 'default' (keep the weight, still zero the bias).
    :return: a function taking one module.
    :raises AssertionError: from the callback, when ``init_type`` is unknown
        and a matching module is visited.
    """
    initializers = {
        'gaussian': lambda w: init.normal_(w, 0.0, 0.02),
        'xavier': lambda w: init.xavier_normal_(w, gain=math.sqrt(2)),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=math.sqrt(2)),
    }

    def init_fun(m):
        class_name = m.__class__.__name__
        if not (class_name.startswith(('Conv', 'Linear')) and hasattr(m, 'weight')):
            return
        if init_type in initializers:
            initializers[init_type](m.weight.data)
        elif init_type != 'default':
            assert 0, "Unsupported initialization: {}".format(init_type)
        if hasattr(m, 'bias') and m.bias is not None:
            init.constant_(m.bias.data, 0.0)

    return init_fun
class Timer:
    """Context manager that prints the time spent inside its ``with`` body.

    :param msg: %-style format string with one placeholder that receives the
        elapsed seconds as a float, e.g. ``"Elapsed time: %f"``.

    The message is printed on exit even when the body raises; the exception
    then propagates normally.
    """

    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, so the interval is not
        # distorted by system clock adjustments the way time.time() can be.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        print(self.msg % (time.perf_counter() - self.start_time))
        # Implicit None (falsy) return: exceptions from the body are not suppressed.
class TrainClock:
    """Training-progress counters: epoch, minibatch within the epoch, global step.

    Epochs start at 1; the minibatch and step counters start at 0.
    """

    def __init__(self):
        self.epoch = 1
        self.minibatch = 0
        self.step = 0

    def tick(self):
        """Record one finished minibatch (advances both minibatch and step)."""
        self.minibatch += 1
        self.step += 1

    def tock(self):
        """Record one finished epoch: bump the epoch, restart the minibatch count."""
        self.epoch += 1
        self.minibatch = 0

    def make_checkpoint(self):
        """Return the three counters as a plain dict."""
        return dict(epoch=self.epoch, minibatch=self.minibatch, step=self.step)

    def restore_checkpoint(self, clock_dict):
        """Load the counters from a dict shaped like ``make_checkpoint``'s output."""
        for field in ('epoch', 'minibatch', 'step'):
            setattr(self, field, clock_dict[field])
class Table(object):
    """Append-style CSV table of experiment results (openable in Excel).

    Each ``write`` rewrites the whole file: existing rows are read back, the
    header is widened to cover the new row's keys, and the new row is appended.
    """

    def __init__(self, filename):
        '''
        create a table to record experiment results that can be opened by excel
        :param filename: using '.csv' as postfix
        '''
        assert '.csv' in filename
        self.filename = filename

    @staticmethod
    def merge_headers(header1, header2):
        """Combine two header lists without dropping any column.

        The longer list (``header2`` on ties) keeps its order and comes first;
        columns present only in the other list are appended after it. When one
        list already contains the other, the result equals the longer list.
        """
        if len(header1) > len(header2):
            primary, secondary = header1, header2
        else:
            primary, secondary = header2, header1
        return list(primary) + [h for h in secondary if h not in primary]

    def write(self, ordered_dict):
        '''
        write an entry
        :param ordered_dict: something like {'name':'exp1', 'acc':90.5, 'epoch':50}
        :return:
        '''
        if not os.path.exists(self.filename):
            headers = list(ordered_dict.keys())
            prev_rec = None
        else:
            # newline='' is what the csv module expects for file objects.
            with open(self.filename, newline='') as f:
                reader = csv.DictReader(f)
                # fieldnames is None when the existing file is empty.
                headers = reader.fieldnames or []
                prev_rec = list(reader)
            headers = self.merge_headers(headers, list(ordered_dict.keys()))
        with open(self.filename, 'w', newline='') as f:
            writer = csv.DictWriter(f, headers)
            writer.writeheader()
            if prev_rec is not None:
                writer.writerows(prev_rec)
            writer.writerow(ordered_dict)
class WorklogLogger:
    """Write work-log lines through the root logger.

    The constructor calls ``logging.basicConfig`` with ``log_file`` and DEBUG
    level. NOTE(review): per the logging docs, basicConfig does nothing when the
    root logger already has handlers, so only the first configuration in a
    process takes effect.
    """

    def __init__(self, log_file):
        line_format = '%(asctime)s - %(threadName)s - %(levelname)s - %(message)s'
        logging.basicConfig(filename=log_file, level=logging.DEBUG, format=line_format)
        self.logger = logging.getLogger()

    def put_line(self, line):
        """Log ``line`` at INFO level."""
        self.logger.info(line)
class AverageMeter:
    """Track the latest value and the running weighted mean of a named metric."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Set the latest value, mean, sum and count back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean over ``n`` samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def save_args(args, save_dir):
    """Dump the attributes of ``args`` to ``<save_dir>/params.json``.

    The JSON is written with 4-space indentation and sorted keys.
    """
    with open(os.path.join(save_dir, 'params.json'), 'w') as fp:
        json.dump(args.__dict__, fp, sort_keys=True, indent=4)
def ensure_dir(path):
    """
    create directory ``path`` (including missing parents) unless something
    already exists at that path
    :param path: directory path
    :return: None
    """
    if not os.path.exists(path):
        # exist_ok covers a concurrent creator winning the race between the
        # existence check and makedirs; previously that raised FileExistsError.
        os.makedirs(path, exist_ok=True)
def ensure_dirs(paths):
    """
    create one or several directories via ensure_dir
    :param paths: a single path, or a list of paths
    :return:
    """
    targets = paths if isinstance(paths, list) else [paths]
    for target in targets:
        ensure_dir(target)
def remkdir(path):
    """
    recreate ``path`` as a fresh, empty directory (deleting it first if present)
    :param path: directory path
    :return:
    """
    already_present = os.path.exists(path)
    if already_present:
        shutil.rmtree(path)
    os.makedirs(path)
def cycle(iterable):
    """Yield the items of ``iterable`` over and over.

    Unlike ``itertools.cycle``, the iterable is re-iterated on every pass
    instead of replaying cached items, so a re-iterable source can produce
    fresh items each pass. The generator stops as soon as a full pass yields
    nothing (an empty iterable, or a one-shot iterator that is exhausted);
    otherwise it would spin forever without yielding.
    """
    while True:
        produced = False
        for x in iterable:
            produced = True
            yield x
        if not produced:
            return
def save_image(image_numpy, image_path):
    """Convert a numpy image array with ``PIL.Image.fromarray`` and save it to ``image_path``."""
    Image.fromarray(image_numpy).save(image_path)
def pad_to_16x(x):
    """Round ``x`` up to the nearest multiple of 16 (multiples are returned unchanged)."""
    remainder = x % 16
    return x - remainder + 16 if remainder > 0 else x
def pad_to_height(tar_height, img_height, img_width):
    """Scale an image size to height ``tar_height``, padding both sides to multiples of 16.

    :return: (padded height, padded scaled width, scale factor
        ``tar_height / img_height``); the scaled width is truncated by ``int``
        before padding.
    """
    scale = tar_height / img_height
    return pad_to_16x(tar_height), pad_to_16x(int(img_width * scale)), scale
def to_gpu(data):
    """Move every tensor value of the dict ``data`` to the GPU, in place.

    Non-tensor values are left untouched. The same dict object is returned.
    """
    for key in data:
        value = data[key]
        if torch.is_tensor(value):
            data[key] = value.cuda()  # rebinding an existing key is safe during iteration
    return data