Spaces:
Runtime error
Runtime error
import ast
import importlib
import io
import os
import re
import string
import sys
import time
from functools import partial
from typing import List

import pysnooper
# Header the generated program is expected to start with (as seen by the code generator).
FUNCTION_HEAD = "def execute_command({input_type}) -> {output_type}:"

# Header actually executed: the extra parameters receive the helper objects
# (ImagePatch, llm_query, ...) as arguments. Note it has no {output_type} field.
EXEC_FUNCTION_HEAD = (
    'def execute_command({input_type}, possible_answers, query, ImagePatch, VideoSegment,'
    ' llm_query, bool_to_yesno, distance, best_image_match):'
)
class CompileTimeError:
    """Marker placed in the error slot of a result tuple when generated code fails to parse.

    It is a plain class, not an ``Exception`` subclass: it is returned, never raised.
    """
class ProgramRuntimeError:
    """Marker placed in the error slot of a result tuple when a generated program fails to run.

    It is a plain class, not an ``Exception`` subclass: it is returned, never raised.
    """
def process_trace(text, function_head, execution_function_head):
    """Reduce a raw pysnooper log to the trace of the generated ``execute_command`` call.

    The pipeline:
      1. ``remove_pre_context`` keeps the LAST line containing ``execution_function_head``
         and everything after it, rewrites that header to ``function_head``, and dedents
         the kept lines when that first line is indented;
      2. ``remove_post_context`` cuts at the first ``Source path:`` line ending with this
         module's ``__file__`` or at the first ``Elapsed time`` line;
      3. ``remove_timestamp`` drops the first 16 characters of lines starting with a digit;
      4. ``remove_tensor`` abbreviates ``tensor([[[...]]])`` reprs.

    Args:
        text: Raw pysnooper output.
        function_head: Header to display in place of ``execution_function_head``.
        execution_function_head: Header of the function that was actually executed.

    Returns:
        The cleaned lines joined with newlines, or ``''`` if the header never appears.
    """
    def remove_indent(lines):
        # The indentation of the first kept line is the amount stripped from every line.
        n_space = len(lines[0]) - len(lines[0].lstrip(' '))

        def dedent(line):
            # Strip at most n_space leading spaces. This is safe on empty lines (indexing
            # line[0] raised IndexError there) and never cuts into the text of lines that
            # are indented by fewer than n_space spaces (e.g. text spanning several lines).
            n_leading = len(line) - len(line.lstrip(' '))
            return line[min(n_space, n_leading):]

        return [dedent(line) for line in lines]

    def remove_pre_context(lines: List[str]) -> List[str]:
        # Scan from the end so that the last line mentioning the executed header is the
        # one the trace starts from, even if the header text also appears earlier.
        for i in range(len(lines) - 1, -1, -1):
            line = lines[i]
            if execution_function_head in line:
                content = [line.replace(execution_function_head, function_head)] + lines[i + 1:]
                return remove_indent(content) if line.startswith(' ') else content
        return []

    def remove_post_context(lines):
        for i, line in enumerate(lines):
            # A "Source path" line pointing back at this module, or pysnooper's closing
            # "Elapsed time" line, marks the end of the interesting part.
            if (line.startswith("Source path:") and line.endswith(__file__)) \
                    or line.startswith("Elapsed time"):
                return lines[:i]
        return lines

    def remove_timestamp(lines):
        # Lines starting with a digit presumably carry a 16-character timestamp prefix
        # ("HH:MM:SS.ffffff ").
        return [line[16:] if line and line[0] in string.digits else line for line in lines]

    def remove_tensor(line):
        # Collapse printed 3-D tensors, whose reprs would otherwise dominate the trace.
        return re.sub(r"tensor\(\[\[\[.*?\]\]\]\)", "tensor([[[...]]])", line)

    lines = text.splitlines()
    lines = remove_pre_context(lines)
    lines = remove_post_context(lines)
    lines = remove_timestamp(lines)
    lines = [remove_tensor(line) for line in lines]
    return '\n'.join(lines)
# Monotonic counter giving each temporary generated module a unique name (x1, x2, ...).
cnt: int = 0
def run_program_with_trace(code, image, input_type_, output_type_):
    """Execute a generated ``execute_command`` program and capture a pysnooper trace of it.

    The program text is normalised to start with ``EXEC_FUNCTION_HEAD`` (whose extra
    parameters receive the ``image_patch`` helpers), written to a temporary module
    ``x<N>.py`` in the current working directory, imported, and called with ``image``.

    NOTE(review): the module is imported by name, so this assumes the current working
    directory is on ``sys.path`` -- confirm for the deployment entry point.

    Args:
        code: The generated program; converted with ``str()``. It may start with the
            formatted ``FUNCTION_HEAD`` or be a bare function body.
        image: First positional argument passed to ``execute_command``.
        input_type_: Substituted for ``{input_type}`` in both headers.
        output_type_: Substituted for ``{output_type}`` in ``FUNCTION_HEAD``.

    Returns:
        A ``(result, error, trace)`` tuple:
          * ``(None, CompileTimeError(), None)`` if the program cannot be parsed;
          * ``(None, ProgramRuntimeError(), '')`` if the module cannot be imported;
          * otherwise ``execute_command``'s return value (``None`` if it failed),
            ``None`` or a ``ProgramRuntimeError()``, and the trace cleaned by
            ``process_trace``.
    """
    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno

    global cnt

    function_head = FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)
    execution_function_head = EXEC_FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)

    code = str(code)
    if code.startswith("\ndef"):
        code = code[1:]  # TODO: just a temporary fix for a single leading newline
    if code.startswith('def'):
        if code.startswith(function_head):
            # Drop only the leading header; str.replace() would also delete any later
            # occurrence of the same text inside the program.
            code = code[len(function_head):]
        else:
            print("--- Code with invalid format\n")
            print(code)
    code = execution_function_head + code
    try:
        # Parse/unparse round trip: rejects invalid programs and normalises the source.
        code = ast.unparse(ast.parse(code))
    except Exception:  # SyntaxError, ValueError (null bytes), RecursionError, ...
        return None, CompileTimeError(), None

    cnt += 1
    name = f'x{cnt}'
    path = f'{name}.py'
    # Python decodes source files as UTF-8 by default, so write them as UTF-8 regardless
    # of the locale's preferred encoding.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(code)

    try:
        module = None
        for _ in range(20):
            # The import system caches directory contents; a module file created after
            # those caches were filled can be missed until they are invalidated.
            importlib.invalidate_caches()
            try:
                module = importlib.import_module(name)
            except ModuleNotFoundError:
                print("Errrr, import error. Wait a bit while.")
                time.sleep(60)  # fallback retry in case the file is still not visible
            except Exception as e:
                print("Import has error:", e)
                break
            else:
                break
        if module is None:
            # execute_command never ran, so there is no trace to report.
            return None, ProgramRuntimeError(), ''

        queues = [None, None]
        image_patch_partial = partial(ImagePatch, queues=queues)
        video_segment_partial = None
        llm_query_partial = partial(llm_query, queues=queues)

        # NOTE: a SIGALRM-based timeout was tried previously and did not work, so a
        # program that never returns will block here.
        result = None
        error = None
        with io.StringIO() as buffer:
            # depth=2 so that the call into execute_command is traced as well.
            with pysnooper.snoop(output=buffer, color=False, depth=2, max_variable_length=1000):
                try:
                    result = module.execute_command(image, None, '', image_patch_partial,
                                                    video_segment_partial, llm_query_partial,
                                                    bool_to_yesno, distance, best_image_match)
                except (Exception, SystemExit):
                    # Generated code may raise anything (or call exit()); report it as a
                    # runtime error, but let KeyboardInterrupt propagate.
                    error = ProgramRuntimeError()
            traced = buffer.getvalue()[:100000]
    finally:
        # Always clean up the temporary module, even if execution was interrupted,
        # and drop it from sys.modules so generated modules do not accumulate.
        os.remove(path)
        sys.modules.pop(name, None)

    traced_processed = process_trace(traced, function_head, execution_function_head)
    return result, error, traced_processed