SerdarHelli commited on
Commit
d64da46
1 Parent(s): 692f3e3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -2
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import sys
2
  import os
3
 
4
- os.system("git clone https://github.com/dunbar12138/pix2pix3D.git")
5
  sys.path.append("pix2pix3D")
6
 
7
  from typing import List, Optional, Tuple, Union
@@ -23,6 +23,15 @@ from tqdm import tqdm
23
  import imageio
24
  import trimesh
25
  import mcubes
 
 
 
 
 
 
 
 
 
26
 
27
  os.environ["PYOPENGL_PLATFORM"] = "egl"
28
 
@@ -112,12 +121,59 @@ models={"seg2cat":network_cat
112
 
113
  device='cuda' if torch.cuda.is_available() else 'cpu'
114
  outdir="/content/"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  def get_all(cfg,input,truncation_psi,mesh_resolution,random_seed,fps,num_frames):
117
 
118
  network=models[cfg]
119
 
120
- with dnnlib.util.open_url(network) as f:
 
 
 
 
121
  G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)
122
 
123
  if cfg == 'seg2cat' or cfg == 'seg2face':
@@ -239,5 +295,6 @@ demo_app = gr.Interface(
239
  theme="huggingface",
240
  description=desc,
241
  examples=examples,
 
242
  )
243
  demo_app.launch(debug=True, enable_queue=True)
 
1
  import sys
2
  import os
3
 
4
+ os.system("https://github.com/dunbar12138/pix2pix3D.git")
5
  sys.path.append("pix2pix3D")
6
 
7
  from typing import List, Optional, Tuple, Union
 
23
  import imageio
24
  import trimesh
25
  import mcubes
26
+ import copy
27
+
28
+ import pickle
29
+ import numpy as np
30
+ import torch
31
+ import dnnlib
32
+ from torch_utils import misc
33
+ from legacy import *
34
+ import io
35
 
36
  os.environ["PYOPENGL_PLATFORM"] = "egl"
37
 
 
121
 
122
  device='cuda' if torch.cuda.is_available() else 'cpu'
123
  outdir="/content/"
124
+ class CPU_Unpickler(pickle.Unpickler):
125
+ def find_class(self, module, name):
126
+ if module == 'torch.storage' and name == '_load_from_bytes':
127
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
128
+ return super().find_class(module, name)
129
+
130
+ def load_network_pkl_cpu(f, force_fp16=False):
131
+ data = CPU_Unpickler(f).load()
132
+
133
+ # Legacy TensorFlow pickle => convert.
134
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
135
+ tf_G, tf_D, tf_Gs = data
136
+ G = convert_tf_generator(tf_G)
137
+ D = convert_tf_discriminator(tf_D)
138
+ G_ema = convert_tf_generator(tf_Gs)
139
+ data = dict(G=G, D=D, G_ema=G_ema)
140
+
141
+ # Add missing fields.
142
+ if 'training_set_kwargs' not in data:
143
+ data['training_set_kwargs'] = None
144
+ if 'augment_pipe' not in data:
145
+ data['augment_pipe'] = None
146
+
147
+ # Validate contents.
148
+ assert isinstance(data['G'], torch.nn.Module)
149
+ assert isinstance(data['D'], torch.nn.Module)
150
+ assert isinstance(data['G_ema'], torch.nn.Module)
151
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
152
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
153
+
154
+ # Force FP16.
155
+ if force_fp16:
156
+ for key in ['G', 'D', 'G_ema']:
157
+ old = data[key]
158
+ kwargs = copy.deepcopy(old.init_kwargs)
159
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
160
+ fp16_kwargs.num_fp16_res = 4
161
+ fp16_kwargs.conv_clamp = 256
162
+ if kwargs != old.init_kwargs:
163
+ new = type(old)(**kwargs).eval().requires_grad_(False)
164
+ misc.copy_params_and_buffers(old, new, require_all=True)
165
+ data[key] = new
166
+ return data
167
 
168
  def get_all(cfg,input,truncation_psi,mesh_resolution,random_seed,fps,num_frames):
169
 
170
  network=models[cfg]
171
 
172
+ if device=="cpu":
173
+ with dnnlib.util.open_url(network) as f:
174
+ G = load_network_pkl_cpu(f)['G_ema'].eval().to(device)
175
+ else:
176
+ with dnnlib.util.open_url(network) as f:
177
  G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)
178
 
179
  if cfg == 'seg2cat' or cfg == 'seg2face':
 
295
  theme="huggingface",
296
  description=desc,
297
  examples=examples,
298
+ cache_examples=True,
299
  )
300
  demo_app.launch(debug=True, enable_queue=True)