Spaces:
Sleeping
Sleeping
import streamlit as st

# Page header.
st.title("TripletMix Demo")
st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")

# Progress bar created once at page level; the (currently disabled)
# classification code updates it through `prog.progress(...)`.
prog = st.progress(0.0, "Idle")

# One tab per demo. The order of the labels is the order of the handles
# unpacked below.
_tab_labels = [
    "Classification",
    "Retrieval w/ Image",
    "Retrieval w/ Text",
    "Retrieval w/ 3D",
    "Image Generation",
    "Captioning",
]
tab_cls, tab_img, tab_text, tab_pc, tab_sd, tab_cap = st.tabs(_tab_labels)
def demo_classification():
    """Render the classification form.

    Draws a form with a free-text field for custom categories (comma
    separated, at most 64) and two submit buttons, one for LVIS categories
    and one for custom categories.  If more than 64 categories are entered,
    an error is shown and the function returns early.

    The inference code that would react to the buttons is currently
    disabled; it is kept below as bare string literals, which have no
    runtime effect.

    Returns:
        None.
    """
    with st.form("clsform"):
        # load_data = misc_utils.input_3d_shape('cls')
        raw_cats = st.text_input("Custom Categories (64 max, separated with comma)")
        # Drop blank entries: empty input or stray commas would otherwise
        # produce empty-string categories that count toward the limit.
        cats = [name for name in (part.strip() for part in raw_cats.split(',')) if name]

        # Create the submit buttons BEFORE validating.  Returning early
        # without them would leave the form with no submit button, so the
        # user could not resubmit a corrected category list.
        # Both results are only consumed by the disabled code below.
        lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
        custom_run = st.form_submit_button("Run Classification on Custom Categories")

        if len(cats) > 64:
            st.error('Maximum 64 custom categories supported in the demo')
            return

        # Disabled inference code (kept for reference).
        """
        if lvis_run or auto_submit("clsauto"):
            pc = load_data(prog)
            col2 = misc_utils.render_pc(pc)
            prog.progress(0.5, "Running Classification")
            pred = classification.pred_lvis_sims(model_g14, pc)
            with col2:
                for i, (cat, sim) in zip(range(5), pred.items()):
                    st.text(cat)
                    st.caption("Similarity %.4f" % sim)
            prog.progress(1.0, "Idle")
        if custom_run:
            pc = load_data(prog)
            col2 = misc_utils.render_pc(pc)
            prog.progress(0.5, "Computing Category Embeddings")
            device = clip_model.device
            tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76, padding=True).to(device)
            feats = clip_model.get_text_features(**tn).float().cpu()
            prog.progress(0.5, "Running Classification")
            pred = classification.pred_custom_sims(model_g14, pc, cats, feats)
            with col2:
                for i, (cat, sim) in zip(range(5), pred.items()):
                    st.text(cat)
                    st.caption("Similarity %.4f" % sim)
            prog.progress(1.0, "Idle")
        """
    # Disabled example picker (kept for reference).
    """
    if image_examples(samples_index.classification, 3, example_text="Examples (Choose one of the following 3D shapes)"):
        queue_auto_submit("clsauto")
    """
# Top-level page body: render each enabled demo inside its tab and surface
# any failure on the page instead of crashing the script run.
try:
    with tab_cls:
        demo_classification()
    # Remaining demos are disabled (kept for reference).
    """
    with tab_cap:
        demo_captioning()
    with tab_sd:
        demo_pc2img()
    demo_retrieval()
    """
except Exception:
    import traceback

    # st.error renders Markdown, where a hard line break needs TWO trailing
    # spaces before the newline; with a single space the traceback would
    # collapse into one paragraph.
    st.error(traceback.format_exc().replace("\n", "  \n"))