File size: 2,204 Bytes
52d252c
 
 
 
 
 
 
 
 
 
 
 
 
 
cd1eaaf
 
52d252c
 
bc8701f
 
 
52d252c
bc8701f
52d252c
 
 
bc8701f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52d252c
 
 
 
 
 
 
 
 
a3348bb
bc8701f
 
adf5cfd
 
9cf125f
 
52d252c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import subprocess
import sys

# Install the pinned Gradio release with the *running* interpreter's pip so it
# lands in the environment that imports it below. The exit status is not
# checked (best-effort, as before), so an already-installed Gradio still works
# when the install cannot run.
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==2.4.6"])
import torch
import gradio as gr
import numpy as np
import torchvision.utils as vutils
import torchvision.transforms as transforms

from PIL import Image
from torch.autograd import Variable
from network.Transformer import Transformer

# NOTE(review): LOAD_SIZE is not referenced anywhere else in this file, so
# images are currently processed at their native resolution.
LOAD_SIZE = 1280
STYLE = "shinkai_makoto"
MODEL_PATH = "models"
COLOUR_MODEL = "RGB"

# True when no CUDA device is present; inference then stays on the CPU.
# (Previously this was set to torch.cuda.is_available(), i.e. inverted.)
disable_gpu = not torch.cuda.is_available()

model = Transformer()
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only host; the model is moved to the GPU afterwards when one exists.
model.load_state_dict(
    torch.load(os.path.join(MODEL_PATH, f"{STYLE}.pth"), map_location="cpu")
)
model.eval()
if not disable_gpu:
    # Keep the weights on the same device the inputs are sent to.
    model.cuda()


def inference(img):
    """Stylize an image with the loaded CartoonGAN generator.

    Args:
        img: PIL image from the Gradio image input. It is converted to
            COLOUR_MODEL ("RGB") first, so alpha/palette modes are accepted.
            ``None`` (no image submitted) yields ``None``.

    Returns:
        PIL.Image.Image | None: the stylized image, matching the interface's
        ``type="pil"`` output component.
    """
    if img is None:
        return None

    # Send the input to whichever device holds the weights, so input and
    # model always agree regardless of how the module-level setup placed them.
    device = next(model.parameters()).device

    # load image as an HxWx3 uint8 array
    input_image = np.asarray(img.convert(COLOUR_MODEL))
    # RGB -> BGR (swapped back after the forward pass)
    input_image = input_image[:, :, [2, 1, 0]]
    # HxWxC uint8 -> 1xCxHxW float in [0, 1]
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    # preprocess, (-1, 1)
    input_image = (-1 + 2 * input_image).to(device=device, dtype=torch.float32)

    # forward; no_grad skips building an autograd graph for pure inference
    with torch.no_grad():
        output_image = model(input_image)[0]

    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    # (-1, 1) -> (0, 1); clamp so ToPILImage's x255 -> uint8 cast cannot wrap
    output_image = (output_image.cpu().float() * 0.5 + 0.5).clamp(0.0, 1.0)

    return transforms.ToPILImage()(output_image)


# Text shown on the demo page; description and article carry raw HTML.
title = "Anime Background GAN"
description = "<a href='http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf' target='_blank'>CartoonGAN from Chen et.al</a> based on <a href='https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch' target='_blank'>Yijunmaverick's implementation</a>."
article = "<p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Github Repo</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center></p>"

# One row per sample scene: [input path, output path]; presumably the "_out"
# file is a pre-rendered result for the "_in" photo — TODO confirm.
examples = [
    [f"examples/{scene}_in.jpeg", f"examples/{scene}_out.jpg"]
    for scene in ("garden", "library")
]


# Build the single-image demo and start serving it. Everything except the
# component wiring is gathered into one options mapping.
# NOTE(review): `article` and `examples` are defined above but passed as None
# here. Each `examples` row holds two paths while this interface has a single
# image input, so that list would need reshaping before it could be enabled —
# confirm whether disabling both was intentional.
_interface_options = {
    "title": title,
    "description": description,
    "article": None,
    "examples": None,
    "allow_flagging": False,
    "allow_screenshot": False,
    "enable_queue": True,
}
demo = gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="pil")],
    outputs=gr.outputs.Image(type="pil"),
    **_interface_options,
)
demo.launch()