winfred2027 commited on
Commit
e03c2d8
1 Parent(s): 5d090b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -107,7 +107,7 @@ def classification_custom(load_data, cats):
107
  ref_dev = next(model_g14.parameters()).device
108
  enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
109
  if model_name == "pb-sn-M":
110
- enc = pc_adapter(enc)
111
  sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
112
  argsort = torch.argsort(sim, descending=True)
113
  pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
@@ -123,9 +123,10 @@ def retrieval_pc(load_data, k, sim_th, filter_fn):
123
  prog.progress(0.5, "Computing Embeddings")
124
  col2 = utils.render_pc(pc)
125
  ref_dev = next(model_g14.parameters()).device
126
- enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
127
-
128
- sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc, dim=-1).squeeze())
 
129
  argsort = torch.argsort(sim, descending=True)
130
  pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
131
  with col2:
@@ -213,7 +214,7 @@ try:
213
  if model_name == "pb-sn-M":
214
  model_g14, pc_adapter = load_tripletmix('tripletmix-pointbert-shapenet')
215
  elif model_name == "pb-sn":
216
- model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
217
  task = st.sidebar.selectbox(
218
  'Task Selection',
219
  ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
 
107
  ref_dev = next(model_g14.parameters()).device
108
  enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
109
  if model_name == "pb-sn-M":
110
+ enc = pc_adapter(enc)
111
  sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
112
  argsort = torch.argsort(sim, descending=True)
113
  pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
 
123
  prog.progress(0.5, "Computing Embeddings")
124
  col2 = utils.render_pc(pc)
125
  ref_dev = next(model_g14.parameters()).device
126
+ enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
127
+ if model_name == "pb-sn-M":
128
+ enc = pc_adapter(enc)
129
+ sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
130
  argsort = torch.argsort(sim, descending=True)
131
  pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
132
  with col2:
 
214
  if model_name == "pb-sn-M":
215
  model_g14, pc_adapter = load_tripletmix('tripletmix-pointbert-shapenet')
216
  elif model_name == "pb-sn":
217
+ model_g14 = load_openshape('openshape-pointbert-shapenet')
218
  task = st.sidebar.selectbox(
219
  'Task Selection',
220
  ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")