秋山翔 committed
Commit: 52d252c
1 Parent(s): 9d1fa22

initial commit

.gitignore ADDED
@@ -0,0 +1,14 @@
+ env/
+ .idea
+
+ # Models
+ pretrained_model/*.pth
+
+ # Output
+ test_img/*
+ test_output/*
+
+ .history/
+ .vscode/
+ __pycache__/
+ .DS_Store
README.md CHANGED
@@ -1,37 +1,32 @@
 ---
 title: AnimeBackgroundGAN
- emoji: 🏢
+ emoji: 🖼
 colorFrom: red
 colorTo: indigo
 sdk: gradio
 app_file: app.py
- pinned: false
+ pinned: true
 ---

 # Configuration

 `title`: _string_
- Display title for the Space
+ Anime Background GAN

 `emoji`: _string_
- Space emoji (emoji-only character allowed)
+ 🖼

 `colorFrom`: _string_
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+ red

 `colorTo`: _string_
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+ indigo

 `sdk`: _string_
- Can be either `gradio`, `streamlit`, or `static`
-
- `sdk_version` : _string_
- Only applicable for `streamlit` SDK.
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+ gradio

 `app_file`: _string_
- Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
- Path is relative to the root of the repository.
+ app.py

 `pinned`: _boolean_
- Whether the Space stays on top of your list.
+ true
app.py ADDED
@@ -0,0 +1,74 @@
+ import os
+
+ os.system("pip install gradio==2.4.6")
+ import torch
+ import gradio as gr
+ import numpy as np
+ import torchvision.utils as vutils
+ import torchvision.transforms as transforms
+
+ from PIL import Image
+ from torch.autograd import Variable
+ from network.Transformer import Transformer
+
+ LOAD_SIZE = 1280
+ STYLE = "Shinkai"
+ MODEL_PATH = "pretrained_model"
+ COLOUR_MODEL = "RGB"
+
+ model = Transformer()
+ model.load_state_dict(
+     torch.load(os.path.join(MODEL_PATH, f"{STYLE}_net_G_float.pth"))
+ )
+ model.eval()
+
+ disable_gpu = not torch.cuda.is_available()
+
+
+ def inference(img):
+     # load image
+     input_image = img.convert(COLOUR_MODEL)
+     input_image = np.asarray(input_image)
+     # RGB -> BGR
+     input_image = input_image[:, :, [2, 1, 0]]
+     input_image = transforms.ToTensor()(input_image).unsqueeze(0)
+     # preprocess, (-1, 1)
+     input_image = -1 + 2 * input_image
+
+     if disable_gpu:
+         input_image = Variable(input_image).float()
+     else:
+         input_image = Variable(input_image).cuda()
+
+     # forward
+     output_image = model(input_image)
+     output_image = output_image[0]
+     # BGR -> RGB
+     output_image = output_image[[2, 1, 0], :, :]
+     output_image = output_image.data.cpu().float() * 0.5 + 0.5
+
+     return output_image
+
+
+ title = "AnimeBackgroundGAN"
+ description = "CartoonGAN from [Chen et al.](http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf) based on [Yijunmaverick's implementation](https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch)"
+ article = "<p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Github Repo</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center>"
+
+ examples = [
+     ["examples/garden_in.jpg", "examples/garden_out.jpg"],
+     ["examples/library_in.jpg", "examples/library_out.jpg"],
+ ]
+
+
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type="pil")],
+     gr.outputs.Image(type="pil"),
+     title=title,
+     description=description,
+     article=article,
+     examples=examples,
+     allow_flagging=False,
+     allow_screenshot=False,
+     enable_queue=True,
+ ).launch()
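
For quick testing outside of Gradio, the same preprocessing and postprocessing can be exercised with a small standalone sketch. This is only a sketch, not part of the Space: it assumes the Shinkai weights have been downloaded to pretrained_model/Shinkai_net_G_float.pth, and the input/output file names below are illustrative.

    import numpy as np
    import torch
    import torchvision.transforms as transforms
    import torchvision.utils as vutils
    from PIL import Image

    from network.Transformer import Transformer

    # Load the generator on CPU (assumes the .pth checkpoint is present locally).
    model = Transformer()
    model.load_state_dict(
        torch.load("pretrained_model/Shinkai_net_G_float.pth", map_location="cpu")
    )
    model.eval()

    # Mirror app.py's preprocessing: RGB -> BGR, then scale [0, 1] -> [-1, 1].
    img = np.asarray(Image.open("examples/garden_in.jpg").convert("RGB"))[:, :, [2, 1, 0]]
    x = transforms.ToTensor()(img).unsqueeze(0) * 2 - 1

    with torch.no_grad():
        y = model(x)[0]

    # Mirror the postprocessing: BGR -> RGB, back to [0, 1], then save to disk.
    y = y[[2, 1, 0], :, :] * 0.5 + 0.5
    vutils.save_image(y, "garden_shinkai_local.jpg")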
examples/garden_in.jpg ADDED
examples/garden_out.jpg ADDED
examples/library_in.jpg ADDED
examples/library_out.jpg ADDED
main.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import os
+ import numpy as np
+ import torchvision.utils as vutils
+
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from torch.autograd import Variable
+
+ from network.Transformer import Transformer
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_dir", default="test_img")
+ parser.add_argument("--load_size", default=1280)
+ parser.add_argument("--model_path", default="./pretrained_model")
+ parser.add_argument("--style", default="Shinkai")
+ parser.add_argument("--output_dir", default="test_output")
+ parser.add_argument("--gpu", type=int, default=0)
+
+ opt = parser.parse_args()
+
+ valid_ext = [".jpg", ".png"]
+
+ # setup
+ if not os.path.exists(opt.input_dir):
+     os.makedirs(opt.input_dir)
+ if not os.path.exists(opt.output_dir):
+     os.makedirs(opt.output_dir)
+
+ # load pretrained model
+ model = Transformer()
+ model.load_state_dict(
+     torch.load(os.path.join(opt.model_path, opt.style + "_net_G_float.pth"))
+ )
+ model.eval()
+
+ disable_gpu = opt.gpu == -1 or not torch.cuda.is_available()
+
+ if disable_gpu:
+     print("CPU mode")
+     model.float()
+ else:
+     print("GPU mode")
+     model.cuda()
+
+ for files in os.listdir(opt.input_dir):
+     ext = os.path.splitext(files)[1]
+     if ext not in valid_ext:
+         continue
+     # load image
+     input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
+     input_image = np.asarray(input_image)
+     # RGB -> BGR
+     input_image = input_image[:, :, [2, 1, 0]]
+     input_image = transforms.ToTensor()(input_image).unsqueeze(0)
+     # preprocess, (-1, 1)
+     input_image = -1 + 2 * input_image
+     if disable_gpu:
+         input_image = Variable(input_image).float()
+     else:
+         input_image = Variable(input_image).cuda()
+
+     # forward
+     output_image = model(input_image)
+     output_image = output_image[0]
+     # BGR -> RGB
+     output_image = output_image[[2, 1, 0], :, :]
+     output_image = output_image.data.cpu().float() * 0.5 + 0.5
+     # save
+     vutils.save_image(
+         output_image,
+         os.path.join(opt.output_dir, files[:-4] + "_" + opt.style + ".jpg"),
+     )
+
+ print("Done!")
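
A typical local batch run, assuming the pretrained weights have already been placed in ./pretrained_model (they are ignored by .gitignore and are not part of this commit), would look like:

    python main.py --input_dir test_img --output_dir test_output --style Shinkai --gpu 0

Passing --gpu -1, or running on a machine without CUDA, falls back to CPU mode; every .jpg/.png in test_img/ is written to test_output/ as <name>_Shinkai.jpg. Note that --load_size is parsed but not otherwise used by this script.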
network/Transformer.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Transformer(nn.Module):
+     def __init__(self):
+         super(Transformer, self).__init__()
+         #
+         self.refpad01_1 = nn.ReflectionPad2d(3)
+         self.conv01_1 = nn.Conv2d(3, 64, 7)
+         self.in01_1 = InstanceNormalization(64)
+         # relu
+         self.conv02_1 = nn.Conv2d(64, 128, 3, 2, 1)
+         self.conv02_2 = nn.Conv2d(128, 128, 3, 1, 1)
+         self.in02_1 = InstanceNormalization(128)
+         # relu
+         self.conv03_1 = nn.Conv2d(128, 256, 3, 2, 1)
+         self.conv03_2 = nn.Conv2d(256, 256, 3, 1, 1)
+         self.in03_1 = InstanceNormalization(256)
+         # relu
+
+         ## res block 1
+         self.refpad04_1 = nn.ReflectionPad2d(1)
+         self.conv04_1 = nn.Conv2d(256, 256, 3)
+         self.in04_1 = InstanceNormalization(256)
+         # relu
+         self.refpad04_2 = nn.ReflectionPad2d(1)
+         self.conv04_2 = nn.Conv2d(256, 256, 3)
+         self.in04_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 2
+         self.refpad05_1 = nn.ReflectionPad2d(1)
+         self.conv05_1 = nn.Conv2d(256, 256, 3)
+         self.in05_1 = InstanceNormalization(256)
+         # relu
+         self.refpad05_2 = nn.ReflectionPad2d(1)
+         self.conv05_2 = nn.Conv2d(256, 256, 3)
+         self.in05_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 3
+         self.refpad06_1 = nn.ReflectionPad2d(1)
+         self.conv06_1 = nn.Conv2d(256, 256, 3)
+         self.in06_1 = InstanceNormalization(256)
+         # relu
+         self.refpad06_2 = nn.ReflectionPad2d(1)
+         self.conv06_2 = nn.Conv2d(256, 256, 3)
+         self.in06_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 4
+         self.refpad07_1 = nn.ReflectionPad2d(1)
+         self.conv07_1 = nn.Conv2d(256, 256, 3)
+         self.in07_1 = InstanceNormalization(256)
+         # relu
+         self.refpad07_2 = nn.ReflectionPad2d(1)
+         self.conv07_2 = nn.Conv2d(256, 256, 3)
+         self.in07_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 5
+         self.refpad08_1 = nn.ReflectionPad2d(1)
+         self.conv08_1 = nn.Conv2d(256, 256, 3)
+         self.in08_1 = InstanceNormalization(256)
+         # relu
+         self.refpad08_2 = nn.ReflectionPad2d(1)
+         self.conv08_2 = nn.Conv2d(256, 256, 3)
+         self.in08_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 6
+         self.refpad09_1 = nn.ReflectionPad2d(1)
+         self.conv09_1 = nn.Conv2d(256, 256, 3)
+         self.in09_1 = InstanceNormalization(256)
+         # relu
+         self.refpad09_2 = nn.ReflectionPad2d(1)
+         self.conv09_2 = nn.Conv2d(256, 256, 3)
+         self.in09_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 7
+         self.refpad10_1 = nn.ReflectionPad2d(1)
+         self.conv10_1 = nn.Conv2d(256, 256, 3)
+         self.in10_1 = InstanceNormalization(256)
+         # relu
+         self.refpad10_2 = nn.ReflectionPad2d(1)
+         self.conv10_2 = nn.Conv2d(256, 256, 3)
+         self.in10_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 8
+         self.refpad11_1 = nn.ReflectionPad2d(1)
+         self.conv11_1 = nn.Conv2d(256, 256, 3)
+         self.in11_1 = InstanceNormalization(256)
+         # relu
+         self.refpad11_2 = nn.ReflectionPad2d(1)
+         self.conv11_2 = nn.Conv2d(256, 256, 3)
+         self.in11_2 = InstanceNormalization(256)
+         # + input
+
+         ##------------------------------------##
+         self.deconv01_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
+         self.deconv01_2 = nn.Conv2d(128, 128, 3, 1, 1)
+         self.in12_1 = InstanceNormalization(128)
+         # relu
+         self.deconv02_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
+         self.deconv02_2 = nn.Conv2d(64, 64, 3, 1, 1)
+         self.in13_1 = InstanceNormalization(64)
+         # relu
+         self.refpad12_1 = nn.ReflectionPad2d(3)
+         self.deconv03_1 = nn.Conv2d(64, 3, 7)
+         # tanh
+
+     def forward(self, x):
+         y = F.relu(self.in01_1(self.conv01_1(self.refpad01_1(x))))
+         y = F.relu(self.in02_1(self.conv02_2(self.conv02_1(y))))
+         t04 = F.relu(self.in03_1(self.conv03_2(self.conv03_1(y))))
+
+         ##
+         y = F.relu(self.in04_1(self.conv04_1(self.refpad04_1(t04))))
+         t05 = self.in04_2(self.conv04_2(self.refpad04_2(y))) + t04
+
+         y = F.relu(self.in05_1(self.conv05_1(self.refpad05_1(t05))))
+         t06 = self.in05_2(self.conv05_2(self.refpad05_2(y))) + t05
+
+         y = F.relu(self.in06_1(self.conv06_1(self.refpad06_1(t06))))
+         t07 = self.in06_2(self.conv06_2(self.refpad06_2(y))) + t06
+
+         y = F.relu(self.in07_1(self.conv07_1(self.refpad07_1(t07))))
+         t08 = self.in07_2(self.conv07_2(self.refpad07_2(y))) + t07
+
+         y = F.relu(self.in08_1(self.conv08_1(self.refpad08_1(t08))))
+         t09 = self.in08_2(self.conv08_2(self.refpad08_2(y))) + t08
+
+         y = F.relu(self.in09_1(self.conv09_1(self.refpad09_1(t09))))
+         t10 = self.in09_2(self.conv09_2(self.refpad09_2(y))) + t09
+
+         y = F.relu(self.in10_1(self.conv10_1(self.refpad10_1(t10))))
+         t11 = self.in10_2(self.conv10_2(self.refpad10_2(y))) + t10
+
+         y = F.relu(self.in11_1(self.conv11_1(self.refpad11_1(t11))))
+         y = self.in11_2(self.conv11_2(self.refpad11_2(y))) + t11
+         ##
+
+         y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y))))
+         y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y))))
+         y = F.tanh(self.deconv03_1(self.refpad12_1(y)))
+
+         return y
+
+
+ class InstanceNormalization(nn.Module):
+     def __init__(self, dim, eps=1e-9):
+         super(InstanceNormalization, self).__init__()
+         self.scale = nn.Parameter(torch.FloatTensor(dim))
+         self.shift = nn.Parameter(torch.FloatTensor(dim))
+         self.eps = eps
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         self.scale.data.uniform_()
+         self.shift.data.zero_()
+
+     def __call__(self, x):
+         n = x.size(2) * x.size(3)
+         t = x.view(x.size(0), x.size(1), n)
+         mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
+         # Calculate the biased var. torch.var returns unbiased var
+         var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * (
+             (n - 1) / float(n)
+         )
+         scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+         scale_broadcast = scale_broadcast.expand_as(x)
+         shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+         shift_broadcast = shift_broadcast.expand_as(x)
+         out = (x - mean) / torch.sqrt(var + self.eps)
+         out = out * scale_broadcast + shift_broadcast
+         return out
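
As a sanity check of the architecture (a fully convolutional encoder, eight residual blocks, and a two-stage upsampling decoder ending in tanh), a minimal smoke test with random weights can confirm that the generator preserves the spatial size of its input. This is only a sketch and assumes the file is importable as network.Transformer:

    import torch
    from network.Transformer import Transformer

    model = Transformer().eval()
    with torch.no_grad():
        dummy = torch.rand(1, 3, 256, 256) * 2 - 1  # a fake BGR batch scaled to [-1, 1]
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256]), values in (-1, 1)

The hand-rolled InstanceNormalization behaves like nn.InstanceNorm2d(dim, affine=True); keeping the custom class presumably preserves the scale/shift parameter names used by the pretrained checkpoints.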
network/__init__.py ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==1.9.0
+ torchvision==0.11.2
+ pillow==9.0.0
requirements_dev.txt ADDED
@@ -0,0 +1,3 @@
+ -r requirements.txt
+ black==21.12b0
+ flake8==4.0.1