秋山翔 commited on
Commit
0cc1558
1 Parent(s): 0778fc2

FEAT: added style choices

Browse files
Files changed (1) hide show
  1. app.py +62 -7
app.py CHANGED
@@ -8,19 +8,65 @@ import torchvision.transforms as transforms
8
  from torch.autograd import Variable
9
  from network.Transformer import Transformer
10
 
 
 
 
 
11
  LOAD_SIZE = 1280
12
- STYLE = "shinkai_makoto"
13
  MODEL_PATH = "models"
14
  COLOUR_MODEL = "RGB"
15
 
16
- model = Transformer()
17
- model.load_state_dict(torch.load(os.path.join(MODEL_PATH, f"{STYLE}.pth")))
18
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  disable_gpu = True
21
 
22
 
23
- def inference(img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # load image
25
  input_image = img.convert(COLOUR_MODEL)
26
  input_image = np.asarray(input_image)
@@ -36,6 +82,7 @@ def inference(img):
36
  input_image = Variable(input_image).cuda()
37
 
38
  # forward
 
39
  output_image = model(input_image)
40
  output_image = output_image[0]
41
  # BGR -> RGB
@@ -47,7 +94,7 @@ def inference(img):
47
 
48
  title = "Anime Background GAN"
49
  description = "Gradio Demo for CartoonGAN by Chen Et. Al. Models are Shinkai Makoto, Hosoda Mamoru, Kon Satoshi, and Miyazaki Hayao."
50
- article = "<p style='text-align: center'><a href='http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf' target='_blank'>CartoonGAN from Chen et.al</a></p><p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Github Repo</a></p><p style='text-align: center'><a href='https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch' target='_blank'>Original Implementation from Yijunmaverick</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center></p>"
51
 
52
  examples = [
53
  ["examples/garden_in.jpg"],
@@ -57,7 +104,15 @@ examples = [
57
 
58
  gr.Interface(
59
  fn=inference,
60
- inputs=[gr.inputs.Image(type="pil")],
 
 
 
 
 
 
 
 
61
  outputs=gr.outputs.Image(type="pil"),
62
  title=title,
63
  description=description,
 
8
  from torch.autograd import Variable
9
  from network.Transformer import Transformer
10
 
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
  LOAD_SIZE = 1280
 
16
  MODEL_PATH = "models"
17
  COLOUR_MODEL = "RGB"
18
 
19
+ STYLE_SHINKAI = "Makoto Shinkai"
20
+ STYLE_HOSODA = "Mamoru Hosoda"
21
+ STYLE_MIYAZAKI = "Hayao Miyazaki"
22
+ STYLE_KON = "Satoshi Kon"
23
+ DEFAULT_STYLE = STYLE_SHINKAI
24
+ STYLE_CHOICE_LIST = [STYLE_SHINKAI, STYLE_HOSODA, STYLE_MIYAZAKI, STYLE_KON]
25
+
26
+ shinkai_model = Transformer()
27
+ hosoda_model = Transformer()
28
+ miyazaki_model = Transformer()
29
+ kon_model = Transformer()
30
+
31
+
32
+ shinkai_model.load_state_dict(
33
+ torch.load(os.path.join(MODEL_PATH, "shinkai_makoto.pth"))
34
+ )
35
+ hosoda_model.load_state_dict(
36
+ torch.load(os.path.join(MODEL_PATH, "hosoda_mamoru.pth"))
37
+ )
38
+ miyazaki_model.load_state_dict(
39
+ torch.load(os.path.join(MODEL_PATH, "miyazaki_hayao.pth"))
40
+ )
41
+ kon_model.load_state_dict(
42
+ torch.load(os.path.join(MODEL_PATH, "kon_satoshi.pth"))
43
+ )
44
+
45
+ shinkai_model.eval()
46
+ hosoda_model.eval()
47
+ miyazaki_model.eval()
48
+ kon_model.eval()
49
 
50
  disable_gpu = True
51
 
52
 
53
+ def get_model(style):
54
+ if style == STYLE_SHINKAI:
55
+ return shinkai_model
56
+ elif style == STYLE_HOSODA:
57
+ return hosoda_model
58
+ elif style == STYLE_MIYAZAKI:
59
+ return miyazaki_model
60
+ elif style == STYLE_KON:
61
+ return kon_model
62
+ else:
63
+ logger.warning(
64
+ f"Style {style} not found. Defaulting to Makoto Shinkai"
65
+ )
66
+ return shinkai_model
67
+
68
+
69
+ def inference(img, style):
70
  # load image
71
  input_image = img.convert(COLOUR_MODEL)
72
  input_image = np.asarray(input_image)
 
82
  input_image = Variable(input_image).cuda()
83
 
84
  # forward
85
+ model = get_model(style)
86
  output_image = model(input_image)
87
  output_image = output_image[0]
88
  # BGR -> RGB
 
94
 
95
  title = "Anime Background GAN"
96
  description = "Gradio Demo for CartoonGAN by Chen Et. Al. Models are Shinkai Makoto, Hosoda Mamoru, Kon Satoshi, and Miyazaki Hayao."
97
+ article = "<p style='text-align: center'><a href='http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf' target='_blank'>CartoonGAN Whitepaper from Chen et.al</a></p><p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Github Repo</a></p><p style='text-align: center'><a href='https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch' target='_blank'>Original Implementation from Yijunmaverick</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center></p>"
98
 
99
  examples = [
100
  ["examples/garden_in.jpg"],
 
104
 
105
  gr.Interface(
106
  fn=inference,
107
+ inputs=[
108
+ gr.inputs.Image(type="pil", label="Input Photo"),
109
+ gradio.inputs.Dropdown(
110
+ STYLE_CHOICE_LIST,
111
+ type="value",
112
+ default=DEFAULT_STYLE,
113
+ label="Style",
114
+ ),
115
+ ],
116
  outputs=gr.outputs.Image(type="pil"),
117
  title=title,
118
  description=description,