hysts HF staff commited on
Commit
e38549b
1 Parent(s): a98b858

Use models from public repo

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. dualstylegan.py +13 -15
  3. packages.txt +1 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ shape_predictor_68_face_landmarks.dat*
dualstylegan.py CHANGED
@@ -2,6 +2,8 @@ from __future__ import annotations
2
 
3
  import argparse
4
  import os
 
 
5
  import sys
6
  from typing import Callable, Union
7
 
@@ -23,8 +25,7 @@ from model.dualstylegan import DualStyleGAN
23
  from model.encoder.align_all_parallel import align_face
24
  from model.encoder.psp import pSp
25
 
26
- HF_TOKEN = os.environ['HF_TOKEN']
27
- MODEL_REPO = 'hysts/DualStyleGAN'
28
 
29
 
30
  class Model:
@@ -54,16 +55,17 @@ class Model:
54
 
55
  @staticmethod
56
  def _create_dlib_landmark_model():
57
- path = huggingface_hub.hf_hub_download(
58
- 'hysts/dlib_face_landmark_model',
59
- 'shape_predictor_68_face_landmarks.dat',
60
- use_auth_token=HF_TOKEN)
61
- return dlib.shape_predictor(path)
 
 
62
 
63
  def _load_encoder(self) -> nn.Module:
64
  ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
65
- 'models/encoder.pt',
66
- use_auth_token=HF_TOKEN)
67
  ckpt = torch.load(ckpt_path, map_location='cpu')
68
  opts = ckpt['opts']
69
  opts['device'] = self.device.type
@@ -87,9 +89,7 @@ class Model:
87
  def _load_generator(self, style_type: str) -> nn.Module:
88
  model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
89
  ckpt_path = huggingface_hub.hf_hub_download(
90
- MODEL_REPO,
91
- f'models/{style_type}/generator.pt',
92
- use_auth_token=HF_TOKEN)
93
  ckpt = torch.load(ckpt_path, map_location='cpu')
94
  model.load_state_dict(ckpt['g_ema'])
95
  model.to(self.device)
@@ -103,9 +103,7 @@ class Model:
103
  else:
104
  filename = 'exstyle_code.npy'
105
  path = huggingface_hub.hf_hub_download(
106
- MODEL_REPO,
107
- f'models/{style_type}/{filename}',
108
- use_auth_token=HF_TOKEN)
109
  exstyles = np.load(path, allow_pickle=True).item()
110
  return exstyles
111
 
 
2
 
3
  import argparse
4
  import os
5
+ import pathlib
6
+ import subprocess
7
  import sys
8
  from typing import Callable, Union
9
 
 
25
  from model.encoder.align_all_parallel import align_face
26
  from model.encoder.psp import pSp
27
 
28
+ MODEL_REPO = 'CVPR/DualStyleGAN'
 
29
 
30
 
31
  class Model:
 
55
 
56
  @staticmethod
57
  def _create_dlib_landmark_model():
58
+ url = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
59
+ path = pathlib.Path('shape_predictor_68_face_landmarks.dat')
60
+ if not path.exists():
61
+ bz2_path = 'shape_predictor_68_face_landmarks.dat.bz2'
62
+ torch.hub.download_url_to_file(url, bz2_path)
63
+ subprocess.run(f'bunzip2 -d {bz2_path}'.split())
64
+ return dlib.shape_predictor(path.as_posix())
65
 
66
  def _load_encoder(self) -> nn.Module:
67
  ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
68
+ 'models/encoder.pt')
 
69
  ckpt = torch.load(ckpt_path, map_location='cpu')
70
  opts = ckpt['opts']
71
  opts['device'] = self.device.type
 
89
  def _load_generator(self, style_type: str) -> nn.Module:
90
  model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
91
  ckpt_path = huggingface_hub.hf_hub_download(
92
+ MODEL_REPO, f'models/{style_type}/generator.pt')
 
 
93
  ckpt = torch.load(ckpt_path, map_location='cpu')
94
  model.load_state_dict(ckpt['g_ema'])
95
  model.to(self.device)
 
103
  else:
104
  filename = 'exstyle_code.npy'
105
  path = huggingface_hub.hf_hub_download(
106
+ MODEL_REPO, f'models/{style_type}/{filename}')
 
 
107
  exstyles = np.load(path, allow_pickle=True).item()
108
  return exstyles
109
 
packages.txt CHANGED
@@ -1,2 +1,3 @@
 
1
  cmake
2
  ninja-build
 
1
+ bzip2
2
  cmake
3
  ninja-build