import os
import sys
from pathlib import Path
# os.system("cd transformers && pip install .")
# Install the local `multimodal` directory in editable mode every time this script
# starts. This runs before the `open_flamingo` imports further down.
# NOTE(review): presumably `multimodal` is what provides `open_flamingo` — confirm.
# The exit status of the shell command is not checked.
os.system("cd multimodal && pip install -e .")
import numpy as np
import torch
from PIL import Image
import tempfile
import string
import cv2
import gradio as gr
import torch
from PIL import Image
from huggingface_hub import hf_hub_download, login
from open_flamingo.src.factory import create_model_and_transforms
from open_flamingo.chat.conversation import Chat, CONV_VISION
# Make the directory three levels above this file importable. This executes after
# the `open_flamingo` imports above, so it only affects imports performed later.
sys.path.append(str(Path(__file__).parent.parent.parent))
# Directory (next to this file) where chat images are written; created on startup.
TEMP_FILE_DIR = Path(__file__).parent / 'temp'
TEMP_FILE_DIR.mkdir(parents=True, exist_ok=True)
# Markdown notice rendered near the top of the page. The f-prefix has no
# placeholders to fill; the text is emitted verbatim.
SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
You can duplicate and use it with a paid private GPU.
Alternatively, you can also use the demo on our [project page](https://compositionalvlm.github.io/).
'''
# Build the model together with its image preprocessor and tokenizer.
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
    "ViT-L-14",
    "datacomp_xl_s13b_b90k",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-1.4b",
    location_token_num=1000,
    lora=False,
    lora_r=16,
    use_sam=None,
    add_visual_token=True,
    use_format_v2=True,
    add_box=True,
    add_pe=False,
    add_relation=False,
    enhance_data=False,
)
model_name = "pythiaS"

# Fetch the checkpoint from the Hugging Face Hub and load its weights on the CPU.
checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]

# Drop every "module." fragment from the parameter names.
model_state_dict = {
    name.replace("module.", ""): weight for name, weight in checkpoint.items()
}

if "vision_encoder.logit_scale" in model_state_dict:
    # previous checkpoint has some unnecessary weights
    for stale_name in (
        "vision_encoder.logit_scale",
        "vision_encoder.visual.proj",
        "vision_encoder.visual.ln_post.weight",
        "vision_encoder.visual.ln_post.bias",
    ):
        del model_state_dict[stale_name]

# strict=True: the remaining keys must match the model's parameters exactly.
flamingo.load_state_dict(model_state_dict, strict=True)
chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size)
def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    """Run a single forward pass of ``model`` and return whatever it returns.

    The call is made under ``torch.inference_mode()`` with ``labels=None``,
    ``added_bbox_list=None`` and ``add_box=False``.

    ``max_generation_length``, ``min_generation_length``, ``num_beams``,
    ``length_penalty`` and ``bad_words_ids`` are accepted but not used here;
    they belong to a ``model.generate(...)`` call that is currently disabled.
    """
    forward_kwargs = {
        "vision_x": batch_images,
        "lang_x": input_ids,
        "attention_mask": attention_mask,
        "labels": None,
        "image_nums": image_nums,
        "image_start_index_list": image_start_index_list,
        "added_bbox_list": None,
        "add_box": False,
    }
    # Mixed precision (torch.cuda.amp.autocast with float16) is intentionally off.
    with torch.inference_mode():
        return model(**forward_kwargs)
def generate(
    idx,
    image,
    text,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
    """Run the model on one image/text pair and format the result for the demo.

    Args:
        idx: Task selector. ``1`` builds the object-grounding prompt and returns
            the best box drawn on the image; ``2`` formats the result as
            ``"Question: ... Answer: ..."``; any other value as ``"Output:..."``.
        image: PIL image from the UI.
        text: User text inserted into the prompt.
        vis_embed_size: Number of pad tokens placed between the image markers.
        rank: Unused.
        world_size: Unused.

    Returns:
        ``(message, PIL.Image)`` when ``idx == 1``, otherwise a single string.

    Raises:
        gr.Error: If ``image`` is ``None``, or if ``idx == 1`` and the model
            returned no scored box.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()

    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]

    # Work from an RGB copy so that RGBA / grayscale / palette inputs also produce
    # an H x W x 3 array when the box is drawn below.
    image_rgb = image.convert("RGB")
    width, height = image_rgb.size
    batch_images = image_processor(image_rgb.resize((224, 224))).unsqueeze(0).unsqueeze(1).unsqueeze(0)

    if idx == 1:
        prompt = [
            f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        bad_words_ids = None
        max_generation_length = 5
    else:
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        # Ids of the 1000 location tokens, only needed on this branch.
        # NOTE(review): the token literal was empty in the source (the loop index
        # was unused and indexing the empty encoding would fail); "<loc_{i}>" is
        # the presumed spelling — confirm against the tokenizer's special tokens.
        bad_words_ids = [
            int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1])
            for i in range(1000)
        ]
        max_generation_length = 30

    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    # Position right after each image marker token, one single-element list per match.
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)

    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )
    boxes = outputs["boxes"]
    scores = outputs["scores"]

    box = None
    if len(scores) > 0:
        # Highest-scoring box, scaled from the 224 x 224 model input to [0, 1].
        box = boxes[scores.argmax()] / 224
        print(f"{box}")

    if idx == 1:
        if box is None:
            raise gr.Error("The model did not return any bounding box.")
        # Convert RGB to BGR
        open_cv_image = np.array(image_rgb)[:, :, ::-1].copy()
        # Back to pixel coordinates of the uploaded image.
        box = np.asarray(box) * np.array([width, height, width, height])
        # OpenCV wants plain int tuples for the two corners.
        top_left = tuple(int(v) for v in box[:2])
        bottom_right = tuple(int(v) for v in box[2:])
        open_cv_image = cv2.rectangle(open_cv_image, top_left, bottom_right, (255, 0, 0), 2)
        out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
        return f"Output:{box}", out_image

    # NOTE(review): `outputs` is the forward-pass result indexed by "boxes" and
    # "scores" above, not generated token ids — confirm that decoding it is intended.
    gen_text = tokenizer.batch_decode(outputs)
    if idx == 2:
        return f"Question: {text.strip()} Answer: {gen_text}"
    return f"Output:{gen_text}"
# Static Markdown snippets rendered, in this order, at the top of the demo page
# (with SHARED_UI_WARNING between `title` and `description`).
title = """Demo of Compositional-VLM
"""
description = """This is the demo of Compositional-VLM. Upload your images and start chatting!
"""
# Placeholder section; currently only a blank line.
article = """
"""
# TODO show examples below
# ========================================
# Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    """Return the values that put the page back into its initial state.

    The six returned values are, in order: the chatbot history (``None``), the
    image widget update, the text box update, the upload button update, the
    conversation state and the image list. A state that is not ``None`` is
    replaced by a fresh empty list; a ``None`` state is passed through.
    """
    chat_state = chat_state if chat_state is None else []
    img_list = img_list if img_list is None else []

    image_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(placeholder='Please upload your image first', interactive=False)
    button_update = gr.update(value="Upload & Start Chat", interactive=True)
    return None, image_update, textbox_update, button_update, chat_state, img_list
def build_image(image):
    """Save ``image`` as a new ``.jpg`` file in ``TEMP_FILE_DIR`` and return its path.

    Args:
        image: Object exposing ``save(path)`` (a PIL image in this demo), or ``None``.

    Returns:
        The path of the written file, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None
    fd, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR)
    # mkstemp hands back an already-open OS-level descriptor. Only the path is
    # needed, so close it right away instead of leaking one descriptor per image.
    os.close(fd)
    image.save(path)
    return path
def upload_img(gr_img, text_input, chat_state, chatbot):
    """Start a new conversation from the uploaded image.

    The handler is wired to six outputs, in this order: image, text box, upload
    button, conversation state, image list, chatbot history. Both return paths
    therefore yield six values.

    Args:
        gr_img: Uploaded PIL image, or ``None`` when nothing was uploaded.
        text_input: Current text box value (unused).
        chat_state: Current conversation state; replaced by a new list on success.
        chatbot: Current chatbot history; the image is appended as a new turn.
    """
    if gr_img is None:
        # Nothing uploaded: keep the button usable and leave the chat history
        # as it is. The history must be returned too so the value count matches
        # the six wired outputs.
        return None, None, gr.update(interactive=True), chat_state, None, chatbot
    chat_state = []
    img_list = []
    path = build_image(gr_img)
    chatbot = chatbot + [[(path,), None]]
    # Called with the fresh state and image lists; its return value is not needed.
    chat.upload_img(gr_img, chat_state, img_list)
    return (
        gr.update(interactive=False),
        gr.update(interactive=True, placeholder='Type and press Enter'),
        gr.update(value="Start Chatting", interactive=False),
        chat_state,
        img_list,
        chatbot,
    )
def gradio_ask(user_message, chatbot, chat_state, radio):
    """Record the user's message in the conversation and in the chat history.

    The handler is wired to two outputs, ``[chatbot, chat_state]``.

    Args:
        user_message: Text typed by the user.
        chatbot: Current chatbot history.
        chat_state: Current conversation state, passed to ``chat.ask``.
        radio: Selected task template name, passed to ``chat.ask``.

    Returns:
        ``(chatbot, chat_state)`` with the new user turn appended to the history.

    Raises:
        gr.Error: If ``user_message`` is empty.
    """
    if not user_message:
        # Returning a text box update here would not match the two wired outputs
        # (it would land on the chatbot), so report the problem as an error.
        raise gr.Error("Input should not be empty!")
    chat.ask(user_message, chat_state, radio, model_name)
    chatbot = chatbot + [[user_message, None]]
    return chatbot, chat_state
def gradio_answer(chatbot, chat_state, img_list, radio, text, num_beams, temperature):
    """Ask the model for a reply and write it into the last chat turn.

    Args:
        chatbot: Chatbot history; the reply fills the second slot of its last entry.
        chat_state: Conversation state passed to ``chat.answer``.
        img_list: Image list passed to ``chat.answer``.
        radio: Selected task template name.
        text: Current text box content, forwarded as ``text_input``.
        num_beams: Slider value. NOTE(review): not forwarded — ``chat.answer`` is
            always called with ``num_beams=1``; confirm this is intended.
        temperature: Sampling temperature forwarded to ``chat.answer``.

    Returns:
        ``("", chatbot, chat_state, img_list)``; the empty string clears the text box.
    """
    llm_message, image = chat.answer(
        conv=chat_state,
        img_list=img_list,
        max_new_tokens=300,
        num_beams=1,
        temperature=temperature,
        max_length=2000,
        radio=radio,
        text_input=text,
        model_name=model_name,
    )
    chatbot[-1][1] = llm_message
    if chat_state[-1]["from"] == "gpt":
        chat_state[-1]["value"] = llm_message
    # Identity check: `== None` would invoke the image object's own __eq__.
    if image is not None:
        # The model also produced an image; show it as an extra assistant turn.
        path = build_image(image)
        chatbot = chatbot + [[None, (path,)]]
    return "", chatbot, chat_state, img_list
# Prompt text per task name. Nothing in the visible part of this file reads this
# mapping, and the "GC" key is not offered by the task radio button.
# NOTE(review): the "REC" and "GC" sentences look like they lost a placeholder
# (e.g. the object / region to refer to) — confirm the intended wording.
task_template = {
"Cap": "Summarize the content of the photo .",
"VQA": "For this image , I want a simple and direct answer to my question: ",
"REC": "Can you point out in the image and provide the coordinates of its location?",
"GC": "Can you give me a description of the region in image ?",
"Advanced": "",
}
# Page layout and event wiring. Components and their listeners are declared inside
# the Blocks context; the app is launched once the context is closed.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(SHARED_UI_WARNING)
    gr.Markdown(description)
    gr.Markdown(article)
    with gr.Row():
        # Controls on the left, conversation on the right. Integer scales 1 and 2
        # keep the previous 0.5 : 1 width ratio.
        with gr.Column(scale=1):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
            radio = gr.Radio(
                ["Cap", "VQA", "REC", "Advanced"], label="Task Template", value='Cap',
            )
            num_beams = gr.Slider(
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
        with gr.Column(scale=2):
            # Per-session conversation state and image list.
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='Compositional-VLM')
            # template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False,
            # value='Provide a comprehensive description of the image and specify the positions of any mentioned objects in square brackets.')
            # text_input = gr.Textbox(label='', show_label=True, placeholder="Please upload your image first, then input...", lines=3,
            # value=None, visible=False, interactive=False)
            text_input = gr.Textbox(label='User', placeholder='Please upload your image first, then input...',
                                    interactive=False)
    # Upload: six outputs, in the order returned by upload_img.
    upload_button.click(upload_img, [image, text_input, chat_state, chatbot],
                        [image, text_input, upload_button, chat_state, img_list, chatbot])
    # Enter in the text box: record the question, then generate the answer.
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state, radio], [chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
        [text_input, chatbot, chat_state, img_list]
    )
    # Restart: reset every widget and both states without going through the queue.
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
                queue=False)
demo.launch(share=True)
#
# with gr.Blocks() as demo:
# gr.Markdown(
# """
# Object Centric Pretraining Demo
# In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
# The model is trained on an interleaved mixture of text, images and bounding box and is able to generate text conditioned on sequences of images/text.
# """
# )
#
# with gr.Accordion("See terms and conditions"):
# gr.Markdown(
# """**Please read the following information carefully before proceeding.**This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
#
# with gr.Tab("📷 Image Captioning"):
# with gr.Row():
#
#
# query_image = gr.Image(type="pil")
# with gr.Row():
# chat_input = gr.Textbox(lines=1, label="Chat Input")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
#
# def on_click_fn(img,text): return generate(0, img, text)
#
# run_btn.click(on_click_fn, inputs=[query_image,chat_input], outputs=[text_output])
#
# with gr.Tab("📦 Grounding"):
# with gr.Row():
# with gr.Column(scale=1):
# query_image = gr.Image(type="pil")
# with gr.Column(scale=1):
# out_image = gr.Image(type="pil")
# with gr.Row():
# chat_input = gr.Textbox(lines=1, label="Chat Input")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
# def on_click_fn(img, text): return generate(1, img, text)
#
#
# run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image])
#
# with gr.Tab("🔢 Counting objects"):
# with gr.Row():
# query_image = gr.Image(type="pil")
# with gr.Row():
# chat_input = gr.Textbox(lines=1, label="Chat Input")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
# def on_click_fn(img,text): return generate(0, img, text)
#
#
# run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
#
# with gr.Tab("🕵️ Visual Question Answering"):
# with gr.Row():
# query_image = gr.Image(type="pil")
# with gr.Row():
# question = gr.Textbox(lines=1, label="Question")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
# def on_click_fn(img, txt): return generate(2, img, txt)
#
#
# run_btn.click(
# on_click_fn, inputs=[query_image, question], outputs=[text_output]
# )
#
# with gr.Tab("Custom"):
# gr.Markdown(
# """### Customize the demonstration by uploading your own images and text samples.
# ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
# )
# with gr.Row():
# query_image = gr.Image(type="pil")
# with gr.Row():
# question = gr.Textbox(lines=1, label="Question")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
# def on_click_fn(img, txt): return generate(2, img, txt)
#
#
# run_btn.click(
# on_click_fn, inputs=[query_image, question], outputs=[text_output]
# )
#
# demo.queue(concurrency_count=1)
# demo.launch()