# Provenance (copied from a hosted file view):
# author: chendl, commit e770d90 ("update cap"), size 5.64 kB
import torch
from tqdm import tqdm
from PIL import Image
from io import BytesIO
import base64
import numpy as np
import time
import json
import os
import cv2
from coco_metric import compute_cider
import random
import pickle
def evaluate_reg(
    model,
    tokenizer,
    image_processor,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    pkl_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_data.pkl",
    annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_label.json",
):
    """Evaluate referring-expression generation (REG) on RefCOCOg val with CIDEr.

    For every (image, region) entry in the pickled evaluation data, the model
    is prompted with the image plus the ground-truth box and generates a short
    caption for that region. Samples are sharded round-robin across ranks.
    Each rank writes its predictions to a temporary JSON file; rank 0 merges
    all of them and scores the merged predictions with ``compute_cider``.

    Args:
        model: Model exposing ``lang_encoder``, ``eval()``, ``cuda()`` and
            ``generate(...)``. Rank 0 also reads ``expr_name`` and ``step_num``
            to name the saved result file.
        tokenizer: Tokenizer that understands the ``<|#image#|>`` family of
            special tokens and defines ``bos_token`` / ``pad_token``.
        image_processor: Callable mapping an image to a tensor.
        vis_embed_size: Number of pad tokens reserved in the prompt for the
            visual embedding. Required despite the ``None`` default.
        rank: Rank of this process.
        world_size: Total number of processes taking part in the evaluation.
        id: Run identifier embedded in temporary/result file names.
        pkl_file: Pickle holding ``data``, ``uniq_id_to_text``,
            ``uniq_id_to_image`` and ``uniq_id_to_image_id``.
        annotations_path: Ground-truth file passed to ``compute_cider``.

    Returns:
        The CIDEr score on rank 0; ``None`` on every other rank.

    Raises:
        ValueError: If ``vis_embed_size`` is not provided.
    """
    if vis_embed_size is None:
        raise ValueError("vis_embed_size must be set to build the image prompt")

    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    # Only the image-start token id is needed below. A tokenizer failure here
    # used to be swallowed by a bare ``except`` and then crash later as a
    # NameError; let it surface immediately instead.
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]

    model.eval().cuda()
    if world_size > 1:
        torch.distributed.barrier()

    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(pkl_file, "rb") as f:
        D = pickle.load(f)
    data = D["data"]
    uniq_id_to_text = D["uniq_id_to_text"]
    uniq_id_to_image = D["uniq_id_to_image"]
    uniq_id_to_image_id = D["uniq_id_to_image_id"]

    # One sample per (image, region): the region's first uniq_id supplies the
    # image and image id, and all of its uniq_ids supply the reference texts.
    lines = []
    for image_id in data:
        for region in data[image_id]:
            region_uniq_ids = data[image_id][region]
            first_uniq_id = region_uniq_ids[0]
            lines.append([
                first_uniq_id,
                uniq_id_to_image_id[first_uniq_id],
                [uniq_id_to_text[r] for r in region_uniq_ids],
                region,
                uniq_id_to_image[first_uniq_id],
            ])
    print("total data:", len(lines))

    this_tot = 0
    predictions = []
    pbar = tqdm(lines, disable=(rank != 0))
    for ii, line in enumerate(pbar):
        # Round-robin sharding of samples across ranks.
        if ii % world_size != rank:
            continue
        uniq_id, image_id, text, region_coord, image = line
        width = image.width
        height = image.height
        image = image.resize((224, 224))
        # The region is divided by (w, h, w, h), i.e. it is assumed to be four
        # pixel coordinates ordered x, y, x, y; rescale them to the 224x224 image.
        gt_box = np.array(region_coord) / np.array([width, height, width, height]) * 224
        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#object#|>"]
        encodings = tokenizer(
            prompt,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]
        # The visual embedding starts right after each <|#image#|> token.
        media_positions = (input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1
        image_start_index_list = [[x] for x in media_positions.tolist()]
        image_nums = [1] * len(input_ids)
        batch_images = batch_images.cuda()
        input_ids = input_ids.cuda()
        attention_mask = attention_mask.cuda()
        # Box scaled to [0, 0.99] relative to the 224x224 image.
        added_bbox_list = [(torch.tensor(gt_box).cuda() / 224).clamp(0, 0.99).unsqueeze(0)]
        # Context managers must be comma-separated: ``a and b`` evaluates to
        # ``b`` alone, which previously meant inference mode was never entered.
        # NOTE(review): if ``generate`` lazily caches tensors on the model and the
        # model is trained afterwards, torch.no_grad() may be the safer choice.
        with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
            outputs = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=25,
                min_length=5,
                num_beams=8,
                length_penalty=0,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list,
            )
        # Keep only the newly generated tokens (drop the echoed prompt).
        outputs = outputs[:, len(input_ids[0]):]
        new_prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip().lower()
        this_tot += 1
        if rank == 0 and this_tot % 10 == 0:
            tqdm.write(f"answer: {text}\nmodel output: {new_prediction}")
        predictions.append({"image_id": image_id, "caption": new_prediction})

    results_path = f"reg_{lang_encoder_name}_{rank}_{id}.json"
    with open(results_path, "w") as f:
        json.dump(predictions, f)
    print("save to", results_path)
    del predictions
    # Presumably gives a shared filesystem time to expose every rank's file
    # before rank 0 reads them.
    time.sleep(5)
    if world_size > 1:
        torch.distributed.barrier()
    if rank != 0:
        return None

    print(f"evaluate on rank {rank}. world size is {world_size}")
    predictions = []
    for rank_i in range(world_size):
        part_results_path = f"reg_{lang_encoder_name}_{rank_i}_{id}.json"
        print("load", part_results_path)
        with open(part_results_path) as f:
            predictions.extend(json.load(f))
        os.remove(part_results_path)
    print("num:", len(predictions))

    results_path = f"reg_{lang_encoder_name}_{id}_result.json"
    with open(results_path, "w") as f:
        json.dump(predictions, f, indent=2)
    metrics = compute_cider(
        result_path=results_path,
        annotations_path=annotations_path,
    )
    os.makedirs("eval_results", exist_ok=True)
    cider = metrics["CIDEr"]
    print("cider", cider)
    with open(os.path.join("eval_results", f"reg_{model.expr_name}_{model.step_num}_{int(time.time())}_{cider}"), "w") as f:
        f.write(json.dumps(predictions, indent=2))
    # Delete the temporary merged-results file.
    os.remove(results_path)
    return cider
if __name__ == "__main__":
    # Quick inspection of the COCO Karpathy test ground-truth file: print its
    # top-level keys. The leftover pdb breakpoint was removed so the script can
    # run non-interactively, and the file is now closed after loading.
    with open("/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json") as f:
        anno = json.load(f)
    print(anno.keys())