|
|
|
import os |
|
import sys |
|
import shutil |
|
import logging |
|
import colorlog |
|
from tqdm import tqdm |
|
import time |
|
import yaml |
|
import random |
|
import importlib |
|
from PIL import Image |
|
from warnings import simplefilter |
|
import imageio |
|
import math |
|
import collections |
|
import json |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from torch.optim import Adam |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data import DataLoader, Dataset |
|
from einops import rearrange, repeat |
|
import torch.distributed as dist |
|
from torchvision import datasets, transforms, utils |
|
|
|
# Raise the root logger threshold to WARNING and silence FutureWarning
# messages for the whole process.
logging.root.setLevel(logging.WARNING)
simplefilter('ignore', category=FutureWarning)
|
|
|
def get_logger(filename=None):
    """Return the shared ``'utils'`` logger with colored console output.

    Examples:
        logger = get_logger('try_logging.txt')

        logger.debug("Do something.")
        logger.info("Start print log.")
        logger.warning("Something maybe fail.")
        try:
            raise ValueError()
        except ValueError:
            logger.error("Error", exc_info=True)

    Tips:
        Do NOT pass big tensors to logger.info(); coloring does not help there.

    Args:
        filename: Optional path of a log file. When given, records are also
            written there (uncolored) in addition to the console.

    Returns:
        The ``logging.Logger`` named ``'utils'``. If
        ``torch.distributed.get_rank()`` succeeds, a thin proxy is returned
        instead: its ``debug``/``info``/``warning`` only emit when the rank
        is 0, and every other attribute is forwarded to the real logger.
    """
    logger = logging.getLogger('utils')
    level = logging.DEBUG
    logger.setLevel(level)

    # Records are handled here only; they are not passed on to the root logger.
    logger.propagate = False

    format_str = '[%(asctime)s <%(filename)s:%(lineno)d> %(funcName)s] %(message)s'

    # logging.getLogger() hands back the same object on every call, so attach
    # the console handler only once; otherwise each extra call to get_logger()
    # would print every message one more time.
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(level)
        colored_formatter = colorlog.ColoredFormatter(
            '%(log_color)s' + format_str,
            datefmt='%Y-%m-%d %H:%M:%S',
            reset=True,
            log_colors={
                'DEBUG': 'cyan',
                'WARNING': 'yellow',
                'ERROR': 'red',
                # Was 'reg,bg_white': 'reg' is not a color name.
                'CRITICAL': 'red,bg_white',
            },
        )
        stream_handler.setFormatter(colored_formatter)
        logger.addHandler(stream_handler)

    if filename:
        log_path = os.path.abspath(filename)
        # Same reasoning as above: one file handler per distinct log file.
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(filename)
            file_handler.setLevel(level)
            file_handler.setFormatter(logging.Formatter(format_str))
            logger.addHandler(file_handler)

    class UniqueLogger:
        """Proxy that emits low-severity records on rank 0 only."""

        def __init__(self, logger):
            self.logger = logger
            self.local_rank = torch.distributed.get_rank()

        def debug(self, msg, *args, **kwargs):
            if self.local_rank == 0:
                return self.logger.debug(msg, *args, **kwargs)

        def info(self, msg, *args, **kwargs):
            if self.local_rank == 0:
                return self.logger.info(msg, *args, **kwargs)

        def warning(self, msg, *args, **kwargs):
            if self.local_rank == 0:
                return self.logger.warning(msg, *args, **kwargs)

        def __getattr__(self, name):
            # Only reached for names not defined above (error, critical,
            # exception, handlers, ...): forward them to the wrapped logger
            # so the proxy is usable wherever a Logger is expected.
            if name == 'logger':
                # Avoid infinite recursion if the attribute is not set yet.
                raise AttributeError(name)
            return getattr(self.logger, name)

    try:
        logger = UniqueLogger(logger)
    except Exception:
        # Deliberate best effort: get_rank() fails when no process group is
        # available, in which case the plain logger is returned.
        pass
    return logger
|
|
|
|
|
# Module-wide logger used by the helpers below; console only, no log file.
logger = get_logger(filename=None)
|
|
|
def split_filename(filename):
    """Split a path into ``(directory, root name, extension)``.

    The path is made absolute first. The extension is whatever follows the
    last ``'.'`` of the base name (without the dot); it is ``None`` when the
    base name contains no dot at all.
    """
    directory, base = os.path.split(os.path.abspath(filename))
    stem, dot, suffix = base.rpartition('.')
    if not dot:
        # No '.' in the base name: the whole name is the root.
        return directory, base, None
    return directory, stem, suffix
|
|
|
def data2file(data, filename, type=None, override=False, printable=False, **kwargs):
    """Save ``data`` to ``filename`` in a format chosen by the extension.

    Supported formats:
        * ``jpg`` / ``png`` / ``jpeg``: ``utils.save_image(data, filename, **kwargs)``.
        * ``gif``: ``imageio.mimsave`` with the ``duration`` keyword taken from
          ``kwargs`` and ``loop=0``.
        * ``txt``: one ``str(item)`` per line, UTF-8 encoded; the optional
          ``max_step`` keyword limits how many items are written.

    Args:
        data: Object to save; what is accepted depends on the format.
        filename: Destination path. Its parent directory is created if missing.
        type: Overrides the extension-derived format when truthy.
        override: Write even if ``filename`` already exists.
        printable: Log a message after a successful save.
        **kwargs: Format-specific options (see above).

    Raises:
        ValueError: If the format is not one of the supported ones.

    When the file exists and ``override`` is false, nothing is written and an
    info message is logged instead.
    """
    dirname, _, extname = split_filename(filename)
    if type:
        extname = type
    if not os.path.exists(dirname):
        os.makedirs(dirname, exist_ok=True)

    if os.path.exists(filename) and not override:
        logger.info(
            'Did not save data to %s because file exists and override is False' % os.path.abspath(
                filename))
        return

    if extname in ['jpg', 'png', 'jpeg']:
        utils.save_image(data, filename, **kwargs)
    elif extname == 'gif':
        imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0)
    elif extname == 'txt':
        # None means "no limit". The previous np.Infinity sentinel no longer
        # exists in NumPy 2.0, so the limit is tested explicitly instead.
        max_step = kwargs.get('max_step')
        with open(filename, 'w', encoding='utf-8') as f:
            for i, e in enumerate(data):
                if max_step is not None and i >= max_step:
                    break
                f.write(str(e) + '\n')
    else:
        raise ValueError('Do not support this type')

    if printable:
        logger.info('Saved data to %s' % os.path.abspath(filename))
|
|
|
|
|
def file2data(filename, type=None, printable=True, **kwargs):
    """Load and return the content of ``filename`` based on its extension.

    Supported formats:
        * ``pth`` / ``ckpt``: ``torch.load`` with the optional ``map_location``
          keyword.
        * ``txt``: without ``top``, a list of the non-empty lines with the
          newline removed. With a truthy ``top`` keyword, the result of
          calling ``readline()`` exactly ``top`` times, i.e. lines keep their
          trailing newline and reads past the end of file yield ``''``.
        * ``yaml`` / ``yml``: the parsed YAML document.

    Args:
        filename: Path of the file to read.
        type: Overrides the extension-derived format when truthy.
        printable: Log a message after loading.
        **kwargs: Format-specific options (see above).

    Returns:
        The loaded data.

    Raises:
        ValueError: If the format is not one of the supported ones.
    """
    _, _, extname = split_filename(filename)
    if type:
        extname = type

    if extname in ['pth', 'ckpt']:
        data = torch.load(filename, map_location=kwargs.get('map_location'))
    elif extname == 'txt':
        top = kwargs.get('top', None)
        with open(filename, encoding='utf-8') as f:
            if top:
                data = [f.readline() for _ in range(top)]
            else:
                data = [e for e in f.read().split('\n') if e]
    elif extname in ['yaml', 'yml']:
        with open(filename, 'r') as f:
            # yaml.load() without an explicit Loader is a TypeError on
            # PyYAML >= 6 and can build arbitrary Python objects on older
            # releases. safe_load handles plain YAML on every version.
            # NOTE(review): files relying on python-specific YAML tags would
            # need an explicit Loader here.
            data = yaml.safe_load(f)
    else:
        raise ValueError('type can only support pth, ckpt, txt, yaml, yml')

    if printable:
        logger.info('Loaded data from %s' % os.path.abspath(filename))
    return data
|
|
|
|
|
def ensure_dirname(dirname, override=False):
    """Make sure the directory ``dirname`` exists.

    With ``override`` set, an existing ``dirname`` is removed first and then
    re-created empty. A failed removal is reported as ``ValueError``.
    """
    must_wipe = os.path.exists(dirname) and override
    if must_wipe:
        logger.info('Removing dirname: %s' % os.path.abspath(dirname))
        try:
            shutil.rmtree(dirname)
        except OSError as e:
            raise ValueError('Failed to delete %s because %s' % (dirname, e))

    if os.path.exists(dirname):
        # Already there (and not wiped): nothing to do.
        return

    logger.info('Making dirname: %s' % os.path.abspath(dirname))
    os.makedirs(dirname, exist_ok=True)
|
|
|
|
|
def import_filename(filename):
    """Import the Python source file ``filename`` and return the module.

    The module is executed and registered in ``sys.modules`` under the fixed
    name ``"mymodule"``, so a later call replaces the entry made by an
    earlier one.

    Args:
        filename: Path of the source file to import.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec/loader can be found for ``filename``
            (for example, the path lacks a recognised source suffix).
    """
    # A bare ``import importlib`` does not guarantee that the ``util``
    # submodule is loaded; import it explicitly before using it.
    import importlib.util

    spec = importlib.util.spec_from_file_location("mymodule", filename)
    if spec is None or spec.loader is None:
        raise ImportError('Cannot import %s as a Python module' % filename)

    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be found while it runs.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
|
|
|
|
|
def adaptively_load_state_dict(target, state_dict):
    """Load the entries of ``state_dict`` that fit ``target`` and report the rest.

    An entry is kept when its key exists in ``target.state_dict()`` and its
    ``size()`` equals the target's. If that comparison raises for any entry,
    the filter falls back to matching on keys only. When a ``'param_groups'``
    entry is kept and the ``'params'`` of its first group differ from the
    target's, they are replaced by the target's own. Unused and missing keys
    are reported through warnings.
    """
    current = target.state_dict()

    try:
        loadable = {}
        for key, value in state_dict.items():
            if key in current and value.size() == current[key].size():
                loadable[key] = value
    except Exception as e:
        logger.warning('load error %s', e)
        loadable = {key: value for key, value in state_dict.items() if key in current}

    if 'param_groups' in loadable:
        expected_params = target.state_dict()['param_groups'][0]['params']
        if loadable['param_groups'][0]['params'] != expected_params:
            logger.warning('Detected mismatch params, auto adapte state_dict to current')
            loadable['param_groups'][0]['params'] = expected_params

    # Overlay the accepted entries on the target's own state and load that.
    current.update(loadable)
    target.load_state_dict(current)

    missing_keys = [key for key in current if key not in loadable]
    unexpected_keys = [key for key in state_dict if key not in loadable]

    if unexpected_keys:
        logger.warning(
            f"Some weights of state_dict were not used in target: {unexpected_keys}"
        )
    if missing_keys:
        logger.warning(
            f"Some weights of state_dict are missing used in target {missing_keys}"
        )
    if not unexpected_keys and not missing_keys:
        logger.warning("Strictly Loaded state_dict.")
|
|
|
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch random generators with ``seed``.

    Also exports ``PYTHONHASHSEED`` and asks cuDNN for deterministic
    algorithms.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Was misspelled 'PYHTONHASHSEED', so the variable was never actually set.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
|
|
|
def image2pil(filename):
    """Open ``filename`` with ``Image.open`` and return the resulting image."""
    pil_image = Image.open(filename)
    return pil_image
|
|
|
|
|
def image2arr(filename):
    """Read the image file ``filename`` and return it as a NumPy array."""
    return pil2arr(image2pil(filename))
|
|
|
|
|
|
|
def pil2arr(pil):
    """Convert a PIL image, or a list of PIL images, to a NumPy array.

    Args:
        pil: A single image, or a ``list`` of images.

    Returns:
        For a list: the images are each converted to RGB and stacked with
        ``np.array``, giving uint8 frames of shape (height, width, 3).
        Otherwise: ``np.array(pil)`` with no mode conversion.
    """
    if isinstance(pil, list):
        # Convert each frame directly instead of going through getdata(),
        # which builds one Python tuple per pixel before NumPy sees the data.
        return np.array([np.array(e.convert('RGB'), dtype=np.uint8) for e in pil])
    return np.array(pil)
|
|
|
|
|
def arr2pil(arr):
    """Convert an array to an RGB PIL image, or to a list of them.

    A 3-dimensional array yields one image; a 4-dimensional array yields one
    image per entry along its first axis. Values are cast to ``uint8``.
    Any other number of dimensions raises ``ValueError``.
    """
    rank = arr.ndim
    if rank not in (3, 4):
        raise ValueError('arr must has ndim of 3 or 4, but got %s' % arr.ndim)
    if rank == 3:
        return Image.fromarray(arr.astype('uint8'), 'RGB')
    return [Image.fromarray(frame.astype('uint8'), 'RGB') for frame in arr]
|
|
|
def notebook_show(*images):
    """Display each of ``images`` through IPython's ``Image`` and ``display``."""
    # Imported here so IPython is only needed when this helper is called; the
    # local ``Image`` is IPython's, not the PIL one used elsewhere in the file.
    from IPython.display import Image, display

    wrapped = [Image(item) for item in images]
    display(*wrapped)