winfred2027 commited on
Commit
fbeb0dc
1 Parent(s): a403cac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -33,11 +33,10 @@ def load_openshape(name, to_cpu=False):
33
 
34
 
35
  def load_tripletmix(name, to_cpu=False):
36
- pce, pca = openshape.load_pc_encoder_mix(name)
37
  if to_cpu:
38
  pce = pce.cpu()
39
- pca = pca.cpu()
40
- return pce, pca
41
 
42
 
43
 
@@ -81,10 +80,8 @@ def classification_lvis(load_data):
81
  pc = load_data(prog)
82
  col2 = utils.render_pc(pc)
83
  prog.progress(0.5, "Running Classification")
84
- ref_dev = next(model_g14.parameters()).device
85
- enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
86
- if model_name == "pb-sn-M":
87
- enc = pc_adapter(enc)
88
  sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
89
  argsort = torch.argsort(sim, descending=True)
90
  pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
@@ -104,10 +101,8 @@ def classification_custom(load_data, cats):
104
  feats = clip_model.get_text_features(**tn).float().cpu()
105
 
106
  prog.progress(0.5, "Running Classification")
107
- ref_dev = next(model_g14.parameters()).device
108
- enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
109
- if model_name == "pb-sn-M":
110
- enc = pc_adapter(enc)
111
  sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
112
  argsort = torch.argsort(sim, descending=True)
113
  pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
@@ -207,14 +202,7 @@ try:
207
 
208
  st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
209
  st.sidebar.title("TripletMix Demo Configuration Panel")
210
- model_name = st.sidebar.selectbox(
211
- 'Model Selection',
212
- ("pb-sn-M", "pb-sn")
213
- )
214
- if model_name == "pb-sn-M":
215
- model_g14, pc_adapter = load_tripletmix('tripletmix-pointbert-shapenet')
216
- elif model_name == "pb-sn":
217
- model_g14 = load_openshape('openshape-pointbert-shapenet')
218
  task = st.sidebar.selectbox(
219
  'Task Selection',
220
  ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
@@ -225,6 +213,14 @@ try:
225
  'Choose the source of categories',
226
  ("LVIS Categories", "Custom Categories")
227
  )
 
 
 
 
 
 
 
 
228
  load_data = utils.input_3d_shape('rpcinput')
229
  if cls_mode == "LVIS Categories":
230
  st.title("Classification with LVIS Categories")
 
33
 
34
 
35
  def load_tripletmix(name, to_cpu=False):
36
+ pce = openshape.load_pc_encoder_mix(name)
37
  if to_cpu:
38
  pce = pce.cpu()
39
+ return pce
 
40
 
41
 
42
 
 
80
  pc = load_data(prog)
81
  col2 = utils.render_pc(pc)
82
  prog.progress(0.5, "Running Classification")
83
+ ref_dev = next(model_classification.parameters()).device
84
+ enc = model_classification(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
 
 
85
  sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
86
  argsort = torch.argsort(sim, descending=True)
87
  pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
 
101
  feats = clip_model.get_text_features(**tn).float().cpu()
102
 
103
  prog.progress(0.5, "Running Classification")
104
+ ref_dev = next(model_classification.parameters()).device
105
+ enc = model_classification(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
 
 
106
  sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
107
  argsort = torch.argsort(sim, descending=True)
108
  pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
 
202
 
203
  st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
204
  st.sidebar.title("TripletMix Demo Configuration Panel")
205
+
 
 
 
 
 
 
 
206
  task = st.sidebar.selectbox(
207
  'Task Selection',
208
  ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
 
213
  'Choose the source of categories',
214
  ("LVIS Categories", "Custom Categories")
215
  )
216
+ model_name = st.sidebar.selectbox(
217
+ 'Model Selection',
218
+ ("pb-Mix", "pb")
219
+ )
220
+ if model_name == "pb-Mix":
221
+ model_classification = load_tripletmix('tripletmix-pointbert-all-modelnet40')
222
+ elif model_name == "pb":
223
+ model_classification = load_openshape('openshape-pointbert-vitg14-rgb')
224
  load_data = utils.input_3d_shape('rpcinput')
225
  if cls_mode == "LVIS Categories":
226
  st.title("Classification with LVIS Categories")