import os

# os.system("cd transformers && pip install .")
os.system("cd multimodal && pip install .")
os.system("cd multimodal/YOLOX && pip install .")

import numpy as np
import torch
from PIL import Image
import string
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download, login

from open_flamingo.src.factory import create_model_and_transforms
from open_flamingo.chat.conversation import Chat, CONV_VISION

SHARED_UI_WARNING = f'''### [NOTE] You may be waiting in a lengthy queue.
You can duplicate this Space and use it with a paid private GPU.
Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
'''

flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
    "ViT-L-14",
    "datacomp_xl_s13b_b90k",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-1.4b",
    location_token_num=1000,
    lora=False,
    lora_r=16,
    use_sam=None,
    add_visual_token=True,
    use_format_v2=True,
    add_box=True,
    add_pe=False,
    add_relation=False,
    enhance_data=False,
)

checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
model_state_dict = {}
for key in checkpoint.keys():
    model_state_dict[key.replace("module.", "")] = checkpoint[key]
if "vision_encoder.logit_scale" in model_state_dict:
    # previous checkpoint has some unnecessary weights
    del model_state_dict["vision_encoder.logit_scale"]
    del model_state_dict["vision_encoder.visual.proj"]
    del model_state_dict["vision_encoder.visual.ln_post.weight"]
    del model_state_dict["vision_encoder.visual.ln_post.bias"]
flamingo.load_state_dict(model_state_dict, strict=True)

chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size)


def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    # and torch.cuda.amp.autocast(dtype=torch.float16)
    with torch.inference_mode():
        outputs = model(
            vision_x=batch_images,
            lang_x=input_ids,
            attention_mask=attention_mask,
            labels=None,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            added_bbox_list=None,
            add_box=False,
        )
        # outputs = model.generate(
        #     batch_images,
        #     input_ids,
        #     attention_mask=attention_mask,
        #     max_new_tokens=max_generation_length,
        #     min_length=min_generation_length,
        #     num_beams=num_beams,
        #     length_penalty=length_penalty,
        #     image_start_index_list=image_start_index_list,
        #     image_nums=image_nums,
        #     bad_words_ids=bad_words_ids,
        # )
    return outputs


def generate(
    idx,
    image,
    text,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
    # idx selects the task: 0 = captioning/counting, 1 = grounding (returns an image
    # with the predicted box drawn), 2 = visual question answering.
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()
    loc_token_ids = []
    for i in range(1000):
        loc_token_ids.append(int(tokenizer(f"", add_special_tokens=False)["input_ids"][-1]))
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
    image_ori = image
    image = image.convert("RGB")
    width = image.width
    height = image.height
    image = image.resize((224, 224))
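    # Shape note (based on the OpenFlamingo convention this fork follows): the forward
    # pass consumes vision_x shaped (batch, num_images, num_frames, channels, height, width),
    # so the unsqueeze() chain below turns the (3, 224, 224) tensor into (1, 1, 1, 3, 224, 224).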
    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    if idx == 1:
        # grounding: ask the model to produce a box for the referred object
        prompt = [
            f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        bad_words_ids = None
        max_generation_length = 5
    else:
        # captioning / VQA: plain text continuation after the image placeholder
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        bad_words_ids = loc_token_ids
        max_generation_length = 30
    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )
    boxes = outputs["boxes"]
    scores = outputs["scores"]
    if len(scores) > 0:
        box = boxes[scores.argmax()] / 224
        print(f"{box}")
    if idx == 1:
        open_cv_image = np.array(image_ori)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        box = box * [width, height, width, height]
        # for box in boxes:
        open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
        out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
        return f"Output:{box}", out_image
    elif idx == 2:
        gen_text = tokenizer.batch_decode(outputs)
        return f"Question: {text.strip()} Answer: {gen_text}"
    else:
        gen_text = tokenizer.batch_decode(outputs)
        return f"Output:{gen_text}"


title = """

Demo of Compositional-VLM
"""

description = """

This is the demo of Compositional-VLM. Upload your images and start chatting!
"""

article = """
"""

# TODO show examples below

# ========================================
#             Gradio Setting
# ========================================

def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
                                                                    interactive=False), gr.update(
        value="Upload & Start Chat", interactive=True), chat_state, img_list


def upload_img(gr_img, text_input, chat_state):
    if gr_img is None:
        return None, None, gr.update(interactive=True), chat_state, None
    chat_state = []
    img_list = []
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
        value="Start Chatting", interactive=False), chat_state, img_list


def gradio_ask(user_message, chatbot, chat_state):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state


def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    llm_message = \
        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
                    max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list


with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(SHARED_UI_WARNING)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")

            num_beams = gr.Slider(
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

        with gr.Column():
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='Compositional-VLM')
            text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)

    upload_button.click(upload_img, [image, text_input, chat_state],
                        [image, text_input, upload_button, chat_state, img_list])

    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    clear.click(gradio_reset, [chat_state, img_list],
                [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)

demo.launch(enable_queue=True)

#
# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#         🍜 Object Centric Pretraining Demo
#         In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric
#         Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones
#         presented to improve the demo experience.
#         The model is trained on an interleaved mixture of text, images and bounding boxes and is able to generate
#         text conditioned on sequences of images/text.
#         """
#     )
#
#     with gr.Accordion("See terms and conditions"):
#         gr.Markdown(
#             """**Please read the following information carefully before proceeding.** This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
#
#     with gr.Tab("📷 Image Captioning"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text): return generate(0, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
#
#     with gr.Tab("🦓 Grounding"):
#         with gr.Row():
#             with gr.Column(scale=1):
#                 query_image = gr.Image(type="pil")
#             with gr.Column(scale=1):
#                 out_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text): return generate(1, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image])
#
#     with gr.Tab("🔢 Counting objects"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text): return generate(0, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
#
#     with gr.Tab("🕵️ Visual Question Answering"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             question = gr.Textbox(lines=1, label="Question")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, txt): return generate(2, img, txt)
#
#         run_btn.click(
#             on_click_fn, inputs=[query_image, question], outputs=[text_output]
#         )
#
#     with gr.Tab("🌎 Custom"):
#         gr.Markdown(
#             """### Customize the demonstration by uploading your own images and text samples.
#             ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
#         )
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             question = gr.Textbox(lines=1, label="Question")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, txt): return generate(2, img, txt)
#
#         run_btn.click(
#             on_click_fn, inputs=[query_image, question], outputs=[text_output]
#         )
#
# demo.queue(concurrency_count=1)
# demo.launch()
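# Example of calling the grounding path directly, without the Gradio UI (a minimal
# sketch; the image path and query string below are hypothetical):
#
#     from PIL import Image
#     img = Image.open("example.jpg")
#     caption, boxed = generate(1, img, "the red car")  # returns ("Output:[...]", PIL image with the box drawn)
#     boxed.save("example_boxed.jpg")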