File size: 2,131 Bytes
0b7b08a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from open_flamingo.eval.vqa_metric import compute_vqa_accuracy
import sys
import json
from bert_score import BERTScorer
from tqdm.contrib.concurrent import process_map
from tqdm import tqdm
import random
import time
NUM_GPU = 128

def single_job(args):
    """Worker: snap each prediction onto the closest reference answer.

    For every record in ``data`` whose ``"answer"`` is non-empty, scores the
    answer against every string in ``refs`` with BERTScore and replaces the
    answer with the reference that has the highest F1 — i.e. maps free-form
    model output onto the closed answer vocabulary.

    Args:
        args: 3-tuple ``(data, refs, idx)`` where
            data: list of dicts, each with an ``"answer"`` key (mutated in place).
            refs: list of candidate reference answer strings.
            idx:  worker index; selects the CUDA device and silences the
                  progress bar on all workers except 0.

    Returns:
        The (mutated) ``data`` list.
    """
    data, refs, idx = args
    # Workers race to load the model onto a small set of GPUs; stagger the
    # start with random sleeps and retry on failure (e.g. transient CUDA OOM).
    # Bound the retries so a persistent failure surfaces instead of spinning
    # forever.
    scorer = None
    for _attempt in range(100):
        try:
            time.sleep(random.random() * 10)
            scorer = BERTScorer(
                lang="en",
                rescale_with_baseline=True,
                # model_type="microsoft/deberta-xlarge-mnli",
                model_type="bert-base-uncased",
                batch_size=4096,
                # NOTE(review): only 6 devices are cycled even though
                # NUM_GPU == 128 workers — presumably a 6-GPU node; confirm.
                device=f"cuda:{idx % 6}"
            )
            break
        except Exception:
            # Bare `except:` would also trap KeyboardInterrupt/SystemExit,
            # making the pool unkillable; catch Exception instead.
            time.sleep(random.random() * 5)
    if scorer is None:
        raise RuntimeError("failed to construct BERTScorer after repeated attempts")
    for d in tqdm(data, disable=(idx != 0)):
        if d["answer"] == "":
            continue  # nothing to re-map for empty predictions
        # Score the single candidate against every reference in one batch.
        cands = [d["answer"]] * len(refs)
        P, R, F1 = scorer.score(cands, refs, verbose=False)
        # Replace the free-form answer with the best-matching reference.
        d["answer"] = refs[F1.argmax()]
    return data


if __name__ == "__main__":
    # Usage: script.py <task> <result_json>
    # Only the VQAv2 validation split is wired up.
    if sys.argv[1] == "vqav2":
        question_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_OpenEnded_mscoco_val2014_questions.json"
        annotation_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_mscoco_val2014_annotations.json"
    else:
        raise NotImplementedError
    # Use context managers so handles are closed (and, for the output file,
    # flushed to disk before compute_vqa_accuracy reads it back).
    with open("answer_list.json") as f:
        answer_list = json.load(f)
    with open(sys.argv[2]) as f:
        data = json.load(f)
    # Shard the records round-robin across NUM_GPU workers; each part is
    # (records, answer_list, worker_idx) as expected by single_job.
    data_parts = [[[], answer_list, i] for i in range(NUM_GPU)]
    for i, d in enumerate(data):
        data_parts[i % NUM_GPU][0].append(d)
    results = process_map(single_job, data_parts, max_workers=NUM_GPU, disable=True)
    # Re-flatten the per-worker shards (original ordering is not preserved,
    # but VQA accuracy is order-independent).
    all_data = []
    for part in results:
        all_data.extend(part)
    with open("temp_result", "w") as f:
        json.dump(all_data, f)
    acc = compute_vqa_accuracy(
        result_json_path="temp_result",
        question_json_path=question_json_path,
        annotation_json_path=annotation_json_path,
    )
    print(acc)