Spaces:
Sleeping
Sleeping
winfred2027
commited on
Commit
•
fbeb0dc
1
Parent(s):
a403cac
Update app.py
Browse files
app.py
CHANGED
@@ -33,11 +33,10 @@ def load_openshape(name, to_cpu=False):
|
|
33 |
|
34 |
|
35 |
def load_tripletmix(name, to_cpu=False):
|
36 |
-
pce
|
37 |
if to_cpu:
|
38 |
pce = pce.cpu()
|
39 |
-
|
40 |
-
return pce, pca
|
41 |
|
42 |
|
43 |
|
@@ -81,10 +80,8 @@ def classification_lvis(load_data):
|
|
81 |
pc = load_data(prog)
|
82 |
col2 = utils.render_pc(pc)
|
83 |
prog.progress(0.5, "Running Classification")
|
84 |
-
ref_dev = next(
|
85 |
-
enc =
|
86 |
-
if model_name == "pb-sn-M":
|
87 |
-
enc = pc_adapter(enc)
|
88 |
sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
|
89 |
argsort = torch.argsort(sim, descending=True)
|
90 |
pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
|
@@ -104,10 +101,8 @@ def classification_custom(load_data, cats):
|
|
104 |
feats = clip_model.get_text_features(**tn).float().cpu()
|
105 |
|
106 |
prog.progress(0.5, "Running Classification")
|
107 |
-
ref_dev = next(
|
108 |
-
enc =
|
109 |
-
if model_name == "pb-sn-M":
|
110 |
-
enc = pc_adapter(enc)
|
111 |
sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
|
112 |
argsort = torch.argsort(sim, descending=True)
|
113 |
pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
|
@@ -207,14 +202,7 @@ try:
|
|
207 |
|
208 |
st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
|
209 |
st.sidebar.title("TripletMix Demo Configuration Panel")
|
210 |
-
|
211 |
-
'Model Selection',
|
212 |
-
("pb-sn-M", "pb-sn")
|
213 |
-
)
|
214 |
-
if model_name == "pb-sn-M":
|
215 |
-
model_g14, pc_adapter = load_tripletmix('tripletmix-pointbert-shapenet')
|
216 |
-
elif model_name == "pb-sn":
|
217 |
-
model_g14 = load_openshape('openshape-pointbert-shapenet')
|
218 |
task = st.sidebar.selectbox(
|
219 |
'Task Selection',
|
220 |
("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
|
@@ -225,6 +213,14 @@ try:
|
|
225 |
'Choose the source of categories',
|
226 |
("LVIS Categories", "Custom Categories")
|
227 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
load_data = utils.input_3d_shape('rpcinput')
|
229 |
if cls_mode == "LVIS Categories":
|
230 |
st.title("Classification with LVIS Categories")
|
|
|
33 |
|
34 |
|
35 |
def load_tripletmix(name, to_cpu=False):
|
36 |
+
pce = openshape.load_pc_encoder_mix(name)
|
37 |
if to_cpu:
|
38 |
pce = pce.cpu()
|
39 |
+
return pce
|
|
|
40 |
|
41 |
|
42 |
|
|
|
80 |
pc = load_data(prog)
|
81 |
col2 = utils.render_pc(pc)
|
82 |
prog.progress(0.5, "Running Classification")
|
83 |
+
ref_dev = next(model_classification.parameters()).device
|
84 |
+
enc = model_classification(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
|
|
|
|
|
85 |
sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
|
86 |
argsort = torch.argsort(sim, descending=True)
|
87 |
pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
|
|
|
101 |
feats = clip_model.get_text_features(**tn).float().cpu()
|
102 |
|
103 |
prog.progress(0.5, "Running Classification")
|
104 |
+
ref_dev = next(model_classification.parameters()).device
|
105 |
+
enc = model_classification(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
|
|
|
|
|
106 |
sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
|
107 |
argsort = torch.argsort(sim, descending=True)
|
108 |
pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
|
|
|
202 |
|
203 |
st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
|
204 |
st.sidebar.title("TripletMix Demo Configuration Panel")
|
205 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
task = st.sidebar.selectbox(
|
207 |
'Task Selection',
|
208 |
("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
|
|
|
213 |
'Choose the source of categories',
|
214 |
("LVIS Categories", "Custom Categories")
|
215 |
)
|
216 |
+
model_name = st.sidebar.selectbox(
|
217 |
+
'Model Selection',
|
218 |
+
("pb-Mix", "pb")
|
219 |
+
)
|
220 |
+
if model_name == "pb-Mix":
|
221 |
+
model_classification = load_tripletmix('tripletmix-pointbert-all-modelnet40')
|
222 |
+
elif model_name == "pb":
|
223 |
+
model_classification = load_openshape('openshape-pointbert-vitg14-rgb')
|
224 |
load_data = utils.input_3d_shape('rpcinput')
|
225 |
if cls_mode == "LVIS Categories":
|
226 |
st.title("Classification with LVIS Categories")
|