SerdarHelli commited on
Commit
aafb4d8
1 Parent(s): 8fd8da4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -4
app.py CHANGED
@@ -166,6 +166,27 @@ def load_network_pkl_cpu(f, force_fp16=False):
166
  data[key] = new
167
  return data
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  def get_all(cfg,input,truncation_psi,mesh_resolution,random_seed,fps,num_frames):
170
 
171
  network=models[cfg]
@@ -194,11 +215,16 @@ def get_all(cfg,input,truncation_psi,mesh_resolution,random_seed,fps,num_frames)
194
 
195
  save_dir = Path(outdir)
196
 
197
- input_label = PIL.Image.open(input)
198
- input_label = PIL.ImageOps.grayscale(input_label)
 
 
 
 
 
 
199
  input_label = np.asarray(input_label).astype(np.uint8)
200
  input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
201
- print(input_label.shape)
202
  input_pose = forward_pose.to(device)
203
 
204
  # Generate videos
@@ -282,7 +308,7 @@ demo_outputs=[
282
 
283
  ]
284
  examples = [
285
- ["seg2cat", "example_input.png", 1, 32, 128, 30, 30],
286
 
287
  ]
288
 
 
166
  data[key] = new
167
  return data
168
 
169
+ color_list = [[255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
170
+
171
+ def colormap2labelmap(color_img):
172
+ im_base = np.zeros((color_img.shape[0], color_img.shape[1]))
173
+ for idx, color in enumerate(color_list):
174
+
175
+ k1=((color_img == np.asarray(color))[:,:,0])*1
176
+ k2=((color_img == np.asarray(color))[:,:,1])*1
177
+ k3=((color_img == np.asarray(color))[:,:,2])*1
178
+ k=((k1*k2*k3)==1)
179
+
180
+ im_base[k] = idx
181
+ return im_base
182
+
183
+
184
+ def checklabelmap(img):
185
+ labels=np.unique(img)
186
+ for idx,label in enumerate(labels):
187
+ img[img==label]=idx
188
+ return img
189
+
190
  def get_all(cfg,input,truncation_psi,mesh_resolution,random_seed,fps,num_frames):
191
 
192
  network=models[cfg]
 
215
 
216
  save_dir = Path(outdir)
217
 
218
+
219
+ if isinstance(input,str):
220
+ input_label =np.asarray( PIL.Image.open(input))
221
+ else:
222
+ input_label=np.asarray(input)
223
+
224
+ input_label=colormap2labelmap(input_label)
225
+ input_label=checklabelmap(input_label)
226
  input_label = np.asarray(input_label).astype(np.uint8)
227
  input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
 
228
  input_pose = forward_pose.to(device)
229
 
230
  # Generate videos
 
308
 
309
  ]
310
  examples = [
311
+ ["seg2cat", "img.png", 1, 32, 128, 30, 30],
312
 
313
  ]
314