loubnabnl HF staff commited on
Commit
f117372
1 Parent(s): b73b144

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +2 -445
utils.py CHANGED
@@ -2,7 +2,7 @@ import itertools
2
  import numpy as np
3
  from typing import Dict
4
  from datasets import load_dataset
5
- #import testing_util as test_util
6
 
7
  DATASET = "codeparrot/apps"
8
 
@@ -33,7 +33,7 @@ def evaluate_generations(generations: list, level: str = "all", debug: bool = Fa
33
  for o_idx, o in enumerate(problem_generations):
34
  curr_res = [-2]
35
  try:
36
- curr_res = run_test(sample, test=o, debug=debug)
37
  if debug:
38
  print(f"\nSuccessful compilation of task {index}!")
39
  fixed = []
@@ -185,446 +185,3 @@ def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=
185
 
186
  #import doctest
187
  #doctest.testmod()
188
-
189
- #---------------------------------------------------------------------------------------------
190
- # below is the content of testing_util.py as a temporary workaround for the relative path problem
191
- #----------------------------------------------------------------------------------------------
192
-
193
- import json
194
- import sys
195
- import faulthandler
196
-
197
- # used for debugging to time steps
198
- from datetime import datetime
199
-
200
- # to run the solution files we're using a timing based approach
201
- import signal
202
-
203
- import numpy as np
204
- # for capturing the stdout
205
- from io import StringIO
206
- # used for testing the code that reads from input
207
- from unittest.mock import patch, mock_open
208
-
209
- from pyext import RuntimeModule
210
-
211
- from enum import Enum
212
- class CODE_TYPE(Enum):
213
- call_based = 0
214
- standard_input = 1
215
-
216
- # stuff for setting up signal timer
217
- class TimeoutException(Exception):
218
- pass
219
- def timeout_handler(signum, frame):
220
- print("alarm went off")
221
- #return
222
- raise TimeoutException
223
- signal.signal(signal.SIGALRM, timeout_handler)
224
- timeout = 4 # seconds
225
-
226
- # used to capture stdout as a list
227
- # from https://stackoverflow.com/a/16571630/6416660
228
- # alternative use redirect_stdout() from contextlib
229
- class Capturing(list):
230
- def __enter__(self):
231
- self._stdout = sys.stdout
232
- sys.stdout = self._stringio = StringIO()
233
- # Make closing the StringIO a no-op
234
- self._stringio.close = lambda x: 1
235
- return self
236
- def __exit__(self, *args):
237
- self.extend(self._stringio.getvalue().splitlines())
238
- del self._stringio # free up some memory
239
- sys.stdout = self._stdout
240
-
241
-
242
- def run_test(sample, test=None, debug=False):
243
- """
244
- if test(generated_code) is not None it'll try to run the code.
245
- otherwise it'll just return an input and output pair.
246
- """
247
- if debug:
248
- print(f"start = {datetime.now().time()}")
249
-
250
- try:
251
- in_outs = json.loads(sample["input_output"])
252
- except ValueError:
253
- in_outs = None
254
- if in_outs:
255
- if in_outs.get("fn_name") is None:
256
- which_type = CODE_TYPE.standard_input # Standard input
257
- method_name = None
258
- else:
259
- which_type = CODE_TYPE.call_based # Call-based
260
- method_name = in_outs["fn_name"]
261
-
262
- if debug:
263
- print(f"loaded input_output = {datetime.now().time()}")
264
-
265
- if test is None:
266
- return in_outs
267
- elif test is not None:
268
- results = []
269
- sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
270
- if debug:
271
- print(f"loading test code = {datetime.now().time()}")
272
-
273
- if which_type == CODE_TYPE.call_based:
274
- sol += test
275
- if debug:
276
- print(f"sol = {sol}")
277
- signal.alarm(timeout)
278
- try:
279
- tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
280
- if "class Solution" not in test:
281
- tmp = tmp_sol
282
- else:
283
- tmp = tmp_sol.Solution()
284
- signal.alarm(0)
285
- except Exception as e:
286
- signal.alarm(0)
287
- if debug:
288
- print(f"type 0 compilation error = {e}")
289
- results.append(-2)
290
- return results
291
- signal.alarm(0)
292
-
293
- elif which_type == CODE_TYPE.standard_input:
294
- # sol
295
- tmp_test = test.split("\n")
296
-
297
- new_test = []
298
- for x in tmp_test:
299
- if (not x.startswith("from ")) and (not x.startswith("import ")):
300
- new_test.append("\t" + x + "\n")
301
- else:
302
- new_test.append(x + "\n")
303
- tmp_test = new_test
304
-
305
- new_test = ""
306
- started = False
307
- for i in tmp_test:
308
- if i.startswith("\t") and not started:
309
- new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
310
- new_test += "def code():\n"
311
- new_test += i
312
- started = True
313
- elif started and ((i.startswith("from ")) or (i.startswith("import "))):
314
- new_test += "\t" + i
315
- else:
316
- new_test += i
317
- tmp_test = new_test
318
-
319
- sol += tmp_test
320
- if debug:
321
- print(f"sol = {sol}")
322
- method_name = "code"
323
- signal.alarm(timeout)
324
- try:
325
- tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
326
- tmp = tmp_sol
327
- signal.alarm(0)
328
- except Exception as e:
329
- signal.alarm(0)
330
- if debug:
331
- print(f"type 1 compilation error = {e}")
332
- results.append(-2)
333
- return results
334
- signal.alarm(0)
335
- if debug:
336
- print(f"get method = {datetime.now().time()}")
337
-
338
- try:
339
- method = getattr(tmp, method_name) # get_attr second arg must be str
340
- except:
341
- signal.alarm(0)
342
- e = sys.exc_info()
343
- print(f"unable to get function error = {e}")
344
- return results
345
-
346
- for index, inputs in enumerate(in_outs["inputs"]):
347
- # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
348
- try:
349
- if isinstance(inputs[0], dict):
350
- inputs = [{int(k): v for k,v in inputs[0].items()}]
351
- except:
352
- True
353
- try:
354
- if isinstance(in_outs["outputs"][index], dict):
355
- in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}]
356
- except:
357
- True
358
- try:
359
- if isinstance(in_outs["outputs"][index][0], dict):
360
- in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}]
361
- except:
362
- True
363
-
364
- if debug:
365
- print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}")
366
- if which_type == CODE_TYPE.call_based: # Call-based
367
- signal.alarm(timeout)
368
- faulthandler.enable()
369
- try:
370
- output = method(*inputs)
371
-
372
- # ground truth sequences are not tuples
373
- if isinstance(output, tuple):
374
- output = list(output)
375
-
376
- tmp_result = output == in_outs["outputs"][index]
377
- if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
378
- tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
379
-
380
- # ground truth sequences are not tuples
381
- try:
382
- if isinstance(output[0], tuple):
383
- tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
384
- except:
385
- True
386
- results.append(tmp_result)
387
-
388
- # reset the alarm
389
- signal.alarm(0)
390
- except Exception as e:
391
- signal.alarm(0)
392
- faulthandler.disable()
393
- print(f"Standard input runtime error or time limit exceeded error = {e}")
394
- results.append(-1)
395
- continue
396
- faulthandler.disable()
397
- signal.alarm(0)
398
- if debug:
399
- print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
400
- elif which_type == CODE_TYPE.standard_input: # Standard input
401
- faulthandler.enable()
402
- signal.alarm(timeout)
403
- passed = False
404
-
405
- if isinstance(inputs, list):
406
- inputs = "\n".join(inputs)
407
- if isinstance(in_outs['outputs'][index], list):
408
- in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
409
-
410
- with Capturing() as output:
411
- try:
412
- call_method(method, inputs)
413
- # reset the alarm
414
- signal.alarm(0)
415
- passed = True
416
- except Exception as e:
417
- # runtime error or took too long
418
- signal.alarm(0)
419
- print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
420
- results.append(-1)
421
- signal.alarm(0)
422
-
423
- if not passed:
424
- if debug:
425
- nl = "\n"
426
- if not isinstance(inputs, list):
427
- print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
428
- else:
429
- print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
430
- continue
431
-
432
- if passed and debug:
433
- print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
434
-
435
- if custom_compare_(output, in_outs['outputs'][index]):
436
- tmp_result = True
437
- results.append(tmp_result)
438
- continue
439
-
440
- # ground truth sequences are expressed as lists not tuples
441
- if isinstance(output, tuple):
442
- output = list(output)
443
-
444
- tmp_result = False
445
- try:
446
- tmp_result = (output == [in_outs["outputs"][index]])
447
- if isinstance(in_outs["outputs"][index], list):
448
- tmp_result = tmp_result or (output == in_outs["outputs"][index])
449
- if isinstance(output[0], str):
450
- tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
451
- except Exception as e:
452
- if debug:
453
- print(f"Failed check1 exception = {e}")
454
- pass
455
-
456
- if tmp_result == True:
457
- results.append(tmp_result)
458
- continue
459
-
460
- # try one more time without \n
461
- if isinstance(in_outs["outputs"][index], list):
462
- for tmp_index, i in enumerate(in_outs["outputs"][index]):
463
- in_outs["outputs"][index][tmp_index] = i.split("\n")
464
- in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x]
465
- else:
466
- in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
467
- in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
468
- in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index]))
469
-
470
- try:
471
- tmp_result = (output == [in_outs["outputs"][index]])
472
- if isinstance(in_outs["outputs"][index], list):
473
- tmp_result = tmp_result or (output == in_outs["outputs"][index])
474
- except Exception as e:
475
- if debug:
476
- print(f"Failed check2 exception = {e}")
477
- pass
478
-
479
- if tmp_result == True:
480
- results.append(tmp_result)
481
- continue
482
-
483
- # try by converting the output into a split up list too
484
- if isinstance(output, list):
485
- output = list(filter(len, output))
486
-
487
- if debug:
488
- nl = "\n"
489
- if not isinstance(inputs, list):
490
- print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
491
- else:
492
- print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
493
-
494
- if tmp_result == True:
495
- results.append(tmp_result)
496
- continue
497
-
498
- try:
499
- tmp_result = (output == [in_outs["outputs"][index]])
500
- if isinstance(in_outs["outputs"][index], list):
501
- tmp_result = tmp_result or (output == in_outs["outputs"][index])
502
- except Exception as e:
503
- if debug:
504
- print(f"Failed check3 exception = {e}")
505
- pass
506
-
507
- try:
508
- output_float = [float(e) for e in output]
509
- gt_float = [float(e) for e in in_outs['outputs'][index]]
510
- tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
511
- except Exception as e:
512
- pass
513
- try:
514
- if isinstance(output[0], list):
515
- output_float = [float(e) for e in output[0]]
516
- gt_float = [float(e) for e in in_outs['outputs'][index][0]]
517
- tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
518
- except Exception as e:
519
- pass
520
-
521
- if tmp_result == True:
522
- results.append(tmp_result)
523
- continue
524
-
525
- # try by converting the stuff into split up list
526
- if isinstance(in_outs["outputs"][index], list):
527
- for tmp_index, i in enumerate(in_outs["outputs"][index]):
528
- in_outs["outputs"][index][tmp_index] = set(i.split())
529
- else:
530
- in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
531
-
532
- try:
533
- tmp_result = (output == in_outs["outputs"][index])
534
- except Exception as e:
535
- if debug:
536
- print(f"Failed check4 exception = {e}")
537
- continue
538
-
539
- if tmp_result == True:
540
- results.append(tmp_result)
541
- continue
542
-
543
- # try by converting the output into a split up list too
544
- if isinstance(output, list):
545
- for tmp_index, i in enumerate(output):
546
- output[tmp_index] = i.split()
547
- output = list(filter(len, output))
548
- for tmp_index, i in enumerate(output):
549
- output[tmp_index] = set(i)
550
- else:
551
- output = output.split()
552
- output = list(filter(len, output))
553
- output = set(output)
554
-
555
- try:
556
- tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
557
- except Exception as e:
558
- if debug:
559
- print(f"Failed check5 exception = {e}")
560
-
561
-
562
- # if they are all numbers, round so that similar numbers are treated as identical
563
- try:
564
- tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
565
- set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
566
- except Exception as e:
567
- if debug:
568
- print(f"Failed check6 exception = {e}")
569
-
570
- if tmp_result == True and debug:
571
- print("PASSED")
572
-
573
- results.append(tmp_result)
574
-
575
- if debug:
576
- nl = "\n"
577
- if not isinstance(inputs, list):
578
- print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
579
- else:
580
- print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
581
-
582
-
583
- return results
584
-
585
-
586
- def custom_compare_(output, ground_truth):
587
-
588
- if isinstance(output, list):
589
- output_1 = "\n".join(output)
590
- if stripped_string_compare(output_1, ground_truth):
591
- return True
592
-
593
- if isinstance(output, list):
594
- output_2 = [o.lstrip().rstrip() for o in output]
595
- output_2 = "\n".join(output_2)
596
- if stripped_string_compare(output_2, ground_truth):
597
- return True
598
-
599
- return False
600
-
601
- def stripped_string_compare(s1, s2):
602
- s1 = s1.lstrip().rstrip()
603
- s2 = s2.lstrip().rstrip()
604
- return s1 == s2
605
-
606
- def call_method(method, inputs):
607
-
608
- if isinstance(inputs, list):
609
- inputs = "\n".join(inputs)
610
-
611
- inputs_line_iterator = iter(inputs.split("\n"))
612
-
613
- # sys.setrecursionlimit(10000)
614
-
615
- # @patch('builtins.input', side_effect=inputs.split("\n"))
616
- @patch('builtins.open', mock_open(read_data=inputs))
617
- @patch('sys.stdin', StringIO(inputs))
618
- @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
619
- @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
620
- @patch('sys.stdin.read', lambda *args: inputs)
621
- # @patch('sys.stdout.write', print)
622
- def _inner_call_method(_method):
623
- try:
624
- return _method()
625
- except SystemExit as e:
626
- pass
627
- finally:
628
- pass
629
- return _inner_call_method(method)
630
-
 
2
  import numpy as np
3
  from typing import Dict
4
  from datasets import load_dataset
5
+ import testing_util as test_util
6
 
7
  DATASET = "codeparrot/apps"
8
 
 
33
  for o_idx, o in enumerate(problem_generations):
34
  curr_res = [-2]
35
  try:
36
+ curr_res = test_util.run_test(sample, test=o, debug=debug)
37
  if debug:
38
  print(f"\nSuccessful compilation of task {index}!")
39
  fixed = []
 
185
 
186
  #import doctest
187
  #doctest.testmod()