import sys
import threading
from collections import OrderedDict

import numpy
import streamlit as st
import torch
import transformers
from huggingface_hub import HfFolder, snapshot_download
from PIL import Image

import openshape
from demo_support import retrieval, generation, utils, lvis

# Upper bound on user-supplied categories for the custom classification task.
MAX_CUSTOM_CATEGORIES = 64
# Number of result thumbnails per row in the retrieval grid.
RESULT_COLUMNS = 4


@st.cache_resource
def load_openclip():
    """Load the CLIP ViT-bigG-14 model and processor once per server process.

    Reads the module-level ``half`` dtype, which the main script section assigns
    before calling this function. Also installs ``sys.clip_move_lock``, the lock
    used to serialize moving the shared CLIP model between CPU and CUDA.

    Returns:
        ``(clip_model, clip_prep)``: the ``transformers.CLIPModel`` (moved to
        CUDA when available) and its ``transformers.CLIPProcessor``.
    """
    sys.clip_move_lock = threading.Lock()
    clip_model = transformers.CLIPModel.from_pretrained(
        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        low_cpu_mem_usage=True,
        torch_dtype=half,
        offload_state_dict=True,
    )
    clip_prep = transformers.CLIPProcessor.from_pretrained(
        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
    )
    if torch.cuda.is_available():
        with sys.clip_move_lock:
            clip_model.cuda()
    return clip_model, clip_prep


@st.cache_resource
def load_openshape(name, to_cpu=False):
    """Load (and cache) an OpenShape point-cloud encoder by name.

    Args:
        name: model identifier passed to ``openshape.load_pc_encoder``.
        to_cpu: if true, move the loaded encoder to the CPU.
    """
    pce = openshape.load_pc_encoder(name)
    if to_cpu:
        pce = pce.cpu()
    return pce


@st.cache_resource
def load_tripletmix(name, to_cpu=False):
    """Load (and cache) a TripletMix point-cloud encoder by name.

    Cached like ``load_openshape`` so that Streamlit reruns (every widget
    interaction) do not reload the weights.

    Args:
        name: model identifier passed to ``openshape.load_pc_encoder_mix``.
        to_cpu: if true, move the loaded encoder to the CPU.
    """
    pce = openshape.load_pc_encoder_mix(name)
    if to_cpu:
        pce = pce.cpu()
    return pce


def retrieval_filter_expand():
    """Build the retrieval similarity threshold and metadata filter.

    Only the similarity threshold is exposed in the sidebar. The tag, face-count
    and animation-count bounds are fixed at their "no restriction" values, so
    the returned filter currently accepts every entry.

    Returns:
        ``(sim_th, filter_fn)``: the threshold from the sidebar slider and a
        predicate over entries that carry ``'anims'``, ``'faces'`` and ``'tags'``.
    """
    sim_th = st.sidebar.slider("Similarity Threshold", 0.05, 0.5, 0.1, key='rsimth')
    tag = ""
    face_min = 0
    face_max = 34985808
    anim_min = 0
    anim_max = 563
    # Each *_n flag is true when that filter is unconstrained and can be skipped.
    tag_n = not bool(tag.strip())
    anim_n = not (anim_min > 0 or anim_max < 563)
    face_n = not (face_min > 0 or face_max < 34985808)
    filter_fn = lambda x: (
        (anim_n or anim_min <= x['anims'] <= anim_max)
        and (face_n or face_min <= x['faces'] <= face_max)
        and (tag_n or tag in x['tags'])
    )
    return sim_th, filter_fn


def retrieval_results(results):
    """Render retrieval hits as a grid of thumbnails with Objaverse links.

    Args:
        results: indexable sequence of entries with ``'u'``, ``'img'`` and
            ``'name'`` keys, as produced by ``retrieval.retrieve``.
    """
    st.caption("Click the link to view the 3D shape")
    # Step by full rows so that a trailing partial row is still rendered.
    for row_start in range(0, len(results), RESULT_COLUMNS):
        cols = st.columns(RESULT_COLUMNS)
        for j, col in enumerate(cols):
            idx = row_start + j
            if idx >= len(results):
                break
            entry = results[idx]
            with col:
                ext_link = f"https://objaverse.allenai.org/explore/?query={entry['u']}"
                st.image(entry['img'])
                # Escape brackets so the name cannot break the Markdown link syntax.
                quote_name = entry['name'].replace('[', '\\[').replace(']', '\\]').replace('\n', ' ')
                st.markdown(f"[{quote_name}]({ext_link})")


def _encode_pc(pc):
    """Embed a point cloud with the TripletMix encoder, returning a CPU tensor."""
    ref_dev = next(model_g14.parameters()).device
    # Columns 1 and 2 are swapped before encoding. This presumably converts the
    # loader's axis convention to the encoder's (pc assumed (N, 6) xyz+rgb; confirm).
    return model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()


def _show_top_categories(container, names, feats, enc, top_k=5):
    """Render the ``top_k`` names whose features are most cosine-similar to ``enc``.

    Args:
        container: Streamlit container to write into.
        names: category names, aligned row-wise with ``feats``.
        feats: (num_categories, dim) category feature tensor.
        enc: shape embedding returned by ``_encode_pc``.
    """
    sim = torch.matmul(
        torch.nn.functional.normalize(feats, dim=-1),
        torch.nn.functional.normalize(enc, dim=-1).squeeze(),
    )
    order = torch.argsort(sim, descending=True).tolist()
    pred = OrderedDict((names[i], sim[i]) for i in order if i < len(names))
    with container:
        for cat, score in list(pred.items())[:top_k]:
            st.text(cat)
            st.caption("Similarity %.4f" % score)


@torch.no_grad()
def classification_lvis(load_data):
    """Classify the loaded shape against the LVIS category embeddings."""
    pc = load_data(prog)
    col2 = utils.render_pc(pc)
    prog.progress(0.5, "Running Classification")
    enc = _encode_pc(pc)
    _show_top_categories(col2, lvis.categories, lvis.feats, enc)
    prog.progress(1.0, "Idle")


@torch.no_grad()
def classification_custom(load_data, cats):
    """Classify the loaded shape against user-supplied category names.

    The category names are embedded with the CLIP text encoder.
    """
    pc = load_data(prog)
    col2 = utils.render_pc(pc)
    prog.progress(0.5, "Computing Category Embeddings")
    device = clip_model.device
    tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76, padding=True).to(device)
    feats = clip_model.get_text_features(**tn).float().cpu()
    prog.progress(0.7, "Running Classification")
    enc = _encode_pc(pc)
    _show_top_categories(col2, cats, feats, enc)
    prog.progress(1.0, "Idle")


@torch.no_grad()
def retrieval_pc(load_data, k, sim_th, filter_fn):
    """Show LVIS classification for the loaded shape, then retrieve similar shapes."""
    pc = load_data(prog)
    prog.progress(0.5, "Computing Embeddings")
    col2 = utils.render_pc(pc)
    enc = _encode_pc(pc)
    _show_top_categories(col2, lvis.categories, lvis.feats, enc)
    prog.progress(0.7, "Running Retrieval")
    retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
    prog.progress(1.0, "Idle")


@torch.no_grad()
def retrieval_img(pic, k, sim_th, filter_fn):
    """Retrieve shapes matching an uploaded image, using CLIP image features."""
    if pic is None:
        st.error("Please upload an image before submitting")
        return
    img = Image.open(pic)
    prog.progress(0.5, "Computing Embeddings")
    st.image(img)
    device = clip_model.device
    tn = clip_prep(images=[img], return_tensors="pt").to(device)
    # Cast pixels to the CLIP weights' reduced-precision dtype.
    enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
    prog.progress(0.7, "Running Retrieval")
    retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
    prog.progress(1.0, "Idle")


@torch.no_grad()
def retrieval_text(text, k, sim_th, filter_fn):
    """Retrieve shapes matching a text query, using CLIP text features."""
    prog.progress(0.5, "Computing Embeddings")
    device = clip_model.device
    tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
    enc = clip_model.get_text_features(**tn).float().cpu()
    prog.progress(0.7, "Running Retrieval")
    retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
    prog.progress(1.0, "Idle")


@torch.no_grad()
def generation_img(load_data, prompt, noise_scale, cfg_scale, steps):
    """Generate an image from the loaded shape with ``generation.pc_to_image``.

    On CUDA, the shared CLIP model is moved to the CPU during diffusion to free
    GPU memory. It is always moved back afterwards, even if generation fails,
    because the model is a cached resource shared by every session.
    """
    pc = load_data(prog)
    prog.progress(0.5, "Running Generation")
    col2 = utils.render_pc(pc)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        with sys.clip_move_lock:
            clip_model.cpu()
    width = 640
    height = 640
    try:
        img = generation.pc_to_image(
            model_g14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
            lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
        )
    finally:
        if use_cuda:
            with sys.clip_move_lock:
                clip_model.cuda()
    with col2:
        st.image(img)
    prog.progress(1.0, "Idle")


@torch.no_grad()
def generation_text(load_data, cond_scale):
    """Generate a caption for the loaded shape with ``generation.pc_to_text``."""
    pc = load_data(prog)
    prog.progress(0.5, "Running Generation")
    col2 = utils.render_pc(pc)
    cap = generation.pc_to_text(model_g14, pc, cond_scale)
    st.text(cap)
    prog.progress(1.0, "Idle")


try:
    f32 = numpy.float32
    # bfloat16 on CPU, where float16 matmuls are poorly supported.
    half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
    clip_model, clip_prep = load_openclip()
    # Alternative encoder: model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
    model_g14 = load_tripletmix('tripletmix-spconv-all')

    st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
    st.sidebar.title("TripletMix Demo Configuration Panel")
    task = st.sidebar.selectbox(
        'Task Selection',
        ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
    )

    if task == "3D Classification":
        cls_mode = st.sidebar.selectbox(
            'Choose the source of categories',
            ("LVIS Categories", "Custom Categories")
        )
        load_data = utils.input_3d_shape('rpcinput')
        if cls_mode == "LVIS Categories":
            st.title("Classification with LVIS Categories")
            prog = st.progress(0.0, "Idle")
            if st.sidebar.button("submit"):
                classification_lvis(load_data)
        elif cls_mode == "Custom Categories":
            st.title("Classification with Custom Categories")
            prog = st.progress(0.0, "Idle")
            cats_raw = st.sidebar.text_input("Custom Categories (64 max, separated with comma)")
            # Drop blank entries left by empty input or stray/trailing commas.
            cats = [a.strip() for a in cats_raw.split(',') if a.strip()]
            too_many = len(cats) > MAX_CUSTOM_CATEGORIES
            if too_many:
                st.error('Maximum 64 custom categories supported in the demo')
            if st.sidebar.button("submit"):
                if not cats:
                    st.error("Please enter at least one category")
                elif not too_many:
                    classification_custom(load_data, cats)

    elif task == "Cross-modal retrieval":
        input_mode = st.sidebar.selectbox(
            'Choose an input modality',
            ("Point Cloud", "Image", "Text")
        )
        k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
        sim_th, filter_fn = retrieval_filter_expand()
        if input_mode == "Point Cloud":
            st.title("Retrieval with Point Cloud")
            prog = st.progress(0.0, "Idle")
            load_data = utils.input_3d_shape('rpcinput')
            if st.sidebar.button("submit"):
                retrieval_pc(load_data, k, sim_th, filter_fn)
        elif input_mode == "Image":
            st.title("Retrieval with Image")
            prog = st.progress(0.0, "Idle")
            pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
            if st.sidebar.button("submit"):
                retrieval_img(pic, k, sim_th, filter_fn)
        elif input_mode == "Text":
            st.title("Retrieval with Text")
            prog = st.progress(0.0, "Idle")
            text = st.sidebar.text_input("Input Text", key='rtextinput')
            if st.sidebar.button("submit"):
                retrieval_text(text, k, sim_th, filter_fn)

    elif task == "Cross-modal generation":
        generation_mode = st.sidebar.selectbox(
            'Choose the mode of generation',
            ("PointCloud-to-Image", "PointCloud-to-Text")
        )
        load_data = utils.input_3d_shape('rpcinput')
        if generation_mode == "PointCloud-to-Image":
            st.title("Image Generation")
            prog = st.progress(0.0, "Idle")
            prompt = st.sidebar.text_input("Prompt (Optional)", key='gprompt')
            noise_scale = st.sidebar.slider('Variation Level', 0, 5, 1)
            cfg_scale = st.sidebar.slider('Guidance Scale', 0.0, 30.0, 10.0)
            steps = st.sidebar.slider('Diffusion Steps', 8, 50, 25)
            if st.sidebar.button("submit"):
                generation_img(load_data, prompt, noise_scale, cfg_scale, steps)
        elif generation_mode == "PointCloud-to-Text":
            st.title("Text Generation")
            prog = st.progress(0.0, "Idle")
            cond_scale = st.sidebar.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='gcond')
            if st.sidebar.button("submit"):
                generation_text(load_data, cond_scale)
except Exception:
    import traceback
    # Two trailing spaces make a Markdown hard line break, so each traceback
    # line renders on its own line inside st.error.
    st.error(traceback.format_exc().replace("\n", "  \n"))