shubham-goel committed on
Commit 0c2905b
1 Parent(s): ad3be6b

Update app.py

Files changed (1)
  1. app.py +24 -9
app.py CHANGED
@@ -1,6 +1,7 @@
 import argparse
 import os
 from pathlib import Path
+import tempfile
 import sys
 import cv2
 import gradio as gr
@@ -25,6 +26,8 @@ except:
     import os
     os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
 
+OUT_FOLDER = 'demo_out'
+os.makedirs(OUT_FOLDER, exist_ok=True)
 
 # Setup HMR2.0 model
 LIGHT_BLUE=(0.65098039, 0.74117647, 0.85882353)
@@ -71,7 +74,10 @@ def infer(in_pil_img, in_threshold=0.8, out_pil_img=None):
 
     all_verts = []
     all_cam_t = []
+    all_mesh_paths = []
 
+    temp_name = next(tempfile._get_candidate_names())
+
     for batch in dataloader:
         batch = recursive_to(batch, device)
         with torch.no_grad():
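Note: tempfile._get_candidate_names() is a private CPython helper (an iterator of random name candidates), so its behavior is not guaranteed across Python versions. A minimal sketch of a public-API way to build the same kind of unique filename prefix; the uuid-based naming is an assumption for illustration, not what this commit uses:

    import uuid

    # Random 8-character hex prefix, analogous to next(tempfile._get_candidate_names())
    temp_name = uuid.uuid4().hex[:8]
    print(temp_name)  # e.g. '3f9a1c2e'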
@@ -101,6 +107,15 @@ def infer(in_pil_img, in_threshold=0.8, out_pil_img=None):
             all_verts.append(verts)
             all_cam_t.append(cam_t)
 
+            # Save all meshes to disk
+            # if args.save_mesh:
+            if True:
+                camera_translation = cam_t.copy()
+                tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_BLUE)
+
+                temp_path = os.path.join(f'{OUT_FOLDER}/{temp_name}_{person_id}.obj')
+                tmesh.export(temp_path)
+                all_mesh_paths.append(temp_path)
 
     # Render front view
     if len(all_verts) > 0:
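The per-person meshes are exported as .obj files under demo_out/. A minimal sketch of reading one back with trimesh to sanity-check an export; the filename below is illustrative (real files are named <temp_name>_<person_id>.obj):

    import trimesh

    # Illustrative path; actual names depend on temp_name and person_id
    mesh = trimesh.load('demo_out/example_0.obj', force='mesh')
    print(mesh.vertices.shape, mesh.faces.shape)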
@@ -118,9 +133,9 @@ def infer(in_pil_img, in_threshold=0.8, out_pil_img=None):
         # convert to PIL image
         out_pil_img = Image.fromarray((input_img_overlay*255).astype(np.uint8))
 
-        return out_pil_img
+        return out_pil_img, all_mesh_paths
     else:
-        return None
+        return None, []
 
 
 with gr.Blocks(title="4DHumans", css=".gradio-container") as demo:
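With this change infer returns a pair (overlay image, list of mesh paths), which the UI wiring below passes to the image and file outputs in order. A minimal sketch of calling it directly, outside Gradio; assets/test1.png is one of the example images referenced further down, and overlay.png is an illustrative output name:

    from PIL import Image

    overlay, mesh_paths = infer(Image.open('assets/test1.png'), in_threshold=0.6)
    if overlay is not None:
        overlay.save('overlay.png')  # rendered front view
    print(mesh_paths)                # e.g. ['demo_out/..._0.obj']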
@@ -128,15 +143,18 @@ with gr.Blocks(title="4DHumans", css=".gradio-container") as demo:
     gr.HTML("""<div style="font-weight:bold; text-align:center; color:royalblue;">HMR 2.0</div>""")
 
     with gr.Row():
-        input_image = gr.Image(label="Input image", type="pil", width=300, height=300, fixed_size=True)
-        output_image = gr.Image(label="Reconstructions", type="pil", width=300, height=300, fixed_size=True)
+        with gr.Column():
+            input_image = gr.Image(label="Input image", type="pil")
+        with gr.Column():
+            output_image = gr.Image(label="Reconstructions", type="pil")
+            output_meshes = gr.File(label="3D meshes")
 
     gr.HTML("""<br/>""")
 
     with gr.Row():
         threshold = gr.Slider(0, 1.0, value=0.6, label='Detection Threshold')
         send_btn = gr.Button("Infer")
-        send_btn.click(fn=infer, inputs=[input_image, threshold], outputs=[output_image])
+        send_btn.click(fn=infer, inputs=[input_image, threshold], outputs=[output_image, output_meshes])
 
     # gr.Examples([
     #     ['assets/test1.png', 0.6],
@@ -156,9 +174,6 @@ with gr.Blocks(title="4DHumans", css=".gradio-container") as demo:
         ],
         inputs=[input_image, 0.6])
 
-    gr.HTML("""</ul>""")
-
-
 
 #demo.queue()
 demo.launch(debug=True)
@@ -166,4 +181,4 @@ demo.launch(debug=True)
 
 
 
-### EOF ###
+### EOF ###
 