winfred2027 commited on
Commit
5d090b9
1 Parent(s): 89a587b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -33,10 +33,11 @@ def load_openshape(name, to_cpu=False):
33
 
34
 
35
  def load_tripletmix(name, to_cpu=False):
36
- pce = openshape.load_pc_encoder_mix(name)
37
  if to_cpu:
38
  pce = pce.cpu()
39
- return pce
 
40
 
41
 
42
 
@@ -81,9 +82,10 @@ def classification_lvis(load_data):
81
  col2 = utils.render_pc(pc)
82
  prog.progress(0.5, "Running Classification")
83
  ref_dev = next(model_g14.parameters()).device
84
- enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
85
-
86
- sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc, dim=-1).squeeze())
 
87
  argsort = torch.argsort(sim, descending=True)
88
  pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
89
  with col2:
@@ -103,8 +105,10 @@ def classification_custom(load_data, cats):
103
 
104
  prog.progress(0.5, "Running Classification")
105
  ref_dev = next(model_g14.parameters()).device
106
- enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
107
- sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc, dim=-1).squeeze())
 
 
108
  argsort = torch.argsort(sim, descending=True)
109
  pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
110
  with col2:
@@ -197,11 +201,19 @@ try:
197
  f32 = numpy.float32
198
  half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
199
  clip_model, clip_prep = load_openclip()
200
- model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
201
  #model_g14 = load_tripletmix('tripletmix-spconv-all')
202
 
203
  st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
204
  st.sidebar.title("TripletMix Demo Configuration Panel")
 
 
 
 
 
 
 
 
205
  task = st.sidebar.selectbox(
206
  'Task Selection',
207
  ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
 
33
 
34
 
35
  def load_tripletmix(name, to_cpu=False):
36
+ pce, pca = openshape.load_pc_encoder_mix(name)
37
  if to_cpu:
38
  pce = pce.cpu()
39
+ pca = pca.cpu()
40
+ return pce, pca
41
 
42
 
43
 
 
82
  col2 = utils.render_pc(pc)
83
  prog.progress(0.5, "Running Classification")
84
  ref_dev = next(model_g14.parameters()).device
85
+ enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
86
+ if model_name == "pb-sn-M":
87
+ enc = pc_adapter(enc)
88
+ sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
89
  argsort = torch.argsort(sim, descending=True)
90
  pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
91
  with col2:
 
105
 
106
  prog.progress(0.5, "Running Classification")
107
  ref_dev = next(model_g14.parameters()).device
108
+ enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
109
+ if model_name == "pb-sn-M":
110
+ enc = pc_adapter(enc)
111
+ sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
112
  argsort = torch.argsort(sim, descending=True)
113
  pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
114
  with col2:
 
201
  f32 = numpy.float32
202
  half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
203
  clip_model, clip_prep = load_openclip()
204
+ #model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
205
  #model_g14 = load_tripletmix('tripletmix-spconv-all')
206
 
207
  st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
208
  st.sidebar.title("TripletMix Demo Configuration Panel")
209
+ model_name = st.sidebar.selectbox(
210
+ 'Model Selection',
211
+ ("pb-sn-M", "pb-sn")
212
+ )
213
+ if model_name == "pb-sn-M":
214
+ model_g14, pc_adapter = load_tripletmix('tripletmix-pointbert-shapenet')
215
+ elif model_name == "pb-sn":
216
+ model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
217
  task = st.sidebar.selectbox(
218
  'Task Selection',
219
  ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")