File size: 1,070 Bytes
a977e3e
cbe9336
a977e3e
cbe9336
a977e3e
 
cbe9336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import json
from multiprocessing import freeze_support

from apps_metric import apps_metric


if __name__ == '__main__':
    # Smoke test for the metric: score stored solutions for two problems and
    # compare the results with the values expected when every submitted
    # solution passes every test case (strict accuracy == 1, pass@k == 1).
    #
    # NOTE: the checks below are hard asserts, so any mismatch aborts the
    # script with an AssertionError rather than merely emitting a warning.

    # Must be called right after the main guard; it only has an effect when
    # the script is run as a frozen Windows executable.
    freeze_support()

    # Load the fixtures with context managers so the file handles are closed
    # deterministically, and read them as UTF-8 (the JSON interchange
    # encoding) instead of relying on the platform default.
    with open("test_examples/solutions_problem_1.json", "r", encoding="utf-8") as fh:
        solution_sample1 = json.load(fh)
    with open("test_examples/solutions_problem_2.json", "r", encoding="utf-8") as fh:
        solution_sample2 = json.load(fh)

    # One candidate per problem vs. three candidates per problem.
    single_solutions = [solution_sample1[:1], solution_sample2[:1]]
    multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]

    metric = apps_metric()
    result_1 = metric.compute(predictions=single_solutions, level="all")
    result_2 = metric.compute(predictions=multiple_solutions, level="all", k_list=[1, 2, 3])

    # Expected results: with a single candidate per problem the accuracy
    # fields are populated and pass_at_k is None; with several candidates
    # only pass_at_k is populated.
    expected_1 = {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
    expected_2 = {
        'avg_accuracy': None,
        'strict_accuracy': None,
        'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0},
    }

    assert result_1 == expected_1, f"single-solution result mismatch: {result_1!r}"
    assert result_2 == expected_2, f"multi-solution result mismatch: {result_2!r}"