loubnabnl HF staff committed on
Commit
942b4fc
1 Parent(s): 186b61d

add global timeout check

Browse files
Files changed (1) hide show
  1. utils.py +25 -1
utils.py CHANGED
@@ -1,11 +1,35 @@
1
  import itertools
 
 
2
  import numpy as np
3
  from typing import Dict
4
  from datasets import load_dataset
5
  from .testing_util import run_test
6
 
7
  DATASET = "codeparrot/apps"
8
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
11
  """We take the list of code generations and try to compile them
 
1
  import itertools
2
+ import json
3
+ import multiprocessing
4
  import numpy as np
5
  from typing import Dict
6
  from datasets import load_dataset
7
  from .testing_util import run_test
8
 
9
  DATASET = "codeparrot/apps"
10
+ TIMEOUT = 10
11
+
12
def _temp_run(sample, generation, debug, result):
    """Run `run_test` on one generation and append its return value to `result`.

    This lives at module level (rather than nested inside `check_correctness`)
    so that it can be pickled as a `multiprocessing.Process` target under the
    "spawn" and "forkserver" start methods, not only under "fork".
    """
    result.append(run_test(sample, test=generation, debug=debug))


def check_correctness(sample, generation, timeout, debug=True):
    """Check correctness of a code generation with a global timeout.

    The global timeout is there to catch extreme/rare cases that are not
    handled by the timeouts inside `run_test`.

    Args:
        sample: Mapping for one problem. Only `sample["input_output"]` is read
            here (a JSON string with an "inputs" list); the whole object is
            forwarded to `run_test`.
        generation: Candidate solution, forwarded to `run_test` as `test=`.
        timeout: Seconds to wait for the worker; one extra second of slack is
            added before the worker is killed.
        debug: Forwarded to `run_test`; also enables the "global timeout"
            message when no result was produced.

    Returns:
        Whatever `run_test` returned, or a list with one `-1` per entry of
        `inputs` when the worker produced no result (it was killed after the
        timeout or exited without appending anything).
    """
    # The context manager shuts the manager's server process down on exit
    # instead of leaving one helper process behind per call.
    with multiprocessing.Manager() as manager:
        result = manager.list()
        p = multiprocessing.Process(
            target=_temp_run, args=(sample, generation, debug, result)
        )
        p.start()
        p.join(timeout=timeout + 1)
        if p.is_alive():
            p.kill()
            # Reap the killed worker so it does not linger as a zombie.
            p.join()
        # Copy out of the proxy while the manager is still running.
        outcomes = list(result)

    if not outcomes:
        in_outs = json.loads(sample["input_output"])
        # consider that all tests failed
        outcomes = [[-1] * len(in_outs["inputs"])]
        if debug:
            print("global timeout")
    return outcomes[0]
33
 
34
  def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
35
  """We take the list of code generations and try to compile them