import glob
import logging
import math
import os
import random
import re
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
from adamp import AdamP

_logger = None

def increment_path(path):
    # Increment path, i.e. runs/exp1 --> runs/exp2, runs/exp3, etc.
    res = re.search(r"\d+", path)
    if res is None:
        print("Set initial exp number!")
        exit(1)
    if not Path(path).exists():
        return str(path)
    else:
        path = path[:res.start()]
        dirs = glob.glob(f"{path}*")  # similar paths
        matches = [re.search(rf"{re.escape(Path(path).stem)}(\d+)", d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # existing indices
        n = max(i) + 1  # next free index
        return f"{path}{n}"  # updated path
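
# Usage sketch (hypothetical paths): if runs/exp1 and runs/exp2 already exist,
#   increment_path("runs/exp1")  # -> "runs/exp3"
# If runs/exp1 does not exist yet, the path is returned unchanged.
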
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, fmt=':f'):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
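
# Usage sketch: track a running loss over an epoch.
#   loss_meter = AverageMeter()
#   loss_meter.update(0.8, n=16)  # batch of 16 samples with mean loss 0.8
#   loss_meter.update(0.6, n=16)
#   loss_meter.avg  # -> 0.7
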
def create_logger(log_file, level=logging.INFO):
    global _logger
    _logger = logging.getLogger()
    formatter = logging.Formatter(
        '[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s')
    fh = logging.FileHandler(log_file)
    fh.setFormatter(formatter)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    _logger.setLevel(level)
    _logger.addHandler(fh)
    _logger.addHandler(sh)
    return _logger
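
# Usage sketch (hypothetical log path): logs to both file and console.
#   logger = create_logger(os.path.join("runs", "train.log"))
#   logger.info("Start training")
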
def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x, y, ...) coordinates in a range of -1 to 1.'''
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)

    if dim == 2:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32)
        pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / max(sidelen[0] - 1, 1)
        pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / max(sidelen[1] - 1, 1)
    elif dim == 3:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
        pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
        pixel_coords[..., 1] = pixel_coords[..., 1] / max(sidelen[1] - 1, 1)
        pixel_coords[..., 2] = pixel_coords[..., 2] / max(sidelen[2] - 1, 1)
    else:
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    # Map [0, 1] to [-1, 1].
    pixel_coords -= 0.5
    pixel_coords *= 2.
    pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)
    return pixel_coords
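
# Usage sketch: a 3x3 grid yields 9 coordinates spanning [-1, 1] in each axis.
#   get_mgrid(3).shape       # -> torch.Size([9, 2])
#   get_mgrid((4, 6)).shape  # -> torch.Size([24, 2])
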
def lin2img(tensor, image_resolution=None):
    batch_size, num_samples, channels = tensor.shape
    if image_resolution is None:
        # Assume a square image when no resolution is given.
        width = np.sqrt(num_samples).astype(int)
        height = width
    else:
        if isinstance(image_resolution, int):
            image_resolution = (image_resolution, image_resolution)
        height = image_resolution[0]
        width = image_resolution[1]

    return tensor.permute(0, 2, 1).contiguous().view(batch_size, channels, height, width)
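
# Usage sketch: fold 64*64 = 4096 flattened RGB samples back into an image.
#   lin2img(torch.rand(2, 4096, 3)).shape        # -> (2, 3, 64, 64)
#   lin2img(torch.rand(2, 24, 3), (4, 6)).shape  # -> (2, 3, 4, 6)
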
def normalize(x, opt, mode='normal'):
    device = x.device
    mean = torch.tensor(np.array(opt.transform_mean), dtype=x.dtype)[np.newaxis, :, np.newaxis, np.newaxis].to(device)
    var = torch.tensor(np.array(opt.transform_var), dtype=x.dtype)[np.newaxis, :, np.newaxis, np.newaxis].to(device)
    if mode == 'normal':
        return (x - mean) / var
    elif mode == 'inv':
        return x * var + mean
    else:
        raise ValueError('Unknown normalization mode: %s' % mode)
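
# Usage sketch: assuming opt.transform_mean = opt.transform_var = (0.5, 0.5, 0.5),
# 'normal' maps [0, 1] images to [-1, 1] and 'inv' maps them back.
#   x_norm = normalize(x, opt, 'normal')
#   x_back = normalize(x_norm, opt, 'inv')  # == x
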
def prepare_cooridinate_input(mask, dim=2):
    '''Generates a flattened grid of (x, y, ...) coordinates in a range of -1 to 1,
    sized to match the given mask, as a channel-first numpy array.'''
    if mask.shape[0] == mask.shape[1]:
        sidelen = mask.shape[0]
    else:
        sidelen = mask.shape[:2]
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)

    if dim == 2:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32)
        pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / max(sidelen[0] - 1, 1)
        pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / max(sidelen[1] - 1, 1)
    elif dim == 3:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
        pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
        pixel_coords[..., 1] = pixel_coords[..., 1] / max(sidelen[1] - 1, 1)
        pixel_coords[..., 2] = pixel_coords[..., 2] / max(sidelen[2] - 1, 1)
    else:
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    # Map [0, 1] to [-1, 1].
    pixel_coords -= 0.5
    pixel_coords *= 2.
    return pixel_coords.squeeze(0).transpose(2, 0, 1)
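
# Usage sketch: channel-first coordinates matching a 256x256 mask.
#   prepare_cooridinate_input(np.zeros((256, 256))).shape  # -> (2, 256, 256)
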
def visualize(real, composite, mask, pred_fg, pred_harmonized, lut_transform_image, opt, epoch,
              show=False, wandb=True, isAll=False, step=None):
    save_path = os.path.join(opt.save_path, "figs", str(epoch))
    os.makedirs(save_path, exist_ok=True)

    if isAll:
        final_index = 1
        # Uncomment the following line to save all the results; otherwise only
        # the first image of each batch is saved.
        # final_index = len(real)
    else:
        final_index = 1

    def to_bgr_uint8(t, i):
        # Undo the input normalization and convert sample `i` to an HWC uint8 BGR array.
        return cv2.cvtColor(
            normalize(t, opt, 'inv')[i].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype(np.uint8),
            cv2.COLOR_RGB2BGR)

    for idx in range(final_index):
        if show:
            cv2.imshow("pred_fg", normalize(pred_fg, opt, 'inv')[idx].permute(1, 2, 0).cpu().numpy())
            cv2.imshow("real", normalize(real, opt, 'inv')[idx].permute(1, 2, 0).cpu().numpy())
            cv2.imshow("lut_transform", normalize(lut_transform_image, opt, 'inv')[idx].permute(1, 2, 0).cpu().numpy())
            cv2.imshow("composite", normalize(composite, opt, 'inv')[idx].permute(1, 2, 0).cpu().numpy())
            cv2.imshow("mask", mask[idx].permute(1, 2, 0).cpu().numpy())
            cv2.imshow("pred_harmonized_image",
                       normalize(pred_harmonized, opt, 'inv')[idx].permute(1, 2, 0).cpu().numpy())
            cv2.waitKey()
        real_tmp = to_bgr_uint8(real, idx)
        composite_tmp = to_bgr_uint8(composite, idx)
        lut_transform_image_tmp = to_bgr_uint8(lut_transform_image, idx)
        mask_tmp = mask[idx].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype(np.uint8)
        if opt.INRDecode:
            pred_fg_tmp = to_bgr_uint8(pred_fg, idx)
            pred_harmonized_tmp = to_bgr_uint8(pred_harmonized, idx)
        if isAll:
            cv2.imwrite(os.path.join(save_path, f"{step}_{idx}_composite.jpg"), composite_tmp)
            cv2.imwrite(os.path.join(save_path, f"{step}_{idx}_real.jpg"), real_tmp)
            if opt.INRDecode:
                cv2.imwrite(os.path.join(save_path, f"{step}_{idx}_pred_harmonized_image.jpg"), pred_harmonized_tmp)
            cv2.imwrite(os.path.join(save_path, f"{step}_{idx}_lut_transform_image.jpg"), lut_transform_image_tmp)
            cv2.imwrite(os.path.join(save_path, f"{step}_{idx}_mask.jpg"), mask_tmp)
        else:
            if not opt.INRDecode:
                cv2.imwrite(os.path.join(save_path, f"real_{step}_{idx}.jpg"), real_tmp)
                cv2.imwrite(os.path.join(save_path, f"composite_{step}_{idx}.jpg"), composite_tmp)
                cv2.imwrite(os.path.join(save_path, f"mask_{step}_{idx}.jpg"), mask_tmp)
                cv2.imwrite(os.path.join(save_path, f"lut_transform_image_{step}_{idx}.jpg"), lut_transform_image_tmp)
            else:
                cv2.imwrite(os.path.join(save_path, f"pred_fg_{step}_{idx}.jpg"), pred_fg_tmp)
                cv2.imwrite(os.path.join(save_path, f"real_{step}_{idx}.jpg"), real_tmp)
                cv2.imwrite(os.path.join(save_path, f"composite_{step}_{idx}.jpg"), composite_tmp)
                cv2.imwrite(os.path.join(save_path, f"mask_{step}_{idx}.jpg"), mask_tmp)
                cv2.imwrite(os.path.join(save_path, f"pred_harmonized_image_{step}_{idx}.jpg"), pred_harmonized_tmp)
                cv2.imwrite(os.path.join(save_path, f"lut_transform_image_{step}_{idx}.jpg"), lut_transform_image_tmp)

        # Only upload the first image of the first batch each epoch to save storage.
        if wandb and idx == 0 and step == 0:
            import wandb  # note: shadows the `wandb` boolean flag from here on
            real_tmp = wandb.Image(real_tmp, caption=epoch)
            composite_tmp = wandb.Image(composite_tmp, caption=epoch)
            if opt.INRDecode:
                pred_fg_tmp = wandb.Image(pred_fg_tmp, caption=epoch)
                pred_harmonized_tmp = wandb.Image(pred_harmonized_tmp, caption=epoch)
            lut_transform_image_tmp = wandb.Image(lut_transform_image_tmp, caption=epoch)
            mask_tmp = wandb.Image(mask_tmp, caption=epoch)
            if not opt.INRDecode:
                wandb.log(
                    {"pic/real": real_tmp, "pic/composite": composite_tmp,
                     "pic/mask": mask_tmp,
                     "pic/lut_trans": lut_transform_image_tmp,
                     "pic/epoch": epoch})
            else:
                wandb.log(
                    {"pic/pred_fg": pred_fg_tmp, "pic/real": real_tmp, "pic/composite": composite_tmp,
                     "pic/mask": mask_tmp,
                     "pic/lut_trans": lut_transform_image_tmp,
                     "pic/pred_harmonized": pred_harmonized_tmp,
                     "pic/epoch": epoch})
            wandb.log({})  # commit the pending log step
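
# Usage sketch inside a validation loop (tensors, `opt`, and the step counter
# are hypothetical and come from the surrounding training code):
#   visualize(real, composite, mask, pred_fg, pred_harmonized, lut_image,
#             opt, epoch, show=False, wandb=True, isAll=False, step=step)
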
def get_optimizer(model, opt_name, opt_kwargs):
    params = []
    base_lr = opt_kwargs['lr']
    for name, param in model.named_parameters():
        param_group = {'params': [param]}
        if not param.requires_grad:
            params.append(param_group)
            continue

        if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0):
            # print(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
            param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult

        params.append(param_group)

    optimizer = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'adamw': torch.optim.AdamW,
        'adamp': AdamP
    }[opt_name.lower()](params, **opt_kwargs)

    return optimizer
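
# Usage sketch (hypothetical hyper-parameters): parameters tagged with an
# `lr_mult` attribute (e.g. via LRMult below) get base_lr * lr_mult.
#   optimizer = get_optimizer(model, 'adamw', {'lr': 1e-4, 'weight_decay': 1e-2})
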
def improved_efficient_matmul(a, c, index, batch=256):
    """
    Reduce the unneeded memory cost of a batched matmul, at the price of speed.
    :param a: N * I * J
    :param c: M * J * K, indexed along its first dimension by `index`
    :param index: N indices into c
    :return: N * I * K
    """
    # The first (commented) variant is fast but only supports the case where `a`
    # does not require grad; the second variant supports all situations, but is
    # quite slow. More details in
    # https://discuss.pytorch.org/t/many-weird-phenomena-about-torch-matmul-operation/158208
    # out = torch.cat(
    #     [torch.matmul(a[i * batch:i * batch + batch, :, :], c[index[i * batch:i * batch + batch], :, :]) for i in
    #      range(a.shape[0] // batch)], dim=0)
    batch = 1  # the slow-but-safe path processes one matrix at a time
    out = torch.cat(
        [torch.matmul(a[i * batch:i * batch + batch, :, :], c[index[i * batch], :, :]) for i in
         range(a.shape[0] // batch)], dim=0)
    return out
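
# Usage sketch (hypothetical shapes): each row of `a` is multiplied by the
# shared matrix selected by `index`.
#   a = torch.randn(1024, 1, 64)           # N * I * J
#   c = torch.randn(16, 64, 3)             # M * J * K
#   index = torch.randint(0, 16, (1024,))  # N indices into c
#   improved_efficient_matmul(a, c, index).shape  # -> (1024, 1, 3)
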
class LRMult(object):
    def __init__(self, lr_mult=1.):
        self.lr_mult = lr_mult

    def __call__(self, m):
        # Tag module weights/biases so get_optimizer can scale their learning rate.
        if getattr(m, 'weight', None) is not None:
            m.weight.lr_mult = self.lr_mult
        if getattr(m, 'bias', None) is not None:
            m.bias.lr_mult = self.lr_mult
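
# Usage sketch: lower the backbone learning rate tenfold (hypothetical model).
#   model.backbone.apply(LRMult(0.1))
#   optimizer = get_optimizer(model, 'adam', {'lr': 1e-4})
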
def customRandomCrop(objects, crop_height, crop_width, h_start=None, w_start=None):
    if h_start is None:
        h_start = random.random()
    if w_start is None:
        w_start = random.random()
    if isinstance(objects, list):
        out = []
        for obj in objects:
            out.append(random_crop(obj, crop_height, crop_width, h_start, w_start))
    else:
        out = random_crop(objects, crop_height, crop_width, h_start, w_start)
    return out, h_start, w_start
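
# Usage sketch: crop an image and its mask with the same random window.
#   (img_crop, mask_crop), h0, w0 = customRandomCrop([img, mask], 256, 256)
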
def get_random_crop_coords(height: int, width: int, crop_height: int, crop_width: int, h_start: float,
                           w_start: float):
    y1 = int((height - crop_height) * h_start)
    y2 = y1 + crop_height
    x1 = int((width - crop_width) * w_start)
    x2 = x1 + crop_width
    return x1, y1, x2, y2


def random_crop(img: np.ndarray, crop_height: int, crop_width: int, h_start: float, w_start: float):
    height, width = img.shape[:2]
    if height < crop_height or width < crop_width:
        raise ValueError(
            "Requested crop size ({crop_height}, {crop_width}) is "
            "larger than the image size ({height}, {width})".format(
                crop_height=crop_height, crop_width=crop_width, height=height, width=width
            )
        )
    x1, y1, x2, y2 = get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start)
    img = img[y1:y2, x1:x2]
    return img
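
# Usage sketch: h_start/w_start in [0, 1) place the window deterministically.
#   random_crop(np.zeros((480, 640, 3)), 256, 256, 0.5, 0.0).shape  # -> (256, 256, 3)
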
class PadToDivisor:
    def __init__(self, divisor):
        super().__init__()
        self.divisor = divisor
        self._pads = None

    def transform(self, images):
        self._pads = (*self._get_dim_padding(images[0].shape[-1]), *self._get_dim_padding(images[0].shape[-2]))
        self.pad_operation = nn.ZeroPad2d(padding=self._pads)
        out = []
        for im in images:
            out.append(self.pad_operation(im))
        return out

    def inv_transform(self, images):
        assert self._pads is not None, \
            'Something went wrong, inv_transform(...) should be called after transform(...)'
        return self._remove_padding(images)

    def _get_dim_padding(self, dim_size):
        # Split the padding needed to reach a multiple of `divisor` between both sides.
        pad = (self.divisor - dim_size % self.divisor) % self.divisor
        pad_upper = pad // 2
        pad_lower = pad - pad_upper
        return pad_upper, pad_lower

    def _remove_padding(self, tensors):
        tensor_h, tensor_w = tensors[0].shape[-2:]
        out = []
        for t in tensors:
            out.append(t[..., self._pads[2]:tensor_h - self._pads[3], self._pads[0]:tensor_w - self._pads[1]])
        return out
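
# Usage sketch: pad inputs to a multiple of 32, run a model, crop back
# (hypothetical `model`).
#   padder = PadToDivisor(32)
#   img_p, mask_p = padder.transform([img, mask])  # e.g. (1, 3, 500, 375) -> (1, 3, 512, 384)
#   out = padder.inv_transform([model(img_p, mask_p)])[0]
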