Spaces:
Runtime error
Runtime error
import inspect | |
import json | |
import os | |
import random | |
from typing import Literal, cast | |
import gradio as gr | |
import torch | |
from PIL import Image | |
from gradio.data_classes import InterfaceTypes | |
from gradio.flagging import CSVLogger | |
from torchvision import transforms | |
from transformers import AutoTokenizer, LlamaForCausalLM | |
from trace_exec import run_program_with_trace, CompileTimeError | |
from vision_processes import load_models | |
print("-" * 10, "Loading models...")
load_models()

# Few-shot prompt for CodeLlama; predict() substitutes the user's query for the
# INSERT_QUERY_HERE marker inside it.
with open('joint.prompt') as prompt_file:
    prompt_template = prompt_file.read().strip()

# Input/output type names handed to run_program_with_trace, also used to build
# the header line that every generated program starts with.
INPUT_TYPE = 'image'
OUTPUT_TYPE = 'str'
SIGNATURE = 'def execute_command({}) -> {}:'.format(INPUT_TYPE, OUTPUT_TYPE)
def generate(model, input_text):
    """Load a causal LM, greedily complete ``input_text`` with it, then free it.

    The tokenizer and weights are loaded fresh on every call and released
    (``del`` + ``torch.cuda.empty_cache()``) before returning, including when
    generation fails.

    Args:
        model: Hugging Face model id/path passed to ``from_pretrained`` for
            both the tokenizer and the ``LlamaForCausalLM`` weights.
        input_text: Prompt to complete.

    Returns:
        The decoded continuation only (prompt tokens are dropped). Decoding is
        greedy, at most 256 new tokens, and stops once a blank line is produced.

    Raises:
        ValueError: if the CODELLAMA_DTYPE environment variable is not one of
            'bfloat16', '8bit' or '4bit'.
    """
    supported_dtypes = ('bfloat16', '8bit', '4bit')
    dtype = os.environ.get("CODELLAMA_DTYPE")
    # Validate with a real exception (an `assert` is stripped under `python -O`)
    # and do it before touching the GPU or downloading any weights.
    if dtype not in supported_dtypes:
        raise ValueError(
            f"CODELLAMA_DTYPE must be one of {supported_dtypes}, got {dtype!r}"
        )

    torch.cuda.empty_cache()
    print("-" * 10, "Before loading LLM:")
    print(torch.cuda.memory_summary())

    tokenizer = AutoTokenizer.from_pretrained(model)
    # Separate name so the `model` argument (an id/path) is not shadowed.
    llm = LlamaForCausalLM.from_pretrained(
        model,
        device_map="auto",
        load_in_8bit=dtype == "8bit",
        load_in_4bit=dtype == "4bit",
        torch_dtype=torch.bfloat16 if dtype == "bfloat16" else None,
    )
    try:
        print("-" * 10, "LLM loaded:")
        print(llm)
        print(torch.cuda.memory_summary())

        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        generated_ids = llm.generate(
            input_ids.to('cuda'),
            max_new_tokens=256,
            stop_strings=["\n\n"],
            do_sample=False,
            tokenizer=tokenizer,  # generate() needs it to match stop_strings
        )
        # Keep only the continuation, dropping the echoed prompt tokens.
        new_ids = generated_ids[0][input_ids.shape[1]:]
        text = tokenizer.decode(
            new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
    finally:
        # Release the weights even if generation raised, so a failed request
        # does not leave the model resident on the GPU.
        del llm
        torch.cuda.empty_cache()

    print("-" * 10, "After unloading LLM:")
    print(torch.cuda.memory_summary())
    return text
def to_custom_trace(result, error, traced):
    """Render an execution outcome as display text.

    The text is "-> <result>", a blank line, a "--- Trace" header, a blank
    line, and then the trace.

    Args:
        result: Value returned by the executed program (formatted with str()).
        error: Error reported by the executor; only inspected when there is
            no trace.
        traced: Execution trace text, or None when no trace was produced.

    Returns:
        The formatted string. A missing trace is shown as 'Compile Error'.
    """
    trace_text = traced
    if trace_text is None:
        # No trace is only expected when the program failed to compile.
        assert isinstance(error, CompileTimeError)
        trace_text = 'Compile Error'
    return f"-> {result}\n\n--- Trace\n\n{trace_text}"
def answer_from_trace(x):
    """Extract the final answer from text in the to_custom_trace format.

    The answer is the first line after the leading "->" marker, stripped.

    Args:
        x: Text of the form "-> <answer>" followed by a trace section. It may
            originate from a UI textbox (re_debug feeds one back in), so it is
            validated with a real exception rather than an assert.

    Returns:
        The stripped first line following "->", or "" if nothing follows it.

    Raises:
        ValueError: if ``x`` does not start with "->".
    """
    if not x.startswith("->"):
        raise ValueError(
            f"expected execution output starting with '->', got {x[:40]!r}"
        )
    lines = x[2:].splitlines()
    # "->" alone yields no lines; return an empty answer instead of IndexError.
    return lines[0].strip() if lines else ""
def debug(image, question, code, traced_info):
    """Run one critic/refiner round on a visual program, streaming UI updates.

    Yields 6-tuples for the outputs: program, its execution trace, critic
    output, refiner output, execution of the refined program, and final answer.
    Later tuples replace the "RUNNING IN PROGRESS..." placeholders as each
    stage finishes.

    Args:
        image: Image passed to run_program_with_trace when re-executing the
            refined program.
        question: Natural-language question about the image.
        code: Program to critique.
        traced_info: Its execution output, in the to_custom_trace format.
    """
    # --- Critic: judge whether the program is wrong.
    prompt = f"# Given an image: {question}\n{code}\n\n{traced_info}\n\n# Program is"
    print("--- For debug: critic prompt is ---")
    print(prompt)
    print("---\n")
    critic_out = generate("VDebugger/VDebugger-critic-generalist-7B", prompt)
    incorrect = critic_out.strip().startswith('wrong')
    critic_out = "# Program is" + critic_out
    if not incorrect:
        # The critic accepts the program: no refinement, keep its answer.
        yield code, traced_info, critic_out, "N/A", "N/A", answer_from_trace(traced_info)
        return
    yield code, traced_info, critic_out, "RUNNING IN PROGRESS...", "", ""

    # --- Refiner: hand the critic's version of the program back for correction.
    parts = critic_out.split('def execute_command')
    if len(parts) > 1:
        critic_code = ('def execute_command' + parts[1]).strip()
    else:
        # The critic flagged the program but did not reproduce it; fall back
        # to the original program rather than crashing with IndexError.
        critic_code = code
    # NOTE(review): this tests `code` but trims `critic_code`; the intent is
    # unclear from here (original comment: "errr, an awkward fix") - confirm.
    if '# Program is' in code:
        critic_code = critic_code.split("# Program is")[0].strip()
    prompt = f"# Given an image: {question}\n{critic_code}\n\n{traced_info}\n\n# Correction"
    print("--- For debug: refiner prompt is ---")
    print(prompt)
    print("---\n")
    refiner_out = generate("VDebugger/VDebugger-refiner-generalist-7B", prompt).strip()
    yield code, traced_info, critic_out, refiner_out, "RUNNING IN PROGRESS...", ""

    # --- Execute the refined program and report its answer.
    result, error, traced = run_program_with_trace(refiner_out, image, INPUT_TYPE, OUTPUT_TYPE)
    traced_info_2 = to_custom_trace(result, error, traced)
    yield code, traced_info, critic_out, refiner_out, traced_info_2, answer_from_trace(traced_info_2)
def predict(image, question):
    """Answer a question about an image with CodeLlama, then run one VDebugger round.

    Streams 6-tuples for the outputs: program, its execution trace, critic
    output, refiner output, execution of the refined program, and final answer.
    A missing image or an empty question shows a Gradio warning and ends the
    stream without yielding anything.
    """
    if image is None:
        gr.Warning("Please provide an image", duration=5)
        return
    to_tensor = transforms.Compose([transforms.ToTensor()])
    image_tensor = to_tensor(image)
    query = question.strip()
    if not query:
        gr.Warning("Please provide a question", duration=5)
        return

    # Step 1: have CodeLlama write a program that starts with SIGNATURE.
    codellama_prompt = prompt_template.replace(
        "INSERT_QUERY_HERE", f"Given an image: {query}\n{SIGNATURE}"
    )
    completion = generate("codellama/CodeLlama-7b-Python-hf", codellama_prompt)
    program = (SIGNATURE + completion).strip()
    yield program, "RUNNING IN PROGRESS...", "", "", "", ""

    # Step 2: execute it with tracing.
    result, error, traced = run_program_with_trace(program, image_tensor, INPUT_TYPE, OUTPUT_TYPE)
    trace_text = to_custom_trace(result, error, traced)
    yield program, trace_text, "RUNNING IN PROGRESS...", "", "", ""

    # Step 3: one critic/refiner round, streamed straight through.
    yield from debug(image_tensor, query, program, trace_text)
def re_debug(image, question, code, traced_info):
    """Run another critic/refiner round on the previously refined program.

    ``code`` and ``traced_info`` are the refiner output and its execution trace
    from the last round. Streams the same 6-tuples as ``debug``. A Gradio
    warning is shown instead when there is nothing to debug or an input is
    missing.
    """
    # Output values that mean "no refined program yet": empty or None before
    # any run, or the placeholder shown while a stage is still running.
    not_ready = (None, "", "RUNNING IN PROGRESS...")
    if code in not_ready or traced_info in not_ready:
        gr.Warning("No prior debugging round", duration=5)
        return
    # debug() writes "N/A" to both fields when the critic accepts the program,
    # so there is no refined program (or valid trace) to feed back in.
    if code == "N/A" or traced_info == "N/A":
        gr.Warning("The critic judged the last program correct; there is no refined program to debug", duration=5)
        return
    # Prepare the inputs exactly as predict() does, so run_program_with_trace
    # receives a tensor rather than the raw PIL image from the UI.
    if image is None:
        gr.Warning("Please provide an image", duration=5)
        return
    image = transforms.Compose([transforms.ToTensor()])(image)
    question = question.strip()
    if question == "":
        gr.Warning("Please provide a question", duration=5)
        return

    yield code, traced_info, "RUNNING IN PROGRESS...", "", "", ""
    yield from debug(image, question, code, traced_info)
# Markdown header text for the demo page. MyInterface assigns it to
# self.description (presumably rendered by render_title_description()).
DESCRIPTION = """# VDebugger
| [Paper](https://arxiv.org/abs/2406.13444) | [Project](https://shirley-wu.github.io/vdebugger/) | [Code](https://github.com/shirley-wu/vdebugger/) | [Models and Data](https://huggingface.co/VDebugger) |
**VDebugger** is a novel critic-refiner framework trained to localize and debug *visual programs* by tracking execution step by step. In this demo, we show the visual programs, the outputs from both the critic and the refiner, as well as the final result.
**Warning:** Reduced performance and accuracy may be observed. Due to resource limitation of huggingface spaces, this demo runs Llama inference in 4-bit quantization and uses smaller foundation VLMs. For full capacity, please use the original code."""
class MyInterface(gr.Interface):
    """Hand-built VDebugger demo page that reuses parts of ``gr.Interface``.

    ``super(gr.Interface, self).__init__`` deliberately skips
    ``gr.Interface.__init__`` and initialises the next class in the MRO, so the
    layout and events can be wired by hand. The Interface-style attributes are
    then set manually, presumably because the borrowed helpers
    (``render_title_description``, ``attach_clear_events``, ``render_examples``)
    read them.
    """

    def __init__(self):
        super(gr.Interface, self).__init__(
            title=None,
            theme=None,
            analytics_enabled=None,
            mode="tabbed_interface",
            css=None,
            js=None,
            head=None,
        )
        self.interface_type = InterfaceTypes.STANDARD
        self.description = DESCRIPTION
        self.cache_examples = None
        self.examples_per_page = 5
        self.example_labels = None
        self.batch = False
        self.live = False
        self.api_name = "predict"
        self.max_batch_size = 4
        self.concurrency_limit = 'default'
        self.show_progress = "full"
        self.allow_flagging = 'auto'
        self.flagging_options = [("Flag", ""), ]
        self.flagging_callback = CSVLogger()
        self.flagging_dir = 'flagged'

        # Load examples: each entry is [PIL image, question text].
        with open('examples/questions.json') as f:
            example_questions = json.load(f)
        self.examples = [
            [Image.open('examples/{}.jpg'.format(q['imageId'])), q['question']]
            for q in example_questions
        ]

        def load_random_example():
            # Fill the inputs with a random example and clear all six outputs.
            image, question = random.choice(self.examples)
            return image, question, "", "", "", "", "", ""

        # Render the Gradio UI
        with self:
            self.render_title_description()
            with gr.Row():
                image = gr.Image(label="Image", type="pil", width="30%", scale=1)
                question = gr.Textbox(label="Question", scale=2)
            with gr.Row():
                _clear_btn = gr.ClearButton(value="Clear", variant="secondary")
                _random_eg_btn = gr.Button("Random Example Input")
                _submit_btn = gr.Button("Submit", variant="primary")
                # predict and re_debug both stream (they are generators), and
                # the event wiring below always needs a Stop button for each.
                _stop1_btn = gr.Button("Stop", variant="stop", visible=False)
                _redebug_btn = gr.Button("Debug for Another Round", variant="primary")
                _stop2_btn = gr.Button("Stop", variant="stop", visible=False)
            with gr.Row():
                o1 = gr.Textbox(label="No debugging: program")
                o2 = gr.Textbox(label="No debugging: execution")
            with gr.Row():
                o3 = gr.Textbox(label="VDebugger: critic")
                o4 = gr.Textbox(label="VDebugger: refiner")
            with gr.Row():
                o5 = gr.Textbox(label="VDebugger: execution")
                o6 = gr.Textbox(label="VDebugger: final answer")

            _random_eg_btn.click(fn=load_random_example, outputs=[image, question, o1, o2, o3, o4, o5, o6])

            # Another round re-debugs the refined program (o4) using its trace (o5).
            self._register_cancellable_run(
                [_redebug_btn.click, ],
                re_debug,
                [image, question, o4, o5],
                [o1, o2, o3, o4, o5, o6],
                _redebug_btn,
                _stop2_btn,
            )
            # Submit button or Enter in the question box runs the full pipeline.
            # question.submit is registered only here: a second, separate
            # question.submit(fn=predict) would start the pipeline twice per Enter.
            self._register_cancellable_run(
                [_submit_btn.click, question.submit, ],
                predict,
                [image, question],
                [o1, o2, o3, o4, o5, o6],
                _submit_btn,
                _stop1_btn,
            )

            # Finally borrow Interface stuff
            self.input_components = [image, question]
            self.output_components = [o1, o2, o3, o4, o5, o6]
            self.fn = predict
            self.attach_clear_events(_clear_btn, None)
            self.render_examples()

    def _register_cancellable_run(self, triggers, fn, inputs, outputs, run_btn, stop_btn):
        """Run ``fn`` on ``triggers`` and let ``stop_btn`` cancel it.

        While ``fn`` runs, ``run_btn`` is hidden and ``stop_btn`` is shown. Both
        are restored when ``fn`` finishes or when ``stop_btn`` cancels it. This
        must be called inside this Blocks' ``with self:`` context.
        """
        async def cleanup():
            return [gr.Button(visible=True), gr.Button(visible=False)]

        # NOTE(review): gr.Interface.__init__ is bypassed in __init__, so
        # api_mode may not be set on this object; default to False - confirm
        # against the installed gradio version.
        api_mode = getattr(self, "api_mode", False)

        run_event = gr.on(
            triggers,
            gr.utils.async_lambda(
                lambda: (
                    gr.Button(visible=False),
                    gr.Button(visible=True),
                )
            ),
            inputs=None,
            outputs=[run_btn, stop_btn],
            queue=False,
            show_api=False,
        ).then(
            fn,
            inputs,
            outputs,
            api_name=self.api_name,
            scroll_to_output=False,
            preprocess=not api_mode,
            postprocess=not api_mode,
            batch=self.batch,
            max_batch_size=self.max_batch_size,
            concurrency_limit=self.concurrency_limit,
            show_progress=cast(
                Literal["full", "minimal", "hidden"], self.show_progress
            ),
        )
        run_event.then(
            cleanup,
            inputs=None,
            outputs=[run_btn, stop_btn],  # type: ignore
            queue=False,
            show_api=False,
        )
        stop_btn.click(
            cleanup,
            inputs=None,
            outputs=[run_btn, stop_btn],
            cancels=run_event,
            queue=False,
            show_api=False,
        )
if __name__ == "__main__":
    demo = MyInterface()
    # Any non-empty SHARE environment variable requests a public share link.
    share_link = bool(os.environ.get("SHARE"))
    demo.launch(share=share_link)