import argparse
import time
import re
import string
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any

import cv2
import numpy as np
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download, login
from open_flamingo.src.factory import create_model_and_transforms


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None, encounters=1):
        super().__init__()
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False


CONV_VISION = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
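    """Run a single grounded forward pass and return the model outputs.

    The free-form generation path (model.generate) is kept below as a
    commented-out alternative; this function currently only runs a forward
    pass, so the generation-related arguments (max/min length, beams,
    length penalty, bad_words_ids) are accepted but unused.
    """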
    # and torch.cuda.amp.autocast(dtype=torch.float16)
    with torch.inference_mode():
        outputs = model(
            vision_x=batch_images,
            lang_x=input_ids,
            attention_mask=attention_mask,
            labels=None,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            added_bbox_list=None,
            add_box=False,
        )
        # outputs = model.generate(
        #     batch_images,
        #     input_ids,
        #     attention_mask=attention_mask,
        #     max_new_tokens=max_generation_length,
        #     min_length=min_generation_length,
        #     num_beams=num_beams,
        #     length_penalty=length_penalty,
        #     image_start_index_list=image_start_index_list,
        #     image_nums=image_nums,
        #     bad_words_ids=bad_words_ids,
        # )
    return outputs


def generate(
    idx,
    image,
    text,
    image_processor,
    tokenizer,
    flamingo,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
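    """Single-turn demo entry point.

    idx == 1 runs referring-expression grounding (the model predicts a box
    for `text`); any other idx runs plain text generation conditioned on the
    image. Returns a status string and, when a box is predicted, the image
    with the box drawn on it.
    """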
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()
    loc_token_ids = []
    for i in range(1000):
        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
    image_ori = image
    image = image.convert("RGB")
    width = image.width
    height = image.height
    image = image.resize((224, 224))
    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    if idx == 1:
        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        bad_words_ids = None
        max_generation_length = 5
    else:
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        # ban the location tokens during free-form generation
        bad_words_ids = [[token_id] for token_id in loc_token_ids]
        max_generation_length = 300
    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )
    boxes = outputs["boxes"]
    scores = outputs["scores"]
    if len(boxes) > 0 and len(scores) > 0:
        box = boxes[scores.argmax()] / 224
        print(f"{box}")
        open_cv_image = np.array(image_ori)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        box = box * [width, height, width, height]
        # for box in boxes:
        open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
        out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
        return f"Output:{box}", out_image
    else:
        gen_text = tokenizer.batch_decode(outputs)
        return f"{gen_text}", None


def preprocess_conv(data):
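    """Flatten a list of {"from": ..., "value": ...} turns into a single
    '### Human: ... / ### Assistant: ...' style prompt string."""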
    conversation = ""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    for idx, d in enumerate(data):
        from_str = d["from"]
        if from_str.lower() == "human":
            from_str = "Human"
        elif from_str.lower() == "gpt":
            from_str = "Assistant"
        else:
            from_str = "unknown"
        conversation += (BEGIN_SIGNAL + from_str + ": " + d["value"] + END_SIGNAL)
    return conversation


def preprocess_image(sample, image_processor):
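    """Apply the vision processor and return a plain tensor, unwrapping the
    BatchFeature that HuggingFace image processors return."""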
    image = image_processor(sample)
    if isinstance(image, transformers.image_processing_utils.BatchFeature):
        image = torch.tensor(image["pixel_values"][0])
    return image


class Chat:
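    """Thin wrapper around the model for the multi-turn Gradio demo: it keeps
    the conversation list, builds prompts per task (Cap / VQA / REC / free
    chat), and decodes text and box outputs."""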

    def __init__(self, model, vis_processor, tokenizer, vis_embed_size):
        self.model = model
        self.vis_processor = vis_processor
        self.tokenizer = tokenizer
        self.vis_embed_size = vis_embed_size
        self.conv = []
        # stop_words_ids = [torch.tensor([835]).to(self.device),
        #                   torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv, radio):
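        """Append the user turn to `conv`, templated according to the selected
        task (`radio`): empty prompt for captioning, a short-answer instruction
        for VQA, a grounding instruction for REC, or the raw text otherwise."""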
        if radio in ["Cap"]:
            conv.append({
                "from": "human",
                "value": "",
            })
        elif radio in ["VQA"]:
            conv.append({
                "from": "human",
                "value": f"Answer the question using a single word or phrase. {text}",
            })
        elif radio in ["REC"]:
            conv.append({
                "from": "human",
                "value": f"Please provide the bounding box coordinate of the region this sentence describes: {text}.",
            })
        else:
            conv.append({
                "from": "human",
                "value": text,
            })
        # if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
        #         and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
        #     conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        # else:
        #     conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, radio, text_input, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
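        """Append the assistant turn, run generation for the text reply and a
        forward pass for box prediction, and return (reply_text, image_with_box
        or None). Most sampling arguments are currently unused and are kept for
        interface compatibility with the commented-out LLaMA generation path."""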
        # conv.append_message(conv.roles[1], None)
        # embs = self.get_context_emb(conv, img_list)
        #
        # # current_max_len = embs.shape[1] + max_new_tokens + 100
        # # begin_idx = max(0, current_max_len - max_length)
        # # embs = embs[:, begin_idx:]
        # outputs = self.model.llama_model.generate(
        #     inputs_embeds=embs,
        #     max_new_tokens=max_new_tokens,
        #     stopping_criteria=self.stopping_criteria,
        #     num_beams=num_beams,
        #     min_length=min_length,
        #     top_p=top_p,
        #     repetition_penalty=repetition_penalty,
        #     length_penalty=length_penalty,
        #     temperature=temperature,
        # )
        # output_token = outputs[0]
        # if output_token[0] == 0:
        #     output_token = output_token[1:]
        # output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        # output_text = output_text.split('###')[0]  # remove the stop sign '###'
        # output_text = output_text.split('Assistant:')[-1].strip()
        # conv.messages[-1][1] = output_text
        visual_token = "<|#visual#|>"
        previsual_token = "<|#previsual#|>"
        box_token = "<|#box#|>"
        prebox_token = "<|#prebox#|>"
        end_token = "<|#endofobject#|>"
        object_token = "<|#object#|>"
        end_of_attr_token = "<|#endofattr#|>"
        preend_of_attr_token = "<|#preendofattr#|>"
        media_token_id = self.tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
        box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        endofattr_token_id = self.tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
        endofmedia_token_id = self.tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
        visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
        previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
        prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        size = 224
        self.model.eval()
        # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
        # image_path = input("Please enter the image path: ")
        image = img_list[0].convert("RGB")
        image_ori = image
        image = image.resize((size, size))
        print(f"image size: {image.size}")
        batch_images = preprocess_image(image, self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        # conversation = []
        human_sentence = None
        if radio in ["Cap", "VQA"]:
            conv.append({
                "from": "gpt",
                "value": "",
            })
        elif radio in ["REC"]:
            conv.append(
                {
                    "from": "gpt",
                    "value": object_token + text_input + end_token + visual_token,
                }
            )
        else:
            conv.append({
                "from": "gpt",
                "value": "",
            })
        # while True:
        #     human_sentence = input("### Human: ")
        #     if human_sentence == "#end#":
        #         break
        #     conversation.append({
        #         "from": "human",
        #         "value": human_sentence,
        #     })
        #     conversation.append({
        #         "from": "gpt",
        #         "value": "",
        #     })
        text = preprocess_conv(conv).strip()
        caption = f"<|#image#|>{self.tokenizer.pad_token * self.vis_embed_size}<|#endofimage#|>{text}"
        encodings = self.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]
        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        image_start_index_list = [[x] for x in image_start_index_list]
        image_nums = [1] * len(input_ids)
        added_bbox_list = []
        with torch.inference_mode():
            text_outputs = self.model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=20,
                # min_new_tokens=8,
                num_beams=1,
                # length_penalty=0,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
            )
        # and torch.cuda.amp.autocast(dtype=torch.float16)
        with torch.no_grad():
            outputs = self.model(
                vision_x=batch_images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=None,
                add_box=False,
            )
        boxes = outputs["boxes"]
        scores = outputs["scores"]
        out_image = None
        if len(boxes) > 0 and len(scores) > 0:
            box = boxes[scores.argmax()] / 224
            print(f"{box}")
            width, height = image_ori.size
            open_cv_image = np.array(image_ori)
            # Convert RGB to BGR
            open_cv_image = open_cv_image[:, :, ::-1].copy()
            box = box * [width, height, width, height]
            # for box in boxes:
            open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
            out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
        # output_token = outputs[0, input_ids.shape[1]:]
        # output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
        # conv[-1]["value"] = output_text
        # # conv.messages[-1][1] = output_text
        # print(
        #     f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
        output_text = self.tokenizer.decode(text_outputs[0])
        matches = re.findall(r"Assistant:(.+)", output_text)
        if matches:
            output_text = matches[-1].strip()
        return output_text, out_image

    def upload_img(self, image, conv, img_list):
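        """Store the uploaded PIL image in `img_list`; preprocessing happens
        later in `answer`. Returns a short acknowledgement string."""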
        img_list.append(image)
        # if isinstance(image, str):  # is a image path
        #     raw_image = Image.open(image).convert('RGB')
        #     image = image.resize((224, 224))
        #     image = self.vis_processor(raw_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        # elif isinstance(image, Image.Image):
        #     raw_image = image
        #     image = image.resize((224, 224))
        #     image = self.vis_processor(raw_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        # elif isinstance(image, torch.Tensor):
        #     if len(image.shape) == 3:
        #         image = image.unsqueeze(0)
        #     # image = image.to(self.device)
        #
        # # image_emb, _ = self.model.encode_img(image)
        # img_list.append(image_emb)
        # conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg

    # def get_context_emb(self, conv, img_list):
    #     prompt = conv.get_prompt()
    #     prompt_segs = prompt.split('<ImageHere>')
    #     assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
    #     seg_tokens = [
    #         self.model.llama_tokenizer(
    #             seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
    #         # only add bos to the first seg
    #         for i, seg in enumerate(prompt_segs)
    #     ]
    #     seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
    #     mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
    #     mixed_embs = torch.cat(mixed_embs, dim=1)
    #     return mixed_embs


def evaluate_exp(
    model,
    tokenizer,
    image_processor,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    add_visual=True,
):
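    """Interactive command-line chat loop (not used by the Gradio UI): reads an
    image path and Human turns from stdin, then prints the Assistant reply.
    Type '#end#' to stop entering turns."""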
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
    size = image_processor.size["shortest_edge"]
    model.eval()
    # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
    image_path = input("Please enter the image path: ")
    image = Image.open(image_path).convert("RGB")
    image = image.resize((size, size))
    print(f"image size: {image.size}")
    batch_images = preprocess_image(image, image_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    conversation = []
    human_sentence = None
    while True:
        human_sentence = input("### Human: ")
        if human_sentence == "#end#":
            break
        conversation.append({
            "from": "human",
            "value": human_sentence,
        })
        conversation.append({
            "from": "gpt",
            "value": "",
        })
    text = preprocess_conv(conversation).strip()
    caption = f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text}"
    encodings = tokenizer(
        caption,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"].to("cuda")
    attention_mask = encodings["attention_mask"].to("cuda")
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    # note: `with A() and B():` only enters B; use a combined with-statement instead
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        outputs = model.generate(
            batch_images,
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,
            # min_new_tokens=8,
            num_beams=1,
            image_start_index_list=image_start_index_list,
            image_nums=image_nums,
        )
    print(f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
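

# A minimal sketch of how `Chat` could be wired into a Gradio demo, assuming a
# model, vision processor, tokenizer, and vis_embed_size have already been
# built elsewhere (e.g. via create_model_and_transforms); the widget layout and
# handler below are illustrative, not this Space's actual UI code.
#
# chat = Chat(model, image_processor, tokenizer, vis_embed_size)
#
# with gr.Blocks() as demo:
#     image_input = gr.Image(type="pil")
#     radio = gr.Radio(["Cap", "VQA", "REC", "Chat"], value="Chat", label="Task")
#     text_input = gr.Textbox(label="Message")
#     text_output = gr.Textbox(label="Response")
#     image_output = gr.Image(type="pil")
#     conv_state = gr.State([])
#     img_state = gr.State([])
#
#     def respond(image, text, task, conv, img_list):
#         chat.upload_img(image, conv, img_list)
#         chat.ask(text, conv, task)
#         reply, boxed = chat.answer(conv, img_list, task, text)
#         return reply, boxed, conv, img_list
#
#     text_input.submit(
#         respond,
#         [image_input, text_input, radio, conv_state, img_state],
#         [text_output, image_output, conv_state, img_state],
#     )
#
# demo.launch()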