"""Gradio demo: turn a photo into an anime-style background with CartoonGAN.

Loads a pretrained CartoonGAN generator (``network.Transformer``) for the style
named by ``STYLE`` and serves it through a Gradio 2.x web interface.
"""

import os
import subprocess
import sys

# The interface below uses the Gradio 2.x API (gr.inputs / gr.outputs,
# allow_screenshot, enable_queue), so that version is pinned at start-up.
# ``sys.executable -m pip`` installs into the interpreter running this script
# and avoids a shell; check=False keeps the previous ignore-failure behaviour.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "gradio==2.4.6"], check=False
)

import torch  # noqa: E402  (imported after the pinned install above)
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torchvision.utils as vutils  # noqa: E402,F401
import torchvision.transforms as transforms  # noqa: E402
from PIL import Image  # noqa: E402
from torch.autograd import Variable  # noqa: E402,F401

from network.Transformer import Transformer  # noqa: E402

LOAD_SIZE = 1280  # longest side (pixels) an input image is downscaled to
STYLE = "Shinkai"
MODEL_PATH = "pretrained_model"
COLOUR_MODEL = "RGB"

model = Transformer()
# map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only hosts;
# the model is moved to the selected device below.
model.load_state_dict(
    torch.load(
        os.path.join(MODEL_PATH, f"{STYLE}_net_G_float.pth"),
        map_location="cpu",
    )
)
model.eval()

# True when no CUDA device is present; inference then runs on the CPU.
disable_gpu = not torch.cuda.is_available()
device = torch.device("cpu" if disable_gpu else "cuda")
model.to(device)


@torch.no_grad()
def inference(img):
    """Stylise an image with the loaded CartoonGAN generator.

    Args:
        img: input ``PIL.Image.Image`` (any mode; converted to ``COLOUR_MODEL``).

    Returns:
        The stylised image as a ``PIL.Image.Image``, which is what the
        ``gr.outputs.Image(type="pil")`` output component expects.
    """
    # load image; convert() returns a new image, so the in-place thumbnail()
    # below does not mutate the caller's object.
    input_image = img.convert(COLOUR_MODEL)
    # Cap the longer side at LOAD_SIZE (keeps aspect ratio, never upscales)
    # to bound memory use on very large uploads.
    input_image.thumbnail((LOAD_SIZE, LOAD_SIZE), Image.BICUBIC)
    input_image = np.asarray(input_image)
    # RGB -> BGR: the generator is fed BGR and its output is swapped back below.
    input_image = input_image[:, :, [2, 1, 0]]
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    # preprocess: ToTensor yields [0, 1]; rescale to (-1, 1)
    input_image = (-1 + 2 * input_image).to(device)

    # forward (no autograd graph is needed; see @torch.no_grad above)
    output_image = model(input_image)[0]
    # BGR -> RGB, then map back from (-1, 1) to [0, 1]
    output_image = output_image[[2, 1, 0], :, :]
    output_image = output_image.cpu().float() * 0.5 + 0.5
    # ToPILImage multiplies floats by 255 and casts to uint8; clamp first so
    # values slightly outside [0, 1] do not wrap around.
    return transforms.ToPILImage()(output_image.clamp(0, 1))


title = "AnimeBackgroundGAN"
description = "CartoonGAN from [Chen et.al](http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf) based on [Yijunmaverick's implementation](https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch)"
# NOTE(review): only the visible text of the GitHub-repo link and the visitor
# badge survives here; the surrounding HTML/markdown appears to have been
# stripped. Restore the original markup if it is available.
article = "\n\nGithub Repo\n\nvisitor badge\n\n"
examples = [
    ["examples/garden_in.jpeg", "examples/garden_out.jpg"],
    ["examples/library_in.jpeg", "examples/library_out.jpg"],
]

gr.Interface(
    inference,
    [gr.inputs.Image(type="pil")],
    gr.outputs.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging=False,
    allow_screenshot=False,
    enable_queue=True,
).launch()