winfred2027 committed on
Commit
a8e5fc3
1 Parent(s): 4429214

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -41
app.py CHANGED
@@ -30,25 +30,22 @@ def load_openshape(name, to_cpu=False):
30
  pce = pce.cpu()
31
  return pce
32
 
33
def retrieval_filter_expand(key):
    """Render the retrieval filter widgets and build the matching predicate.

    Draws an expander with a similarity-threshold slider plus tag,
    face-count, and animation-count filters. Returns ``(sim_th, filter_fn)``
    where ``filter_fn`` takes a result dict with 'anims', 'faces' and 'tags'
    keys and reports whether it passes every active filter.
    """
    with st.expander("Filters"):
        sim_th = st.slider("Similarity Threshold", 0.05, 0.5, 0.1, key=key + 'rtsimth')
        tag = st.text_input("Has Tag", "", key=key + 'rthastag')
        col1, col2 = st.columns(2)
        face_min = int(col1.text_input("Face Count Min", "0", key=key + 'rtfcmin'))
        face_max = int(col2.text_input("Face Count Max", "34985808", key=key + 'rtfcmax'))
        col1, col2 = st.columns(2)
        anim_min = int(col1.text_input("Animation Count Min", "0", key=key + 'rtacmin'))
        anim_max = int(col2.text_input("Animation Count Max", "563", key=key + 'rtacmax'))
        # A filter dimension is inactive while it sits at its default full range
        # (or, for the tag, while the input is blank).
        tag_off = not tag.strip()
        anim_off = anim_min <= 0 and anim_max >= 563
        face_off = face_min <= 0 and face_max >= 34985808

        def filter_fn(x):
            return (
                (anim_off or anim_min <= x['anims'] <= anim_max)
                and (face_off or face_min <= x['faces'] <= face_max)
                and (tag_off or tag in x['tags'])
            )

        return sim_th, filter_fn
52
 
53
  def retrieval_results(results):
54
  st.caption("Click the link to view the 3D shape")
@@ -148,32 +145,125 @@ def demo_retrieval():
148
 
149
  prog.progress(1.0, "Idle")
150
 
151
- st.title("TripletMix Demo")
152
- st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
153
- prog = st.progress(0.0, "Idle")
154
- tab_cls, tab_pc, tab_img, tab_text, tab_sd, tab_cap = st.tabs([
155
- "Classification",
156
- "Retrieval w/ 3D",
157
- "Retrieval w/ Image",
158
- "Retrieval w/ Text",
159
- "Image Generation",
160
- "Captioning",
161
- ])
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- f32 = numpy.float32
165
- half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
166
- clip_model, clip_prep = load_openclip()
167
- model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
168
 
169
  try:
170
- with tab_cls:
171
- demo_classification()
172
- with tab_cap:
173
- demo_captioning()
174
- with tab_sd:
175
- demo_pc2img()
176
- demo_retrieval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  except Exception:
178
  import traceback
179
  st.error(traceback.format_exc().replace("\n", " \n"))
 
30
  pce = pce.cpu()
31
  return pce
32
 
33
def retrieval_filter_expand():
    """Render the retrieval configuration controls in the sidebar.

    Only the similarity threshold is user-adjustable in this version; the
    tag, face-count, and animation-count filters were removed from the UI
    and were previously hard-coded to their full default ranges, which made
    the original predicate constant-fold to True for every item. The dead
    computation is removed here; behavior is unchanged.

    Returns:
        tuple: ``(sim_th, filter_fn)`` — the slider value and a predicate
        over a result dict that accepts every candidate.
    """
    sim_th = st.sidebar.slider("Similarity Threshold", 0.05, 0.5, 0.1, key='rsimth')

    def filter_fn(x):
        # All auxiliary filters are disabled, so every item passes.
        return True

    return sim_th, filter_fn
 
 
 
49
 
50
  def retrieval_results(results):
51
  st.caption("Click the link to view the 3D shape")
 
145
 
146
  prog.progress(1.0, "Idle")
147
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
def retrieval_pc(load_data, k, sim_th, filter_fn):
    """Embed a point cloud, preview its top-5 LVIS categories, and run retrieval.

    Uses the module-level ``prog`` progress bar, ``model_g14`` encoder, and
    ``lvis`` category features.
    """
    cloud = load_data(prog)
    prog.progress(0.49, "Computing Embeddings")
    side_col = utils.render_pc(cloud)
    # Reorder columns (swap Y/Z) to the layout the encoder expects, add a
    # batch axis, and place the tensor on the model's device.
    device = next(model_g14.parameters()).device
    batch = torch.tensor(cloud[:, [0, 2, 1, 3, 4, 5]].T[None], device=device)
    enc = model_g14(batch).cpu()

    # Rank LVIS categories by cosine similarity against the shape embedding.
    feats_n = torch.nn.functional.normalize(lvis.feats, dim=-1)
    query_n = torch.nn.functional.normalize(enc, dim=-1).squeeze()
    scores = torch.matmul(feats_n, query_n)
    order = torch.argsort(scores, descending=True)
    ranked = OrderedDict(
        (lvis.categories[idx], scores[idx])
        for idx in order if idx < len(lvis.categories)
    )
    with side_col:
        for _, (category, score) in zip(range(5), ranked.items()):
            st.text(category)
            st.caption("Similarity %.4f" % score)

    prog.progress(0.7, "Running Retrieval")
    retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))

    prog.progress(1.0, "Idle")
168
+
169
def retrieval_img(pic, k, sim_th, filter_fn):
    """Embed an uploaded image with CLIP and run cross-modal retrieval.

    Uses the module-level ``prog`` bar, ``clip_model``/``clip_prep``, and the
    ``half`` dtype chosen at startup.
    """
    image = Image.open(pic)
    prog.progress(0.49, "Computing Embeddings")
    st.image(image)
    dev = clip_model.device
    inputs = clip_prep(images=[image], return_tensors="pt").to(dev)
    pixels = inputs['pixel_values'].type(half)
    enc = clip_model.get_image_features(pixel_values=pixels).float().cpu()

    prog.progress(0.7, "Running Retrieval")
    retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))

    prog.progress(1.0, "Idle")
181
+
182
def retrieval_text(text, k, sim_th, filter_fn):
    """Embed a text query with CLIP's text encoder and run retrieval.

    Uses the module-level ``prog`` bar and ``clip_model``/``clip_prep``.
    """
    prog.progress(0.49, "Computing Embeddings")
    dev = clip_model.device
    # Truncate to stay inside CLIP's text context window.
    tokens = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(dev)
    enc = clip_model.get_text_features(**tokens).float().cpu()

    prog.progress(0.7, "Running Retrieval")
    retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))

    prog.progress(1.0, "Idle")
192
 
193
try:
    # Global setup: models and dtypes used by the task handlers above.
    f32 = numpy.float32
    # fp16 on GPU, bf16 on CPU — bf16 is the safer half-precision on CPUs.
    half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
    clip_model, clip_prep = load_openclip()
    model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')

    st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
    st.sidebar.title("TripletMix Demo Configuration Panel")
    task = st.sidebar.selectbox(
        'Task Selection',
        ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
    )

    if task == "3D Classification":
        cls_mode = st.sidebar.selectbox(
            'Choose the source of categories',
            ("LVIS Categories", "Custom Categories")
        )
        # NOTE(review): key 'rtextinput' is reused by the Text-retrieval and
        # generation branches; safe only because the branches are mutually
        # exclusive per rerun — consider unique keys.
        pc = st.sidebar.text_input("Input pc", key='rtextinput')
        if cls_mode == "LVIS Categories":
            if st.sidebar.button("submit"):
                st.title("Classification with LVIS Categories")
                # NOTE(review): stub — no classification is invoked after the
                # progress bar is created.
                prog = st.progress(0.0, "Idle")

        elif cls_mode == "Custom Categories":
            cats = st.sidebar.text_input("Custom Categories (64 max, separated with comma)")
            cats = [a.strip() for a in cats.split(',')]
            if len(cats) > 64:
                st.error('Maximum 64 custom categories supported in the demo')
            if st.sidebar.button("submit"):
                st.title("Classification with Custom Categories")
                # NOTE(review): stub — see LVIS branch.
                prog = st.progress(0.0, "Idle")

    elif task == "Cross-modal retrieval":
        input_mode = st.sidebar.selectbox(
            'Choose an input modality',
            ("Point Cloud", "Image", "Text")
        )
        k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
        sim_th, filter_fn = retrieval_filter_expand()
        if input_mode == "Point Cloud":
            load_data = utils.input_3d_shape('rpcinput')
            if st.sidebar.button("submit"):
                st.title("Retrieval with Point Cloud")
                # prog is read as a global by the retrieval_* helpers, so it
                # must be assigned before they run.
                prog = st.progress(0.0, "Idle")
                retrieval_pc(load_data, k, sim_th, filter_fn)
        elif input_mode == "Image":
            pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
            if st.sidebar.button("submit"):
                st.title("Retrieval with Image")
                prog = st.progress(0.0, "Idle")
                retrieval_img(pic, k, sim_th, filter_fn)
        elif input_mode == "Text":
            text = st.sidebar.text_input("Input Text", key='rtextinput')
            if st.sidebar.button("submit"):
                st.title("Retrieval with Text")
                prog = st.progress(0.0, "Idle")
                retrieval_text(text, k, sim_th, filter_fn)
    elif task == "Cross-modal generation":
        generation_mode = st.sidebar.selectbox(
            'Choose the mode of generation',
            ("PointCloud-to-Image", "PointCloud-to-Text")
        )
        pc = st.sidebar.text_input("Input pc", key='rtextinput')
        if generation_mode == "PointCloud-to-Image":
            if st.sidebar.button("submit"):
                st.title("Image Generation")
                # NOTE(review): stub — no generation call follows.
                prog = st.progress(0.0, "Idle")

        elif generation_mode == "PointCloud-to-Text":
            if st.sidebar.button("submit"):
                st.title("Text Generation")
                # NOTE(review): stub — no generation call follows.
                prog = st.progress(0.0, "Idle")

except Exception:
    import traceback
    # Surface the traceback in the app; trailing-space newline keeps line
    # breaks in Streamlit's markdown rendering.
    st.error(traceback.format_exc().replace("\n", " \n"))