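# Remaps open-ended VQA predictions onto a closed answer vocabulary:
# every non-empty predicted answer is replaced by the entry in
# answer_list.json with the highest BERTScore F1 against it, and the
# remapped results are then scored with the VQAv2 accuracy metric via
# open_flamingo's compute_vqa_accuracy.
#
# Usage (script filename is illustrative):
#   python remap_and_score.py vqav2 <path to predictions json>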
from open_flamingo.eval.vqa_metric import compute_vqa_accuracy
import sys
import json
from bert_score import BERTScorer
from tqdm.contrib.concurrent import process_map
from tqdm import tqdm
import random
import time

# Number of data shards / worker processes handed to process_map below.
NUM_GPU = 128


def single_job(args):
    data, refs, idx = args
    # Build a BERTScorer for this worker, retrying with a random backoff:
    # many workers start at once, and model loading / GPU allocation can
    # fail transiently under contention.
    success = False
    while not success:
        try:
            time.sleep(random.random() * 10)
            scorer = BERTScorer(
                lang="en",
                rescale_with_baseline=True,
                # model_type="microsoft/deberta-xlarge-mnli",
                model_type="bert-base-uncased",
                batch_size=4096,
                device=f"cuda:{idx % 6}",  # spread workers over the node's 6 GPUs
            )
            success = True
        except Exception:
            time.sleep(random.random() * 5)
    # Replace each non-empty predicted answer with the reference answer that
    # maximizes BERTScore F1 against it (progress bar only on worker 0).
    for i, d in enumerate(tqdm(data, disable=(idx != 0))):
        if d["answer"] == "":
            continue
        cands = [d["answer"]] * len(refs)
        P, R, F1 = scorer.score(cands, refs, verbose=False)
        d["answer"] = refs[F1.argmax()]
        data[i] = d
    return data


if __name__ == "__main__":
    if sys.argv[1] == "vqav2":
        question_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_OpenEnded_mscoco_val2014_questions.json"
        annotation_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_mscoco_val2014_annotations.json"
    else:
        raise NotImplementedError
    answer_list = json.load(open("answer_list.json"))
    data = json.load(open(sys.argv[2]))
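    # Each prediction entry is assumed to be a dict with at least an "answer"
    # field, plus the question id expected by compute_vqa_accuracy, e.g.
    #   {"question_id": <int>, "answer": "<free-form answer>"}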
    # Shard the predictions round-robin across NUM_GPU workers; every shard
    # carries the shared reference answer list and its worker index.
    data_parts = []
    for i in range(NUM_GPU):
        data_parts.append([[], answer_list, i])
    for i, d in enumerate(data):
        data_parts[i % NUM_GPU][0].append(d)
    # Remap answers in parallel, then merge the shards back together.
    datas = process_map(single_job, data_parts, max_workers=NUM_GPU, disable=True)
    all_data = []
    for part in datas:
        all_data.extend(part)
    json.dump(all_data, open("temp_result", "w"))
    acc = compute_vqa_accuracy(
        result_json_path="temp_result",
        question_json_path=question_json_path,
        annotation_json_path=annotation_json_path,
    )
    print(acc)