akiyamasho commited on
Commit
3d6551f
1 Parent(s): a8868f5

MAINT: use HF hub instead of directly loading models

Browse files
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import sys
3
  import torch
@@ -8,6 +9,7 @@ import torchvision.transforms as transforms
8
 
9
  from torch.autograd import Variable
10
  from network.Transformer import Transformer
 
11
 
12
  from PIL import Image
13
 
@@ -16,6 +18,8 @@ import logging
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
 
19
  MAX_DIMENSION = 1280
20
  MODEL_PATH = "models"
21
  COLOUR_MODEL = "RGB"
@@ -27,23 +31,37 @@ STYLE_KON = "Satoshi Kon"
27
  DEFAULT_STYLE = STYLE_SHINKAI
28
  STYLE_CHOICE_LIST = [STYLE_SHINKAI, STYLE_HOSODA, STYLE_MIYAZAKI, STYLE_KON]
29
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  shinkai_model = Transformer()
31
  hosoda_model = Transformer()
32
  miyazaki_model = Transformer()
33
  kon_model = Transformer()
34
 
 
 
35
 
36
  shinkai_model.load_state_dict(
37
- torch.load(os.path.join(MODEL_PATH, "shinkai_makoto.pth"))
38
  )
39
  hosoda_model.load_state_dict(
40
- torch.load(os.path.join(MODEL_PATH, "hosoda_mamoru.pth"))
41
  )
42
  miyazaki_model.load_state_dict(
43
- torch.load(os.path.join(MODEL_PATH, "miyazaki_hayao.pth"))
44
  )
45
  kon_model.load_state_dict(
46
- torch.load(os.path.join(MODEL_PATH, "kon_satoshi.pth"))
47
  )
48
 
49
  shinkai_model.eval()
@@ -51,8 +69,7 @@ hosoda_model.eval()
51
  miyazaki_model.eval()
52
  kon_model.eval()
53
 
54
- enable_gpu = torch.cuda.is_available()
55
-
56
 
57
  def get_model(style):
58
  if style == STYLE_SHINKAI:
@@ -109,6 +126,8 @@ def inference(img, style):
109
  return transforms.ToPILImage()(output_image)
110
 
111
 
 
 
112
  title = "Anime Background GAN"
113
  description = "Gradio Demo for CartoonGAN by Chen Et. Al. Models are Shinkai Makoto, Hosoda Mamoru, Kon Satoshi, and Miyazaki Hayao."
114
  article = "<p style='text-align: center'><a href='http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf' target='_blank'>CartoonGAN Whitepaper from Chen et.al</a></p><p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Github Repo</a></p><p style='text-align: center'><a href='https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch' target='_blank'>Original Implementation from Yijunmaverick</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center></p>"
 
1
+ from cgitb import enable
2
  import os
3
  import sys
4
  import torch
 
9
 
10
  from torch.autograd import Variable
11
  from network.Transformer import Transformer
12
+ from huggingface_hub import hf_hub_download
13
 
14
  from PIL import Image
15
 
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
+ # Constants
22
+
23
  MAX_DIMENSION = 1280
24
  MODEL_PATH = "models"
25
  COLOUR_MODEL = "RGB"
 
31
  DEFAULT_STYLE = STYLE_SHINKAI
32
  STYLE_CHOICE_LIST = [STYLE_SHINKAI, STYLE_HOSODA, STYLE_MIYAZAKI, STYLE_KON]
33
 
34
+ MODEL_REPO_ID = "akiyamasho/AnimeBackgroundGAN"
35
+ MODEL_FILE_SHINKAI = "shinkai_makoto.pth"
36
+ MODEL_FILE_HOSODA = "hosoda_mamoru.pth"
37
+ MODEL_FILE_MIYAZAKI = "miyazaki_hayao.pth"
38
+ MODEL_FILE_KON = "kon_satoshi.pth"
39
+
40
+ # Model Initalisation
41
+ shinkai_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILE_SHINKAI)
42
+ hosoda_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILE_HOSODA)
43
+ miyazaki_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILE_MIYAZAKI)
44
+ kon_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILE_KON)
45
+
46
  shinkai_model = Transformer()
47
  hosoda_model = Transformer()
48
  miyazaki_model = Transformer()
49
  kon_model = Transformer()
50
 
51
+ enable_gpu = torch.cuda.is_available()
52
+ map_location = torch.device("cuda") if enable_gpu else "cpu"
53
 
54
  shinkai_model.load_state_dict(
55
+ torch.load(shinkai_model_hfhub, map_location=map_location)
56
  )
57
  hosoda_model.load_state_dict(
58
+ torch.load(hosoda_model_hfhub, map_location=map_location)
59
  )
60
  miyazaki_model.load_state_dict(
61
+ torch.load(miyazaki_model_hfhub, map_location=map_location)
62
  )
63
  kon_model.load_state_dict(
64
+ torch.load(kon_model_hfhub, map_location=map_location)
65
  )
66
 
67
  shinkai_model.eval()
 
69
  miyazaki_model.eval()
70
  kon_model.eval()
71
 
72
+ # Functions
 
73
 
74
  def get_model(style):
75
  if style == STYLE_SHINKAI:
 
126
  return transforms.ToPILImage()(output_image)
127
 
128
 
129
+ # Gradio setup
130
+
131
  title = "Anime Background GAN"
132
  description = "Gradio Demo for CartoonGAN by Chen Et. Al. Models are Shinkai Makoto, Hosoda Mamoru, Kon Satoshi, and Miyazaki Hayao."
133
  article = "<p style='text-align: center'><a href='http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf' target='_blank'>CartoonGAN Whitepaper from Chen et.al</a></p><p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Github Repo</a></p><p style='text-align: center'><a href='https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch' target='_blank'>Original Implementation from Yijunmaverick</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center></p>"
models/hosoda_mamoru.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c666eea7700864d5972765cc43e926d900174648297bfef494006dc230fd1bf0
3
- size 44529096
 
 
 
 
models/kon_satoshi.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0629352a54838e56a2ad7fca3e6e51e6889d4338c37469f9ddb43e5929ef9475
3
- size 44529096
 
 
 
 
models/miyazaki_hayao.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8ab0e492efb3b705487db38679e363dc8b1f016692913bbe100587d695a9e2b5
3
- size 44529096
 
 
 
 
models/shinkai_makoto.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c3547f611e780e79aebde7f7bc2b6c278555d701620f125583d666351044c486
3
- size 44529096
 
 
 
 
requirements_dev.txt CHANGED
@@ -2,4 +2,5 @@
2
  black==21.12b0
3
  flake8==4.0.1
4
  gradio==2.9.1
5
- jinja2==3.1.1
 
 
2
  black==21.12b0
3
  flake8==4.0.1
4
  gradio==2.9.1
5
+ jinja2==3.1.1
6
+ huggingface_hub==0.4.0