Spaces:
Runtime error
Runtime error
from lavis.datasets.builders import load_dataset | |
import torch | |
import more_itertools | |
from tqdm import tqdm | |
from coco_metric import compute_cider, postprocess_captioning_generation | |
import json | |
import time | |
import os | |
from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor | |
from PIL import Image | |
class VisualLogitsProcessor(LogitsProcessor):
    """Steer generation through the object/box grounding marker tokens.

    Only row 0 of ``scores`` is modified. Chosen logits are overwritten in
    place with ``1000`` to force the next token:

    * ``<|#object#|>`` when it is among the ``topk`` highest-scoring
      candidates of row 0 but is not the top one, EOS is not among those
      candidates, and the number of ``<|#endofobject#|>`` tokens in
      ``input_ids`` (all rows) equals twice the number of ``<|#object#|>``
      tokens.
    * ``<|#previsual#|>`` right after an ``<|#object#|>`` that is not
      preceded by ``<|#prebox#|>``, provided row 0 already contained an
      earlier ``<|#object#|>``.
    * EOS right after ``<|#previsual#|>`` or ``<|#visual#|>``, so generation
      stops and the caller can run box prediction.
    * ``<|#visual#|>`` right after an ``<|#endofobject#|>`` that is not
      preceded by ``<|#box#|>``.
    """

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        # Each marker is identified by the last id its text tokenizes to.
        self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
        self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
        self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
        self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
        self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        # How many of the highest-scoring candidates the <|#object#|> boost inspects.
        self.topk = 2

    def __call__(self, input_ids, scores):
        """Apply the forcing rules to row 0 of ``scores`` in place and return it."""
        # Only the best ``topk`` ids (in descending score order) are needed.
        # ``topk`` avoids sorting the whole vocabulary and converting it to a
        # Python list on every generation step.
        top_ids = scores[0].topk(self.topk).indices.tolist()
        if (
            self.object_token_id in top_ids[1:]
            and self.eos_token_id not in top_ids
            and (input_ids == self.object_token_id).sum() * 2
            == (input_ids == self.endofobject_token_id).sum()
        ):
            scores[0, self.object_token_id] = 1000

        last_id = input_ids[0, -1]
        if last_id == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
            if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
                # Force a previsual token next.
                scores[0, self.previsual_token_id] = 1000
        elif last_id == self.previsual_token_id or last_id == self.visual_token_id:
            # Stop here so the caller can run box prediction.
            scores[0, self.eos_token_id] = 1000
        elif last_id == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
            # Force a visual token next.
            scores[0, self.visual_token_id] = 1000
        return scores
def prepare_batch_images(batch, image_processor):
    """Preprocess every sample's image and concatenate them into one tensor.

    Args:
        batch: Iterable of samples, each holding its image under ``"image"``.
        image_processor: Callable mapping an image to a tensor.

    Returns:
        The processed images, each given three leading singleton dimensions
        via ``unsqueeze(0).unsqueeze(1).unsqueeze(0)`` and concatenated along
        dim 0. That is shape ``(len(batch), 1, 1, *processed_shape)``.
        Returns ``None`` when ``batch`` is empty.
    """
    processed = [
        image_processor(sample["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        for sample in batch
    ]
    if not processed:
        return None
    # A single concatenation, instead of re-copying the growing tensor per sample.
    return torch.cat(processed, dim=0)
def captioner(
    model,
    tokenizer,
    image_ori,
    batch_images,
    input_ids,
    attention_mask,
    image_start_index_list,
    image_nums,
    added_bbox_list,
    debug=False,
):
    """Caption an image, interleaving greedy text generation with box prediction.

    Generation runs in rounds.

    * A round ending in ``<|#previsual#|>`` or ``<|#visual#|>`` followed by
      the BOS token triggers an extra forward pass that predicts boxes.
    * The highest-scoring box (divided by 224) is appended to
      ``added_bbox_list``.
    * The matching ``<|#prebox#|><|#object#|>`` or
      ``<|#box#|><|#endofobject#|>`` tokens are spliced into the prompt.
    * The next round continues from that edited prompt.

    Generation stops at the first round that does not end that way, or when
    box prediction returns no boxes.

    Args:
        model: Model exposing ``generate`` and a forward pass that returns a
            mapping with ``"boxes"`` and ``"scores"``.
        tokenizer: Tokenizer that knows the ``<|#...#|>`` marker tokens.
        image_ori: Original image; converted with ``np.array`` and used only to
            draw the previsual boxes into the returned image.
        batch_images: Vision input, passed unchanged to the model.
        input_ids, attention_mask, image_start_index_list, image_nums: The
            encoded initial prompt and its image bookkeeping. Row 0 drives all
            control-flow decisions. Re-encoded prompts are placed on
            ``input_ids.device``.
        added_bbox_list: List of boxes already inserted. It is MUTATED in place
            as boxes are added or replaced.
        debug: Print intermediate prompts. Enter ``pdb`` when no boxes are
            predicted.

    Returns:
        tuple: ``(caption, out_image)``. ``caption`` is the post-processed text
        generated after the initial prompt. ``out_image`` is a PIL image with
        the last previsual boxes drawn, or ``None`` if none were predicted.
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    model.eval()
    device = input_ids.device
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"
    batch_size = len(input_ids)
    ori_prompt_length = len(input_ids[0])
    have_prebox = False
    out_image = None

    def encode(texts):
        """Tokenize ``texts`` and rebuild the image bookkeeping the model expects."""
        encodings = tokenizer(
            texts,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        ids = encodings["input_ids"].to(device)
        mask = encodings["attention_mask"].to(device)
        # Image features start right after each <|#image#|> token.
        starts = ((ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        return ids, mask, [[x] for x in starts], [1] * len(ids)

    while True:
        if debug:
            print("input--->", tokenizer.decode(input_ids[0]))
        # Suppress bos_token_id until at least 5 new tokens have been generated.
        min_new_tokens = MinNewTokensLengthLogitsProcessor(
            prompt_length_to_skip=input_ids.shape[-1],
            min_new_tokens=5,
            eos_token_id=bos_token_id,
        )
        with torch.inference_mode():
            generated = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=20,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                logits_processor_list=[min_new_tokens, visual_logits_processor],
            )
        if debug:
            print("outputs--->", tokenizer.decode(generated[0]))
        if not (
            generated[0, -2] in [previsual_token_id, visual_token_id]
            and generated[0, -1] == bos_token_id
        ):
            break

        prompt = tokenizer.decode(generated.clone()[0])
        is_visual = generated[0, -2] == visual_token_id
        # Drop the trailing BOS so the box forward pass ends on the (pre)visual token.
        input_ids, attention_mask, image_start_index_list, image_nums = encode(
            tokenizer.batch_decode(generated[:, :-1])
        )
        if debug:
            print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
        with torch.no_grad():
            model_out = model(
                vision_x=batch_images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
            )
        boxes = model_out["boxes"]
        scores = model_out["scores"]

        if boxes is None:
            # Generation is greedy (num_beams=1). Regenerating from the same
            # prefix with the same boxes would most likely reproduce the same
            # (pre)visual tokens and loop forever, so keep the text produced so far.
            if debug:
                import pdb; pdb.set_trace()
            break

        if is_visual:
            if have_prebox:
                # The final box replaces the provisional previsual box and its tokens.
                added_bbox_list.pop()
                prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                have_prebox = False
                if debug:
                    print("find previsual and remove it--->", prompt)
            first_box = boxes[scores.argmax()]
            added_bbox_list += [torch.as_tensor(first_box).unsqueeze(0).to(device) / 224]
            prompt = prompt[:-len(tokenizer.eos_token)]
            prompt += box_token + endofobject_token
            if debug:
                print("after inserting visual---->", prompt)
        else:
            import numpy as np
            import cv2
            # RGB -> BGR channel order for OpenCV drawing.
            open_cv_image = np.array(image_ori)[:, :, ::-1].copy()
            for i, pre_box in enumerate(boxes):
                # OpenCV requires plain int tuples for the corner points.
                open_cv_image = cv2.rectangle(
                    open_cv_image,
                    tuple(int(v) for v in pre_box[:2]),
                    tuple(int(v) for v in pre_box[2:]),
                    (0, 255, 0),
                    i + 1,
                )
            out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
            pre_box = boxes[scores.argmax()]
            added_bbox_list += [torch.as_tensor(pre_box).unsqueeze(0).to(device) / 224]
            prompt = prompt[:-len(tokenizer.eos_token)]
            prompt += prebox_token + object_token
            have_prebox = True
            if debug:
                print("after inserting previsual---->", prompt)

        # The next round must continue from the edited prompt (with box tokens).
        input_ids, attention_mask, image_start_index_list, image_nums = encode(
            [prompt] * batch_size
        )

    caption_ids = generated[:, ori_prompt_length:]
    caption = postprocess_captioning_generation(
        tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]
    ).replace('"', "")
    return caption, out_image
def evaluate_coco_flickr(
    model,
    tokenizer,
    image_processor,
    batch_size,
    is_flickr=False,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    debug=False,
):
    """Caption the ``coco_caption`` test split and score the captions with CIDEr.

    Batches are assigned round-robin: batch ``ii`` is handled by the rank
    ``ii % world_size``. Each rank writes its predictions to a temporary JSON
    file. After a barrier, rank 0:

    * merges and deletes the parts;
    * computes CIDEr against the hard-coded Karpathy test annotations;
    * archives the predictions under ``eval_results/``.

    Captions are generated with the same grounding loop as ``captioner``.

    Args:
        model: Model exposing ``generate``, a box-predicting forward pass,
            ``lang_encoder``, ``expr_name`` and ``step_num``.
        tokenizer: Tokenizer that knows the ``<|#...#|>`` marker tokens.
        image_processor: Callable applied to each sample's ``"image"``.
        batch_size: Samples per generation batch.
        is_flickr: Only changes the progress label and result file names; the
            data is always ``load_dataset("coco_caption")["test"]``.
        vis_embed_size: Number of pad tokens reserved for the image in the
            prompt. It must be an int; the default ``None`` fails.
        rank, world_size: Distributed position; a barrier is used when
            ``world_size > 1``.
        id: Suffix that keeps temporary result files of concurrent runs apart.
        debug: Print intermediate prompts. Enter ``pdb`` when no boxes are
            predicted.

    Returns:
        float: CIDEr score multiplied by 100 on rank 0, ``0.0`` on other ranks.
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    coco_dataset = load_dataset("coco_caption")
    eval_dataset = coco_dataset["test"]
    model.eval().cuda()
    predictions = dict()
    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"

    def encode(texts):
        """Tokenize ``texts`` onto the GPU and rebuild the image bookkeeping."""
        encodings = tokenizer(
            texts,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        ids = encodings["input_ids"].cuda()
        mask = encodings["attention_mask"].cuda()
        # Image features start right after each <|#image#|> token.
        starts = ((ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        return ids, mask, [[x] for x in starts], [1] * len(ids)

    if world_size > 1:
        torch.distributed.barrier()
    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
    for ii, batch in enumerate(more_itertools.chunked(
        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
    )):
        if ii % world_size != rank:
            continue
        batch_images = prepare_batch_images(
            batch=batch,
            image_processor=image_processor,
        ).cuda()
        prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
        added_bbox_list = []
        input_ids, attention_mask, image_start_index_list, image_nums = encode(
            [prompt for _ in batch]
        )
        ori_prompt_length = input_ids.shape[-1]
        have_prebox = False
        while True:
            if debug:
                print("input--->", tokenizer.decode(input_ids[0]))
            # Suppress bos_token_id until at least 5 new tokens have been generated.
            min_new_tokens = MinNewTokensLengthLogitsProcessor(
                prompt_length_to_skip=input_ids.shape[-1],
                min_new_tokens=5,
                eos_token_id=bos_token_id,
            )
            # Both context managers must be entered. The original
            # ``A and B`` form only entered B.
            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
                generated = model.generate(
                    batch_images,
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=20,
                    num_beams=1,
                    image_start_index_list=image_start_index_list,
                    image_nums=image_nums,
                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                    logits_processor_list=[min_new_tokens, visual_logits_processor],
                )
            if debug:
                print("outputs--->", tokenizer.decode(generated[0]))
            if not (
                generated[0, -2] in [previsual_token_id, visual_token_id]
                and generated[0, -1] == bos_token_id
            ):
                break

            prompt = tokenizer.decode(generated.clone()[0])
            is_visual = generated[0, -2] == visual_token_id
            # Drop the trailing BOS so the box forward pass ends on the (pre)visual token.
            input_ids, attention_mask, image_start_index_list, image_nums = encode(
                tokenizer.batch_decode(generated[:, :-1])
            )
            if debug:
                print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
            with torch.autocast("cuda", dtype=torch.float16), torch.no_grad():
                model_out = model(
                    vision_x=batch_images,
                    lang_x=input_ids,
                    attention_mask=attention_mask,
                    image_nums=image_nums,
                    image_start_index_list=image_start_index_list,
                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                    add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
                )
            boxes = model_out["boxes"]
            scores = model_out["scores"]

            if boxes is None:
                # Greedy decoding from the same prefix would most likely repeat
                # the same (pre)visual tokens forever; keep the text so far.
                if debug:
                    import pdb; pdb.set_trace()
                break

            if is_visual:
                if have_prebox:
                    # The final box replaces the provisional previsual box and its tokens.
                    added_bbox_list.pop()
                    prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                    have_prebox = False
                    if debug:
                        print("find previsual and remove it--->", prompt)
                first_box = boxes[scores.argmax()]
                added_bbox_list += [torch.as_tensor(first_box).unsqueeze(0).cuda() / 224]
                prompt = prompt[:-len(tokenizer.eos_token)]
                prompt += box_token + endofobject_token
                if debug:
                    print("after inserting visual---->", prompt)
            else:
                pre_box = boxes[scores.argmax()]
                added_bbox_list += [torch.as_tensor(pre_box).unsqueeze(0).cuda() / 224]
                prompt = prompt[:-len(tokenizer.eos_token)]
                prompt += prebox_token + object_token
                have_prebox = True
                if debug:
                    print("after inserting previsual---->", prompt)

            # Continue the next round from the edited prompt (with box tokens).
            input_ids, attention_mask, image_start_index_list, image_nums = encode(
                [prompt for _ in batch]
            )

        new_predictions = [
            postprocess_captioning_generation(out).replace('"', "")
            for out in tokenizer.batch_decode(
                generated[:, ori_prompt_length:], skip_special_tokens=True
            )
        ]
        if rank == 0:
            tqdm.write(new_predictions[0])
        for sample, caption in zip(batch, new_predictions):
            predictions[int(sample["image_id"])] = {
                "caption": caption,
            }

    results_path = (
        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
        if is_flickr
        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
    )
    with open(results_path, "w") as f:
        f.write(
            json.dumps(
                [
                    {"image_id": k, "caption": predictions[k]["caption"]}
                    for k in predictions
                ],
                indent=2,
            )
        )
    print("save to", results_path)
    del predictions
    # Give other ranks' files time to become visible on the shared filesystem.
    time.sleep(10)
    if world_size > 1:
        torch.distributed.barrier()

    if rank != 0:
        return 0.0

    print(f"evaluate on rank {rank}. world size is {world_size}")
    predictions = []
    for rank_i in range(world_size):
        part_results_path = (
            f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
            if is_flickr
            else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
        )
        print("load", part_results_path)
        with open(part_results_path) as f:
            predictions.extend(json.load(f))
        os.remove(part_results_path)
    print("num:", len(predictions))
    results_path = (
        f"flickrresults_{lang_encoder_name}.json"
        if is_flickr
        else f"cocoresults_{lang_encoder_name}.json"
    )
    # Close (and flush) the merged file before compute_cider reads it.
    with open(results_path, "w") as f:
        json.dump(predictions, f, indent=2)
    metrics = compute_cider(
        result_path=results_path,
        annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
    )
    metrics["CIDEr"] *= 100
    os.makedirs("eval_results", exist_ok=True)
    acc = metrics["CIDEr"]
    with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
        f.write(json.dumps(predictions, indent=2))
    # Delete the temporary merged file.
    os.remove(results_path)
    return metrics["CIDEr"]