import os
import sys
from pathlib import Path

# os.system("cd transformers && pip install .")
os.system("cd multimodal && pip install -e .")

import numpy as np
import torch
from PIL import Image
import tempfile
import string
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download, login

from open_flamingo.src.factory import create_model_and_transforms
from open_flamingo.chat.conversation import Chat, CONV_VISION

sys.path.append(str(Path(__file__).parent.parent.parent))

TEMP_FILE_DIR = Path(__file__).parent / 'temp'
TEMP_FILE_DIR.mkdir(parents=True, exist_ok=True)

SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
You can duplicate this Space and use it with a paid private GPU.
Alternatively, you can also use the demo on our [project page](https://compositionalvlm.github.io/).
'''

flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
    "ViT-L-14",
    "datacomp_xl_s13b_b90k",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-1.4b",
    location_token_num=1000,
    lora=False,
    lora_r=16,
    use_sam=None,
    add_visual_token=True,
    use_format_v2=True,
    add_box=True,
    add_pe=False,
    add_relation=False,
    enhance_data=False,
)

model_name = "pythiaS"
checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
model_state_dict = {}
for key in checkpoint.keys():
    model_state_dict[key.replace("module.", "")] = checkpoint[key]
if "vision_encoder.logit_scale" in model_state_dict:
    # previous checkpoint has some unnecessary weights
    del model_state_dict["vision_encoder.logit_scale"]
    del model_state_dict["vision_encoder.visual.proj"]
    del model_state_dict["vision_encoder.visual.ln_post.weight"]
    del model_state_dict["vision_encoder.visual.ln_post.bias"]
flamingo.load_state_dict(model_state_dict, strict=True)

chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size)


def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    # and torch.cuda.amp.autocast(dtype=torch.float16)
    with torch.inference_mode():
        outputs = model(
            vision_x=batch_images,
            lang_x=input_ids,
            attention_mask=attention_mask,
            labels=None,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            added_bbox_list=None,
            add_box=False,
        )
        # outputs = model.generate(
        #     batch_images,
        #     input_ids,
        #     attention_mask=attention_mask,
        #     max_new_tokens=max_generation_length,
        #     min_length=min_generation_length,
        #     num_beams=num_beams,
        #     length_penalty=length_penalty,
        #     image_start_index_list=image_start_index_list,
        #     image_nums=image_nums,
        #     bad_words_ids=bad_words_ids,
        # )
    return outputs


def generate(
    idx,
    image,
    text,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()
    loc_token_ids = []
    for i in range(1000):
        loc_token_ids.append(int(tokenizer(f"", add_special_tokens=False)["input_ids"][-1]))
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
    image_ori = image
    image = image.convert("RGB")
    width = image.width
    height = image.height
    image = image.resize((224, 224))
    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    if idx == 1:
        prompt = [
            f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        bad_words_ids = None
        max_generation_length = 5
    else:
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        bad_words_ids = loc_token_ids
        max_generation_length = 30
    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )
    boxes = outputs["boxes"]
    scores = outputs["scores"]
    if len(scores) > 0:
        box = boxes[scores.argmax()] / 224
        print(f"{box}")
    if idx == 1:
        open_cv_image = np.array(image_ori)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        box = box * [width, height, width, height]
        # for box in boxes:
        open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
        out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
        return f"Output:{box}", out_image
    elif idx == 2:
        gen_text = tokenizer.batch_decode(outputs)
        return f"Question: {text.strip()} Answer: {gen_text}"
    else:
        gen_text = tokenizer.batch_decode(outputs)
        return f"Output:{gen_text}"


title = """

# Demo of Compositional-VLM

""" description = """

This is the demo of Compositional-VLM. Upload your images and start chatting!

""" article = """
""" # TODO show examples below # ======================================== # Gradio Setting # ======================================== def gradio_reset(chat_state, img_list): if chat_state is not None: chat_state = [] if img_list is not None: img_list = [] return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update( value="Upload & Start Chat", interactive=True), chat_state, img_list def build_image(image): if image is None: return None # res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8) from torchvision.transforms import ToPILImage # res = ToPILImage()(res) _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR) image.save(path) return path def upload_img(gr_img, text_input, chat_state, chatbot): if gr_img is None: return None, None, gr.update(interactive=True), chat_state, None chat_state = [] img_list = [] path = build_image(gr_img) chatbot = chatbot + [[(path,), None]] llm_message = chat.upload_img(gr_img, chat_state, img_list) return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update( value="Start Chatting", interactive=False), chat_state, img_list, chatbot def gradio_ask(user_message, chatbot, chat_state, radio): if len(user_message) == 0: return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state chat.ask(user_message, chat_state, radio, model_name) chatbot = chatbot + [[user_message, None]] return chatbot, chat_state def gradio_answer(chatbot, chat_state, img_list, radio, text, num_beams, temperature): image = None llm_message, image = \ chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000, radio=radio, text_input=text, model_name=model_name) chatbot[-1][1] = llm_message if chat_state[-1]["from"] == "gpt": chat_state[-1]["value"] = llm_message if image == None: return "", chatbot, chat_state, img_list else: path = build_image(image) chatbot = chatbot + [[None, (path,)]] return "", chatbot, chat_state, img_list task_template = { "Cap": "Summarize the content of the photo .", "VQA": "For this image , I want a simple and direct answer to my question: ", "REC": "Can you point out in the image and provide the coordinates of its location?", "GC": "Can you give me a description of the region in image ?", "Advanced": "", } with gr.Blocks() as demo: gr.Markdown(title) gr.Markdown(SHARED_UI_WARNING) gr.Markdown(description) gr.Markdown(article) with gr.Row(): with gr.Column(scale=0.5): image = gr.Image(type="pil") upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") clear = gr.Button("Restart") radio = gr.Radio( ["Cap", "VQA", "REC", "Advanced"], label="Task Template", value='Cap', ) num_beams = gr.Slider( minimum=1, maximum=5, value=1, step=1, interactive=True, label="beam search numbers)", ) temperature = gr.Slider( minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature", ) with gr.Column(): chat_state = gr.State() img_list = gr.State() chatbot = gr.Chatbot(label='Compositional-VLM') # template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False, # value='Provide a comprehensive description of the image and specify the positions of any mentioned objects in square brackets.') # text_input = gr.Textbox(label='', show_label=True, placeholder="Please upload your image first, then input...", lines=3, # value=None, 
            #                         visible=False, interactive=False)
            text_input = gr.Textbox(label='User', placeholder='Please upload your image first, then input...',
                                    interactive=False)

    upload_button.click(upload_img, [image, text_input, chat_state, chatbot],
                        [image, text_input, upload_button, chat_state, img_list, chatbot])

    text_input.submit(gradio_ask, [text_input, chatbot, chat_state, radio], [chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
        [text_input, chatbot, chat_state, img_list]
    )
    clear.click(gradio_reset, [chat_state, img_list],
                [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)

demo.launch(share=True)

# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#         🍜 Object Centric Pretraining Demo
#         In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
#         The model is trained on an interleaved mixture of text, images and bounding boxes and is able to generate text conditioned on sequences of images/text.
#         """
#     )
#
#     with gr.Accordion("See terms and conditions"):
#         gr.Markdown(
#             """**Please read the following information carefully before proceeding.** This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
#
#     with gr.Tab("📷 Image Captioning"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text):
#             return generate(0, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
#
#     with gr.Tab("🦓 Grounding"):
#         with gr.Row():
#             with gr.Column(scale=1):
#                 query_image = gr.Image(type="pil")
#             with gr.Column(scale=1):
#                 out_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text):
#             return generate(1, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image])
#
#     with gr.Tab("🔒 Counting objects"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text):
#             return generate(0, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
#
#     with gr.Tab("🕵️ Visual Question Answering"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             question = gr.Textbox(lines=1, label="Question")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, txt):
#             return generate(2, img, txt)
#
#         run_btn.click(
#             on_click_fn, inputs=[query_image, question], outputs=[text_output]
#         )
#
#     with gr.Tab("🌎 Custom"):
#         gr.Markdown(
#             """### Customize the demonstration by uploading your own images and text samples.
#             ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
#         )
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             question = gr.Textbox(lines=1, label="Question")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, txt):
#             return generate(2, img, txt)
#
#         run_btn.click(
#             on_click_fn, inputs=[query_image, question], outputs=[text_output]
#         )
#
# demo.queue(concurrency_count=1)
# demo.launch()
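
# --- Optional local smoke test (a minimal sketch, not part of the original demo) ---
# Assumes a hypothetical image file "example.jpg" next to this script and a made-up
# referring expression; adjust both as needed. It exercises the grounding branch
# (idx == 1) of generate(), which returns a text string plus a copy of the image with
# the predicted box drawn on it. Keep it commented out when serving the Gradio app above.
# if __name__ == "__main__":
#     test_img = Image.open("example.jpg")
#     caption, boxed = generate(1, test_img, "the dog")
#     print(caption)
#     boxed.save("example_boxed.jpg")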