Spaces:
Runtime error
Runtime error
SerdarHelli
commited on
Commit
•
aafb4d8
1
Parent(s):
8fd8da4
Update app.py
Browse files
app.py
CHANGED
@@ -166,6 +166,27 @@ def load_network_pkl_cpu(f, force_fp16=False):
|
|
166 |
data[key] = new
|
167 |
return data
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
def get_all(cfg,input,truncation_psi,mesh_resolution,random_seed,fps,num_frames):
|
170 |
|
171 |
network=models[cfg]
|
@@ -194,11 +215,16 @@ def get_all(cfg,input,truncation_psi,mesh_resolution,random_seed,fps,num_frames)
|
|
194 |
|
195 |
save_dir = Path(outdir)
|
196 |
|
197 |
-
|
198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
input_label = np.asarray(input_label).astype(np.uint8)
|
200 |
input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
|
201 |
-
print(input_label.shape)
|
202 |
input_pose = forward_pose.to(device)
|
203 |
|
204 |
# Generate videos
|
@@ -282,7 +308,7 @@ demo_outputs=[
|
|
282 |
|
283 |
]
|
284 |
examples = [
|
285 |
-
["seg2cat", "
|
286 |
|
287 |
]
|
288 |
|
|
|
166 |
data[key] = new
|
167 |
return data
|
168 |
|
169 |
+
color_list = [[255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
|
170 |
+
|
171 |
+
def colormap2labelmap(color_img):
|
172 |
+
im_base = np.zeros((color_img.shape[0], color_img.shape[1]))
|
173 |
+
for idx, color in enumerate(color_list):
|
174 |
+
|
175 |
+
k1=((color_img == np.asarray(color))[:,:,0])*1
|
176 |
+
k2=((color_img == np.asarray(color))[:,:,1])*1
|
177 |
+
k3=((color_img == np.asarray(color))[:,:,2])*1
|
178 |
+
k=((k1*k2*k3)==1)
|
179 |
+
|
180 |
+
im_base[k] = idx
|
181 |
+
return im_base
|
182 |
+
|
183 |
+
|
184 |
+
def checklabelmap(img):
|
185 |
+
labels=np.unique(img)
|
186 |
+
for idx,label in enumerate(labels):
|
187 |
+
img[img==label]=idx
|
188 |
+
return img
|
189 |
+
|
190 |
def get_all(cfg,input,truncation_psi,mesh_resolution,random_seed,fps,num_frames):
|
191 |
|
192 |
network=models[cfg]
|
|
|
215 |
|
216 |
save_dir = Path(outdir)
|
217 |
|
218 |
+
|
219 |
+
if isinstance(input,str):
|
220 |
+
input_label =np.asarray( PIL.Image.open(input))
|
221 |
+
else:
|
222 |
+
input_label=np.asarray(input)
|
223 |
+
|
224 |
+
input_label=colormap2labelmap(input_label)
|
225 |
+
input_label=checklabelmap(input_label)
|
226 |
input_label = np.asarray(input_label).astype(np.uint8)
|
227 |
input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
|
|
|
228 |
input_pose = forward_pose.to(device)
|
229 |
|
230 |
# Generate videos
|
|
|
308 |
|
309 |
]
|
310 |
examples = [
|
311 |
+
["seg2cat", "img.png", 1, 32, 128, 30, 30],
|
312 |
|
313 |
]
|
314 |
|