Spaces:
Sleeping
Sleeping
winfred2027
commited on
Commit
•
b0a14dc
1
Parent(s):
3e21179
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import openshape
|
|
7 |
import transformers
|
8 |
from PIL import Image
|
9 |
from huggingface_hub import HfFolder, snapshot_download
|
10 |
-
from demo_support import retrieval
|
11 |
|
12 |
@st.cache_resource
|
13 |
def load_openclip():
|
@@ -22,6 +22,12 @@ def load_openclip():
|
|
22 |
clip_model.cuda()
|
23 |
return clip_model, clip_prep
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def retrieval_filter_expand(key):
|
27 |
with st.expander("Filters"):
|
@@ -85,11 +91,19 @@ def demo_retrieval():
|
|
85 |
with tab_pc:
|
86 |
with st.form("rpcform"):
|
87 |
k = st.slider("Number of items to retrieve", 1, 100, 16, key='rpc')
|
88 |
-
|
89 |
sim_th, filter_fn = retrieval_filter_expand('pc')
|
90 |
if st.form_submit_button("Retrieve with Point Cloud"):
|
91 |
prog.progress(0.49, "Computing Embeddings")
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
with tab_img:
|
95 |
with st.form("rimgform"):
|
@@ -141,6 +155,7 @@ tab_cls, tab_pc, tab_img, tab_text, tab_sd, tab_cap = st.tabs([
|
|
141 |
f32 = numpy.float32
|
142 |
half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
|
143 |
clip_model, clip_prep = load_openclip()
|
|
|
144 |
|
145 |
try:
|
146 |
with tab_cls:
|
|
|
7 |
import transformers
|
8 |
from PIL import Image
|
9 |
from huggingface_hub import HfFolder, snapshot_download
|
10 |
+
from demo_support import retrieval, utils
|
11 |
|
12 |
@st.cache_resource
|
13 |
def load_openclip():
|
|
|
22 |
clip_model.cuda()
|
23 |
return clip_model, clip_prep
|
24 |
|
25 |
+
@st.cache_resource
|
26 |
+
def load_openshape(name, to_cpu=False):
|
27 |
+
pce = openshape.load_pc_encoder(name)
|
28 |
+
if to_cpu:
|
29 |
+
pce = pce.cpu()
|
30 |
+
return pce
|
31 |
|
32 |
def retrieval_filter_expand(key):
|
33 |
with st.expander("Filters"):
|
|
|
91 |
with tab_pc:
|
92 |
with st.form("rpcform"):
|
93 |
k = st.slider("Number of items to retrieve", 1, 100, 16, key='rpc')
|
94 |
+
load_data = utils.load_3D_shape('rpcinput')
|
95 |
sim_th, filter_fn = retrieval_filter_expand('pc')
|
96 |
if st.form_submit_button("Retrieve with Point Cloud"):
|
97 |
prog.progress(0.49, "Computing Embeddings")
|
98 |
+
pc = load_data(prog)
|
99 |
+
col2 = utils.render_pc(pc)
|
100 |
+
ref_dev = next(model_g14.parameters()).device
|
101 |
+
enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
|
102 |
+
|
103 |
+
prog.progress(0.7, "Running Retrieval")
|
104 |
+
retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
|
105 |
+
|
106 |
+
prog.progress(1.0, "Idle")
|
107 |
|
108 |
with tab_img:
|
109 |
with st.form("rimgform"):
|
|
|
155 |
f32 = numpy.float32
|
156 |
half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
|
157 |
clip_model, clip_prep = load_openclip()
|
158 |
+
model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
|
159 |
|
160 |
try:
|
161 |
with tab_cls:
|