Spaces:
Runtime error
Runtime error
File size: 2,131 Bytes
0b7b08a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from open_flamingo.eval.vqa_metric import compute_vqa_accuracy
import sys
import json
from bert_score import BERTScorer
from tqdm.contrib.concurrent import process_map
from tqdm import tqdm
import random
import time
# Number of worker processes / data shards used by the __main__ block.
# NOTE: this is not the number of physical GPUs — single_job maps workers
# onto a smaller set of CUDA devices (idx % num_devices), so several workers
# share each device.
NUM_GPU = 128


def single_job(args, max_attempts=20, num_devices=6):
    """Replace each answer in a shard with its closest reference answer.

    Worker entry point for ``process_map``. Builds a ``BERTScorer``
    (``bert-base-uncased``, baseline-rescaled) on one CUDA device. Then, for
    every record whose ``"answer"`` is not the empty string, it scores that
    answer against every string in ``refs``. The answer is overwritten with
    the reference that has the highest BERTScore F1.

    Args:
        args: ``(data, refs, idx)``.
            ``data`` is a list of dicts with an ``"answer"`` key and is
            mutated in place.
            ``refs`` is the list of candidate reference answers.
            ``idx`` is the worker index. It selects device
            ``cuda:{idx % num_devices}``, and only worker 0 shows a progress
            bar.
        max_attempts: How many times to try constructing the scorer before
            giving up (must be >= 1).
        num_devices: How many CUDA devices workers are spread over.

    Returns:
        ``data``, with non-empty answers replaced by reference answers.

    Raises:
        ValueError: If ``max_attempts`` < 1.
        Exception: The last error raised while constructing ``BERTScorer``,
            once ``max_attempts`` attempts have all failed.
    """
    data, refs, idx = args
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    scorer = None
    for attempt in range(1, max_attempts + 1):
        # Random jitter, presumably to stagger scorer initialisation across
        # the many concurrent workers.
        time.sleep(random.random() * 10)
        try:
            scorer = BERTScorer(
                lang="en",
                rescale_with_baseline=True,
                # Alternative backbone: model_type="microsoft/deberta-xlarge-mnli"
                model_type="bert-base-uncased",
                batch_size=4096,
                device=f"cuda:{idx % num_devices}",
            )
        except Exception:
            # Previously a bare `except:` in an unbounded loop: it swallowed
            # KeyboardInterrupt and hung forever on permanent errors
            # (e.g. a missing device). Now bounded, and the last error
            # surfaces to the caller.
            if attempt == max_attempts:
                raise
            time.sleep(random.random() * 5)
        else:
            break

    for record in tqdm(data, disable=(idx != 0)):
        answer = record["answer"]
        if answer == "":
            continue
        # Score the single answer against every reference in one batch.
        cands = [answer] * len(refs)
        _, _, f1 = scorer.score(cands, refs, verbose=False)
        record["answer"] = refs[int(f1.argmax())]
    return data
if __name__ == "__main__":
    # Usage: <script> <dataset> <result_json>
    #   dataset      only "vqav2" is supported
    #   result_json  JSON list of records, each with an "answer" field
    # Each answer is snapped to the closest entry of answer_list.json (via
    # single_job in parallel workers), the merged results are written to
    # "temp_result", and VQA accuracy is printed.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <dataset> <result_json>")

    dataset = sys.argv[1]
    if dataset == "vqav2":
        question_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_OpenEnded_mscoco_val2014_questions.json"
        annotation_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_mscoco_val2014_annotations.json"
    else:
        raise NotImplementedError(f"unsupported dataset: {dataset!r}")

    with open("answer_list.json", encoding="utf-8") as f:
        answer_list = json.load(f)
    with open(sys.argv[2], encoding="utf-8") as f:
        records = json.load(f)

    # Round-robin sharding: worker i gets records i, i + NUM_GPU, i + 2*NUM_GPU, ...
    data_parts = [(records[i::NUM_GPU], answer_list, i) for i in range(NUM_GPU)]
    shard_results = process_map(single_job, data_parts, max_workers=NUM_GPU, disable=True)
    all_data = [record for shard in shard_results for record in shard]

    # Close (and therefore flush) the file before compute_vqa_accuracy reads it.
    with open("temp_result", "w", encoding="utf-8") as f:
        json.dump(all_data, f)

    acc = compute_vqa_accuracy(
        result_json_path="temp_result",
        question_json_path=question_json_path,
        annotation_json_path=annotation_json_path,
    )
    print(acc)
|