# Source: Hugging Face file view, author chendl, commit e770d90 ("update cap"), 19.6 kB.
from lavis.datasets.builders import load_dataset
import torch
import more_itertools
from tqdm import tqdm
from coco_metric import compute_cider, postprocess_captioning_generation
import json
import time
import os
from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
from PIL import Image
class VisualLogitsProcessor(LogitsProcessor):
    """Logits processor that steers decoding through the grounding-token protocol.

    Only batch row 0 of ``input_ids``/``scores`` is inspected and edited; a
    token is "forced" by setting its logit to 1000:

    * ``<|#object#|>`` is boosted when it sits at ranks 2..``topk`` of the
      next-token candidates, EOS is not among the top ``topk``, and the count
      of ``<|#endofobject#|>`` ids in ``input_ids`` equals twice the count of
      ``<|#object#|>`` ids;
    * right after an ``<|#object#|>`` that is not preceded by
      ``<|#prebox#|>`` (and when an earlier ``<|#object#|>`` exists),
      ``<|#previsual#|>`` is forced;
    * right after ``<|#previsual#|>`` or ``<|#visual#|>``, EOS is forced so
      the caller can stop and run box prediction;
    * right after an ``<|#endofobject#|>`` not preceded by ``<|#box#|>``,
      ``<|#visual#|>`` is forced.
    """

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
        self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
        self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
        self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
        self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        # Number of top-ranked candidates considered when deciding to boost <|#object#|>.
        self.topk = 2

    def __call__(self, input_ids, scores):
        """Adjust ``scores`` in place for row 0 and return it.

        Args:
            input_ids: token ids generated so far (row 0 is used for the
                last-token checks; the object/endofobject counts use the
                whole tensor).
            scores: next-token logits; row 0 is modified in place.

        Returns:
            The same ``scores`` tensor.
        """
        # Sort once and materialise only the top-k ids as a Python list; the
        # previous version sorted the full vocabulary twice and converted the
        # whole sorted index tensor to lists on every decoding step.
        top_ids = scores.sort(descending=True).indices[0, : self.topk].tolist()
        if (
            self.object_token_id in top_ids[1:]
            and self.eos_token_id not in top_ids
            and (input_ids == self.object_token_id).sum() * 2
            == (input_ids == self.endofobject_token_id).sum()
        ):
            scores[0, self.object_token_id] = 1000

        last_id = input_ids[0, -1]
        # input_ids[0, -2] is read lazily, only in the branches that need it.
        if last_id == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
            if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
                scores[0, self.previsual_token_id] = 1000
        elif last_id == self.previsual_token_id or last_id == self.visual_token_id:
            # Stop here so the caller can run box prediction.
            scores[0, self.eos_token_id] = 1000
        elif last_id == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
            scores[0, self.visual_token_id] = 1000
        return scores
def prepare_batch_images(batch, image_processor):
    """Preprocess each sample's image and stack them along a new leading batch axis.

    Args:
        batch: iterable of dicts, each holding an ``"image"`` accepted by
            ``image_processor``.
        image_processor: callable that turns one image into a tensor.

    Returns:
        A tensor of shape ``(len(batch), 1, 1, *image_shape)``, where
        ``image_shape`` is the shape returned by ``image_processor``, or
        ``None`` if ``batch`` is empty.
    """
    # Each processed image gets three leading singleton axes; axis 0 is the
    # one samples are concatenated along.
    images = [
        image_processor(sample["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        for sample in batch
    ]
    if not images:
        return None
    # A single concatenation avoids re-copying the growing tensor per sample.
    return torch.cat(images, dim=0)
def captioner(
    model,
    tokenizer,
    image_ori,
    batch_images,
    input_ids,
    attention_mask,
    image_start_index_list,
    image_nums,
    added_bbox_list,
    debug=False,
):
    """Generate a grounded caption, interleaving text generation with box prediction.

    Generation runs in rounds. When a round ends with ``<|#previsual#|>`` or
    ``<|#visual#|>`` followed by the BOS/EOS id, a forward pass predicts
    boxes. The highest-scoring box (divided by 224) is appended to
    ``added_bbox_list``, and ``<|#prebox#|><|#object#|>`` or
    ``<|#box#|><|#endofobject#|>`` is appended to the text prompt. That
    prompt is then re-encoded for the next round. Any other round ending
    finishes the caption.

    Args:
        model: grounded VLM exposing ``generate(...)`` and a forward call
            returning a mapping with ``"boxes"`` and ``"scores"``.
        tokenizer: tokenizer that knows the grounding special tokens.
        image_ori: original image (array-convertible, channels last); used
            only to draw previsual candidate boxes.
        batch_images: preprocessed image tensor passed to the model.
        input_ids: encoded initial prompt; its device is reused for every
            tensor created here.
        attention_mask: attention mask for ``input_ids``.
        image_start_index_list: image start positions for ``input_ids``.
        image_nums: number of images per row of ``input_ids``.
        added_bbox_list: list of box tensors; mutated in place.
        debug: print intermediate prompts and drop into pdb when no box is
            predicted.

    Returns:
        tuple: ``(caption, out_image)``. ``caption`` is the post-processed
        text generated after the initial prompt (row 0). ``out_image`` is a
        PIL image with the previsual candidate boxes drawn on it, or
        ``None`` if no previsual box was predicted.
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    model.eval()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"

    # Create every tensor on the caller's device instead of assuming CUDA
    # (the previous version mixed CPU tensors with a hard-coded .cuda()).
    device = input_ids.device
    num_rows = len(input_ids)
    ori_prompt_length = len(input_ids[0])
    have_prebox = False
    out_image = None
    # Text prompt for the next round; None means "use the caller's encoding".
    prompt = None

    def _encode(texts):
        """Tokenize ``texts`` and rebuild the image metadata the model expects."""
        encodings = tokenizer(
            texts,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        ids = encodings["input_ids"].to(device)
        mask = encodings["attention_mask"].to(device)
        starts = ((ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        return ids, mask, [[x] for x in starts], [1] * len(ids)

    while True:
        if prompt is not None:
            # Continue from the edited prompt that contains the inserted box
            # tokens. Without this, the round would restart from ids ending in
            # <|#visual#|>/<|#previsual#|>. VisualLogitsProcessor would then
            # force EOS again and the same round would repeat.
            input_ids, attention_mask, image_start_index_list, image_nums = _encode(
                [prompt] * num_rows
            )
        if debug:
            print("input--->", tokenizer.decode(input_ids[0]))
        min_tokens_processor = MinNewTokensLengthLogitsProcessor(
            prompt_length_to_skip=input_ids.shape[-1],
            min_new_tokens=5,
            eos_token_id=bos_token_id,
        )
        with torch.inference_mode():
            generated = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=20,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                logits_processor_list=[min_tokens_processor, visual_logits_processor],
            )
        if debug:
            print("outputs--->", tokenizer.decode(generated[0]))
        # Only a round ending in <|#previsual#|>/<|#visual#|> + BOS/EOS asks for
        # a box; anything else is the final caption.
        if not (
            generated[0, -2] in [previsual_token_id, visual_token_id]
            and generated[0, -1] == bos_token_id
        ):
            break

        prompt = tokenizer.decode(generated[0])
        is_visual = generated[0, -2] == visual_token_id
        # Re-encode the generation without its trailing EOS to predict boxes.
        input_ids, attention_mask, image_start_index_list, image_nums = _encode(
            tokenizer.batch_decode(generated[:, :-1])
        )
        if debug:
            print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
        with torch.no_grad():
            box_outputs = model(
                vision_x=batch_images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
            )
        boxes = box_outputs["boxes"]
        scores = box_outputs["scores"]

        if boxes is None:
            if debug:
                import pdb
                pdb.set_trace()
            # No box proposal: drop the grounding token and EOS, then keep
            # generating. NOTE(review): nothing stops the model from emitting
            # the same token again, so this path may repeat indefinitely.
            prompt = tokenizer.decode(generated[0, :-2])
            continue

        best_box = torch.tensor(boxes[scores.argmax()]).unsqueeze(0).to(device) / 224
        if is_visual:
            if have_prebox:
                # A confirmed box replaces the tentative previsual one.
                added_bbox_list.pop()
                prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                have_prebox = False
                if debug:
                    print("find previsual and remove it--->", prompt)
            added_bbox_list += [best_box]
            # Strip the decoded EOS text before appending the box tokens.
            prompt = prompt[:-len(tokenizer.eos_token)]
            prompt += box_token + endofobject_token
            if debug:
                print("after inserting visual---->", prompt)
        else:
            import numpy as np
            import cv2

            # Draw every candidate box (RGB -> BGR for OpenCV and back).
            canvas = np.array(image_ori)[:, :, ::-1].copy()
            for i, candidate in enumerate(boxes):
                # cv2.rectangle needs points as tuples of Python ints.
                canvas = cv2.rectangle(
                    canvas,
                    tuple(int(v) for v in candidate[:2]),
                    tuple(int(v) for v in candidate[2:]),
                    (0, 255, 0),
                    i + 1,
                )
            out_image = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
            added_bbox_list += [best_box]
            prompt = prompt[:-len(tokenizer.eos_token)]
            prompt += prebox_token + object_token
            have_prebox = True
            if debug:
                print("after inserting previsual---->", prompt)

    caption_ids = generated[:, ori_prompt_length:]
    caption = postprocess_captioning_generation(
        tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]
    ).replace('"', "")
    return caption, out_image
def evaluate_coco_flickr(
    model,
    tokenizer,
    image_processor,
    batch_size,
    is_flickr=False,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    debug=False,
):
    """Run grounded captioning on the COCO caption test split and score it with CIDEr.

    Each rank handles every ``world_size``-th batch and writes its predictions
    to a per-rank JSON file. Rank 0 then merges the files, computes CIDEr
    against the COCO Karpathy test annotations and saves the merged
    predictions under ``eval_results/``.

    Args:
        model: grounded VLM (moved to CUDA here) exposing ``generate`` and a
            forward call returning ``"boxes"`` and ``"scores"``.
        tokenizer: tokenizer that knows the grounding special tokens.
        image_processor: callable mapping an image to a tensor.
        batch_size: samples per chunk. NOTE(review): the decoding loop
            rebuilds the prompt from row 0 only, so results look meaningful
            only for ``batch_size == 1``. Confirm this.
        is_flickr: only changes the progress description and result file
            names; the dataset and annotations loaded are always COCO.
        vis_embed_size: number of pad tokens placed between the image tokens
            of the prompt.
        rank, world_size: distributed position; barriers are used when
            ``world_size > 1``.
        id: extra suffix in the per-rank result file names.
        debug: print intermediate prompts, save previsual box drawings to
            ``Atest.png`` and drop into pdb when no box is predicted.

    Returns:
        float: CIDEr x 100 on rank 0, ``0.0`` on other ranks.
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    coco_dataset = load_dataset("coco_caption")
    eval_dataset = coco_dataset["test"]
    model.eval().cuda()
    predictions = dict()
    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    results_prefix = "flickrresults" if is_flickr else "cocoresults"

    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"

    def _encode(texts):
        """Tokenize ``texts`` on CUDA and rebuild the image metadata."""
        encodings = tokenizer(
            texts,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        ids = encodings["input_ids"].cuda()
        mask = encodings["attention_mask"].cuda()
        starts = ((ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        return ids, mask, [[x] for x in starts], [1] * len(ids)

    if world_size > 1:
        torch.distributed.barrier()
    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
    for ii, batch in enumerate(more_itertools.chunked(
        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
    )):
        if ii % world_size != rank:
            continue
        batch_images = prepare_batch_images(
            batch=batch,
            image_processor=image_processor,
        ).cuda()
        prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
        added_bbox_list = []
        ori_prompt_length = None
        have_prebox = False
        while True:
            input_ids, attention_mask, image_start_index_list, image_nums = _encode(
                [prompt] * len(batch)
            )
            if ori_prompt_length is None:
                # Length of the initial prompt; generated text starts after it.
                ori_prompt_length = input_ids.shape[-1]
            if debug:
                print("input--->", tokenizer.decode(input_ids[0]))
            min_tokens_processor = MinNewTokensLengthLogitsProcessor(
                prompt_length_to_skip=input_ids.shape[-1],
                min_new_tokens=5,
                eos_token_id=bos_token_id,
            )
            # Both context managers must be listed with a comma; `a and b`
            # would enter only `b`.
            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
                generated = model.generate(
                    batch_images,
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=20,
                    num_beams=1,
                    image_start_index_list=image_start_index_list,
                    image_nums=image_nums,
                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                    logits_processor_list=[min_tokens_processor, visual_logits_processor],
                )
            if debug:
                print("outputs--->", tokenizer.decode(generated[0]))
            # Only a round ending in <|#previsual#|>/<|#visual#|> + BOS/EOS
            # asks for a box; anything else is the final caption.
            if not (
                generated[0, -2] in [previsual_token_id, visual_token_id]
                and generated[0, -1] == bos_token_id
            ):
                break

            prompt = tokenizer.decode(generated[0])
            is_visual = generated[0, -2] == visual_token_id
            # Re-encode the generation without its trailing EOS to predict boxes.
            input_ids, attention_mask, image_start_index_list, image_nums = _encode(
                tokenizer.batch_decode(generated[:, :-1])
            )
            if debug:
                print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
            with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
                box_outputs = model(
                    vision_x=batch_images,
                    lang_x=input_ids,
                    attention_mask=attention_mask,
                    image_nums=image_nums,
                    image_start_index_list=image_start_index_list,
                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                    add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
                )
            boxes = box_outputs["boxes"]
            scores = box_outputs["scores"]

            if boxes is None:
                if debug:
                    import pdb
                    pdb.set_trace()
                # No box proposal: drop the grounding token and EOS, then keep
                # generating. NOTE(review): this path may repeat indefinitely
                # if the model keeps emitting the same token.
                prompt = tokenizer.decode(generated[0, :-2])
                continue

            best_box = torch.tensor(boxes[scores.argmax()]).unsqueeze(0).cuda() / 224
            if is_visual:
                if have_prebox:
                    # A confirmed box replaces the tentative previsual one.
                    added_bbox_list.pop()
                    prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                    have_prebox = False
                    if debug:
                        print("find previsual and remove it--->", prompt)
                added_bbox_list += [best_box]
                # Strip the decoded EOS text before appending the box tokens.
                prompt = prompt[:-len(tokenizer.eos_token)]
                prompt += box_token + endofobject_token
                if debug:
                    print("after inserting visual---->", prompt)
            else:
                if debug:
                    import numpy as np
                    import cv2

                    canvas = np.array(batch[0]["image"])[:, :, ::-1].copy()
                    for i, candidate in enumerate(boxes):
                        canvas = cv2.rectangle(
                            canvas,
                            tuple(int(v) for v in candidate[:2]),
                            tuple(int(v) for v in candidate[2:]),
                            (0, 255, 0),
                            i + 1,
                        )
                    cv2.imwrite("Atest.png", canvas)
                added_bbox_list += [best_box]
                prompt = prompt[:-len(tokenizer.eos_token)]
                prompt += prebox_token + object_token
                have_prebox = True
                if debug:
                    print("after inserting previsual---->", prompt)

        new_predictions = [
            postprocess_captioning_generation(out).replace('"', "")
            for out in tokenizer.batch_decode(
                generated[:, ori_prompt_length:], skip_special_tokens=True
            )
        ]
        if rank == 0:
            tqdm.write(new_predictions[0])
        for sample, caption in zip(batch, new_predictions):
            predictions[int(sample["image_id"])] = {
                "caption": caption,
            }

    results_path = f"{results_prefix}_{lang_encoder_name}_{rank}_{id}.json"
    with open(results_path, "w") as f:
        f.write(
            json.dumps(
                [
                    {"image_id": image_id, "caption": entry["caption"]}
                    for image_id, entry in predictions.items()
                ],
                indent=2,
            )
        )
    print("save to", results_path)
    del predictions
    # Presumably gives other ranks / a shared filesystem time to see the files.
    time.sleep(10)
    if world_size > 1:
        torch.distributed.barrier()
    if rank != 0:
        return 0.0

    print(f"evaluate on rank {rank}. world size is {world_size}")
    predictions = []
    for rank_i in range(world_size):
        part_results_path = f"{results_prefix}_{lang_encoder_name}_{rank_i}_{id}.json"
        print("load", part_results_path)
        with open(part_results_path) as f:
            predictions.extend(json.load(f))
        os.remove(part_results_path)
    print("num:", len(predictions))
    results_path = f"{results_prefix}_{lang_encoder_name}.json"
    # Close (and flush) the merged file before compute_cider reads it.
    with open(results_path, "w") as f:
        json.dump(predictions, f, indent=2)
    metrics = compute_cider(
        result_path=results_path,
        annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
    )
    metrics["CIDEr"] *= 100
    os.makedirs("eval_results", exist_ok=True)
    acc = metrics["CIDEr"]
    with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
        f.write(json.dumps(predictions, indent=2))
    # Delete the temporary merged file.
    os.remove(results_path)
    return metrics["CIDEr"]