Spaces:
Build error
Build error
import torch | |
import os | |
import numpy as np | |
import torchvision.utils as vutils | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from torch.autograd import Variable | |
from network.Transformer import Transformer | |
import argparse | |
# Command-line options for batch-stylising every image in a directory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--input_dir", default="test_img", help="directory of input images"
)
# type=int so a value passed on the command line is parsed as a number,
# consistent with the integer default (it would otherwise arrive as a str).
# NOTE(review): opt.load_size is not read anywhere in this script, so images
# are processed at their original resolution.
parser.add_argument("--load_size", type=int, default=1280, help="target image size")
parser.add_argument(
    "--model_path",
    default="./pretrained_model",
    help="directory containing <style>_net_G_float.pth",
)
parser.add_argument(
    "--style", default="Shinkai", help="name of the pretrained style to load"
)
parser.add_argument(
    "--output_dir", default="test_output", help="directory for stylised images"
)
parser.add_argument("--gpu", type=int, default=0, help="pass -1 to run on the CPU")
opt = parser.parse_args()
# Extensions of the input files that will be processed; anything else in the
# input directory is skipped.
valid_ext = [".jpg", ".png"]

# setup: make sure both directories exist. exist_ok=True avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(opt.input_dir, exist_ok=True)
os.makedirs(opt.output_dir, exist_ok=True)
# load pretrained model: the generator weights for the selected style.
model = Transformer()
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host; the weights are moved to the GPU below when one is used.
model.load_state_dict(
    torch.load(
        os.path.join(opt.model_path, opt.style + "_net_G_float.pth"),
        map_location="cpu",
    )
)
model.eval()

# Any negative --gpu value, or a machine without CUDA, selects CPU inference.
disable_gpu = opt.gpu < 0 or not torch.cuda.is_available()
if disable_gpu:
    print("CPU mode")
    model.float()
else:
    print("GPU mode")
    # Make --gpu the current CUDA device so every subsequent `.cuda()` call
    # (the model here and the input tensors later) targets the chosen card
    # instead of always device 0.
    torch.cuda.set_device(opt.gpu)
    model.cuda()
# Stylise every supported image in the input directory. The generator is fed
# BGR input scaled to (-1, 1); its output is converted back to RGB in (0, 1)
# before being written out as JPEG.
for filename in os.listdir(opt.input_dir):
    stem, ext = os.path.splitext(filename)
    # Case-insensitive so files such as "photo.JPG" are not silently skipped.
    if ext.lower() not in valid_ext:
        continue

    # load image; the context manager closes the file handle. convert() returns
    # a fully loaded copy, so the data stays usable after the file is closed.
    with Image.open(os.path.join(opt.input_dir, filename)) as img:
        input_image = img.convert("RGB")
    input_image = np.asarray(input_image)
    # RGB -> BGR
    input_image = input_image[:, :, [2, 1, 0]]
    # HxWxC uint8 -> 1xCxHxW float in [0, 1]
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    # preprocess, (-1, 1)
    input_image = -1 + 2 * input_image
    if disable_gpu:
        input_image = input_image.float()
    else:
        input_image = input_image.cuda()

    # forward; inference only, so skip building the autograd graph (saves a
    # large amount of memory on high-resolution images).
    with torch.no_grad():
        output_image = model(input_image)[0]

    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    # Map (-1, 1) back to (0, 1), assuming the generator's output lies in
    # (-1, 1) like its input.
    output_image = output_image.cpu().float() * 0.5 + 0.5

    # save as <stem>_<style>.jpg; splitext handles extensions of any length.
    vutils.save_image(
        output_image,
        os.path.join(opt.output_dir, stem + "_" + opt.style + ".jpg"),
    )

print("Done!")