Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
torch.backends.cuda.matmul.allow_tf32 = True | |
torch.backends.cudnn.allow_tf32 = True | |
import torch.distributed as dist | |
from torch.utils.data import Dataset, DataLoader | |
from torch.utils.data.distributed import DistributedSampler | |
from torchvision.datasets import ImageFolder | |
from torchvision import transforms | |
from tqdm import tqdm | |
import os | |
from PIL import Image | |
import numpy as np | |
import itertools | |
import argparse | |
import random | |
from skimage.metrics import peak_signal_noise_ratio as psnr_loss | |
from skimage.metrics import structural_similarity as ssim_loss | |
from omegaconf import OmegaConf | |
from tokenizer.vqgan.model import VQModel | |
from tokenizer.vqgan.model import VQGAN_FROM_TAMING | |
class SingleFolderDataset(Dataset): | |
def __init__(self, directory, transform=None): | |
super().__init__() | |
self.directory = directory | |
self.transform = transform | |
self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory) | |
if os.path.isfile(os.path.join(directory, file_name))] | |
def __len__(self): | |
return len(self.image_paths) | |
def __getitem__(self, idx): | |
image_path = self.image_paths[idx] | |
image = Image.open(image_path).convert('RGB') | |
if self.transform: | |
image = self.transform(image) | |
return image, torch.tensor(0) | |
def create_npz_from_sample_folder(sample_dir, num=50_000): | |
""" | |
Builds a single .npz file from a folder of .png samples. | |
""" | |
samples = [] | |
for i in tqdm(range(num), desc="Building .npz file from samples"): | |
sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") | |
sample_np = np.asarray(sample_pil).astype(np.uint8) | |
samples.append(sample_np) | |
random.shuffle(samples) # This is very important for IS(Inception Score) !!! | |
samples = np.stack(samples) | |
assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) | |
npz_path = f"{sample_dir}.npz" | |
np.savez(npz_path, arr_0=samples) | |
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") | |
return npz_path | |
def center_crop_arr(pil_image, image_size): | |
""" | |
Center cropping implementation from ADM. | |
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 | |
""" | |
while min(*pil_image.size) >= 2 * image_size: | |
pil_image = pil_image.resize( | |
tuple(x // 2 for x in pil_image.size), resample=Image.BOX | |
) | |
scale = image_size / min(*pil_image.size) | |
pil_image = pil_image.resize( | |
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC | |
) | |
arr = np.array(pil_image) | |
crop_y = (arr.shape[0] - image_size) // 2 | |
crop_x = (arr.shape[1] - image_size) // 2 | |
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) | |
def main(args): | |
# Setup PyTorch: | |
assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" | |
torch.set_grad_enabled(False) | |
# Setup DDP: | |
dist.init_process_group("nccl") | |
rank = dist.get_rank() | |
device = rank % torch.cuda.device_count() | |
seed = args.global_seed * dist.get_world_size() + rank | |
torch.manual_seed(seed) | |
torch.cuda.set_device(device) | |
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") | |
# create and load vqgan | |
cfg, ckpt = VQGAN_FROM_TAMING[args.vqgan] | |
config = OmegaConf.load(cfg) | |
vq_model = VQModel(**config.model.get("params", dict())).to(device) | |
vq_model.init_from_ckpt(ckpt, logging=False) | |
vq_model.eval() | |
# Create folder to save samples: | |
folder_name = f"{args.vqgan}-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}" | |
sample_folder_dir = f"{args.sample_dir}/{folder_name}" | |
if rank == 0: | |
os.makedirs(sample_folder_dir, exist_ok=True) | |
print(f"Saving .png samples at {sample_folder_dir}") | |
dist.barrier() | |
# Setup data: | |
transform = transforms.Compose([ | |
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) | |
]) | |
if args.dataset == 'imagenet': | |
dataset = ImageFolder(args.data_path, transform=transform) | |
num_fid_samples = 50000 | |
elif args.dataset == 'coco': | |
dataset = SingleFolderDataset(args.data_path, transform=transform) | |
num_fid_samples = 5000 | |
else: | |
raise Exception("please check dataset") | |
sampler = DistributedSampler( | |
dataset, | |
num_replicas=dist.get_world_size(), | |
rank=rank, | |
shuffle=False, | |
seed=args.global_seed | |
) | |
loader = DataLoader( | |
dataset, | |
batch_size=args.per_proc_batch_size, | |
shuffle=False, | |
sampler=sampler, | |
num_workers=args.num_workers, | |
pin_memory=True, | |
drop_last=False | |
) | |
# Figure out how many samples we need to generate on each GPU and how many iterations we need to run: | |
n = args.per_proc_batch_size | |
global_batch_size = n * dist.get_world_size() | |
psnr_val_rgb = [] | |
ssim_val_rgb = [] | |
loader = tqdm(loader) if rank == 0 else loader | |
total = 0 | |
for x, _ in loader: | |
rgb_gts = x | |
rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1] | |
x = x.to(device) | |
with torch.no_grad(): | |
latent, _, [_, _, indices] = vq_model.encode(x) | |
samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1] | |
samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() | |
# Save samples to disk as individual .png files | |
for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)): | |
index = i * dist.get_world_size() + rank + total | |
Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") | |
# metric | |
rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1] | |
psnr = psnr_loss(rgb_restored, rgb_gt) | |
ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1) | |
psnr_val_rgb.append(psnr) | |
ssim_val_rgb.append(ssim) | |
total += global_batch_size | |
# ------------------------------------ | |
# Summary | |
# ------------------------------------ | |
# Make sure all processes have finished saving their samples | |
dist.barrier() | |
world_size = dist.get_world_size() | |
gather_psnr_val = [None for _ in range(world_size)] | |
gather_ssim_val = [None for _ in range(world_size)] | |
dist.all_gather_object(gather_psnr_val, psnr_val_rgb) | |
dist.all_gather_object(gather_ssim_val, ssim_val_rgb) | |
if rank == 0: | |
gather_psnr_val = list(itertools.chain(*gather_psnr_val)) | |
gather_ssim_val = list(itertools.chain(*gather_ssim_val)) | |
psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val) | |
ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val) | |
print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) | |
result_file = f"{sample_folder_dir}_results.txt" | |
print("writing results to {}".format(result_file)) | |
with open(result_file, 'w') as f: | |
print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f) | |
create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) | |
print("Done.") | |
dist.barrier() | |
dist.destroy_process_group() | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--data-path", type=str, required=True) | |
parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet') | |
parser.add_argument("--vqgan", type=str, choices=list(VQGAN_FROM_TAMING.keys()), default="vqgan_imagenet_f16_16384") | |
parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) | |
parser.add_argument("--sample-dir", type=str, default="reconstructions") | |
parser.add_argument("--per-proc-batch-size", type=int, default=32) | |
parser.add_argument("--global-seed", type=int, default=0) | |
parser.add_argument("--num-workers", type=int, default=4) | |
args = parser.parse_args() | |
main(args) |