File size: 2,157 Bytes
52d252c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os

os.system("pip install gradio==2.4.6")
import torch
import gradio as gr
import numpy as np
import torchvision.utils as vutils
import torchvision.transforms as transforms

from PIL import Image
from torch.autograd import Variable
from network.Transformer import Transformer

LOAD_SIZE = 1280  # NOTE(review): not referenced by inference(); confirm whether input resizing was intended.
STYLE = "Shinkai"  # selects the checkpoint file <STYLE>_net_G_float.pth
MODEL_PATH = "pretrained_model"  # directory that holds the checkpoint
COLOUR_MODEL = "RGB"  # PIL mode that input images are converted to

# True when no CUDA device is available, i.e. inference must stay on the CPU.
disable_gpu = not torch.cuda.is_available()

model = Transformer()
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host; the model is moved to the GPU below when one exists.
model.load_state_dict(
    torch.load(
        os.path.join(MODEL_PATH, f"{STYLE}_net_G_float.pth"),
        map_location="cpu",
    )
)
if not disable_gpu:
    # Keep the weights on the same device that inference() sends inputs to.
    model = model.cuda()
# Evaluation mode (affects dropout / normalisation layers, if the network has any).
model.eval()


def inference(img):
    """Stylise an image with the loaded CartoonGAN generator.

    Args:
        img: the ``PIL.Image.Image`` delivered by the Gradio image input
            (declared with ``type="pil"``). It is converted to
            ``COLOUR_MODEL`` before use.

    Returns:
        A ``PIL.Image.Image``, matching the ``type="pil"`` Gradio output
        component.

    NOTE(review): the module-level ``LOAD_SIZE`` is not applied here, so the
    image is processed at its full resolution -- confirm whether a resize
    was intended.
    """
    # Send the input wherever the generator's weights live, so model and
    # tensor are always on the same device regardless of how setup placed it.
    device = next(model.parameters()).device

    # load image as an H x W x 3 uint8 array
    rgb = np.asarray(img.convert(COLOUR_MODEL))
    # RGB -> BGR (presumably the order the pretrained weights were trained on)
    bgr = rgb[:, :, [2, 1, 0]]
    # ToTensor: HWC uint8 [0, 255] -> CHW float [0, 1]; then add a batch dim
    batch = transforms.ToTensor()(bgr).unsqueeze(0)
    # preprocess, [0, 1] -> [-1, 1]
    batch = (batch * 2 - 1).to(device)

    # forward; no_grad skips building an autograd graph for pure inference
    with torch.no_grad():
        # assumes the model returns a (1, 3, H, W) batch -- TODO confirm
        output = model(batch)[0]

    # BGR -> RGB, then [-1, 1] -> [0, 1]
    output = output[[2, 1, 0], :, :].cpu().float() * 0.5 + 0.5
    # Clamp so out-of-range values cannot wrap when scaled to 8-bit.
    output = output.clamp(0.0, 1.0)

    return transforms.ToPILImage()(output)


# Text displayed around the Gradio demo. `description` uses Markdown link
# syntax; `article` is an HTML footer.
title = "AnimeBackgroundGAN"
description = "CartoonGAN from [Chen et al.](http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf) based on [Yijunmaverick's implementation](https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch)"
# The <p> and <center> elements are siblings; each is opened and closed once.
article = "<p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Github Repo</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center>"

# Each row pairs an input image path with a reference output image path.
# NOTE(review): the Interface declares a single input component; confirm how
# this Gradio version treats the second column.
examples = [
    ["examples/garden_in.jpeg", "examples/garden_out.jpg"],
    ["examples/library_in.jpeg", "examples/library_out.jpg"],
]


# Wire inference() into a Gradio UI: one PIL image in, one PIL image out.
# The interface is kept in a named variable before being served.
demo = gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="pil")],
    outputs=gr.outputs.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging=False,
    allow_screenshot=False,
    enable_queue=True,
)
demo.launch()