Spaces:
Running
Running
import gradio as gr | |
import os | |
import json | |
import logging | |
import numpy as np | |
from utils import (PromptTemplate, api_configs, setup_logging) | |
from data_loader import load_data | |
from evaluate import evaluate | |
from main import SwiftSage, run_test, run_benchmark | |
import multiprocessing | |
def solve_problem(problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, reward_model_id, use_retrieval, start_with_sage):
    """Solve one problem with a freshly built SwiftSage instance.

    Args:
        problem: Problem statement, passed unchanged to ``SwiftSage.solve``.
        max_iterations: Iteration limit; converted with ``int()`` before use.
        reward_threshold: Reward cut-off; converted with ``int()`` before use.
        swift_model_id: Model id placed in the Swift config.
        sage_model_id: Model id placed in the Sage config.
        reward_model_id: Model id placed in the reward (feedback) config.
        use_retrieval: Forwarded to ``SwiftSage`` as ``use_retrieval``.
        start_with_sage: Forwarded to ``SwiftSage`` as ``start_with_sage``.

    Returns:
        A ``(reasoning, solution)`` pair as returned by ``SwiftSage.solve``,
        with the "Answer (from running the code):" prefix in the solution
        replaced by a single space.

    Raises:
        ValueError: If ``max_iterations`` or ``reward_threshold`` cannot be
            parsed by ``int()``.
    """
    # The UI delivers both limits as text, so coerce them up front.
    iteration_cap = int(max_iterations)
    min_reward = int(reward_threshold)

    def together_config(model_id):
        # Every model is served through the 'Together' API entry.
        return {
            "model_id": model_id,
            "api_config": api_configs['Together'],
        }

    swift_cfg = together_config(swift_model_id)
    reward_cfg = together_config(reward_model_id)
    sage_cfg = together_config(sage_model_id)

    agent = SwiftSage(
        [],  # dataset: empty in this demo
        [],  # embeddings: TODO, retrieval augmentation is not implemented yet
        './prompt_templates',  # directory holding the prompt templates
        swift_cfg,
        sage_cfg,
        reward_cfg,
        use_retrieval=use_retrieval,
        start_with_sage=start_with_sage,
    )

    steps, answer = agent.solve(problem, iteration_cap, min_reward)
    # Drop the execution banner so only the answer text is displayed.
    return steps, answer.replace("Answer (from running the code):\n ", " ")
# Gradio UI for the demo: model pickers, advanced options, the problem input,
# and two read-only outputs that are filled in by solve_problem.
# NOTE(review): the indentation below was reconstructed; confirm the nesting
# of the Row / Accordion / Column containers against the deployed layout.
# NOTE(review): several label strings start with what looks like mis-encoded
# emoji; they are left untouched here. Confirm the intended characters.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # gr.Markdown("## SwiftSage: A Multi-Agent Framework for Reasoning")
    # Use raw HTML instead of Markdown so the title can be centered.
    gr.HTML("<h1 style='text-align: center;'>SwiftSage: A Multi-Agent Framework for Reasoning</h1>")

    # One row with the three model ids (Swift, Feedback, Sage).
    with gr.Row():
        swift_model_id = gr.Textbox(label="π Swift Model ID", value="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo")
        reward_model_id = gr.Textbox(label="π€ Feedback Model ID", value="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")
        sage_model_id = gr.Textbox(label="π Sage Model ID", value="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")
        # the following two should have a smaller width
        # (presumably refers to "Max Iterations" and "Reward Threshold" below)

    # Collapsed by default (open=False).
    with gr.Accordion(label="βοΈ Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                # Entered as text; solve_problem converts both with int().
                max_iterations = gr.Textbox(label="Max Iterations", value="5")
                reward_threshold = gr.Textbox(label="Reward Threshold", value="8")
            # TODO: add top-p and temperature for each module for controlling
            # The six sampling fields below are displayed but are not listed
            # in the inputs of solve_button.click, so they are not passed to
            # solve_problem.
            with gr.Column():
                top_p_swift = gr.Textbox(label="Top-p for Swift", value="0.9")
                temperature_swift = gr.Textbox(label="Temperature for Swift", value="0.7")
            with gr.Column():
                top_p_sage = gr.Textbox(label="Top-p for Sage", value="0.9")
                temperature_sage = gr.Textbox(label="Temperature for Sage", value="0.7")
            with gr.Column():
                top_p_reward = gr.Textbox(label="Top-p for Feedback", value="0.9")
                temperature_reward = gr.Textbox(label="Temperature for Feedback", value="0.7")

    # Hidden checkboxes (visible=False): both default to False and are still
    # passed to solve_problem through the click inputs.
    use_retrieval = gr.Checkbox(label="Use Retrieval Augmentation", value=False, visible=False)
    start_with_sage = gr.Checkbox(label="Start with Sage", value=False, visible=False)

    problem = gr.Textbox(label="Input your problem", value="How many letter r are there in the sentence 'My strawberry is so ridiculously red.'?", lines=2)

    solve_button = gr.Button("π Solve Problem")
    # Read-only outputs (interactive=False).
    reasoning_output = gr.Textbox(label="Reasoning steps with Code", interactive=False)
    solution_output = gr.Textbox(label="Final answer", interactive=False)

    # The order of `inputs` must match solve_problem's positional parameters,
    # and `outputs` receives its (reasoning, solution) return pair in order.
    solve_button.click(
        solve_problem,
        inputs=[problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, reward_model_id, use_retrieval, start_with_sage],
        outputs=[reasoning_output, solution_output]
    )
if __name__ == "__main__":
    # Select the 'spawn' start method before anything creates a child process.
    # NOTE(review): presumably needed by code in `main`; confirm the reason.
    multiprocessing.set_start_method("spawn")

    # Serve locally only: no public share link, and the API page is hidden.
    launch_options = {"share": False, "show_api": False}
    demo.launch(**launch_options)