# Committed by Xueqing Wu in commit e20ef71 ("init"); original file size 4.64 kB.
import ast
import importlib
import io
import os
import re
import string
import time
from functools import partial
from typing import List
import pysnooper
# Signature generated programs are expected to begin with. ``{input_type}`` and
# ``{output_type}`` are filled in with ``str.format`` before use.
FUNCTION_HEAD = "def execute_command({input_type}) -> {output_type}:"

# Signature that is actually written to disk and executed. The helper objects
# the generated body uses are passed in as extra positional parameters, so the
# generated module needs no imports of its own.
EXEC_FUNCTION_HEAD = (
    'def execute_command({input_type}, possible_answers, query, ImagePatch, VideoSegment,'
    ' llm_query, bool_to_yesno, distance, best_image_match):'
)
class CompileTimeError:
    """Marker returned (not raised) in the error slot when generated code fails to parse."""
class ProgramRuntimeError:
    """Marker returned (not raised) in the error slot when a generated program fails while running."""
def process_trace(text, function_head, execution_function_head):
    """Cut the generated program's part out of a raw pysnooper trace and tidy it.

    Steps, in order:
    1. Keep only the lines from the last occurrence of ``execution_function_head``
       onward, rewriting that signature to ``function_head``.
    2. Drop everything after control returns to this file or after the
       "Elapsed time" footer.
    3. Strip the per-line timestamps.
    4. Collapse ``tensor([[[ ... ]]])`` reprs to ``tensor([[[...]]])``.

    Args:
        text: raw pysnooper output.
        function_head: signature to display in place of the executed one.
        execution_function_head: signature that was actually executed. Its last
            occurrence in ``text`` marks the start of the program's trace.

    Returns:
        The cleaned trace as newline-joined lines, or ``""`` if
        ``execution_function_head`` never occurs in ``text``.
    """

    def remove_indent(lines):
        # Remove the first line's leading-space indent from every line. The head
        # line is indented when the program ran as a nested call under pysnooper.
        # Only actual leading spaces are removed, so a blank line no longer raises
        # IndexError and a shallower-indented line keeps its content.
        n_space = len(lines[0]) - len(lines[0].lstrip(' '))
        result = []
        for line in lines:
            n_leading = len(line) - len(line.lstrip(' '))
            result.append(line[min(n_space, n_leading):])
        return result

    def remove_pre_context(lines: List[str]) -> List[str]:
        # Search backwards so the LAST line containing the executed signature wins.
        for i in range(len(lines) - 1, -1, -1):
            line = lines[i]
            if execution_function_head in line:
                # TODO: double-check that this is the pysnooper "call" line?
                content = [line.replace(execution_function_head, function_head)] + lines[i + 1:]
                return remove_indent(content) if line.startswith(' ') else content
        return []

    def remove_post_context(lines):
        # A "Source path:" line naming this file presumably means control came back
        # from the generated program to the caller here. "Elapsed time" is the footer
        # pysnooper writes when the snoop ends.
        for i, line in enumerate(lines):
            if line.startswith("Source path:") and line.endswith(__file__):
                return lines[:i]
            elif line.startswith("Elapsed time"):
                return lines[:i]
        return lines

    def remove_timestamp(lines):
        # Lines that start with a digit carry a timestamp prefix. Drop its first 16
        # characters (assumes pysnooper's "HH:MM:SS.ffffff " format).
        ret = []
        for line in lines:
            if len(line) > 0 and line[0] in string.digits:
                line = line[16:]
            ret.append(line)
        return ret

    def remove_tensor(line):
        # Tensor reprs can be huge; keep only a placeholder.
        return re.sub(r"tensor\(\[\[\[.*?\]\]\]\)", "tensor([[[...]]])", line)

    lines = text.splitlines()
    lines = remove_pre_context(lines)
    lines = remove_post_context(lines)
    lines = remove_timestamp(lines)
    lines = [remove_tensor(line) for line in lines]
    return '\n'.join(lines)
# Counter used to give every generated program a fresh module name (x1, x2, ...).
cnt = 0


def _import_generated_module(name, attempts=20, retry_delay=60):
    """Import the just-written module ``name``, retrying on ``ModuleNotFoundError``.

    Args:
        name: module name of a ``<name>.py`` file written to the working directory.
        attempts: maximum number of import attempts.
        retry_delay: seconds to sleep after each ``ModuleNotFoundError``.

    Returns:
        The imported module, or ``None`` if it could not be imported. Any
        exception other than ``ModuleNotFoundError`` is printed and ends the
        retries immediately.
    """
    for _ in range(attempts):
        # The file was created after the interpreter started. Path-based finders
        # cache directory contents, so the new file can be invisible to the import
        # system until those caches are invalidated. This is the likely cause of
        # the sporadic ModuleNotFoundError this retry loop was working around.
        importlib.invalidate_caches()
        try:
            return importlib.import_module(name)
        except ModuleNotFoundError:
            print("Errrr, import error. Wait a bit while.")
            time.sleep(retry_delay)
        except Exception as e:
            print("Import has error:", e)
            return None
    return None


def run_program_with_trace(code, image, input_type_, output_type_):
    """Run one generated ``execute_command`` program on ``image`` under pysnooper.

    What happens, in order:
    1. The program text is normalized to start with ``EXEC_FUNCTION_HEAD``,
       which receives the helper objects as parameters.
    2. It is round-tripped through ``ast``.
    3. It is written to a temporary module ``x<N>.py`` in the current working
       directory.
    4. The module is imported and ``execute_command`` is called while pysnooper
       records a trace.

    The temporary file and module are always removed afterwards.

    Args:
        code: generated program, converted with ``str()``. Either a full function
            starting with ``FUNCTION_HEAD`` or just its body.
        image: first argument passed to ``execute_command``.
        input_type_: text substituted for ``{input_type}`` in both function heads.
        output_type_: text substituted for ``{output_type}`` in ``FUNCTION_HEAD``.

    Returns:
        A tuple ``(result, error, trace)``:
        - ``(None, CompileTimeError(), None)`` if the code does not parse.
        - Otherwise ``result`` is the program's return value (``None`` on failure)
          and ``error`` is ``None`` or a ``ProgramRuntimeError`` instance.
        - ``trace`` is the trace cleaned by ``process_trace``.
    """
    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno
    import sys

    function_head = FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)
    execution_function_head = EXEC_FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)

    code = str(code)
    if code.startswith("\ndef"):
        code = code[1:]  # TODO: just a temporary fix
    if code.startswith('def'):
        if code.startswith(function_head):
            code = code.replace(function_head, '')
        else:
            print("--- Code with invalid format\n")
            print(code)
    # Prepend the executable signature, which receives the helpers as arguments.
    code = execution_function_head + code
    try:
        # Parse/unparse round trip: rejects code that does not parse and
        # normalizes the text that is written to disk (and later traced).
        code = ast.unparse(ast.parse(code))
    except Exception:  # SyntaxError, ValueError (e.g. null bytes), RecursionError, ...
        return None, CompileTimeError(), None

    global cnt
    cnt += 1
    name = f'x{cnt}'
    path = f'{name}.py'
    with open(path, 'w') as f:
        f.write(code)

    try:
        module = _import_generated_module(name)

        queues = [None, None]
        image_patch_partial = partial(ImagePatch, queues=queues)
        video_segment_partial = None
        llm_query_partial = partial(llm_query, queues=queues)
        # NOTE: there is no timeout (a SIGALRM-based one was tried and abandoned),
        # so a non-terminating generated program blocks here.
        with io.StringIO() as buffer:
            # depth=2 so the execute_command call made from this frame is traced too.
            with pysnooper.snoop(output=buffer, color=False, depth=2, max_variable_length=1000):
                result = None
                error = None
                try:
                    if module is None:
                        raise ImportError(f"could not import generated module {name!r}")
                    result = module.execute_command(image, None, '', image_patch_partial, video_segment_partial,
                                                    llm_query_partial, bool_to_yesno, distance, best_image_match)
                except Exception:
                    # Any failure inside the generated program, or a failed import.
                    error = ProgramRuntimeError()
            buffer.seek(0)
            traced = buffer.read(100000)
    finally:
        # Always clean up, even if tracing itself raised. Forget the one-off module
        # so such modules do not pile up in sys.modules, then delete its source file.
        sys.modules.pop(name, None)
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    traced_processed = process_trace(traced, function_head, execution_function_head)
    return result, error, traced_processed