M-DOL committed on
Commit 4a2ca64
1 Parent(s): efe0a64

no styleclip

Files changed (2)
  1. app.py +3 -5
  2. styleclip/styleclip_global.py +0 -181
app.py CHANGED
@@ -18,8 +18,7 @@ import torchvision.transforms as transforms
from torchvision import utils

from model.sg2_model import Generator
- from generate_videos import generate_frames, video_from_interpolations, project_code_by_edit_name
- from styleclip.styleclip_global import project_code_with_styleclip, style_tensor_to_style_dict
+ from generate_videos import project_code_by_edit_name

import clip

@@ -30,7 +29,6 @@ os.makedirs(model_dir, exist_ok=True)
model_repos = {
    "e4e": ("akhaliq/JoJoGAN_e4e_ffhq_encode", "e4e_ffhq_encode.pt"),
    "dlib": ("akhaliq/jojogan_dlib", "shape_predictor_68_face_landmarks.dat"),
-     "sc_fs3": ("rinong/stylegan-nada-models", "fs3.npy"),
    "base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt"),
    "sketch": ("rinong/stylegan-nada-models", "sketch.pt"),
    "santa": ("mjdolan/stylegan-nada-models", "santa.pt"),
@@ -78,7 +76,7 @@ class ImageEditor(object):

        self.generators = {}

-         self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib", "sc_fs3"]]
+         self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]

        for model in self.model_list:
            g_ema = Generator(
@@ -116,7 +114,6 @@ class ImageEditor(object):
            model_paths["dlib"]
        )

-         self.styleclip_fs3 = torch.from_numpy(np.load(model_paths["sc_fs3"])).to(self.device)

        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)

@@ -163,6 +160,7 @@ class ImageEditor(object):

        return inner

+
    def get_target_latent(self, source_latent, alter, generators):
        np_source_latent = source_latent.squeeze(0).cpu().detach().numpy()
        if alter == "None":
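With the StyleCLIP entries gone, the generator list is now filtered only against the e4e encoder and the dlib landmark predictor. A minimal, self-contained sketch of the revised filtering; the checkpoint paths below are hypothetical stand-ins for the files downloaded from the repos listed above, and only the keys matter here:

# Hypothetical stand-ins for the downloaded checkpoints; only the keys matter.
model_paths = {
    "e4e": "models/e4e_ffhq_encode.pt",
    "dlib": "models/shape_predictor_68_face_landmarks.dat",
    "base": "models/stylegan2-ffhq-config-f.pt",
    "sketch": "models/sketch.pt",
    "santa": "models/santa.pt",
}

# With "sc_fs3" removed from model_repos, only the non-generator helpers are excluded.
model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]
print(model_list)  # ['base', 'sketch', 'santa']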
styleclip/styleclip_global.py DELETED
@@ -1,181 +0,0 @@
- '''
- Code adapted from Stitch it in Time by Tzaban et al.
- https://github.com/rotemtzaban/STIT
- '''
-
-
- import numpy as np
- import torch
- from tqdm import tqdm
- from pathlib import Path
- import os
-
- import clip
-
- imagenet_templates = [
-     'a bad photo of a {}.',
-     'a photo of many {}.',
-     'a sculpture of a {}.',
-     'a photo of the hard to see {}.',
-     'a low resolution photo of the {}.',
-     'a rendering of a {}.',
-     'graffiti of a {}.',
-     'a bad photo of the {}.',
-     'a cropped photo of the {}.',
-     'a tattoo of a {}.',
-     'the embroidered {}.',
-     'a photo of a hard to see {}.',
-     'a bright photo of a {}.',
-     'a photo of a clean {}.',
-     'a photo of a dirty {}.',
-     'a dark photo of the {}.',
-     'a drawing of a {}.',
-     'a photo of my {}.',
-     'the plastic {}.',
-     'a photo of the cool {}.',
-     'a close-up photo of a {}.',
-     'a black and white photo of the {}.',
-     'a painting of the {}.',
-     'a painting of a {}.',
-     'a pixelated photo of the {}.',
-     'a sculpture of the {}.',
-     'a bright photo of the {}.',
-     'a cropped photo of a {}.',
-     'a plastic {}.',
-     'a photo of the dirty {}.',
-     'a jpeg corrupted photo of a {}.',
-     'a blurry photo of the {}.',
-     'a photo of the {}.',
-     'a good photo of the {}.',
-     'a rendering of the {}.',
-     'a {} in a video game.',
-     'a photo of one {}.',
-     'a doodle of a {}.',
-     'a close-up photo of the {}.',
-     'a photo of a {}.',
-     'the origami {}.',
-     'the {} in a video game.',
-     'a sketch of a {}.',
-     'a doodle of the {}.',
-     'a origami {}.',
-     'a low resolution photo of a {}.',
-     'the toy {}.',
-     'a rendition of the {}.',
-     'a photo of the clean {}.',
-     'a photo of a large {}.',
-     'a rendition of a {}.',
-     'a photo of a nice {}.',
-     'a photo of a weird {}.',
-     'a blurry photo of a {}.',
-     'a cartoon {}.',
-     'art of a {}.',
-     'a sketch of the {}.',
-     'a embroidered {}.',
-     'a pixelated photo of a {}.',
-     'itap of the {}.',
-     'a jpeg corrupted photo of the {}.',
-     'a good photo of a {}.',
-     'a plushie {}.',
-     'a photo of the nice {}.',
-     'a photo of the small {}.',
-     'a photo of the weird {}.',
-     'the cartoon {}.',
-     'art of the {}.',
-     'a drawing of the {}.',
-     'a photo of the large {}.',
-     'a black and white photo of a {}.',
-     'the plushie {}.',
-     'a dark photo of a {}.',
-     'itap of a {}.',
-     'graffiti of the {}.',
-     'a toy {}.',
-     'itap of my {}.',
-     'a photo of a cool {}.',
-     'a photo of a small {}.',
-     'a tattoo of the {}.',
- ]
-
- CONV_CODE_INDICES = [(0, 512), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)]
- FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
-                     [(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
-
- def zeroshot_classifier(model, classnames, templates, device):
-
-     with torch.no_grad():
-         zeroshot_weights = []
-         for classname in tqdm(classnames):
-             texts = [template.format(classname) for template in templates] # format with class
-             texts = clip.tokenize(texts).to(device) # tokenize
-             class_embeddings = model.encode_text(texts) # embed with text encoder
-             class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
-             class_embedding = class_embeddings.mean(dim=0)
-             class_embedding /= class_embedding.norm()
-             zeroshot_weights.append(class_embedding)
-         zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
-     return zeroshot_weights
-
- def expand_to_full_dim(partial_tensor):
-     full_dim_tensor = torch.zeros(size=(1, 9088))
-
-     start_idx = 0
-     for conv_start, conv_end in CONV_CODE_INDICES:
-         length = conv_end - conv_start
-         full_dim_tensor[:, conv_start:conv_end] = partial_tensor[start_idx:start_idx + length]
-         start_idx += length
-
-     return full_dim_tensor
-
- def get_direction(neutral_class, target_class, beta, di, clip_model=None):
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     if clip_model is None:
-         clip_model, _ = clip.load("ViT-B/32", device=device)
-
-     class_names = [neutral_class, target_class]
-     class_weights = zeroshot_classifier(clip_model, class_names, imagenet_templates, device)
-
-     dt = class_weights[:, 1] - class_weights[:, 0]
-     dt = dt / dt.norm()
-
-     dt = dt.float()
-     di = di.float()
-
-     relevance = di @ dt
-     mask = relevance.abs() > beta
-     direction = relevance * mask
-     direction_max = direction.abs().max()
-     if direction_max > 0:
-         direction = direction / direction_max
-     else:
-         raise ValueError(f'Beta value {beta} is too high for mapping from {neutral_class} to {target_class},'
-                          f' try setting it to a lower value')
-     return direction
-
- def style_tensor_to_style_dict(style_tensor, refernce_generator):
-     style_layers = refernce_generator.modulation_layers
-
-     style_dict = {}
-     for layer_idx, layer in enumerate(style_layers):
-         style_dict[layer] = style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]]
-
-     return style_dict
-
- def style_dict_to_style_tensor(style_dict, reference_generator):
-     style_layers = reference_generator.modulation_layers
-
-     style_tensor = torch.zeros(size=(1, 9088))
-     for layer in style_dict:
-         layer_idx = style_layers.index(layer)
-         style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]
-
-     return style_tensor
-
- def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
-     edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
-
-     edit_full_dim = expand_to_full_dim(edit_direction)
-
-     source_s = style_dict_to_style_tensor(source_latent, reference_generator)
-
-     return source_s + alpha * edit_full_dim
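For reference, the deleted module applied its edits by expanding a compact direction over the 17 CONV_CODE_INDICES ranges (6048 channels) into the full 9088-dimensional StyleSpace tensor and adding alpha times that direction to the source style code. Below is a self-contained sketch of that expansion step, with a random vector standing in for the CLIP-derived, beta-thresholded direction that get_direction would normally produce:

import torch

# The 17 conv ranges from the deleted module; together they cover 6048 of the
# 9088 StyleSpace channels.
CONV_CODE_INDICES = [(0, 512), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584),
                     (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680),
                     (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896),
                     (8960, 9024), (9024, 9056)]

def expand_to_full_dim(partial_tensor):
    # Scatter the compact direction into its conv ranges; all other channels stay zero.
    full_dim_tensor = torch.zeros(size=(1, 9088))
    start_idx = 0
    for conv_start, conv_end in CONV_CODE_INDICES:
        length = conv_end - conv_start
        full_dim_tensor[:, conv_start:conv_end] = partial_tensor[start_idx:start_idx + length]
        start_idx += length
    return full_dim_tensor

# Random stand-in for the thresholded CLIP direction.
partial = torch.randn(sum(end - start for start, end in CONV_CODE_INDICES))
full = expand_to_full_dim(partial)
print(partial.shape, full.shape)  # torch.Size([6048]) torch.Size([1, 9088])
# project_code_with_styleclip then returned source_s + alpha * full.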