import csv
import json
import logging
import math
import os
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision.utils as vutils
import yaml
from easydict import EasyDict as edict
from PIL import Image
from torch.autograd import Variable
from torch.optim import lr_scheduler


def get_config(config_path):
    with open(config_path, 'r') as stream:
        config = yaml.load(stream, Loader=yaml.SafeLoader)
    config = edict(config)
    _, config_filename = os.path.split(config_path)
    config_name, _ = os.path.splitext(config_filename)
    config.name = config_name
    return config

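# Illustrative usage (a sketch; the config path below is hypothetical):
#
#   config = get_config('configs/example.yaml')
#   print(config.name)   # -> 'example', derived from the file name
#   # every key in the YAML file is attribute-accessible through EasyDict
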
class TextLogger:

    def __init__(self, log_path):
        self.log_path = log_path
        # truncate any existing log file
        with open(self.log_path, "w") as f:
            f.write("")

    def log(self, log):
        with open(self.log_path, "a+") as f:
            f.write(log + "\n")

def eformat(f, prec):
    # format f in scientific notation with `prec` digits after the decimal point,
    # dropping the zero-padding of the exponent (e.g. 'e-04' -> 'e-4')
    s = "%.*e" % (prec, f)
    mantissa, exp = s.split('e')
    return "%se%d" % (mantissa, int(exp))

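# Examples (illustrative):
#
#   eformat(0.000123, 2)   # -> '1.23e-4'
#   eformat(1500.0, 1)     # -> '1.5e3'
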
def __write_images(image_outputs, display_image_num, file_name):
    # expand single-channel images to 3 channels so they can share one grid
    image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs]
    image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
    image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True)
    vutils.save_image(image_grid, file_name, nrow=1)

def write_2images(image_outputs, display_image_num, image_directory, postfix):
    n = len(image_outputs)
    __write_images(image_outputs[0:n//2], display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix))
    __write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix))

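# Illustrative usage (a sketch; the tensor names and directory are assumptions).
# `image_outputs` is a list of batched (N, C, H, W) tensors whose first half is
# treated as a2b results and whose second half as b2a results:
#
#   write_2images([x_a, x_ab, x_b, x_ba], display_image_num=4,
#                 image_directory='outputs/images', postfix='train_00010000')
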
def write_one_row_html(html_file, iterations, img_filename, all_size):
    html_file.write("<h3>iteration [%d] (%s)</h3>" % (iterations, img_filename.split('/')[-1]))
    html_file.write("""
        <p><a href="%s">
          <img src="%s" style="width:%dpx">
        </a><br></p>
        """ % (img_filename, img_filename, all_size))
    return

def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
    html_file = open(filename, "w")
    html_file.write('''
        <!DOCTYPE html>
        <html>
        <head>
          <title>Experiment name = %s</title>
          <meta http-equiv="refresh" content="30">
        </head>
        <body>
        ''' % os.path.basename(filename))
    html_file.write("<h3>current</h3>")
    write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size)
    write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size)
    for j in range(iterations, image_save_iterations - 1, -1):
        if j % image_save_iterations == 0:
            write_one_row_html(html_file, j, '%s/gen_a2b_test_%08d.jpg' % (image_directory, j), all_size)
            write_one_row_html(html_file, j, '%s/gen_b2a_test_%08d.jpg' % (image_directory, j), all_size)
            write_one_row_html(html_file, j, '%s/gen_a2b_train_%08d.jpg' % (image_directory, j), all_size)
            write_one_row_html(html_file, j, '%s/gen_b2a_train_%08d.jpg' % (image_directory, j), all_size)
    html_file.write("</body></html>")
    html_file.close()

def write_loss(iterations, trainer, train_writer):
    members = [attr for attr in dir(trainer)
               if not callable(getattr(trainer, attr)) and not attr.startswith("__")
               and ('loss' in attr or 'grad' in attr or 'nwd' in attr)]
    for m in members:
        train_writer.add_scalar(m, getattr(trainer, m), iterations + 1)

def slerp(val, low, high):
    """
    Spherical linear interpolation between two vectors.

    Original: "Animating Rotation with Quaternion Curves", Ken Shoemake.
    See also https://arxiv.org/abs/1609.04468 and the code by Tom White at
    https://github.com/soumith/dcgan.torch/issues/14.
    """
    omega = np.arccos(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)))
    so = np.sin(omega)
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high

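# Example (illustrative): interpolate halfway along the arc between two latent vectors.
#
#   low, high = np.random.randn(128), np.random.randn(128)
#   mid = slerp(0.5, low, high)
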
def get_slerp_interp(nb_latents, nb_interp, z_dim):
    """
    Modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot,
    https://github.com/ptrblck/prog_gans_pytorch_inference
    """
    latent_interps = np.empty(shape=(0, z_dim), dtype=np.float32)
    for _ in range(nb_latents):
        low = np.random.randn(z_dim)
        high = np.random.randn(z_dim)
        interp_vals = np.linspace(0, 1, num=nb_interp)
        latent_interp = np.array([slerp(v, low, high) for v in interp_vals],
                                 dtype=np.float32)
        latent_interps = np.vstack((latent_interps, latent_interp))

    return latent_interps[:, :, np.newaxis, np.newaxis]

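# Illustrative usage: the result has shape (nb_latents * nb_interp, z_dim, 1, 1) and can
# be fed to a generator after conversion to a tensor.
#
#   latents = get_slerp_interp(nb_latents=4, nb_interp=10, z_dim=128)
#   z = torch.from_numpy(latents)   # shape (40, 128, 1, 1)
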
def get_model_list(dirname, key):
    if not os.path.exists(dirname):
        return None
    gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
                  os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
    if not gen_models:  # no matching checkpoint files
        return None
    gen_models.sort()
    last_model_name = gen_models[-1]
    return last_model_name

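# Illustrative usage (directory and key are assumptions): returns the last checkpoint in
# lexicographic order, which is the newest one when iteration numbers are zero-padded.
#
#   last_gen = get_model_list('outputs/checkpoints', 'gen')
#   # e.g. 'outputs/checkpoints/gen_00100000.pt', or None if nothing matches
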
def get_scheduler(optimizer, hyperparameters, iterations=-1):
    if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
        scheduler = None  # constant learning rate: no scheduler
    elif hyperparameters['lr_policy'] == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
                                        gamma=hyperparameters['gamma'], last_epoch=iterations)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % hyperparameters['lr_policy'])
    return scheduler

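# Illustrative usage (hyperparameter values are assumptions):
#
#   hyperparameters = {'lr_policy': 'step', 'step_size': 100000, 'gamma': 0.5}
#   scheduler = get_scheduler(optimizer, hyperparameters)
#   ...
#   if scheduler is not None:
#       scheduler.step()
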
def weights_init(init_type='gaussian'):
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    return init_fun

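# Illustrative usage: apply the returned initializer to every Conv/Linear submodule.
#
#   net = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU(), nn.Conv2d(64, 3, 3))
#   net.apply(weights_init('kaiming'))
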
class Timer:
    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()

    def __exit__(self, exc_type, exc_value, exc_tb):
        print(self.msg % (time.time() - self.start_time))

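# Illustrative usage: the message must contain one %-style float placeholder.
#
#   with Timer('Elapsed time in update: %f'):
#       trainer.update(batch)   # hypothetical call; any block of work goes here
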
class TrainClock(object):
    """Keeps track of epoch, minibatch and global-step counters during training."""

    def __init__(self):
        self.epoch = 1
        self.minibatch = 0
        self.step = 0

    def tick(self):
        self.minibatch += 1
        self.step += 1

    def tock(self):
        self.epoch += 1
        self.minibatch = 0

    def make_checkpoint(self):
        return {
            'epoch': self.epoch,
            'minibatch': self.minibatch,
            'step': self.step
        }

    def restore_checkpoint(self, clock_dict):
        self.epoch = clock_dict['epoch']
        self.minibatch = clock_dict['minibatch']
        self.step = clock_dict['step']

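# Illustrative usage inside a training loop (names are assumptions):
#
#   clock = TrainClock()
#   for epoch in range(num_epochs):
#       for batch in train_loader:
#           ...                  # forward / backward pass
#           clock.tick()         # advance minibatch and global step
#       clock.tock()             # advance epoch, reset minibatch counter
#   state = clock.make_checkpoint()   # serializable dict, saved alongside the model
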
class Table(object):
    def __init__(self, filename):
        '''
        Create a table for recording experiment results that can be opened in Excel.
        :param filename: must use '.csv' as the suffix
        '''
        assert filename.endswith('.csv')
        self.filename = filename

    @staticmethod
    def merge_headers(header1, header2):
        # keep the longer header, assuming the shorter one is a subset of it
        if len(header1) > len(header2):
            return header1
        else:
            return header2

    def write(self, ordered_dict):
        '''
        Write one entry.
        :param ordered_dict: something like {'name': 'exp1', 'acc': 90.5, 'epoch': 50}
        :return:
        '''
        if not os.path.exists(self.filename):
            headers = list(ordered_dict.keys())
            prev_rec = None
        else:
            with open(self.filename) as f:
                reader = csv.DictReader(f)
                headers = reader.fieldnames
                prev_rec = [row for row in reader]
            headers = self.merge_headers(headers, list(ordered_dict.keys()))

        with open(self.filename, 'w', newline='') as f:
            writer = csv.DictWriter(f, headers)
            writer.writeheader()
            if prev_rec is not None:
                writer.writerows(prev_rec)
            writer.writerow(ordered_dict)

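# Illustrative usage (file name and fields are assumptions): each call appends one row,
# rewriting the CSV so the header always covers the longest entry seen so far.
#
#   table = Table('results.csv')
#   table.write({'name': 'exp1', 'acc': 90.5, 'epoch': 50})
#   table.write({'name': 'exp2', 'acc': 91.2, 'epoch': 50})
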
class WorklogLogger:
    def __init__(self, log_file):
        logging.basicConfig(filename=log_file,
                            level=logging.DEBUG,
                            format='%(asctime)s - %(threadName)s - %(levelname)s - %(message)s')

        self.logger = logging.getLogger()

    def put_line(self, line):
        self.logger.info(line)

class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

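# Illustrative usage: track a running average over minibatches of (possibly) varying size.
#
#   loss_meter = AverageMeter('loss')
#   loss_meter.update(0.8, n=32)
#   loss_meter.update(0.6, n=32)
#   loss_meter.avg   # -> 0.7
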
def save_args(args, save_dir):
    param_path = os.path.join(save_dir, 'params.json')

    with open(param_path, 'w') as fp:
        json.dump(args.__dict__, fp, indent=4, sort_keys=True)

def ensure_dir(path):
    """
    Create the directory if it does not already exist.
    :param path: directory path
    :return:
    """
    if not os.path.exists(path):
        os.makedirs(path)

def ensure_dirs(paths):
    """
    Create each directory if it does not already exist.
    :param paths: a single path or a list of paths
    :return:
    """
    if isinstance(paths, list):
        for path in paths:
            ensure_dir(path)
    else:
        ensure_dir(paths)

def remkdir(path):
    """
    If the directory exists, remove it and create a new one; otherwise just create it.
    :param path: directory path
    :return:
    """
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

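# Illustrative usage: unlike itertools.cycle, this re-iterates the source on every pass,
# so a shuffling DataLoader is reshuffled each epoch.
#
#   loader_iter = cycle(train_loader)
#   batch = next(loader_iter)
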
def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)

def pad_to_16x(x):
    # round x up to the next multiple of 16
    if x % 16 > 0:
        return x - x % 16 + 16
    return x

def pad_to_height(tar_height, img_height, img_width):
    # scale an img_height x img_width image to the target height, then round
    # both sides up to multiples of 16
    scale = tar_height / img_height
    h = pad_to_16x(tar_height)
    w = pad_to_16x(int(img_width * scale))
    return h, w, scale

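# Worked example (illustrative): resize a 480x640 image so its height becomes 256, then
# round both sides up to multiples of 16.
#
#   h, w, scale = pad_to_height(256, 480, 640)
#   # scale = 256/480 ~ 0.533, h = 256, w = pad_to_16x(341) = 352
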
def to_gpu(data):
    # move every tensor entry of the dict to the GPU, leaving other entries untouched
    for key, item in data.items():
        if torch.is_tensor(item):
            data[key] = item.cuda()
    return data

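# Illustrative usage (keys are assumptions; requires a CUDA device):
#
#   batch = {'img': torch.zeros(1, 3, 256, 256), 'path': 'a.png'}
#   batch = to_gpu(batch)   # batch['img'] is now on the GPU; batch['path'] is unchanged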