import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
import time
import json
import os
from coco_metric import compute_cider
import pickle


def evaluate_reg(
    model,
    tokenizer,
    image_processor,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
):
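    """Evaluate referring-expression generation (REG) on the RefCOCOg val split:
    for each annotated region, prompt the model with the image and the
    ground-truth box, generate a region description, and score the merged
    predictions against the reference expressions with CIDEr."""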
    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    dataset_name = "refcocog"
    pkl_file = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_data.pkl"
    try:
        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    except Exception:
        # tokenizer may lack the grounding special tokens; generation below
        # will then fail with a NameError on media_token_id
        pass
    model.eval().cuda()
    if world_size > 1:
        torch.distributed.barrier()
    this_tot = 0
    predictions = []
    with open(pkl_file, "rb") as f:
        D = pickle.load(f)
    lines = []
    data = D["data"]
    uniq_id_to_text = D["uniq_id_to_text"]
    uniq_id_to_image = D["uniq_id_to_image"]
    uniq_id_to_image_id = D["uniq_id_to_image_id"]
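    # flatten the nested dict into one record per region:
    # [uniq_id, image_id, reference texts, box coordinates, PIL image]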
    for image_id in data:
        for region in data[image_id]:
            uniq_id = data[image_id][region][0]
            lines.append([uniq_id, uniq_id_to_image_id[uniq_id], [uniq_id_to_text[r] for r in data[image_id][region]], region, uniq_id_to_image[uniq_id]])
    print("total data:", len(lines))
    pbar = tqdm(lines, disable=(rank != 0))
    for ii, line in enumerate(pbar):
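        # shard the data across ranks: each rank handles every world_size-th example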
        if ii % world_size != rank:
            continue
        uniq_id, image_id, text, region_coord, image = line
        gt_box = np.array(region_coord)
        width = image.width
        height = image.height
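        # resize to the model's 224x224 input and rescale the box to match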
        image = image.resize((224, 224))
        gt_box = gt_box / np.array([width, height, width, height]) * 224
        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
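        # prompt layout: BOS, image placeholder (pad tokens stand in for the
        # visual embeddings), then grounding tokens so generation describes the
        # region supplied via added_bbox_list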
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#object#|>"] | |
encodings = tokenizer( | |
prompt, | |
padding="longest", | |
truncation=True, | |
return_tensors="pt", | |
max_length=2000, | |
) | |
input_ids = encodings["input_ids"] | |
attention_mask = encodings["attention_mask"] | |
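        # positions just after each <|#image#|> token, where the visual
        # embeddings get spliced into the sequence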
        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        image_start_index_list = [[x] for x in image_start_index_list]
        image_nums = [1] * len(input_ids)
        batch_images = batch_images.cuda()
        input_ids = input_ids.cuda()
        attention_mask = attention_mask.cuda()
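        # normalize the box to [0, 1) relative coordinates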
        added_bbox_list = [(torch.tensor(gt_box).cuda() / 224).clamp(0, 0.99).unsqueeze(0)]
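        # beam-search generation under fp16 autocast, without gradient tracking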
        with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
            outputs = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=25,
                min_length=5,
                num_beams=8,
                length_penalty=0,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list,
            )
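        # keep only the newly generated tokens, dropping the prompt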
        outputs = outputs[:, len(input_ids[0]):]
        new_prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip().lower()
        this_tot += 1
        if rank == 0 and this_tot % 10 == 0:
            tqdm.write(f"answer: {text}\nmodel output: {new_prediction}")
        predictions.append(
            {"image_id": image_id, "caption": new_prediction}
        )
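    # each rank writes its shard to a temporary file, then all ranks
    # synchronize before rank 0 merges and scores them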
results_path = f"reg_{lang_encoder_name}_{rank}_{id}.json" | |
json.dump(predictions, open(results_path, "w")) | |
print("save to", results_path) | |
del predictions | |
time.sleep(5) | |
if world_size > 1: | |
torch.distributed.barrier() | |
if rank == 0: | |
print(f"evaluate on rank {rank}. world size is {world_size}") | |
predictions = [] | |
for rank_i in range(world_size): | |
part_results_path = f"reg_{lang_encoder_name}_{rank_i}_{id}.json" | |
print("load", part_results_path) | |
part_data = json.load(open(part_results_path)) | |
predictions.extend(part_data) | |
os.remove(part_results_path) | |
print("num:", len(predictions)) | |
results_path = f"reg_{lang_encoder_name}_{id}_result.json" | |
json.dump(predictions, open(results_path, "w"), indent=2) | |
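        # score the generated captions against the RefCOCOg reference expressions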
        metrics = compute_cider(
            result_path=results_path,
            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_label.json",
        )
        os.makedirs("eval_results", exist_ok=True)
        cider = metrics["CIDEr"]
        print("cider", cider)
        with open(os.path.join("eval_results", f"reg_{model.expr_name}_{model.step_num}_{int(time.time())}_{cider}"), "w") as f:
            f.write(json.dumps(predictions, indent=2))
        # delete the temporary file
        os.remove(results_path)
        return cider
if __name__ == "__main__": | |
anno = json.load(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json")) | |
import pdb; pdb.set_trace() | |
print(anno.keys()) | |
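
# Minimal usage sketch (assumption: `build_model_and_transforms` is a
# hypothetical stand-in; the real eval harness constructs the model,
# tokenizer, and image processor elsewhere before calling evaluate_reg):
#
#     model, image_processor, tokenizer, vis_embed_size = build_model_and_transforms(...)
#     cider = evaluate_reg(
#         model,
#         tokenizer,
#         image_processor,
#         vis_embed_size=vis_embed_size,
#         rank=0,
#         world_size=1,
#         id=0,
#     )
#     print("CIDEr:", cider)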