Spaces:
Sleeping
Sleeping
import random | |
import json | |
import numpy as np | |
from pathlib import Path | |
from typing import Iterable | |
from omegaconf import ListConfig | |
import cv2 | |
import torch | |
from functools import partial | |
import torchvision as thv | |
from torch.utils.data import Dataset | |
from utils import util_sisr | |
from utils import util_image | |
from utils import util_common | |
from basicsr.data.transforms import augment | |
from basicsr.data.realesrgan_dataset import RealESRGANDataset | |
def get_transforms(transform_type, kwargs): | |
''' | |
Accepted optins in kwargs. | |
mean: scaler or sequence, for nornmalization | |
std: scaler or sequence, for nornmalization | |
crop_size: int or sequence, random or center cropping | |
scale, out_shape: for Bicubic | |
min_max: tuple or list with length 2, for cliping | |
''' | |
if transform_type == 'default': | |
transform = thv.transforms.Compose([ | |
thv.transforms.ToTensor(), | |
thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)), | |
]) | |
elif transform_type == 'resize_ccrop_norm': | |
transform = thv.transforms.Compose([ | |
util_image.SmallestMaxSize( | |
max_size=kwargs.get('size'), | |
interpolation=kwargs.get('interpolation'), | |
), | |
thv.transforms.ToTensor(), | |
thv.transforms.CenterCrop(size=kwargs.get('size', None)), | |
thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)), | |
]) | |
elif transform_type == 'ccrop_norm': | |
transform = thv.transforms.Compose([ | |
thv.transforms.ToTensor(), | |
thv.transforms.CenterCrop(size=kwargs.get('size', None)), | |
thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)), | |
]) | |
elif transform_type == 'rcrop_aug_norm': | |
transform = thv.transforms.Compose([ | |
util_image.RandomCrop(pch_size=kwargs.get('pch_size', 256)), | |
util_image.SpatialAug( | |
only_hflip=kwargs.get('only_hflip', False), | |
only_vflip=kwargs.get('only_vflip', False), | |
only_hvflip=kwargs.get('only_hvflip', False), | |
), | |
util_image.ToTensor(max_value=kwargs.get('max_value')), # (ndarray, hwc) --> (Tensor, chw) | |
thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)), | |
]) | |
elif transform_type == 'aug_norm': | |
transform = thv.transforms.Compose([ | |
util_image.SpatialAug( | |
only_hflip=kwargs.get('only_hflip', False), | |
only_vflip=kwargs.get('only_vflip', False), | |
only_hvflip=kwargs.get('only_hvflip', False), | |
), | |
util_image.ToTensor(), # hwc --> chw | |
thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)), | |
]) | |
else: | |
raise ValueError(f'Unexpected transform_variant {transform_variant}') | |
return transform | |
def create_dataset(dataset_config): | |
if dataset_config['type'] == 'base': | |
dataset = BaseData(**dataset_config['params']) | |
elif dataset_config['type'] == 'base_meta': | |
dataset = BaseDataMetaCond(**dataset_config['params']) | |
elif dataset_config['type'] == 'realesrgan': | |
dataset = RealESRGANDataset(dataset_config['params']) | |
else: | |
raise NotImplementedError(f"{dataset_config['type']}") | |
return dataset | |
class BaseData(Dataset): | |
def __init__( | |
self, | |
dir_path, | |
txt_path=None, | |
transform_type='default', | |
transform_kwargs={'mean':0.0, 'std':1.0}, | |
extra_dir_path=None, | |
extra_transform_type=None, | |
extra_transform_kwargs=None, | |
length=None, | |
need_path=False, | |
im_exts=['png', 'jpg', 'jpeg', 'JPEG', 'bmp'], | |
recursive=False, | |
): | |
super().__init__() | |
file_paths_all = [] | |
if dir_path is not None: | |
file_paths_all.extend(util_common.scan_files_from_folder(dir_path, im_exts, recursive)) | |
if txt_path is not None: | |
file_paths_all.extend(util_common.readline_txt(txt_path)) | |
self.file_paths = file_paths_all if length is None else random.sample(file_paths_all, length) | |
self.file_paths_all = file_paths_all | |
self.length = length | |
self.need_path = need_path | |
self.transform = get_transforms(transform_type, transform_kwargs) | |
self.extra_dir_path = extra_dir_path | |
if extra_dir_path is not None: | |
assert extra_transform_type is not None | |
self.extra_transform = get_transforms(extra_transform_type, extra_transform_kwargs) | |
def __len__(self): | |
return len(self.file_paths) | |
def __getitem__(self, index): | |
im_path_base = self.file_paths[index] | |
im_base = util_image.imread(im_path_base, chn='rgb', dtype='float32') | |
im_target = self.transform(im_base) | |
out = {'image':im_target, 'lq':im_target} | |
if self.extra_dir_path is not None: | |
im_path_extra = Path(self.extra_dir_path) / Path(im_path_base).name | |
im_extra = util_image.imread(im_path_extra, chn='rgb', dtype='float32') | |
im_extra = self.extra_transform(im_extra) | |
out['gt'] = im_extra | |
if self.need_path: | |
out['path'] = im_path_base | |
return out | |
def reset_dataset(self): | |
self.file_paths = random.sample(self.file_paths_all, self.length) | |
class BaseDataMetaCond(Dataset): | |
def __init__( | |
self, | |
meta_dir, | |
transform_type='default', | |
transform_kwargs={'mean':0.5, 'std':0.5}, | |
length=None, | |
need_path=False, | |
cond_key='canny', | |
cond_transform_type='default', | |
cond_transform_kwargs={'mean':0.5, 'std':0.5}, | |
): | |
super().__init__() | |
if not isinstance(meta_dir, ListConfig): | |
meta_dir = [meta_dir,] | |
meta_list = [] | |
# for current_dir in meta_dir: | |
# for json_path in Path(current_dir).glob("*.json"): | |
# with open(json_path, 'r') as json_file: | |
# meta_info = json.load(json_file) | |
# meta_list.append(meta_info) | |
for current_dir in meta_dir: | |
meta_list.extend(sorted([str(x) for x in Path(current_dir).glob("*.json")])) | |
self.meta_list = meta_list if length is None else meta_list[:length] | |
self.cond_key = cond_key | |
self.length = length | |
self.need_path = need_path | |
self.transform = get_transforms(transform_type, transform_kwargs) | |
self.cond_trasform = get_transforms(cond_transform_type, cond_transform_kwargs) | |
def __len__(self): | |
return len(self.meta_list) | |
def __getitem__(self, index): | |
# meta_info = self.meta_list[index] | |
json_path = self.meta_list[index] | |
with open(json_path, 'r') as json_file: | |
meta_info = json.load(json_file) | |
# images | |
im_path = meta_info['source'] | |
im_source = util_image.imread(im_path, chn='rgb', dtype='uint8') | |
im_source = self.transform(im_source) | |
out = {'image': im_source,} | |
if self.need_path: | |
out['path'] = im_path | |
# latent | |
if 'latent' in meta_info: | |
latent_path = meta_info['latent'] | |
out['latent'] = np.load(latent_path) | |
# prompt | |
out['txt'] = meta_info['prompt'] | |
# condition | |
cond_key = self.cond_key | |
cond_path = meta_info[cond_key] | |
if cond_key == 'canny': | |
cond = util_image.imread(cond_path, chn='gray', dtype='uint8')[:, :, None] | |
elif cond_key == 'seg': | |
cond = util_image.imread(cond_path, chn='rgb', dtype='uint8') | |
else: | |
raise ValueError(f"Unexpected cond key: {cond_key}") | |
cond = self.cond_trasform(cond) | |
out['cond'] = cond | |
return out | |