import argparse
import dataclasses
import re
import string
import time
from enum import auto, Enum
from typing import List, Tuple, Any

import cv2
import gradio as gr
import numpy as np
import torch
import transformers
from huggingface_hub import hf_hub_download, login
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

from open_flamingo.src.factory import create_model_and_transforms
class SeparatorStyle(Enum):
    """How a Conversation delimits its turns when rendered as a prompt."""

    # One separator (``Conversation.sep``) after every message.
    SINGLE = 1
    # Alternate between ``Conversation.sep`` and ``Conversation.sep2``.
    TWO = 2
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    ``messages`` is a list of ``[role, text]`` pairs. A falsy ``text`` marks an
    open turn: it is rendered as ``"role:"`` with no trailing separator, so a
    model can continue from there.
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    # Number of leading messages that ``to_gradio_chatbot`` skips.
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    # Only read by SeparatorStyle.TWO, where it alternates with ``sep``;
    # it must be set (a str) when that style is used.
    sep2: str = None
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render ``system`` followed by every message, delimited per ``sep_style``.

        Raises:
            ValueError: if ``sep_style`` is not SINGLE or TWO.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    # Even-indexed messages end with ``sep``, odd-indexed with ``sep2``.
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair; pass ``None``/``""`` to open a turn."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Pair messages after ``offset`` as ``[user, assistant]`` rows for a Gradio Chatbot.

        Even positions (relative to ``offset``) start a new row; odd positions
        fill the second slot of the previous row.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy whose ``messages`` list (and each pair in it) is independent.

        All other fields, including ``skip_next``, are carried over unchanged.
        ``roles`` is shared, not copied. Previously ``skip_next`` was dropped and
        reset to False.
        """
        return dataclasses.replace(self, messages=[[x, y] for x, y in self.messages])

    def dict(self):
        """Return a plain-dict snapshot of the fields (``sep_style`` and ``skip_next`` are omitted)."""
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with any stop pattern.

    Args:
        stops: sequence of 1-D token-id patterns (tensors or int lists).
            Defaults to no patterns.
        encounters: accepted for backward compatibility; not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # ``None`` default avoids one mutable list being shared by all instances.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if ``input_ids[0]`` ends with one of ``self.stops``.

        Only row 0 of the batch is inspected. Extra keyword arguments passed by
        the generation loop are accepted and ignored.
        """
        sequence = input_ids[0]
        for stop in self.stops:
            length = len(stop)
            # An empty pattern would be compared against the whole sequence, and a
            # pattern longer than the sequence cannot match yet.
            if length == 0 or length > sequence.shape[-1]:
                continue
            tail = sequence[-length:]
            # Move the pattern next to the ids so CPU patterns work with GPU inputs.
            if torch.all(torch.as_tensor(stop, device=tail.device) == tail).item():
                return True
        return False
# Default single-image conversation template: "###"-separated turns between
# "Human" and "Assistant"; the first two messages are hidden from the chatbot view.
CONV_VISION = Conversation(
    sep="###",
    sep_style=SeparatorStyle.SINGLE,
    offset=2,
    messages=[],
    roles=("Human", "Assistant"),
    system=(
        "Give the following image: ImageContent. "
        "You will be able to see the image once I provide it to you. Please answer my questions."
    ),
)
def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    """Run one forward pass of ``model`` on an image/text batch and return its raw output.

    Only ``batch_images`` (as ``vision_x``), ``input_ids`` (as ``lang_x``),
    ``attention_mask``, ``image_nums`` and ``image_start_index_list`` reach the
    model. No labels and no extra boxes are supplied.

    ``max_generation_length``, ``min_generation_length``, ``num_beams``,
    ``length_penalty`` and ``bad_words_ids`` are accepted for call
    compatibility but ignored, because this helper does not call
    ``model.generate``.
    """
    forward_kwargs = dict(
        vision_x=batch_images,
        lang_x=input_ids,
        attention_mask=attention_mask,
        labels=None,
        image_nums=image_nums,
        image_start_index_list=image_start_index_list,
        added_bbox_list=None,
        add_box=False,
    )
    # Inference mode: no autograd graph is recorded for this call.
    with torch.inference_mode():
        result = model(**forward_kwargs)
    return result
def generate(
    idx,
    image,
    text,
    image_processor,
    tokenizer,
    flamingo,
    vis_embed_size=256,
    rank=0,
    world_size=1,
    *,
    loc_token_template=None,
):
    """Answer one query about ``image`` and draw the highest-scoring predicted box.

    Args:
        idx: ``1`` builds a grounding prompt
            (``<|#object#|> text<|#endofobject#|><|#visual#|>``); any other value
            builds a plain text prompt.
        image: uploaded PIL image; ``None`` raises ``gr.Error``.
        text: user query; trailing periods are removed.
        image_processor: applied (via ``preprocess_image``) to the 224x224 resize.
        tokenizer: tokenizer that knows the ``<|#...#|>`` special tokens.
        flamingo: model run through ``get_outputs`` (a single forward pass).
        vis_embed_size: number of pad tokens reserved for the visual embeddings.
        rank, world_size: unused; kept for call compatibility.
        loc_token_template: optional ``str.format`` template with one ``{}`` slot
            naming the 1000 location tokens to suppress when ``idx != 1``. Use your
            tokenizer's real names; ``"<loc_{}>"`` is only a hypothetical example.
            With no template, no bad-words list is built. An empty template yields
            no token ids, which previously raised IndexError on every call.
            Note that ``get_outputs`` currently ignores ``bad_words_ids``.

    Returns:
        ``(message, annotated_image)`` when the model returns a scored box.
        Otherwise, the decoded model output as a single string.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]

    # Convert once up front: drawing on an RGBA/greyscale upload would break the
    # RGB->BGR channel flip and the cvtColor call below.
    image_rgb = image.convert("RGB")
    width, height = image_rgb.size
    batch_images = preprocess_image(image_rgb.resize((224, 224)), image_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)

    bad_words_ids = None
    if idx == 1:
        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        max_generation_length = 5
    else:
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        if loc_token_template is not None:
            # Hugging Face expects bad_words_ids as a list of token-id lists.
            bad_words_ids = []
            for i in range(1000):
                ids = tokenizer(loc_token_template.format(i), add_special_tokens=False)["input_ids"]
                if ids:
                    bad_words_ids.append([int(ids[-1])])
        max_generation_length = 300

    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    # Position just after each <|#image#|> token, one list per batch row.
    image_start_index_list = [
        [x] for x in ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    ]
    image_nums = [1] * len(input_ids)
    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )
    boxes = outputs["boxes"]
    scores = outputs["scores"]
    # Both must be non-empty: the best box is chosen by score.
    if len(scores) > 0 and len(boxes) > 0:
        box = boxes[scores.argmax()] / 224
        print(f"{box}")
        open_cv_image = np.array(image_rgb)[:, :, ::-1].copy()  # RGB -> BGR for OpenCV
        box = box * [width, height, width, height]
        # OpenCV wants plain-int point tuples; int() truncates like astype(int).
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        open_cv_image = cv2.rectangle(open_cv_image, top_left, bottom_right, (255, 0, 0), 2)
        out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
        return f"Output:{box}", out_image
    # NOTE(review): ``outputs`` is the forward-pass result indexed by "boxes"/"scores"
    # above; confirm batch_decode accepts it. This branch returns a bare string,
    # unlike the (text, image) pair above.
    gen_text = tokenizer.batch_decode(outputs)
    return f"{gen_text}"
def preprocess_conv(data):
    """Render a list of ``{"from": ..., "value": ...}`` turns as ``### Role: text`` lines.

    ``from`` is matched case-insensitively: ``human`` becomes ``Human``, ``gpt``
    becomes ``Assistant``, and anything else becomes ``unknown``. Every line,
    including the last, ends with a newline. An empty list gives ``""``.
    """
    speaker_names = {"human": "Human", "gpt": "Assistant"}
    lines = []
    for turn in data:
        speaker = speaker_names.get(turn["from"].lower(), "unknown")
        lines.append("### " + speaker + ": " + turn["value"] + "\n")
    return "".join(lines)
def preprocess_image(sample, image_processor):
    """Apply ``image_processor`` to ``sample`` and return the processed image.

    If the processor returns a transformers ``BatchFeature``, the first entry of
    its ``pixel_values`` is converted with ``torch.tensor``. Any other result is
    returned unchanged.
    """
    processed = image_processor(sample)
    if not isinstance(processed, transformers.image_processing_utils.BatchFeature):
        return processed
    first_pixels = processed["pixel_values"][0]
    return torch.tensor(first_pixels)
class Chat:
    """Drive the image-chat demo around a Flamingo-style ``model``.

    Conversation state is owned by the caller. ``conv`` is a list of
    ``{"from": ..., "value": ...}`` turns (the format ``preprocess_conv``
    renders), and ``img_list`` holds the uploaded images.
    """

    def __init__(self, model, vis_processor, tokenizer, vis_embed_size):
        self.model = model
        self.vis_processor = vis_processor
        self.tokenizer = tokenizer
        # Number of pad tokens reserved in the prompt for the visual embeddings.
        self.vis_embed_size = vis_embed_size
        # Not read by any method of this class.
        self.conv = []

    def ask(self, text, conv, radio):
        """Append the user's turn to ``conv``, wrapped in the task prompt chosen by ``radio``.

        ``"Cap"`` sends an empty turn (``text`` is ignored). ``"VQA"`` and
        ``"REC"`` wrap ``text`` in a task instruction. Any other value sends
        ``text`` unchanged.
        """
        if radio == "Cap":
            # NOTE(review): the caption prompt is an empty string; confirm this is intended.
            value = ""
        elif radio == "VQA":
            value = f"Answer the question using a single word or phrase. {text}"
        elif radio == "REC":
            value = f"Please provide the bounding box coordinate of the region this sentence describes: {text}."
        else:
            value = text
        conv.append({"from": "human", "value": value})

    def answer(self, conv, img_list, radio, text_input, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
        """Generate the assistant reply and, if the model returns a scored box, draw it.

        Args:
            conv: turn list. An assistant turn is appended; for ``"REC"`` it is
                pre-filled with the object/visual markers around ``text_input``.
            img_list: uploaded images; only ``img_list[0]`` is used.
            radio: task selector (``"Cap"``, ``"VQA"``, ``"REC"`` or other).
            text_input: referring expression, used only when ``radio == "REC"``.
            max_new_tokens, num_beams, min_length, top_p, repetition_penalty,
            length_penalty, temperature, max_length: currently ignored.
                Generation is always greedy with 20 new tokens.

        Returns:
            ``(text, image)``. ``text`` is what follows the last ``Assistant:`` in
            the decoded output, or ``""`` if there is none. ``image`` is an
            annotated RGB copy of the input, or ``None`` when no scored box was
            returned.
        """
        visual_token = "<|#visual#|>"
        end_token = "<|#endofobject#|>"
        object_token = "<|#object#|>"
        media_token_id = self.tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
        size = 224
        self.model.eval()

        image_ori = img_list[0].convert("RGB")
        image = image_ori.resize((size, size))
        print(f"image size: {image.size}")
        batch_images = preprocess_image(image, self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)

        # Open the assistant turn; for REC it is seeded with the referring expression.
        if radio == "REC":
            reply_seed = object_token + text_input + end_token + visual_token
        else:
            reply_seed = ""
        conv.append({"from": "gpt", "value": reply_seed})

        text = preprocess_conv(conv).strip()
        caption = f"<|#image#|>{self.tokenizer.pad_token * self.vis_embed_size}<|#endofimage#|>{text}"
        encodings = self.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]
        # Position just after each <|#image#|> token, one list per batch row.
        image_start_index_list = [
            [x] for x in ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        ]
        image_nums = [1] * len(input_ids)

        with torch.inference_mode():
            text_outputs = self.model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=20,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=None,
            )

        # Separate forward pass over the same prompt to obtain box predictions.
        with torch.no_grad():
            outputs = self.model(
                vision_x=batch_images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=None,
                add_box=False,
            )
        boxes = outputs["boxes"]
        scores = outputs["scores"]
        out_image = None
        # Both must be non-empty: the best box is chosen by score.
        if len(scores) > 0 and len(boxes) > 0:
            box = boxes[scores.argmax()] / 224
            print(f"{box}")
            out_image = self._draw_box(image_ori, box)

        output_text = self.tokenizer.decode(text_outputs[0])
        matches = re.findall(r'Assistant:(.+)', output_text)
        # Empty reply instead of IndexError when nothing follows "Assistant:".
        output_text = matches[-1] if matches else ""
        return output_text, out_image

    @staticmethod
    def _draw_box(image_rgb, box):
        """Return an RGB copy of ``image_rgb`` with a 2-px rectangle drawn for ``box``.

        ``box`` is ``(x1, y1, x2, y2)`` and is multiplied by ``[w, h, w, h]``. It is
        presumably normalised to [0, 1] and must support ``*`` with a list
        (e.g. a numpy array) — TODO confirm.
        """
        width, height = image_rgb.size
        canvas = np.array(image_rgb)[:, :, ::-1].copy()  # RGB -> BGR for OpenCV
        box = box * [width, height, width, height]
        # OpenCV wants plain-int point tuples; int() truncates like astype(int).
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        canvas = cv2.rectangle(canvas, top_left, bottom_right, (255, 0, 0), 2)
        return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    def upload_img(self, image, conv, img_list):
        """Record ``image`` for later ``answer`` calls and return an acknowledgement.

        ``conv`` is accepted for interface compatibility and is not modified.
        """
        img_list.append(image)
        return "Received."
# def get_context_emb(self, conv, img_list):
# prompt = conv.get_prompt()
# prompt_segs = prompt.split('')
# assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
# seg_tokens = [
# self.model.llama_tokenizer(
# seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
# # only add bos to the first seg
# for i, seg in enumerate(prompt_segs)
# ]
# seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
# mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
# mixed_embs = torch.cat(mixed_embs, dim=1)
# return mixed_embs
def evaluate_exp(
    model,
    tokenizer,
    image_processor,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    add_visual=True,
):
    """Interactive terminal chat about a single image, run on CUDA under fp16 autocast.

    Prompts for an image path, then reads ``### Human:`` lines until ``#end#``.
    For each line it prints the model's greedy continuation (up to 100 new
    tokens).

    Args:
        model: model exposing ``generate(vision_x, lang_x, ...)``.
        tokenizer: tokenizer that knows the ``<|#image#|>`` / ``<|#endofimage#|>`` tokens.
        image_processor: processor with ``size["shortest_edge"]``; the image is
            resized to that square before preprocessing.
        vis_embed_size: number of pad tokens reserved for visual embeddings (required).
        rank, world_size, id, add_visual: unused; kept for call compatibility.

    Raises:
        ValueError: if ``vis_embed_size`` is None.
    """
    if vis_embed_size is None:
        # Fail before prompting instead of a TypeError on ``pad_token * None`` later.
        raise ValueError("vis_embed_size must be set to the number of visual embedding tokens")
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    size = image_processor.size["shortest_edge"]
    model.eval()
    # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
    image_path = input("Please enter the image path: ")
    # Context manager closes the underlying file once the RGB copy is made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image = image.resize((size, size))
    print(f"image size: {image.size}")
    # Same device as input_ids / attention_mask below.
    batch_images = preprocess_image(image, image_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0).to("cuda")
    conversation = []
    while True:
        human_sentence = input("### Human: ")
        if human_sentence == "#end#":
            break
        conversation.append({
            "from": "human",
            "value": human_sentence,
        })
        # NOTE(review): the model's reply is never written back into this turn, so
        # later prompts contain empty assistant messages. Confirm this is intended.
        conversation.append({
            "from": "gpt",
            "value": "",
        })
        text = preprocess_conv(conversation).strip()
        caption = f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}"
        encodings = tokenizer(
            caption,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"].to("cuda")
        attention_mask = encodings["attention_mask"].to("cuda")
        # Position just after each <|#image#|> token, one list per batch row.
        image_start_index_list = [
            [x] for x in ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        ]
        image_nums = [1] * len(input_ids)
        # Two managers separated by a comma: ``a and b`` would enter only ``b``
        # (autocast) and silently skip no_grad.
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
            outputs = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=100,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
            )
        # The returned sequence starts with the prompt, so only the new tokens are decoded.
        print(f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")