Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,117 Bytes
c9cc441 bd4674a 721391f c9cc441 e7d7d62 bd4674a fca9328 bd4674a fca9328 bd4674a d7efa60 bd4674a c9cc441 099b913 c9cc441 bd4674a c9cc441 bd4674a c9cc441 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
"""Test script for anime-to-sketch translation
Example:
python3 test.py --dataroot /your_path/dir --load_size 512
python3 test.py --dataroot /your_path/img.jpg --load_size 512
"""
import os
import torch
from scripts.data import get_image_list, get_transform, tensor_to_img, save_image
from scripts.model import create_model
import argparse
from tqdm.auto import tqdm
from kornia.enhance import equalize_clahe
from PIL import Image
import numpy as np
import spaces
# Module-level inference state. Both stay None until init_model() runs;
# generate_sketch() reads them as globals.
model = None
device = None


def init_model(use_local=False, model_opt="default"):
    """Create the sketch model, move it to a device and put it in eval mode.

    The results are stored in the module-level globals ``model`` and
    ``device`` rather than returned.

    Args:
        use_local (bool): forwarded unchanged to ``create_model``.
        model_opt (str): model variant name forwarded to ``create_model``.
            Defaults to "default", which is the value this function used
            before the parameter existed.
    """
    global model, device
    # NOTE: CUDA is picked whenever torch reports it as available; there is
    # no argument here to force CPU execution.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = create_model(model_opt, use_local).to(device)
    model.eval()
# Takes an image as a numpy array, generates line art from it, and returns
# the result as a numpy array.
@spaces.GPU
def generate_sketch(image, clahe_clip=-1, load_size=512):
    """
    Generate sketch image from input image

    Args:
        image (np.ndarray): input image, indexed as (height, width, channels)
        clahe_clip (float): clip threshold for CLAHE; CLAHE is skipped unless
            this is greater than 0
        load_size (int): image size passed to ``get_transform``; when greater
            than 0 the model output is resized back to the input's
            height and width
    Returns:
        np.ndarray: output image (whatever ``tensor_to_img`` produces)
    Raises:
        RuntimeError: if ``init_model()`` has not populated the global model
    """
    if model is None:
        # Without this guard the failure is an opaque
        # "'NoneType' object is not callable" at the forward pass.
        raise RuntimeError("init_model() must be called before generate_sketch()")

    # Remember the original (height, width) so the output can be scaled back.
    aus_resize = (image.shape[0], image.shape[1]) if load_size > 0 else None

    transform = get_transform(load_size=load_size)
    # HWC numpy array -> CHW float tensor
    image = torch.from_numpy(image).permute(2, 0, 1).float()
    image = transform(image)

    # If values still exceed 1 after the transform, min-max rescale to [-1, 1].
    hi = image.max()
    if hi > 1:
        lo = image.min()
        # A flat (single-colour) image has hi == lo; clamp the range so the
        # division cannot produce NaNs. Non-degenerate ranges are unaffected.
        span = (hi - lo).clamp_min(1e-8)
        image = (image - lo) / span * 2 - 1

    img = image.unsqueeze(0)  # add batch dimension

    if clahe_clip > 0:
        img = (img + 1) / 2  # [-1,1] to [0,1]
        img = equalize_clahe(img, clip_limit=clahe_clip)
        img = (img - .5) / .5  # [0,1] to [-1,1]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        aus_tensor = model(img.to(device))
        # resize to original size
        if aus_resize is not None:
            aus_tensor = torch.nn.functional.interpolate(
                aus_tensor, aus_resize, mode='bilinear', align_corners=False)

    return tensor_to_img(aus_tensor)
if __name__ == '__main__':
    # Command-line entry point: run the sketch model over one image file or
    # every image found under a directory, writing results to --output_dir.
    os.chdir(os.path.dirname("Anime2Sketch/"))

    parser = argparse.ArgumentParser(description='Anime-to-sketch test options.')
    parser.add_argument('--dataroot','-i', default='test_samples/', type=str)
    parser.add_argument('--load_size','-s', default=512, type=int)
    parser.add_argument('--output_dir','-o', default='results/', type=str)
    parser.add_argument('--gpu_ids', '-g', default=[], help="gpu ids: e.g. 0 0,1,2 0,2.")
    parser.add_argument('--model', default="default", type=str, help="variant of model to use. you can choose from ['default','improved']")
    parser.add_argument('--clahe_clip', default=-1, type=float, help="clip threshold for CLAHE set to -1 to disable")
    opt = parser.parse_args()

    # create model
    # NOTE: --gpu_ids is still accepted for CLI compatibility but has no
    # effect; CUDA is used whenever torch reports it as available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = create_model(opt.model, use_local=True).to(device)  # create a model given opt.model and other options
    model.eval()

    # get input data: accept either a directory or a single image file,
    # as shown by the two usage examples in the module docstring
    if os.path.isdir(opt.dataroot):
        test_list = get_image_list(opt.dataroot)
    elif os.path.isfile(opt.dataroot):
        test_list = [opt.dataroot]
    else:
        parser.error("{} is not a valid directory or image file.".format(opt.dataroot))

    # make sure the destination exists before the first save
    if opt.output_dir:
        os.makedirs(opt.output_dir, exist_ok=True)

    for test_path in tqdm(test_list):
        basename = os.path.basename(test_path)
        aus_path = os.path.join(opt.output_dir, basename)

        # Load as an RGB numpy array; the context manager closes the file.
        with Image.open(test_path) as src:
            img = np.array(src.convert('RGB'))

        # Original size in (width, height) order, handed to save_image so the
        # result can be written at the source resolution.
        aus_resize = None
        if opt.load_size > 0:
            aus_resize = (img.shape[1], img.shape[0])

        transform = get_transform(load_size=opt.load_size)
        # HWC numpy array -> CHW float tensor
        image = transform(torch.from_numpy(img).permute(2, 0, 1).float())

        # If values still exceed 1 after the transform, min-max rescale to [-1, 1].
        hi = image.max()
        if hi > 1:
            lo = image.min()
            # Clamp the range so a flat (single-colour) image does not
            # divide by zero; non-degenerate ranges are unaffected.
            span = (hi - lo).clamp_min(1e-8)
            image = (image - lo) / span * 2 - 1

        batch = image.unsqueeze(0)  # add batch dimension

        if opt.clahe_clip > 0:
            batch = (batch + 1) / 2  # [-1,1] to [0,1]
            batch = equalize_clahe(batch, clip_limit=opt.clahe_clip)
            batch = (batch - .5) / .5  # [0,1] to [-1,1]

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            aus_tensor = model(batch.to(device))

        aus_img = tensor_to_img(aus_tensor)
        save_image(aus_img, aus_path, aus_resize)
|