chendl's picture
update chat
fdffde6
raw
history blame
14.1 kB
import os
# os.system("cd transformers && pip install .")
os.system("cd multimodal && pip install .")
os.system("cd multimodal/YOLOX && pip install .")
import numpy as np
import torch
from PIL import Image
import string
import cv2
import gradio as gr
import torch
from PIL import Image
from huggingface_hub import hf_hub_download, login
from open_flamingo.src.factory import create_model_and_transforms
from open_flamingo.chat.conversation import Chat, CONV_VISION
# Build the model. The positional arguments and flags are handed straight to
# the project's create_model_and_transforms factory; location_token_num=1000
# matches the 1000 "<loc_i>" tokens that generate() looks up.
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
    "ViT-L-14",
    "datacomp_xl_s13b_b90k",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-1.4b",
    location_token_num=1000,
    lora=False,
    lora_r=16,
    use_sam=None,
    add_visual_token=True,
    use_format_v2=True,
    add_box=True,
    add_pe=False,
    add_relation=False,
    enhance_data=False,
)

checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only point this at a trusted repository.
checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]

# Drop every "module." fragment from parameter names (presumably a
# DistributedDataParallel wrapper prefix) so they match the bare model.
model_state_dict = {key.replace("module.", ""): value for key, value in checkpoint.items()}

# Older checkpoints, recognised by "vision_encoder.logit_scale", carry extra
# vision-encoder weights that are not loaded into the model. pop(..., None)
# removes whichever of them are present instead of raising KeyError when one
# of them is missing.
if "vision_encoder.logit_scale" in model_state_dict:
    for stale_key in (
        "vision_encoder.logit_scale",
        "vision_encoder.visual.proj",
        "vision_encoder.visual.ln_post.weight",
        "vision_encoder.visual.ln_post.bias",
    ):
        model_state_dict.pop(stale_key, None)
flamingo.load_state_dict(model_state_dict, strict=True)
# The app only runs inference: switch to eval mode once up front (generate()
# also does this before each call).
flamingo.eval()

chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size)
def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    """Run a single forward pass of ``model`` with autograd disabled.

    Only ``batch_images`` (as ``vision_x``), ``input_ids`` (as ``lang_x``),
    ``attention_mask``, ``image_nums`` and ``image_start_index_list`` are
    forwarded, together with ``labels=None``, ``added_bbox_list=None`` and
    ``add_box=False``.

    ``max_generation_length``, ``min_generation_length``, ``num_beams``,
    ``length_penalty`` and ``bad_words_ids`` are accepted for interface
    compatibility but are not used.

    Returns:
        Whatever ``model(...)`` returns, unchanged.
    """
    forward_kwargs = dict(
        vision_x=batch_images,
        lang_x=input_ids,
        attention_mask=attention_mask,
        labels=None,
        image_nums=image_nums,
        image_start_index_list=image_start_index_list,
        added_bbox_list=None,
        add_box=False,
    )
    # inference_mode turns off gradient tracking for the duration of the call.
    with torch.inference_mode():
        return model(**forward_kwargs)
def _draw_box(rgb_image, box):
    """Return a copy of ``rgb_image`` with ``box`` (x1, y1, x2, y2 in pixels) outlined."""
    # cv2 works on BGR arrays, so the colour tuple is BGR: (255, 0, 0) is blue.
    canvas = np.array(rgb_image)[:, :, ::-1].copy()
    # cv2.rectangle expects plain-int point tuples; int() truncates to whole pixels.
    x1, y1, x2, y2 = (int(coord) for coord in box)
    cv2.rectangle(canvas, (x1, y1), (x2, y2), (255, 0, 0), 2)
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))


def generate(
    idx,
    image,
    text,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
    """Run the legacy single-turn demo task ``idx`` on ``image`` and ``text``.

    ``idx == 1`` is grounding: the highest-scoring predicted box for the phrase
    in ``text`` is drawn on the uploaded image and ``(message, image)`` is
    returned. If the model returns no scores, the undecorated RGB image is
    returned with an "Output: no box found" message. Any other ``idx`` returns
    a single string (``idx == 2`` formats it as question/answer).

    ``rank`` and ``world_size`` are unused and kept for call-site compatibility.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]

    # Work on an RGB copy so RGBA / greyscale uploads can be drawn on later;
    # the model itself receives a 224x224 resize of it.
    rgb_image = image.convert("RGB")
    width, height = rgb_image.width, rgb_image.height
    model_input = rgb_image.resize((224, 224))
    batch_images = image_processor(model_input).unsqueeze(0).unsqueeze(1).unsqueeze(0)

    if idx == 1:
        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        bad_words_ids = None
        max_generation_length = 5
    else:
        # The 1000 "<loc_i>" location tokens are meant to be banned from
        # free-form answers. Each id is wrapped in its own list, the
        # List[List[int]] layout Hugging Face generate() expects for
        # bad_words_ids. (get_outputs currently ignores this argument.)
        bad_words_ids = [
            [int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1])]
            for i in range(1000)
        ]
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        max_generation_length = 30

    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    # Position just after each <|#image#|> token, one image per prompt.
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )

    if idx == 1:
        boxes = outputs["boxes"]
        scores = outputs["scores"]
        if len(scores) == 0:
            # No candidate box: report it instead of failing on an unbound box.
            return "Output: no box found", rgb_image
        # Boxes are divided by the 224-pixel model input size, then scaled to
        # the uploaded image's (width, height, width, height).
        box = boxes[scores.argmax()] / 224
        print(f"{box}")
        box = box * [width, height, width, height]
        return f"Output:{box}", _draw_box(rgb_image, box)

    # NOTE(review): get_outputs runs a plain forward pass, not generation, so
    # ``outputs`` is the forward result (indexed as a mapping above), not token
    # ids; batch_decode on it is likely to fail. Confirm before re-enabling
    # the text-only tabs.
    gen_text = tokenizer.batch_decode(outputs)
    if idx == 2:
        return f"Question: {text.strip()} Answer: {gen_text}"
    return f"Output:{gen_text}"
# Markup shown at the top of the Gradio page: heading, tagline and a row of
# shields.io link badges.
title = '<h1 align="center">Demo of Compositional-VLM</h1>'
description = "<h3>This is the demo of Compositional-VLM. Upload your images and start chatting!</h3>"
# NOTE(review): the "Github-Code" and "Paper-PDF" badges link to the MiniGPT-4
# repository and paper. They look like leftovers from another demo; confirm
# the intended Compositional-VLM links.
article = (
    "<div style='display:flex; gap: 0.25rem; '>"
    "<a href='https://compositionalvlm.github.io/'>"
    "<img src='https://img.shields.io/badge/Project-Page-Green'></a>"
    "<a href='https://github.com/Vision-CAIR/MiniGPT-4'>"
    "<img src='https://img.shields.io/badge/Github-Code-blue'></a>"
    "<a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'>"
    "<img src='https://img.shields.io/badge/Paper-PDF-red'></a>"
    "</div>\n"
)
# TODO show examples below
# ========================================
# Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    """Put the UI back into its initial "upload an image first" state.

    Non-None ``chat_state`` / ``img_list`` values are replaced with fresh empty
    lists; ``None`` is passed through unchanged.

    Returns:
        Updates for (chatbot, image, text_input, upload_button, chat_state,
        img_list): an empty chatbot, a cleared editable image, a locked
        textbox and a re-enabled upload button.
    """
    chat_state = None if chat_state is None else []
    img_list = None if img_list is None else []
    cleared_chatbot = None
    image_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(placeholder='Please upload your image first', interactive=False)
    button_update = gr.update(value="Upload & Start Chat", interactive=True)
    return cleared_chatbot, image_update, textbox_update, button_update, chat_state, img_list
def upload_img(gr_img, text_input, chat_state):
    """Begin a new conversation around the uploaded image.

    ``text_input`` is part of the event wiring but is not read. With no image,
    the upload button is re-enabled and ``chat_state`` is returned as-is.

    Returns:
        Updates for (image, text_input, upload_button, chat_state, img_list).
    """
    if gr_img is not None:
        fresh_state, fresh_images = [], []
        # Presumably Chat.upload_img fills these lists in place (they are
        # returned to the session below); its return value is not needed.
        chat.upload_img(gr_img, fresh_state, fresh_images)
        return (
            gr.update(interactive=False),
            gr.update(interactive=True, placeholder='Type and press Enter'),
            gr.update(value="Start Chatting", interactive=False),
            fresh_state,
            fresh_images,
        )
    return None, None, gr.update(interactive=True), chat_state, None
def gradio_ask(user_message, chatbot, chat_state):
    """Record ``user_message`` as the next user turn.

    Returns:
        Updates for (text_input, chatbot, chat_state). For an empty message,
        only the textbox placeholder changes to a warning. Otherwise the
        message goes to ``chat.ask``, the textbox is cleared and a new
        ``[user_message, None]`` row is appended (on a new list) for the reply
        to fill in.
    """
    if len(user_message) != 0:
        chat.ask(user_message, chat_state)
        return '', chatbot + [[user_message, None]], chat_state
    empty_warning = gr.update(interactive=True, placeholder='Input should not be empty!')
    return empty_warning, chatbot, chat_state
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """Generate the model's reply for the pending user turn.

    Fills in the ``[user_message, None]`` row most recently appended to
    ``chatbot`` with the first element returned by ``chat.answer`` (at most
    300 new tokens, ``max_length=2000``, the given ``temperature``).

    Returns:
        (chatbot, chat_state, img_list).
    """
    # This handler is chained with .then() after gradio_ask, so it also runs
    # when gradio_ask rejected an empty message and appended no row. Without
    # this guard it would raise IndexError on an empty chat, or overwrite the
    # previous answer.
    if not chatbot or chatbot[-1][1]:
        return chatbot, chat_state, img_list
    # NOTE(review): the beam-search slider value (num_beams) is ignored and
    # beam search is pinned to 1. Confirm Chat.answer supports num_beams > 1
    # before wiring the slider through.
    llm_message = chat.answer(
        conv=chat_state,
        img_list=img_list,
        max_new_tokens=300,
        num_beams=1,
        temperature=temperature,
        max_length=2000,
    )[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        # Left column: image upload plus generation controls.
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
            num_beams = gr.Slider(
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
        # Right column: the conversation, plus session state for the chat
        # history and the uploaded images.
        with gr.Column():
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='Compositional-VLM')
            text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)

    upload_button.click(upload_img, [image, text_input, chat_state],
                        [image, text_input, upload_button, chat_state, img_list])
    # Submitting text first records the user turn, then generates the reply.
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
                queue=False)

# Enable the request queue via Blocks.queue(); launch(enable_queue=True) is
# deprecated and removed in newer Gradio releases.
demo.queue().launch()
#
# with gr.Blocks() as demo:
# gr.Markdown(
# """
# 🍜 Object Centric Pretraining Demo
# In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
# The model is trained on an interleaved mixture of text, images and bounding box and is able to generate text conditioned on sequences of images/text.
# """
# )
#
# with gr.Accordion("See terms and conditions"):
# gr.Markdown(
# """**Please read the following information carefully before proceeding.** This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
#
# with gr.Tab("📷 Image Captioning"):
# with gr.Row():
#
#
# query_image = gr.Image(type="pil")
# with gr.Row():
# chat_input = gr.Textbox(lines=1, label="Chat Input")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
#
# def on_click_fn(img,text): return generate(0, img, text)
#
# run_btn.click(on_click_fn, inputs=[query_image,chat_input], outputs=[text_output])
#
# with gr.Tab("🦓 Grounding"):
# with gr.Row():
# with gr.Column(scale=1):
# query_image = gr.Image(type="pil")
# with gr.Column(scale=1):
# out_image = gr.Image(type="pil")
# with gr.Row():
# chat_input = gr.Textbox(lines=1, label="Chat Input")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
# def on_click_fn(img, text): return generate(1, img, text)
#
#
# run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image])
#
# with gr.Tab("🔢 Counting objects"):
# with gr.Row():
# query_image = gr.Image(type="pil")
# with gr.Row():
# chat_input = gr.Textbox(lines=1, label="Chat Input")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
# def on_click_fn(img,text): return generate(0, img, text)
#
#
# run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
#
# with gr.Tab("🕵️ Visual Question Answering"):
# with gr.Row():
# query_image = gr.Image(type="pil")
# with gr.Row():
# question = gr.Textbox(lines=1, label="Question")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
# def on_click_fn(img, txt): return generate(2, img, txt)
#
#
# run_btn.click(
# on_click_fn, inputs=[query_image, question], outputs=[text_output]
# )
#
# with gr.Tab("🌎 Custom"):
# gr.Markdown(
# """### Customize the demonstration by uploading your own images and text samples.
# ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
# )
# with gr.Row():
# query_image = gr.Image(type="pil")
# with gr.Row():
# question = gr.Textbox(lines=1, label="Question")
# text_output = gr.Textbox(value="Output:", label="Model output")
#
# run_btn = gr.Button("Run model")
#
#
# def on_click_fn(img, txt): return generate(2, img, txt)
#
#
# run_btn.click(
# on_click_fn, inputs=[query_image, question], outputs=[text_output]
# )
#
# demo.queue(concurrency_count=1)
# demo.launch()