loubnabnl HF staff commited on
Commit
819e86a
β€’
1 Parent(s): 582c895

change location of utils files

Browse files
tests.py DELETED
@@ -1,17 +0,0 @@
1
- test_cases = [
2
- {
3
- "predictions": [0, 0],
4
- "references": [1, 1],
5
- "result": {"metric_score": 0}
6
- },
7
- {
8
- "predictions": [1, 1],
9
- "references": [1, 1],
10
- "result": {"metric_score": 1}
11
- },
12
- {
13
- "predictions": [1, 0],
14
- "references": [1, 1],
15
- "result": {"metric_score": 0.5}
16
- }
17
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/.ipynb_checkpoints/testing_util-checkpoint.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ import faulthandler
4
+
5
+ # used for debugging to time steps
6
+ from datetime import datetime
7
+
8
+ # to run the solution files we're using a timing based approach
9
+ import signal
10
+
11
+ import numpy as np
12
+ # for capturing the stdout
13
+ from io import StringIO
14
+ # used for testing the code that reads from input
15
+ from unittest.mock import patch, mock_open
16
+
17
+ from pyext import RuntimeModule
18
+
19
+ from enum import Enum
20
+ class CODE_TYPE(Enum):
21
+ call_based = 0
22
+ standard_input = 1
23
+
24
+ # stuff for setting up signal timer
25
+ class TimeoutException(Exception):
26
+ pass
27
+ def timeout_handler(signum, frame):
28
+ print("alarm went off")
29
+ #return
30
+ raise TimeoutException
31
+ signal.signal(signal.SIGALRM, timeout_handler)
32
+ timeout = 4 # seconds
33
+
34
+ # used to capture stdout as a list
35
+ # from https://stackoverflow.com/a/16571630/6416660
36
+ # alternative use redirect_stdout() from contextlib
37
+ class Capturing(list):
38
+ def __enter__(self):
39
+ self._stdout = sys.stdout
40
+ sys.stdout = self._stringio = StringIO()
41
+ # Make closing the StringIO a no-op
42
+ self._stringio.close = lambda x: 1
43
+ return self
44
+ def __exit__(self, *args):
45
+ self.extend(self._stringio.getvalue().splitlines())
46
+ del self._stringio # free up some memory
47
+ sys.stdout = self._stdout
48
+
49
+
50
+ def run_test(sample, test=None, debug=False):
51
+ """
52
+ if test(generated_code) is not None it'll try to run the code.
53
+ otherwise it'll just return an input and output pair.
54
+ """
55
+ if debug:
56
+ print(f"start = {datetime.now().time()}")
57
+
58
+ try:
59
+ in_outs = json.loads(sample["input_output"])
60
+ except ValueError:
61
+ in_outs = None
62
+ if in_outs:
63
+ if in_outs.get("fn_name") is None:
64
+ which_type = CODE_TYPE.standard_input # Standard input
65
+ method_name = None
66
+ else:
67
+ which_type = CODE_TYPE.call_based # Call-based
68
+ method_name = in_outs["fn_name"]
69
+
70
+ if debug:
71
+ print(f"loaded input_output = {datetime.now().time()}")
72
+
73
+ if test is None:
74
+ return in_outs
75
+ elif test is not None:
76
+ results = []
77
+ sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
78
+ if debug:
79
+ print(f"loading test code = {datetime.now().time()}")
80
+
81
+ if which_type == CODE_TYPE.call_based:
82
+ sol += test
83
+ if debug:
84
+ print(f"sol = {sol}")
85
+ signal.alarm(timeout)
86
+ try:
87
+ tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
88
+ if "class Solution" not in test:
89
+ tmp = tmp_sol
90
+ else:
91
+ tmp = tmp_sol.Solution()
92
+ signal.alarm(0)
93
+ except Exception as e:
94
+ signal.alarm(0)
95
+ if debug:
96
+ print(f"type 0 compilation error = {e}")
97
+ results.append(-2)
98
+ return results
99
+ signal.alarm(0)
100
+
101
+ elif which_type == CODE_TYPE.standard_input:
102
+ # sol
103
+ tmp_test = test.split("\n")
104
+
105
+ new_test = []
106
+ for x in tmp_test:
107
+ if (not x.startswith("from ")) and (not x.startswith("import ")):
108
+ new_test.append("\t" + x + "\n")
109
+ else:
110
+ new_test.append(x + "\n")
111
+ tmp_test = new_test
112
+
113
+ new_test = ""
114
+ started = False
115
+ for i in tmp_test:
116
+ if i.startswith("\t") and not started:
117
+ new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
118
+ new_test += "def code():\n"
119
+ new_test += i
120
+ started = True
121
+ elif started and ((i.startswith("from ")) or (i.startswith("import "))):
122
+ new_test += "\t" + i
123
+ else:
124
+ new_test += i
125
+ tmp_test = new_test
126
+
127
+ sol += tmp_test
128
+ if debug:
129
+ print(f"sol = {sol}")
130
+ method_name = "code"
131
+ signal.alarm(timeout)
132
+ try:
133
+ tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
134
+ tmp = tmp_sol
135
+ signal.alarm(0)
136
+ except Exception as e:
137
+ signal.alarm(0)
138
+ if debug:
139
+ print(f"type 1 compilation error = {e}")
140
+ results.append(-2)
141
+ return results
142
+ signal.alarm(0)
143
+ if debug:
144
+ print(f"get method = {datetime.now().time()}")
145
+
146
+ try:
147
+ method = getattr(tmp, method_name) # get_attr second arg must be str
148
+ except:
149
+ signal.alarm(0)
150
+ e = sys.exc_info()
151
+ print(f"unable to get function error = {e}")
152
+ return results
153
+
154
+ for index, inputs in enumerate(in_outs["inputs"]):
155
+ # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
156
+ try:
157
+ if isinstance(inputs[0], dict):
158
+ inputs = [{int(k): v for k,v in inputs[0].items()}]
159
+ except:
160
+ True
161
+ try:
162
+ if isinstance(in_outs["outputs"][index], dict):
163
+ in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}]
164
+ except:
165
+ True
166
+ try:
167
+ if isinstance(in_outs["outputs"][index][0], dict):
168
+ in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}]
169
+ except:
170
+ True
171
+
172
+ if debug:
173
+ print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}")
174
+ if which_type == CODE_TYPE.call_based: # Call-based
175
+ signal.alarm(timeout)
176
+ faulthandler.enable()
177
+ try:
178
+ output = method(*inputs)
179
+
180
+ # ground truth sequences are not tuples
181
+ if isinstance(output, tuple):
182
+ output = list(output)
183
+
184
+ tmp_result = output == in_outs["outputs"][index]
185
+ if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
186
+ tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
187
+
188
+ # ground truth sequences are not tuples
189
+ try:
190
+ if isinstance(output[0], tuple):
191
+ tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
192
+ except:
193
+ True
194
+ results.append(tmp_result)
195
+
196
+ # reset the alarm
197
+ signal.alarm(0)
198
+ except Exception as e:
199
+ signal.alarm(0)
200
+ faulthandler.disable()
201
+ print(f"Standard input runtime error or time limit exceeded error = {e}")
202
+ results.append(-1)
203
+ continue
204
+ faulthandler.disable()
205
+ signal.alarm(0)
206
+ if debug:
207
+ print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
208
+ elif which_type == CODE_TYPE.standard_input: # Standard input
209
+ faulthandler.enable()
210
+ signal.alarm(timeout)
211
+ passed = False
212
+
213
+ if isinstance(inputs, list):
214
+ inputs = "\n".join(inputs)
215
+ if isinstance(in_outs['outputs'][index], list):
216
+ in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
217
+
218
+ with Capturing() as output:
219
+ try:
220
+ call_method(method, inputs)
221
+ # reset the alarm
222
+ signal.alarm(0)
223
+ passed = True
224
+ except Exception as e:
225
+ # runtime error or took too long
226
+ signal.alarm(0)
227
+ print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
228
+ results.append(-1)
229
+ signal.alarm(0)
230
+
231
+ if not passed:
232
+ if debug:
233
+ nl = "\n"
234
+ if not isinstance(inputs, list):
235
+ print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
236
+ else:
237
+ print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
238
+ continue
239
+
240
+ if passed and debug:
241
+ print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
242
+
243
+ if custom_compare_(output, in_outs['outputs'][index]):
244
+ tmp_result = True
245
+ results.append(tmp_result)
246
+ continue
247
+
248
+ # ground truth sequences are expressed as lists not tuples
249
+ if isinstance(output, tuple):
250
+ output = list(output)
251
+
252
+ tmp_result = False
253
+ try:
254
+ tmp_result = (output == [in_outs["outputs"][index]])
255
+ if isinstance(in_outs["outputs"][index], list):
256
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
257
+ if isinstance(output[0], str):
258
+ tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
259
+ except Exception as e:
260
+ if debug:
261
+ print(f"Failed check1 exception = {e}")
262
+ pass
263
+
264
+ if tmp_result == True:
265
+ results.append(tmp_result)
266
+ continue
267
+
268
+ # try one more time without \n
269
+ if isinstance(in_outs["outputs"][index], list):
270
+ for tmp_index, i in enumerate(in_outs["outputs"][index]):
271
+ in_outs["outputs"][index][tmp_index] = i.split("\n")
272
+ in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x]
273
+ else:
274
+ in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
275
+ in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
276
+ in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index]))
277
+
278
+ try:
279
+ tmp_result = (output == [in_outs["outputs"][index]])
280
+ if isinstance(in_outs["outputs"][index], list):
281
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
282
+ except Exception as e:
283
+ if debug:
284
+ print(f"Failed check2 exception = {e}")
285
+ pass
286
+
287
+ if tmp_result == True:
288
+ results.append(tmp_result)
289
+ continue
290
+
291
+ # try by converting the output into a split up list too
292
+ if isinstance(output, list):
293
+ output = list(filter(len, output))
294
+
295
+ if debug:
296
+ nl = "\n"
297
+ if not isinstance(inputs, list):
298
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
299
+ else:
300
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
301
+
302
+ if tmp_result == True:
303
+ results.append(tmp_result)
304
+ continue
305
+
306
+ try:
307
+ tmp_result = (output == [in_outs["outputs"][index]])
308
+ if isinstance(in_outs["outputs"][index], list):
309
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
310
+ except Exception as e:
311
+ if debug:
312
+ print(f"Failed check3 exception = {e}")
313
+ pass
314
+
315
+ try:
316
+ output_float = [float(e) for e in output]
317
+ gt_float = [float(e) for e in in_outs['outputs'][index]]
318
+ tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
319
+ except Exception as e:
320
+ pass
321
+ try:
322
+ if isinstance(output[0], list):
323
+ output_float = [float(e) for e in output[0]]
324
+ gt_float = [float(e) for e in in_outs['outputs'][index][0]]
325
+ tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
326
+ except Exception as e:
327
+ pass
328
+
329
+ if tmp_result == True:
330
+ results.append(tmp_result)
331
+ continue
332
+
333
+ # try by converting the stuff into split up list
334
+ if isinstance(in_outs["outputs"][index], list):
335
+ for tmp_index, i in enumerate(in_outs["outputs"][index]):
336
+ in_outs["outputs"][index][tmp_index] = set(i.split())
337
+ else:
338
+ in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
339
+
340
+ try:
341
+ tmp_result = (output == in_outs["outputs"][index])
342
+ except Exception as e:
343
+ if debug:
344
+ print(f"Failed check4 exception = {e}")
345
+ continue
346
+
347
+ if tmp_result == True:
348
+ results.append(tmp_result)
349
+ continue
350
+
351
+ # try by converting the output into a split up list too
352
+ if isinstance(output, list):
353
+ for tmp_index, i in enumerate(output):
354
+ output[tmp_index] = i.split()
355
+ output = list(filter(len, output))
356
+ for tmp_index, i in enumerate(output):
357
+ output[tmp_index] = set(i)
358
+ else:
359
+ output = output.split()
360
+ output = list(filter(len, output))
361
+ output = set(output)
362
+
363
+ try:
364
+ tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
365
+ except Exception as e:
366
+ if debug:
367
+ print(f"Failed check5 exception = {e}")
368
+
369
+
370
+ # if they are all numbers, round so that similar numbers are treated as identical
371
+ try:
372
+ tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
373
+ set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
374
+ except Exception as e:
375
+ if debug:
376
+ print(f"Failed check6 exception = {e}")
377
+
378
+ if tmp_result == True and debug:
379
+ print("PASSED")
380
+
381
+ results.append(tmp_result)
382
+
383
+ if debug:
384
+ nl = "\n"
385
+ if not isinstance(inputs, list):
386
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
387
+ else:
388
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
389
+
390
+
391
+ return results
392
+
393
+
394
+ def custom_compare_(output, ground_truth):
395
+
396
+ if isinstance(output, list):
397
+ output_1 = "\n".join(output)
398
+ if stripped_string_compare(output_1, ground_truth):
399
+ return True
400
+
401
+ if isinstance(output, list):
402
+ output_2 = [o.lstrip().rstrip() for o in output]
403
+ output_2 = "\n".join(output_2)
404
+ if stripped_string_compare(output_2, ground_truth):
405
+ return True
406
+
407
+ return False
408
+
409
+ def stripped_string_compare(s1, s2):
410
+ s1 = s1.lstrip().rstrip()
411
+ s2 = s2.lstrip().rstrip()
412
+ return s1 == s2
413
+
414
+ def call_method(method, inputs):
415
+
416
+ if isinstance(inputs, list):
417
+ inputs = "\n".join(inputs)
418
+
419
+ inputs_line_iterator = iter(inputs.split("\n"))
420
+
421
+ # sys.setrecursionlimit(10000)
422
+
423
+ # @patch('builtins.input', side_effect=inputs.split("\n"))
424
+ @patch('builtins.open', mock_open(read_data=inputs))
425
+ @patch('sys.stdin', StringIO(inputs))
426
+ @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
427
+ @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
428
+ @patch('sys.stdin.read', lambda *args: inputs)
429
+ # @patch('sys.stdout.write', print)
430
+ def _inner_call_method(_method):
431
+ try:
432
+ return _method()
433
+ except SystemExit as e:
434
+ pass
435
+ finally:
436
+ pass
437
+ return _inner_call_method(method)
438
+
tools/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import numpy as np
3
+ from typing import Dict
4
+ from datasets import load_dataset
5
+ import tools.testing_util as test_util
6
+
7
+
8
+ DATASET = "codeparrot/apps"
9
+
10
+
11
+ def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
12
+ """We take the list of code generations and try to compile them
13
+ and the run their corresponding unit tests which are retrieved from the APPS dataset.
14
+
15
+ Args:
16
+ generations: list of code generations (same order as samples in APPS dataset)
17
+ level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition"
18
+
19
+ Returns:
20
+ results: dictionary of results, key is the problem index, value is a list of results for each generation
21
+ [-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case
22
+ """
23
+
24
+ # generations are code generations in the same order of the dataset
25
+ apps_eval = load_dataset(DATASET, split="test", difficulties=[level])
26
+ results = {}
27
+ for index in range(len(generations)):
28
+ # code generations for problem (index)
29
+ problem_generations = generations[index]
30
+ # get corresponding samples from APPS dataset
31
+ sample = apps_eval[index]
32
+ res = []
33
+ # loop over the generations
34
+ for o_idx, o in enumerate(problem_generations):
35
+ curr_res = [-2]
36
+ try:
37
+ curr_res = test_util.run_test(sample, test=o, debug=debug)
38
+ #if debug:
39
+ print(f"\nSuccessful compilation of task {index}!")
40
+ fixed = []
41
+ for e in curr_res:
42
+ if isinstance(e, np.ndarray):
43
+ e = e.item(0)
44
+ if isinstance(e, np.bool_):
45
+ e = bool(e)
46
+ fixed.append(e)
47
+ curr_res = fixed
48
+ if not np.all(curr_res):
49
+ #if debug:
50
+ print(f"Results were not True for all test cases")
51
+ except Exception as e:
52
+ if debug:
53
+ print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
54
+ break
55
+ finally:
56
+ assert isinstance(curr_res, list)
57
+ res.append(curr_res)
58
+ results[index] = res
59
+ return results
60
+
61
+
62
+ def estimate_pass_at_k(num_samples, num_correct, k):
63
+ """Estimates pass@k of each problem and returns them in an array."""
64
+
65
+ def estimator(n: int, c: int, k: int) -> float:
66
+ """Calculates 1 - comb(n - c, k) / comb(n, k)."""
67
+ if n - c < k:
68
+ return 1.0
69
+ return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
70
+
71
+ if isinstance(num_samples, int):
72
+ num_samples_it = itertools.repeat(num_samples, len(num_correct))
73
+ else:
74
+ assert len(num_samples) == len(num_correct)
75
+ num_samples_it = iter(num_samples)
76
+
77
+ return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
78
+
79
+
80
+ def get_results(results: Dict[int, list], count_errors: bool = False, k_list: list = [1, 10, 100]):
81
+ """
82
+ Given the results evaluated against the testcases we output some statistics.
83
+ For single generations:
84
+ >>> example_results = {0: [[-2]], 1: [[False,False]], 2: [[True,True]], 3: [[False,True,False,True]], 4: [[-1,-1]]}
85
+ >>> get_results(example_results, count_errors=True)
86
+ Computing accuracy metrics...
87
+ number of compile errors = 1 avg = 0.2
88
+ number of runtime errors = 1 avg = 0.2
89
+ number of problems evaluated = 5
90
+ Average Accuracy : 0.3
91
+ Strict Accuracy : 0.2
92
+ {'avg_accuracy': 0.3, 'strict_accuracy': 0.2, 'pass_at_k': None}
93
+
94
+ For multiple generations:
95
+ >>> example_results = {0: [[-2], [True, True, True]], 1: [[-1,-1, -1], [True, False, True]]}
96
+ >>> get_results(example_results, k_list=[1, 2])
97
+ Computing pass@k metric for multiple generations...
98
+ {'pass@1': 0.25, 'pass@2': 0.5}
99
+ {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 0.25, 'pass@2': 0.5}}
100
+ """
101
+
102
+ metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
103
+
104
+ if len(results[0]) == 1:
105
+ # for single generations we compute average accuracy and stric accuracy: original APPS metrics
106
+ print("Computing accuracy metrics...")
107
+ res = []
108
+ per_prob_res = []
109
+ all_correct = []
110
+ for index in results:
111
+ problem_results = np.asarray(results[index])
112
+ res.extend(problem_results)
113
+ per_prob_res.append(np.mean(problem_results > 0))
114
+ all_correct.append(np.all(problem_results > 0))
115
+ # we count campilation and runtime errors once per pronlem
116
+ compile_errors = len([e for e in res if -2 in e])
117
+ runtime_errors = len([e for e in res if -1 in e])
118
+ total_testcases = len(res)
119
+ if count_errors:
120
+ print(f"number of compile errors = {compile_errors} avg = {compile_errors / total_testcases}")
121
+ print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}")
122
+ print(f"number of problems evaluated = {total_testcases}")
123
+
124
+ print(f"Average Accuracy : {np.mean(per_prob_res)}")
125
+ print(f"Strict Accuracy : {np.mean(all_correct)}")
126
+ metrics["avg_accuracy"] = np.mean(per_prob_res)
127
+ metrics["strict_accuracy"] = np.mean(all_correct)
128
+
129
+ else:
130
+ # for multiple generations we use pass@k metric used in the HumanEval benchmark
131
+ # we use strict accuracy, a generation is valid if it has to pass all the tests
132
+ print("Computing pass@k metric for multiple generations...")
133
+ # total is list with nb generations per task (task=index)
134
+ # correct is number of generations that passed all tests per task
135
+ total = []
136
+ correct = []
137
+ for index in results:
138
+ all_correct = []
139
+ for generation in results[index]:
140
+ gen = np.array(generation)
141
+ all_correct.append(np.all(gen>0))
142
+ total.append(len(all_correct))
143
+ correct.append(sum(all_correct))
144
+ total = np.array(total)
145
+ correct = np.array(correct)
146
+ ks = k_list
147
+ pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
148
+ print(pass_at_k)
149
+ metrics["pass_at_k"] = pass_at_k
150
+ return metrics
151
+
152
+ def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
153
+ """Return metrics for the given generations.
154
+ Args:
155
+ generations: list of code generations for each problem (each generation is a list of generations)
156
+ k_list: list of k values to compute pass@k when using multiple generations
157
+ count_errors: whether to count compilation and runtime errors when using single generations
158
+ level: difficulty level in APPS dataset that was used for the given generations (from: "all", "introductory", "interview", "competition")
159
+ Returns:
160
+ metrics: dict of metrics
161
+
162
+ Examples:
163
+
164
+ >>> import json
165
+ >>> # lists of solutions to the two first APPS problems (note not all solutions pass all tests)
166
+ >>> solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
167
+ >>> solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
168
+ >>> single_solutions = [solution_sample1[:1], solution_sample2[:1]]
169
+ >>> compute_metrics(single_solutions, level="all")
170
+ Computing accuracy metrics...
171
+ number of compile errors = 0 avg = 0.0
172
+ number of runtime errors = 0 avg = 0.0
173
+ number of problems evaluated = 2
174
+ Average Accuracy : 1.0
175
+ Strict Accuracy : 1.0
176
+ {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
177
+ >>> multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
178
+ >>> compute_metrics(multiple_solutions, level="all", k_list=[1, 2, 3])
179
+ Computing pass@k metric for multiple generations...
180
+ {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}
181
+ {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
182
+ """
183
+ results = evaluate_generations(generations, level=level, debug=debug)
184
+ metrics = get_results(results, count_errors=count_errors, k_list=k_list)
185
+ return metrics
186
+
187
+ #import doctest
188
+ #doctest.testmod()
tools/__pycache__/testing_util.cpython-39.pyc ADDED
Binary file (11.6 kB). View file
 
tools/__pycache__/utils.cpython-39.pyc ADDED
Binary file (7.36 kB). View file
 
testing_util.py β†’ tools/testing_util.py RENAMED
@@ -60,8 +60,6 @@ def run_test(sample, test=None, debug=False):
60
  except ValueError:
61
  in_outs = None
62
  if in_outs:
63
- #if debug:
64
- # print(f"test cases json = {in_outs['inputs']} {in_outs['outputs']}")
65
  if in_outs.get("fn_name") is None:
66
  which_type = CODE_TYPE.standard_input # Standard input
67
  method_name = None
@@ -72,8 +70,6 @@ def run_test(sample, test=None, debug=False):
72
  if debug:
73
  print(f"loaded input_output = {datetime.now().time()}")
74
 
75
- #else:
76
- # continue
77
  if test is None:
78
  return in_outs
79
  elif test is not None:
@@ -96,7 +92,8 @@ def run_test(sample, test=None, debug=False):
96
  signal.alarm(0)
97
  except Exception as e:
98
  signal.alarm(0)
99
- print(f"type 0 compilation error = {e}")
 
100
  results.append(-2)
101
  return results
102
  signal.alarm(0)
@@ -138,7 +135,8 @@ def run_test(sample, test=None, debug=False):
138
  signal.alarm(0)
139
  except Exception as e:
140
  signal.alarm(0)
141
- print(f"type 1 compilation error = {e}")
 
142
  results.append(-2)
143
  return results
144
  signal.alarm(0)
@@ -205,8 +203,8 @@ def run_test(sample, test=None, debug=False):
205
  continue
206
  faulthandler.disable()
207
  signal.alarm(0)
208
- #if debug:
209
- #print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
210
  elif which_type == CODE_TYPE.standard_input: # Standard input
211
  faulthandler.enable()
212
  signal.alarm(timeout)
@@ -216,16 +214,14 @@ def run_test(sample, test=None, debug=False):
216
  inputs = "\n".join(inputs)
217
  if isinstance(in_outs['outputs'][index], list):
218
  in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
 
219
  with Capturing() as output:
220
  try:
221
- print("doing call")
222
  call_method(method, inputs)
223
- print("call done")
224
  # reset the alarm
225
  signal.alarm(0)
226
  passed = True
227
  except Exception as e:
228
- print("call not done we are in exception")
229
  # runtime error or took too long
230
  signal.alarm(0)
231
  print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
@@ -261,7 +257,8 @@ def run_test(sample, test=None, debug=False):
261
  if isinstance(output[0], str):
262
  tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
263
  except Exception as e:
264
- print(f"Failed check1 exception = {e}")
 
265
  pass
266
 
267
  if tmp_result == True:
@@ -283,7 +280,8 @@ def run_test(sample, test=None, debug=False):
283
  if isinstance(in_outs["outputs"][index], list):
284
  tmp_result = tmp_result or (output == in_outs["outputs"][index])
285
  except Exception as e:
286
- print(f"Failed check2 exception = {e}")
 
287
  pass
288
 
289
  if tmp_result == True:
@@ -310,7 +308,8 @@ def run_test(sample, test=None, debug=False):
310
  if isinstance(in_outs["outputs"][index], list):
311
  tmp_result = tmp_result or (output == in_outs["outputs"][index])
312
  except Exception as e:
313
- print(f"Failed check3 exception = {e}")
 
314
  pass
315
 
316
  try:
@@ -341,7 +340,8 @@ def run_test(sample, test=None, debug=False):
341
  try:
342
  tmp_result = (output == in_outs["outputs"][index])
343
  except Exception as e:
344
- print(f"Failed check4 exception = {e}")
 
345
  continue
346
 
347
  if tmp_result == True:
@@ -363,7 +363,8 @@ def run_test(sample, test=None, debug=False):
363
  try:
364
  tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
365
  except Exception as e:
366
- print(f"Failed check5 exception = {e}")
 
367
 
368
 
369
  # if they are all numbers, round so that similar numbers are treated as identical
@@ -371,7 +372,8 @@ def run_test(sample, test=None, debug=False):
371
  tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
372
  set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
373
  except Exception as e:
374
- print(f"Failed check6 exception = {e}")
 
375
 
376
  if tmp_result == True and debug:
377
  print("PASSED")
@@ -384,10 +386,11 @@ def run_test(sample, test=None, debug=False):
384
  print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
385
  else:
386
  print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
387
-
388
 
389
  return results
390
 
 
391
  def custom_compare_(output, ground_truth):
392
 
393
  if isinstance(output, list):
@@ -432,3 +435,4 @@ def call_method(method, inputs):
432
  finally:
433
  pass
434
  return _inner_call_method(method)
 
 
60
  except ValueError:
61
  in_outs = None
62
  if in_outs:
 
 
63
  if in_outs.get("fn_name") is None:
64
  which_type = CODE_TYPE.standard_input # Standard input
65
  method_name = None
 
70
  if debug:
71
  print(f"loaded input_output = {datetime.now().time()}")
72
 
 
 
73
  if test is None:
74
  return in_outs
75
  elif test is not None:
 
92
  signal.alarm(0)
93
  except Exception as e:
94
  signal.alarm(0)
95
+ if debug:
96
+ print(f"type 0 compilation error = {e}")
97
  results.append(-2)
98
  return results
99
  signal.alarm(0)
 
135
  signal.alarm(0)
136
  except Exception as e:
137
  signal.alarm(0)
138
+ if debug:
139
+ print(f"type 1 compilation error = {e}")
140
  results.append(-2)
141
  return results
142
  signal.alarm(0)
 
203
  continue
204
  faulthandler.disable()
205
  signal.alarm(0)
206
+ if debug:
207
+ print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
208
  elif which_type == CODE_TYPE.standard_input: # Standard input
209
  faulthandler.enable()
210
  signal.alarm(timeout)
 
214
  inputs = "\n".join(inputs)
215
  if isinstance(in_outs['outputs'][index], list):
216
  in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
217
+
218
  with Capturing() as output:
219
  try:
 
220
  call_method(method, inputs)
 
221
  # reset the alarm
222
  signal.alarm(0)
223
  passed = True
224
  except Exception as e:
 
225
  # runtime error or took too long
226
  signal.alarm(0)
227
  print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
 
257
  if isinstance(output[0], str):
258
  tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
259
  except Exception as e:
260
+ if debug:
261
+ print(f"Failed check1 exception = {e}")
262
  pass
263
 
264
  if tmp_result == True:
 
280
  if isinstance(in_outs["outputs"][index], list):
281
  tmp_result = tmp_result or (output == in_outs["outputs"][index])
282
  except Exception as e:
283
+ if debug:
284
+ print(f"Failed check2 exception = {e}")
285
  pass
286
 
287
  if tmp_result == True:
 
308
  if isinstance(in_outs["outputs"][index], list):
309
  tmp_result = tmp_result or (output == in_outs["outputs"][index])
310
  except Exception as e:
311
+ if debug:
312
+ print(f"Failed check3 exception = {e}")
313
  pass
314
 
315
  try:
 
340
  try:
341
  tmp_result = (output == in_outs["outputs"][index])
342
  except Exception as e:
343
+ if debug:
344
+ print(f"Failed check4 exception = {e}")
345
  continue
346
 
347
  if tmp_result == True:
 
363
  try:
364
  tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
365
  except Exception as e:
366
+ if debug:
367
+ print(f"Failed check5 exception = {e}")
368
 
369
 
370
  # if they are all numbers, round so that similar numbers are treated as identical
 
372
  tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
373
  set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
374
  except Exception as e:
375
+ if debug:
376
+ print(f"Failed check6 exception = {e}")
377
 
378
  if tmp_result == True and debug:
379
  print("PASSED")
 
386
  print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
387
  else:
388
  print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
389
+
390
 
391
  return results
392
 
393
+
394
  def custom_compare_(output, ground_truth):
395
 
396
  if isinstance(output, list):
 
435
  finally:
436
  pass
437
  return _inner_call_method(method)
438
+
utils.py β†’ tools/utils.py RENAMED
@@ -2,19 +2,19 @@ import itertools
2
  import numpy as np
3
  from typing import Dict
4
  from datasets import load_dataset
5
- import testing_util as test_util
6
 
7
 
8
  DATASET = "codeparrot/apps"
9
 
10
 
11
- def evaluate_generations(generations, level=["all"]):
12
  """We take the list of code generations and try to compile them
13
  and the run their corresponding unit tests which are retrieved from the APPS dataset.
14
 
15
  Args:
16
- generations: list of code generations, in the same order as APPS dataset samples
17
- level: list of levels to evaluate, can be "all", "introductory", "interview" or "competition"
18
 
19
  Returns:
20
  results: dictionary of results, key is the problem index, value is a list of results for each generation
@@ -22,20 +22,21 @@ def evaluate_generations(generations, level=["all"]):
22
  """
23
 
24
  # generations are code generations in the same order of the dataset
25
- apps_eval = load_dataset(DATASET, split="test", difficulties=level)
26
  results = {}
27
  for index in range(len(generations)):
28
- print(f"task {index}")
29
- generated_code = generations[index]
 
30
  sample = apps_eval[index]
31
  res = []
32
  # loop over the generations
33
- for o_idx, o in enumerate(generated_code):
34
  curr_res = [-2]
35
  try:
36
- print("Run test")
37
- curr_res = test_util.run_test(sample, test=o, debug=False)
38
- print("\nSuccessful compilation!")
39
  fixed = []
40
  for e in curr_res:
41
  if isinstance(e, np.ndarray):
@@ -45,15 +46,16 @@ def evaluate_generations(generations, level=["all"]):
45
  fixed.append(e)
46
  curr_res = fixed
47
  if not np.all(curr_res):
48
- print(f"Results were not True for all test cases") #{curr_res}")
 
49
  except Exception as e:
50
- print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
 
51
  break
52
  finally:
53
  assert isinstance(curr_res, list)
54
  res.append(curr_res)
55
  results[index] = res
56
-
57
  return results
58
 
59
 
@@ -75,37 +77,41 @@ def estimate_pass_at_k(num_samples, num_correct, k):
75
  return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
76
 
77
 
78
- def get_results(results: Dict, count_errors: bool = False, k_list: list = [1, 10, 100]):
79
  """
80
  Given the results evaluated against the testcases we output some statistics.
81
  For single generations:
82
- >>> example_results = {"0": [[-2]],"1": [[False,False]],"2": [[True,True]],"3": [[False,True,False,True]], "4": [[-1,-1]]}
83
  >>> get_results(example_results, count_errors=True)
 
84
  number of compile errors = 1 avg = 0.2
85
  number of runtime errors = 1 avg = 0.2
86
- number of test cases run = 5
87
- Test Case Average (average accuracy over problems) = 0.3
88
- Strict Accuracy (all test cases passed / total problems) = 0.2
 
89
 
90
  For multiple generations:
91
- >>> example_results = {"0": [[-2], [True, True, True]],"1": [[-1,-1, -1], [True, False, True]]}
92
- >>> get_results(example_results k_list=[1, 2])
 
93
  {'pass@1': 0.25, 'pass@2': 0.5}
 
94
  """
95
 
96
  metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
97
 
98
- if len(results["0"]) == 1:
99
  # for single generations we compute average accuracy and stric accuracy: original APPS metrics
100
  print("Computing accuracy metrics...")
101
  res = []
102
  per_prob_res = []
103
  all_correct = []
104
  for index in results:
105
- results[index] = np.array(results[index])
106
- res.extend(results[index])
107
- per_prob_res.append(np.mean(results[index]>0))
108
- all_correct.append(np.all(results[index]>0))
109
  # we count campilation and runtime errors once per pronlem
110
  compile_errors = len([e for e in res if -2 in e])
111
  runtime_errors = len([e for e in res if -1 in e])
@@ -115,8 +121,8 @@ def get_results(results: Dict, count_errors: bool = False, k_list: list = [1, 10
115
  print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}")
116
  print(f"number of problems evaluated = {total_testcases}")
117
 
118
- print(f"Test Case Average Accuracy (ver tests) = {np.mean(per_prob_res)}")
119
- print(f"Strict Accuracy (over problems that pass all tests) = {np.mean(all_correct)}")
120
  metrics["avg_accuracy"] = np.mean(per_prob_res)
121
  metrics["strict_accuracy"] = np.mean(all_correct)
122
 
@@ -143,16 +149,40 @@ def get_results(results: Dict, count_errors: bool = False, k_list: list = [1, 10
143
  metrics["pass_at_k"] = pass_at_k
144
  return metrics
145
 
146
- def compute_metrics(generations, k_list=[1, 10, 100], count_errors=True, level=["all"]):
147
  """Return metrics for the given generations.
148
  Args:
149
- generations: dict of generations, keyed by problem index
150
  k_list: list of k values to compute pass@k when using multiple generations
151
  count_errors: whether to count compilation and runtime errors when using single generations
152
- level: which level difficulty in APPS dataset was used for the given generations
153
  Returns:
154
  metrics: dict of metrics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  """
156
- results = evaluate_generations(generations, level=level)
157
  metrics = get_results(results, count_errors=count_errors, k_list=k_list)
158
- return metrics
 
 
 
 
2
  import numpy as np
3
  from typing import Dict
4
  from datasets import load_dataset
5
+ import tools.testing_util as test_util
6
 
7
 
8
  DATASET = "codeparrot/apps"
9
 
10
 
11
+ def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
12
  """We take the list of code generations and try to compile them
13
  and the run their corresponding unit tests which are retrieved from the APPS dataset.
14
 
15
  Args:
16
+ generations: list of code generations (same order as samples in APPS dataset)
17
+ level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition"
18
 
19
  Returns:
20
  results: dictionary of results, key is the problem index, value is a list of results for each generation
 
22
  """
23
 
24
  # generations are code generations in the same order of the dataset
25
+ apps_eval = load_dataset(DATASET, split="test", difficulties=[level])
26
  results = {}
27
  for index in range(len(generations)):
28
+ # code generations for problem (index)
29
+ problem_generations = generations[index]
30
+ # get corresponding samples from APPS dataset
31
  sample = apps_eval[index]
32
  res = []
33
  # loop over the generations
34
+ for o_idx, o in enumerate(problem_generations):
35
  curr_res = [-2]
36
  try:
37
+ curr_res = test_util.run_test(sample, test=o, debug=debug)
38
+ #if debug:
39
+ print(f"\nSuccessful compilation of task {index}!")
40
  fixed = []
41
  for e in curr_res:
42
  if isinstance(e, np.ndarray):
 
46
  fixed.append(e)
47
  curr_res = fixed
48
  if not np.all(curr_res):
49
+ #if debug:
50
+ print(f"Results were not True for all test cases")
51
  except Exception as e:
52
+ if debug:
53
+ print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
54
  break
55
  finally:
56
  assert isinstance(curr_res, list)
57
  res.append(curr_res)
58
  results[index] = res
 
59
  return results
60
 
61
 
 
77
  return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
78
 
79
 
80
+ def get_results(results: Dict[int, list], count_errors: bool = False, k_list: list = [1, 10, 100]):
81
  """
82
  Given the results evaluated against the testcases we output some statistics.
83
  For single generations:
84
+ >>> example_results = {0: [[-2]], 1: [[False,False]], 2: [[True,True]], 3: [[False,True,False,True]], 4: [[-1,-1]]}
85
  >>> get_results(example_results, count_errors=True)
86
+ Computing accuracy metrics...
87
  number of compile errors = 1 avg = 0.2
88
  number of runtime errors = 1 avg = 0.2
89
+ number of problems evaluated = 5
90
+ Average Accuracy : 0.3
91
+ Strict Accuracy : 0.2
92
+ {'avg_accuracy': 0.3, 'strict_accuracy': 0.2, 'pass_at_k': None}
93
 
94
  For multiple generations:
95
+ >>> example_results = {0: [[-2], [True, True, True]], 1: [[-1,-1, -1], [True, False, True]]}
96
+ >>> get_results(example_results, k_list=[1, 2])
97
+ Computing pass@k metric for multiple generations...
98
  {'pass@1': 0.25, 'pass@2': 0.5}
99
+ {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 0.25, 'pass@2': 0.5}}
100
  """
101
 
102
  metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
103
 
104
+ if len(results[0]) == 1:
105
  # for single generations we compute average accuracy and stric accuracy: original APPS metrics
106
  print("Computing accuracy metrics...")
107
  res = []
108
  per_prob_res = []
109
  all_correct = []
110
  for index in results:
111
+ problem_results = np.asarray(results[index])
112
+ res.extend(problem_results)
113
+ per_prob_res.append(np.mean(problem_results > 0))
114
+ all_correct.append(np.all(problem_results > 0))
115
  # we count campilation and runtime errors once per pronlem
116
  compile_errors = len([e for e in res if -2 in e])
117
  runtime_errors = len([e for e in res if -1 in e])
 
121
  print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}")
122
  print(f"number of problems evaluated = {total_testcases}")
123
 
124
+ print(f"Average Accuracy : {np.mean(per_prob_res)}")
125
+ print(f"Strict Accuracy : {np.mean(all_correct)}")
126
  metrics["avg_accuracy"] = np.mean(per_prob_res)
127
  metrics["strict_accuracy"] = np.mean(all_correct)
128
 
 
149
  metrics["pass_at_k"] = pass_at_k
150
  return metrics
151
 
152
+ def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
153
  """Return metrics for the given generations.
154
  Args:
155
+ generations: list of code generations for each problem (each generation is a list of generations)
156
  k_list: list of k values to compute pass@k when using multiple generations
157
  count_errors: whether to count compilation and runtime errors when using single generations
158
+ level: difficulty level in APPS dataset that was used for the given generations (from: "all", "introductory", "interview", "competition")
159
  Returns:
160
  metrics: dict of metrics
161
+
162
+ Examples:
163
+
164
+ >>> import json
165
+ >>> # lists of solutions to the two first APPS problems (note not all solutions pass all tests)
166
+ >>> solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
167
+ >>> solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
168
+ >>> single_solutions = [solution_sample1[:1], solution_sample2[:1]]
169
+ >>> compute_metrics(single_solutions, level="all")
170
+ Computing accuracy metrics...
171
+ number of compile errors = 0 avg = 0.0
172
+ number of runtime errors = 0 avg = 0.0
173
+ number of problems evaluated = 2
174
+ Average Accuracy : 1.0
175
+ Strict Accuracy : 1.0
176
+ {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
177
+ >>> multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
178
+ >>> compute_metrics(multiple_solutions, level="all", k_list=[1, 2, 3])
179
+ Computing pass@k metric for multiple generations...
180
+ {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}
181
+ {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
182
  """
183
+ results = evaluate_generations(generations, level=level, debug=debug)
184
  metrics = get_results(results, count_errors=count_errors, k_list=k_list)
185
+ return metrics
186
+
187
+ #import doctest
188
+ #doctest.testmod()