loubnabnl HF staff commited on
Commit
0535e18
1 Parent(s): 9f92296

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +445 -3
utils.py CHANGED
@@ -2,7 +2,6 @@ import itertools
2
  import numpy as np
3
  from typing import Dict
4
  from datasets import load_dataset
5
- import testing_util as test_util
6
 
7
 
8
  DATASET = "codeparrot/apps"
@@ -34,7 +33,7 @@ def evaluate_generations(generations: list, level: str = "all", debug: bool = Fa
34
  for o_idx, o in enumerate(problem_generations):
35
  curr_res = [-2]
36
  try:
37
- curr_res = test_util.run_test(sample, test=o, debug=debug)
38
  if debug:
39
  print(f"\nSuccessful compilation of task {index}!")
40
  fixed = []
@@ -185,4 +184,447 @@ def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=
185
  return metrics
186
 
187
  #import doctest
188
- #doctest.testmod()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  from typing import Dict
4
  from datasets import load_dataset
 
5
 
6
 
7
  DATASET = "codeparrot/apps"
 
33
  for o_idx, o in enumerate(problem_generations):
34
  curr_res = [-2]
35
  try:
36
+ curr_res = run_test(sample, test=o, debug=debug)
37
  if debug:
38
  print(f"\nSuccessful compilation of task {index}!")
39
  fixed = []
 
184
  return metrics
185
 
186
  #import doctest
187
+ #doctest.testmod()
188
+
189
+ #---------------------------------------------------------------------------------------------
190
+ # below is the content of testing_util.py as a temporary workaround for the relative path problem
191
+ #----------------------------------------------------------------------------------------------
192
+
193
+ import json
194
+ import sys
195
+ import faulthandler
196
+
197
+ # used for debugging to time steps
198
+ from datetime import datetime
199
+
200
+ # to run the solution files we're using a timing based approach
201
+ import signal
202
+
203
+ import numpy as np
204
+ # for capturing the stdout
205
+ from io import StringIO
206
+ # used for testing the code that reads from input
207
+ from unittest.mock import patch, mock_open
208
+
209
+ from pyext import RuntimeModule
210
+
211
+ from enum import Enum
212
+ class CODE_TYPE(Enum):
213
+ call_based = 0
214
+ standard_input = 1
215
+
216
+ # stuff for setting up signal timer
217
+ class TimeoutException(Exception):
218
+ pass
219
+ def timeout_handler(signum, frame):
220
+ print("alarm went off")
221
+ #return
222
+ raise TimeoutException
223
+ signal.signal(signal.SIGALRM, timeout_handler)
224
+ timeout = 4 # seconds
225
+
226
+ # used to capture stdout as a list
227
+ # from https://stackoverflow.com/a/16571630/6416660
228
+ # alternative use redirect_stdout() from contextlib
229
+ class Capturing(list):
230
+ def __enter__(self):
231
+ self._stdout = sys.stdout
232
+ sys.stdout = self._stringio = StringIO()
233
+ # Make closing the StringIO a no-op
234
+ self._stringio.close = lambda x: 1
235
+ return self
236
+ def __exit__(self, *args):
237
+ self.extend(self._stringio.getvalue().splitlines())
238
+ del self._stringio # free up some memory
239
+ sys.stdout = self._stdout
240
+
241
+
242
+ def run_test(sample, test=None, debug=False):
243
+ """
244
+ if test(generated_code) is not None it'll try to run the code.
245
+ otherwise it'll just return an input and output pair.
246
+ """
247
+ if debug:
248
+ print(f"start = {datetime.now().time()}")
249
+
250
+ try:
251
+ in_outs = json.loads(sample["input_output"])
252
+ except ValueError:
253
+ in_outs = None
254
+ if in_outs:
255
+ if in_outs.get("fn_name") is None:
256
+ which_type = CODE_TYPE.standard_input # Standard input
257
+ method_name = None
258
+ else:
259
+ which_type = CODE_TYPE.call_based # Call-based
260
+ method_name = in_outs["fn_name"]
261
+
262
+ if debug:
263
+ print(f"loaded input_output = {datetime.now().time()}")
264
+
265
+ if test is None:
266
+ return in_outs
267
+ elif test is not None:
268
+ results = []
269
+ sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
270
+ if debug:
271
+ print(f"loading test code = {datetime.now().time()}")
272
+
273
+ if which_type == CODE_TYPE.call_based:
274
+ sol += test
275
+ if debug:
276
+ print(f"sol = {sol}")
277
+ signal.alarm(timeout)
278
+ try:
279
+ tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
280
+ if "class Solution" not in test:
281
+ tmp = tmp_sol
282
+ else:
283
+ tmp = tmp_sol.Solution()
284
+ signal.alarm(0)
285
+ except Exception as e:
286
+ signal.alarm(0)
287
+ if debug:
288
+ print(f"type 0 compilation error = {e}")
289
+ results.append(-2)
290
+ return results
291
+ signal.alarm(0)
292
+
293
+ elif which_type == CODE_TYPE.standard_input:
294
+ # sol
295
+ tmp_test = test.split("\n")
296
+
297
+ new_test = []
298
+ for x in tmp_test:
299
+ if (not x.startswith("from ")) and (not x.startswith("import ")):
300
+ new_test.append("\t" + x + "\n")
301
+ else:
302
+ new_test.append(x + "\n")
303
+ tmp_test = new_test
304
+
305
+ new_test = ""
306
+ started = False
307
+ for i in tmp_test:
308
+ if i.startswith("\t") and not started:
309
+ new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
310
+ new_test += "def code():\n"
311
+ new_test += i
312
+ started = True
313
+ elif started and ((i.startswith("from ")) or (i.startswith("import "))):
314
+ new_test += "\t" + i
315
+ else:
316
+ new_test += i
317
+ tmp_test = new_test
318
+
319
+ sol += tmp_test
320
+ if debug:
321
+ print(f"sol = {sol}")
322
+ method_name = "code"
323
+ signal.alarm(timeout)
324
+ try:
325
+ tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
326
+ tmp = tmp_sol
327
+ signal.alarm(0)
328
+ except Exception as e:
329
+ signal.alarm(0)
330
+ if debug:
331
+ print(f"type 1 compilation error = {e}")
332
+ results.append(-2)
333
+ return results
334
+ signal.alarm(0)
335
+ if debug:
336
+ print(f"get method = {datetime.now().time()}")
337
+
338
+ try:
339
+ method = getattr(tmp, method_name) # get_attr second arg must be str
340
+ except:
341
+ signal.alarm(0)
342
+ e = sys.exc_info()
343
+ print(f"unable to get function error = {e}")
344
+ return results
345
+
346
+ for index, inputs in enumerate(in_outs["inputs"]):
347
+ # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
348
+ try:
349
+ if isinstance(inputs[0], dict):
350
+ inputs = [{int(k): v for k,v in inputs[0].items()}]
351
+ except:
352
+ True
353
+ try:
354
+ if isinstance(in_outs["outputs"][index], dict):
355
+ in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}]
356
+ except:
357
+ True
358
+ try:
359
+ if isinstance(in_outs["outputs"][index][0], dict):
360
+ in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}]
361
+ except:
362
+ True
363
+
364
+ if debug:
365
+ print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}")
366
+ if which_type == CODE_TYPE.call_based: # Call-based
367
+ signal.alarm(timeout)
368
+ faulthandler.enable()
369
+ try:
370
+ output = method(*inputs)
371
+
372
+ # ground truth sequences are not tuples
373
+ if isinstance(output, tuple):
374
+ output = list(output)
375
+
376
+ tmp_result = output == in_outs["outputs"][index]
377
+ if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
378
+ tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
379
+
380
+ # ground truth sequences are not tuples
381
+ try:
382
+ if isinstance(output[0], tuple):
383
+ tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
384
+ except:
385
+ True
386
+ results.append(tmp_result)
387
+
388
+ # reset the alarm
389
+ signal.alarm(0)
390
+ except Exception as e:
391
+ signal.alarm(0)
392
+ faulthandler.disable()
393
+ print(f"Standard input runtime error or time limit exceeded error = {e}")
394
+ results.append(-1)
395
+ continue
396
+ faulthandler.disable()
397
+ signal.alarm(0)
398
+ if debug:
399
+ print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
400
+ elif which_type == CODE_TYPE.standard_input: # Standard input
401
+ faulthandler.enable()
402
+ signal.alarm(timeout)
403
+ passed = False
404
+
405
+ if isinstance(inputs, list):
406
+ inputs = "\n".join(inputs)
407
+ if isinstance(in_outs['outputs'][index], list):
408
+ in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
409
+
410
+ with Capturing() as output:
411
+ try:
412
+ call_method(method, inputs)
413
+ # reset the alarm
414
+ signal.alarm(0)
415
+ passed = True
416
+ except Exception as e:
417
+ # runtime error or took too long
418
+ signal.alarm(0)
419
+ print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
420
+ results.append(-1)
421
+ signal.alarm(0)
422
+
423
+ if not passed:
424
+ if debug:
425
+ nl = "\n"
426
+ if not isinstance(inputs, list):
427
+ print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
428
+ else:
429
+ print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
430
+ continue
431
+
432
+ if passed and debug:
433
+ print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
434
+
435
+ if custom_compare_(output, in_outs['outputs'][index]):
436
+ tmp_result = True
437
+ results.append(tmp_result)
438
+ continue
439
+
440
+ # ground truth sequences are expressed as lists not tuples
441
+ if isinstance(output, tuple):
442
+ output = list(output)
443
+
444
+ tmp_result = False
445
+ try:
446
+ tmp_result = (output == [in_outs["outputs"][index]])
447
+ if isinstance(in_outs["outputs"][index], list):
448
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
449
+ if isinstance(output[0], str):
450
+ tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
451
+ except Exception as e:
452
+ if debug:
453
+ print(f"Failed check1 exception = {e}")
454
+ pass
455
+
456
+ if tmp_result == True:
457
+ results.append(tmp_result)
458
+ continue
459
+
460
+ # try one more time without \n
461
+ if isinstance(in_outs["outputs"][index], list):
462
+ for tmp_index, i in enumerate(in_outs["outputs"][index]):
463
+ in_outs["outputs"][index][tmp_index] = i.split("\n")
464
+ in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x]
465
+ else:
466
+ in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
467
+ in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
468
+ in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index]))
469
+
470
+ try:
471
+ tmp_result = (output == [in_outs["outputs"][index]])
472
+ if isinstance(in_outs["outputs"][index], list):
473
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
474
+ except Exception as e:
475
+ if debug:
476
+ print(f"Failed check2 exception = {e}")
477
+ pass
478
+
479
+ if tmp_result == True:
480
+ results.append(tmp_result)
481
+ continue
482
+
483
+ # try by converting the output into a split up list too
484
+ if isinstance(output, list):
485
+ output = list(filter(len, output))
486
+
487
+ if debug:
488
+ nl = "\n"
489
+ if not isinstance(inputs, list):
490
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
491
+ else:
492
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
493
+
494
+ if tmp_result == True:
495
+ results.append(tmp_result)
496
+ continue
497
+
498
+ try:
499
+ tmp_result = (output == [in_outs["outputs"][index]])
500
+ if isinstance(in_outs["outputs"][index], list):
501
+ tmp_result = tmp_result or (output == in_outs["outputs"][index])
502
+ except Exception as e:
503
+ if debug:
504
+ print(f"Failed check3 exception = {e}")
505
+ pass
506
+
507
+ try:
508
+ output_float = [float(e) for e in output]
509
+ gt_float = [float(e) for e in in_outs['outputs'][index]]
510
+ tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
511
+ except Exception as e:
512
+ pass
513
+ try:
514
+ if isinstance(output[0], list):
515
+ output_float = [float(e) for e in output[0]]
516
+ gt_float = [float(e) for e in in_outs['outputs'][index][0]]
517
+ tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
518
+ except Exception as e:
519
+ pass
520
+
521
+ if tmp_result == True:
522
+ results.append(tmp_result)
523
+ continue
524
+
525
+ # try by converting the stuff into split up list
526
+ if isinstance(in_outs["outputs"][index], list):
527
+ for tmp_index, i in enumerate(in_outs["outputs"][index]):
528
+ in_outs["outputs"][index][tmp_index] = set(i.split())
529
+ else:
530
+ in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
531
+
532
+ try:
533
+ tmp_result = (output == in_outs["outputs"][index])
534
+ except Exception as e:
535
+ if debug:
536
+ print(f"Failed check4 exception = {e}")
537
+ continue
538
+
539
+ if tmp_result == True:
540
+ results.append(tmp_result)
541
+ continue
542
+
543
+ # try by converting the output into a split up list too
544
+ if isinstance(output, list):
545
+ for tmp_index, i in enumerate(output):
546
+ output[tmp_index] = i.split()
547
+ output = list(filter(len, output))
548
+ for tmp_index, i in enumerate(output):
549
+ output[tmp_index] = set(i)
550
+ else:
551
+ output = output.split()
552
+ output = list(filter(len, output))
553
+ output = set(output)
554
+
555
+ try:
556
+ tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
557
+ except Exception as e:
558
+ if debug:
559
+ print(f"Failed check5 exception = {e}")
560
+
561
+
562
+ # if they are all numbers, round so that similar numbers are treated as identical
563
+ try:
564
+ tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
565
+ set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
566
+ except Exception as e:
567
+ if debug:
568
+ print(f"Failed check6 exception = {e}")
569
+
570
+ if tmp_result == True and debug:
571
+ print("PASSED")
572
+
573
+ results.append(tmp_result)
574
+
575
+ if debug:
576
+ nl = "\n"
577
+ if not isinstance(inputs, list):
578
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
579
+ else:
580
+ print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
581
+
582
+
583
+ return results
584
+
585
+
586
+ def custom_compare_(output, ground_truth):
587
+
588
+ if isinstance(output, list):
589
+ output_1 = "\n".join(output)
590
+ if stripped_string_compare(output_1, ground_truth):
591
+ return True
592
+
593
+ if isinstance(output, list):
594
+ output_2 = [o.lstrip().rstrip() for o in output]
595
+ output_2 = "\n".join(output_2)
596
+ if stripped_string_compare(output_2, ground_truth):
597
+ return True
598
+
599
+ return False
600
+
601
+ def stripped_string_compare(s1, s2):
602
+ s1 = s1.lstrip().rstrip()
603
+ s2 = s2.lstrip().rstrip()
604
+ return s1 == s2
605
+
606
+ def call_method(method, inputs):
607
+
608
+ if isinstance(inputs, list):
609
+ inputs = "\n".join(inputs)
610
+
611
+ inputs_line_iterator = iter(inputs.split("\n"))
612
+
613
+ # sys.setrecursionlimit(10000)
614
+
615
+ # @patch('builtins.input', side_effect=inputs.split("\n"))
616
+ @patch('builtins.open', mock_open(read_data=inputs))
617
+ @patch('sys.stdin', StringIO(inputs))
618
+ @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
619
+ @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
620
+ @patch('sys.stdin.read', lambda *args: inputs)
621
+ # @patch('sys.stdout.write', print)
622
+ def _inner_call_method(_method):
623
+ try:
624
+ return _method()
625
+ except SystemExit as e:
626
+ pass
627
+ finally:
628
+ pass
629
+ return _inner_call_method(method)
630
+