winfred2027 commited on
Commit
b0a14dc
1 Parent(s): 3e21179

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -7,7 +7,7 @@ import openshape
7
  import transformers
8
  from PIL import Image
9
  from huggingface_hub import HfFolder, snapshot_download
10
- from demo_support import retrieval
11
 
12
  @st.cache_resource
13
  def load_openclip():
@@ -22,6 +22,12 @@ def load_openclip():
22
  clip_model.cuda()
23
  return clip_model, clip_prep
24
 
 
 
 
 
 
 
25
 
26
  def retrieval_filter_expand(key):
27
  with st.expander("Filters"):
@@ -85,11 +91,19 @@ def demo_retrieval():
85
  with tab_pc:
86
  with st.form("rpcform"):
87
  k = st.slider("Number of items to retrieve", 1, 100, 16, key='rpc')
88
- #pc = utils.load_3D_shape('rpcinput')
89
  sim_th, filter_fn = retrieval_filter_expand('pc')
90
  if st.form_submit_button("Retrieve with Point Cloud"):
91
  prog.progress(0.49, "Computing Embeddings")
92
-
 
 
 
 
 
 
 
 
93
 
94
  with tab_img:
95
  with st.form("rimgform"):
@@ -141,6 +155,7 @@ tab_cls, tab_pc, tab_img, tab_text, tab_sd, tab_cap = st.tabs([
141
  f32 = numpy.float32
142
  half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
143
  clip_model, clip_prep = load_openclip()
 
144
 
145
  try:
146
  with tab_cls:
 
7
  import transformers
8
  from PIL import Image
9
  from huggingface_hub import HfFolder, snapshot_download
10
+ from demo_support import retrieval, utils
11
 
12
  @st.cache_resource
13
  def load_openclip():
 
22
  clip_model.cuda()
23
  return clip_model, clip_prep
24
 
25
+ @st.cache_resource
26
+ def load_openshape(name, to_cpu=False):
27
+ pce = openshape.load_pc_encoder(name)
28
+ if to_cpu:
29
+ pce = pce.cpu()
30
+ return pce
31
 
32
  def retrieval_filter_expand(key):
33
  with st.expander("Filters"):
 
91
  with tab_pc:
92
  with st.form("rpcform"):
93
  k = st.slider("Number of items to retrieve", 1, 100, 16, key='rpc')
94
+ load_data = utils.load_3D_shape('rpcinput')
95
  sim_th, filter_fn = retrieval_filter_expand('pc')
96
  if st.form_submit_button("Retrieve with Point Cloud"):
97
  prog.progress(0.49, "Computing Embeddings")
98
+ pc = load_data(prog)
99
+ col2 = utils.render_pc(pc)
100
+ ref_dev = next(model_g14.parameters()).device
101
+ enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
102
+
103
+ prog.progress(0.7, "Running Retrieval")
104
+ retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
105
+
106
+ prog.progress(1.0, "Idle")
107
 
108
  with tab_img:
109
  with st.form("rimgform"):
 
155
  f32 = numpy.float32
156
  half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
157
  clip_model, clip_prep = load_openclip()
158
+ model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
159
 
160
  try:
161
  with tab_cls: