import base64
import json
import os
import random
import re
import subprocess
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import datasets
import gradio as gr
import openai
from sympy import N, simplify
from sympy.parsing.latex import parse_latex

# Initialize OpenAI client to use local API
openai.api_base = os.environ.get("SERVER_URL", "http://0.0.0.0:6061")
openai.api_key = os.environ.get("HF_TOKEN", "")  # If no key needed, set empty string


@dataclass
class Config:
    """Run-time settings for the solver (model choice, sampling, bookkeeping)."""

    model_id: str  # SELECT MODEL
    revision: str  # SELECT REVISION

    # Append an optional system prompt to each problem
    system_prompt: str

    # Number of samples to generate per problem
    num_samples: int
    num_generations: int

    # Generation parameters
    do_sample: bool
    temperature: float
    top_p: float
    top_k: int
    max_new_tokens: int
    restart_on_fail: bool

    # Enable 4-bit quantization
    is_quantized: bool

    # Run on train or test data?
    is_submission: bool = bool(os.getenv("KAGGLE_IS_COMPETITION_RERUN"))
    validation_set: str = "kaggle-validation-set-medium"
    notebook_time_limit: int = 9 * 60 * 60 - 15 * 60  # 9 hours - 15 minute buffer

    # Debug by solving only the first problem
    debug: bool = False

    # Push solutions to the Hub
    push_to_hub: bool = False


class PythonREPL:
    """Run a snippet of Python in a fresh ``python3`` subprocess with a timeout.

    Instances are callable: ``repl(code)`` returns ``(success, text)`` where
    ``text`` is the stripped stdout on success, or a condensed traceback /
    timeout message on failure.
    """

    def __init__(self, timeout=5):
        # Timeout in seconds, used both for the subprocess and for the
        # thread-level wait in ``__call__``.
        self.timeout = timeout

    def execute(self, query: str) -> Tuple[bool, str]:
        """Write ``query`` to a temporary script, run it, and collect the result.

        The snippet is prefixed with ``math``/``numpy``/``sympy`` imports, and
        when its last line contains no ``print(`` call that line (minus any
        trailing ``#`` comment) is wrapped in ``print(...)`` so that a bare
        final expression still produces output.

        Args:
            query: Python source code to execute.

        Returns:
            ``(True, stdout)`` when the process exits with code 0, otherwise
            ``(False, message)`` with a shortened traceback or a timeout note.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        query = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            # Python reads source files as UTF-8 by default, so write the
            # script as UTF-8 regardless of the platform's locale encoding.
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(query)

            try:
                result = subprocess.run(
                    ["python3", temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    timeout=self.timeout,
                )
            except subprocess.TimeoutExpired:
                return False, f"Timed out after {self.timeout} seconds."

            if result.returncode == 0:
                return True, result.stdout.strip()

            # Condense the traceback: keep the header, the final error line,
            # and for each frame inside the temp script the (scrubbed) frame
            # line plus the source line that follows it.
            msgs = result.stderr.strip().split("\n")
            new_msgs = []
            want_next = False
            for m in msgs:
                if "Traceback" in m:
                    new_msgs.append(m)
                elif m == msgs[-1]:
                    new_msgs.append(m)
                elif temp_file_path in m:
                    # NOTE(review): everything from the opening path quote to
                    # the end of the line is removed, which also discards the
                    # ", line N" part of the frame - confirm this is intended.
                    st = m.index('"/') + 1 if '"/' in m else 0
                    clr = m[st:]
                    m = m.replace(clr, "")
                    new_msgs.append(m)
                    want_next = True
                elif want_next:
                    new_msgs.append(m)
                    want_next = False
            return False, "\n".join(new_msgs).strip()

    def __call__(self, query: str) -> Tuple[bool, str]:
        """Run ``execute`` on a worker thread and wait at most ``timeout`` seconds.

        Returns the same ``(success, text)`` pair as ``execute``; a timeout of
        the wait itself is reported as a failure with the same message that
        ``execute`` uses for a subprocess timeout.
        """
        with ThreadPoolExecutor() as executor:
            future = executor.submit(self.execute, query)
            try:
                return future.result(timeout=self.timeout)
            except TimeoutError:
                return False, f"Timed out after {self.timeout} seconds."
def execute_completion( executor: PythonREPL, completion: str, return_status: bool = False, last_code_block: bool = False, ) -> str | Tuple[str, bool]: # Extract python code blocks enclosed in triple backticks with language 'python' executions = re.findall(r"```python(.*?)```", completion, re.DOTALL) if len(executions) == 0: # directly return COT result return completion, False if return_status else completion else: if last_code_block: executions = [executions[-1]] # Python execution_outputs = [] successes = [] for code in executions: success = False if "subprocess" in code: output = "subprocess is not allowed" execution_outputs.append(output) successes.append(success) continue if "venv" in code: output = "venv is not allowed" execution_outputs.append(output) successes.append(success) continue try: success, output = executor(code) except TimeoutError as e: print("time out") output = str(e) if not success and not return_status: output = "" execution_outputs.append(output) successes.append(success) output = str(execution_outputs[-1]).strip() success = successes[-1] if return_status: return output, success else: return output def postprocess_completion( text: str, return_status: bool = False, last_code_block=False, timeout=5 ) -> str | Tuple[str, bool]: executor = PythonREPL(timeout=timeout) result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block) del executor return result def apply_template(example: Dict[str, Any], prompt: str) -> Dict[str, Any]: return prompt.format(example["prompt"], "{}") def last_boxed_only_string(string): """ Extracts the last LaTeX boxed or framed expression from a string. Args: string (str): The input string containing LaTeX expressions. Returns: str or None: The last boxed or framed expression, if found; otherwise, None. 
""" idx = string.rfind("\\boxed") if idx < 0: idx = string.rfind("\\fbox") if idx < 0: return None i = idx right_brace_idx = None num_left_braces_open = 0 while i < len(string): if string[i] == "{": num_left_braces_open += 1 if string[i] == "}": num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i break i += 1 if right_brace_idx is None: retval = None else: retval = string[idx : right_brace_idx + 1] return retval def remove_boxed(s): """ Removes the LaTeX boxed command, returning the content inside the braces. Args: s (str): The string containing a LaTeX boxed expression. Returns: str or None: The content inside the boxed command, if valid; otherwise, None. """ left = "\\boxed{" try: assert s[: len(left)] == left assert s[-1] == "}" length = len(left) return s[length:-1] except Exception: return None def extract_boxed_answer(pred_str, strip_double_curly_brace=False): """ Extracts the answer from a LaTeX boxed expression within a prediction string. Args: pred_str (str): The string containing one or more LaTeX boxed expressions. strip_double_curly_brace (bool): If True, removes an additional layer of braces. Returns: str or None: The extracted answer, if any; otherwise, None. """ boxed_str = last_boxed_only_string(pred_str) if boxed_str is None: return None answer = remove_boxed(boxed_str) if answer is None: return None if strip_double_curly_brace: match = re.match(r"^\{(.*)\}$", answer) # noqa: W605 if match: answer = match.group(1) return answer def normalize_final_answer(final_answer: str) -> str: """ Normalizes a final answer string by removing or replacing various LaTeX and text elements. Args: final_answer (str): The answer string to normalize. Returns: str: The normalized answer string. 
""" match = re.search(r"(.*?)Problem:", final_answer, flags=re.S) if match: final_answer = match.group(1) # Return all text before 'Problem' """Normalize a final answer to a quantitative reasoning question.""" SUBSTITUTIONS = [ ("an ", ""), ("a ", ""), (".$", "$"), ("\\$", ""), (r"\ ", ""), (" ", ""), ("mbox", "text"), (",\\text{and}", ","), ("\\text{and}", ","), ("\\text{m}", "\\text{}"), ("\\le", "<"), ] REMOVED_EXPRESSIONS = [ "square", "ways", "integers", "dollars", "mph", "inches", "ft", "hours", "km", "units", "\\ldots", "sue", "points", "feet", "minutes", "digits", "cents", "degrees", "cm", "gm", "pounds", "meters", "meals", "edges", "students", "childrentickets", "multiples", "\\text{s}", "\\text{.}", "\\text{\ns}", "\\text{}^2", "\\text{}^3", "\\text{\n}", "\\text{}", r"\mathrm{th}", r"^\circ", r"^{\circ}", r"\;", r",\!", "{,}", '"', "\\dots", "\n", "\r", "\f", "\%", ] for before, after in SUBSTITUTIONS: final_answer = final_answer.replace(before, after) for expr in REMOVED_EXPRESSIONS: final_answer = final_answer.replace(expr, "") # Extract answer that is in LaTeX math, is bold, # is surrounded by a box, etc. 
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) assert "\n" not in final_answer assert "\r" not in final_answer assert "\f" not in final_answer if len(re.findall(r"finalansweris(.*)", final_answer)) > 0: final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1] if len(re.findall(r"answer?is:?(.*)", final_answer)) > 0: final_answer = re.findall(r"answer?is:?(.*)", final_answer)[-1] if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0: final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1] if len(re.findall(r"\$(.*?)\$", final_answer)) > 0: final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1] final_answer = final_answer.strip() if "rac" in final_answer and "\\frac" not in final_answer: final_answer = final_answer.replace("rac", "\\frac") final_answer = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", final_answer) final_answer = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", final_answer) final_answer = final_answer.replace("$", "") if final_answer.replace(",", "").isdigit(): final_answer = final_answer.replace(",", "") return final_answer def naive_parse(answer: str) -> str: """ Extracts and returns the numeric digits from the input string, processing them in reverse order until a non-numeric character is encountered after encountering the first numeric character. Args: answer (str): The input string to parse. Returns: str: A string consisting of the numeric digits extracted from the input, in their original order. 
Example: >>> naive_parse("abc123def") '123' >>> naive_parse("def456ghi") '456' >>> naive_parse("no numbers here") '' """ out = [] start = False end = False for l in reversed(list(answer)): if l in "0123456789" and not end: start = True out.append(l) else: if start: end = True out = reversed(out) return "".join(out) def validate_answer_is_numeric(x: str | int | float) -> int: FLOAT_TOLERANCE = 0.2 try: x = round(float(x)) f = float(x) if abs(x - f) > FLOAT_TOLERANCE: x = -1 except Exception: x = -1 return x def filter_answers(answers: List[str]) -> List[int]: formatted_answers = [validate_answer_is_numeric(a) for a in answers] # Filter for non-negative answers formatted_answers = [a for a in formatted_answers if a >= 0] # Compute modulo formatted_answers = [a % 1_000 for a in formatted_answers] # less than 2.1 billion or cannot convert to C int (32-bit) formatted_answers = [a for a in formatted_answers if a <= 999] return formatted_answers def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool: def do_answers_match(ref_answer: str, model_answer: str) -> bool: ref_sympy = parse_latex(ref_answer) model_sympy = parse_latex(model_answer) diff = simplify(ref_sympy - model_sympy) return True if (-1e-12 < N(diff) < 1e-12) or diff.is_zero else False try: result = do_answers_match(ref_answer, model_answer) return result except Exception as e: print(e) return False def check_string_match(ref_answer: str, model_answer: str) -> bool: try: return ref_answer == model_answer except Exception as e: print(e) return False def check_answer(ref_answer: str, model_answer: str) -> bool: # check if strings are the same correct = check_string_match(ref_answer, model_answer) if correct: return True # use the sympy library to check if the expressions are the same correct = check_sympy_equivalence(ref_answer, model_answer) if correct: return True return False # Configuration Parameters debug = False model_id = "qwen2-7b-math-q8_0" # Update model ID revision = "main" 
system_prompt = "{}"
validation_set = "kaggle-validation-set-medium"
is_submission = True
num_samples = 4
num_generations = 4
temperature = 0.8
is_quantized = False
restart_on_fail = False
top_p = 1.0
top_k = 0
max_new_tokens = 2048

# Papermill related variables
push_to_hub = False
notebook_name = ""

config = Config(
    debug=debug,
    push_to_hub=push_to_hub,
    model_id=model_id,
    revision=revision,
    system_prompt=system_prompt,
    validation_set=validation_set,
    is_quantized=is_quantized,
    restart_on_fail=restart_on_fail,
    is_submission=is_submission,
    num_samples=num_samples,
    num_generations=num_generations,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    max_new_tokens=max_new_tokens,
)

print(f"=== Running submission with config ===\n\n{config}")


def generate(messages, temperature):
    """
    Generates a chat completion response by streaming data from the client chat model.

    This function streams the response from the client chat model and yields the content
    of the response chunk by chunk. If an error occurs, it yields the error message.

    Parameters:
        messages (list of dict): The list of message dicts for the chat model.
        temperature (float): The sampling temperature to use.

    Yields:
        tuple: A tuple containing the content of the response and a boolean flag indicating if an error occurred.
               If no error occurred, the boolean flag will be False and the content will be the response text.
               If an error occurred, the boolean flag will be True and the content will be the error message.
    """
    try:
        # NOTE(review): max_tokens is fixed at 1024 here while
        # config.max_new_tokens is 2048 and otherwise unused - confirm which
        # limit is intended.
        response = openai.ChatCompletion.create(
            model=config.model_id,
            messages=messages,
            stream=True,
            max_tokens=1024,
            temperature=temperature,
        )
    except Exception as e:
        yield str(e), True
        return

    for chunk in response:
        if "choices" in chunk:
            choice = chunk["choices"][0]
            if "delta" in choice:
                content = choice["delta"].get("content", "")
                if content:
                    yield content, False
            if choice.get("finish_reason") is not None:
                break
        elif "error" in chunk:
            yield chunk["error"]["message"], True
            break


def get_majority_text(data):
    """Return the generated text belonging to the most frequent model answer.

    ``data`` must provide parallel sequences under ``"model_answers"`` and
    ``"gen_texts"``; the text at the first occurrence of the majority answer
    is returned.
    """
    from collections import Counter

    # Count the frequency of each answer in model_answers
    answer_counts = Counter(data["model_answers"])

    # Find the majority response
    majority_response = answer_counts.most_common(1)[0][0]

    # Find the index of the first occurrence of the majority response
    majority_index = data["model_answers"].index(majority_response)

    # Return the corresponding text in gen_texts
    return data["gen_texts"][majority_index]


def extract_solution(text):
    """Return the stripped text after the first "### Solution:" marker, or "" if it is absent."""
    _, marker, solution = text.partition("### Solution:")
    return solution.strip() if marker else ""


def process_code(
    example: Dict[str, Any],
    config: Config,
    restart_on_fail: bool = False,
    last_step: bool = False,
) -> Dict[str, Any]:
    """Inspect the text generated so far and decide how the sample proceeds.

    Depending on ``example["gen_texts"]`` this either marks the sample as
    finished (``should_prune``), records the final answer, resets the text for
    a restart, or executes the last ```python block and appends its output so
    the model can continue from it.

    Args:
        example: Mutable sample state; updated in place and returned.
        config: Run configuration; only ``is_submission`` is read here.
        restart_on_fail: Reset the text instead of pruning when generation
            went wrong.
        last_step: True on the final generation round, in which case no code
            is executed.

    Returns:
        The updated ``example``.
    """
    gen_text = example["gen_texts"]
    num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))

    if num_python_blocks == 0:
        if restart_on_fail:
            print("No code has been generated. Restarting generation.")
            # Reset the text to the original
            example["gen_texts"] = "## Solution:\n"
        else:
            print("No code has been generated. Stopping.")
            example["should_prune"] = True
            example["has_code"] = False
        return example

    tail = gen_text[-100:]
    if not gen_text.endswith("```output\n") and ("answer is" in tail or "\\boxed" in tail):
        num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
        if num_output_blocks == 0:
            print("The model hallucinated the code answer.")
            example["should_prune"] = True
            return example

        if "boxed" in tail:
            try:
                answer = normalize_final_answer(extract_boxed_answer(tail))
            except Exception:
                # e.g. the box was cut by the 100-character window
                answer = "-1"
        else:
            answer = normalize_final_answer(tail)

        example["model_answers"] = answer
        if not config.is_submission:
            example["corrects"] = check_answer(example["ground_truth"], answer)
        example["should_prune"] = True
        print("Answer is: ", answer, example["ground_truth"], example["corrects"])
        return example

    if last_step:
        # No point in continuing if we are at the last step
        return example

    if not gen_text.endswith("```output\n"):
        # Something else has gone wrong with the generation
        print("Warning: Output block not found: ", gen_text[-40:])
        if restart_on_fail:
            example["gen_texts"] = "## Solution:\n"
        else:
            example["should_prune"] = True
        return example

    code_result, _ = postprocess_completion(gen_text, return_status=True, last_code_block=True)

    # Add the code result for the next round of generation
    TRUNCATION_LIMIT = 200
    if len(code_result) > TRUNCATION_LIMIT:
        code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
    # gen_text already ends with an open "```output\n" fence, so only the
    # result and the closing fence are appended.
    example["gen_texts"] = gen_text + f"{code_result}\n```\n"

    return example


def solve_problem(problem, temperature, progress=gr.Progress()):
    """
    Stream a solution for ``problem``, alternating generation and code execution.

    Yields:
        tuple: ``(text, stop)`` where ``text`` is the full solution text
        accumulated so far (not escaped for display) and ``stop`` is True on
        the final item or after a generation error.
    """
    # Apply the system prompt template
    problem_formatted = config.system_prompt.format(problem)
    print(f"Problem: {problem_formatted}")

    sample = {
        "problem": problem_formatted,
        "ground_truth": "unknown",
        "text": "## Solution:\n",
        "gen_texts": "## Solution:\n",
        "should_prune": False,
        "problem_index": -1,
        "model_answers": "-1",
        "has_code": True,
        "corrects": False,
    }

    for step in progress.tqdm(range(config.num_generations), desc="Generating candidates"):
        step_response = sample["gen_texts"]

        messages = [
            {"role": "system", "content": problem_formatted},
            {"role": "user", "content": sample["gen_texts"]},
        ]

        for response_message, error in generate(messages, temperature):
            if response_message:
                step_response += response_message
                # Always yield a (text, stop) pair: the consumer unpacks two
                # values and applies the display escaping itself.
                yield step_response, False

            if error:
                yield step_response, True
                return

        sample["gen_texts"] = step_response

        # Process the generated code
        sample = process_code(
            sample,
            config=config,
            restart_on_fail=config.restart_on_fail,
            last_step=(step == (config.num_generations - 1)),
        )
        sample["gen_texts"] = sample["gen_texts"] + "\n"

        # Extract any run code response
        run_code_response = sample["gen_texts"].replace(step_response, "")

        # Append the run code response if it exists
        if run_code_response.strip():
            step_response += run_code_response
            yield step_response, False

        if sample["should_prune"]:
            break

    yield sample["gen_texts"], True


# Load the dataset
# NOTE(review): `use_auth_token` is the older spelling of this argument in
# `datasets`; newer releases expect `token` - confirm against the pinned version.
example_data = datasets.load_dataset(
    "AI-MO/kaggle-validation-set-medium-extended",
    split="train",
    use_auth_token=os.environ.get("HF_DATASET_TOKEN", None),
)

# Load CSS if available
css = ""
if os.path.exists("app.css"):
    with open("app.css", "r") as f:
        css = f.read()

latex_delimiters = [
    {"left": "[", "right": "]", "display": True},
]


def get_random_problem():
    """Return the ``"problem"`` text of a randomly chosen dataset row."""
    # Index directly instead of copying the whole dataset into a list per call.
    example = example_data[random.randrange(len(example_data))]
    return example["problem"]


def update_example_problem():
    """Pick a new example; returned twice (visible HTML and hidden copy)."""
    problem_example_text = get_random_problem()
    return problem_example_text, problem_example_text


def clear():
    """Reset input, temperature and solution, and pick a fresh example problem."""
    problem_example_text = get_random_problem()
    # Reset the slider to the same default it is created with.
    return "", config.temperature, "", problem_example_text, problem_example_text


def preprocess_output(text):
    """Double the backslash of inline-math delimiters ``\\(`` and ``\\)`` for display."""
    return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)")


with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
    # Flag set by the solver when a run finishes; polled by get_running_btns.
    # NOTE(review): this is a single module-level flag shared by every
    # session, while the queue allows 5 concurrent runs - confirm acceptable.
    running_done = False
    btn_list = []
    problem_input_ele_list = []

    problem_example_text = get_random_problem()

    with gr.Row(elem_classes="title"):
        gr.HTML("Math Olympiad Solver", elem_classes="title-content")

    with gr.Row(elem_classes="sub-title"):
        gr.HTML(
            "\nDemo of the qwen2-7b-math-q8_0. Example data are drawn randomly from AMC12, year 2022-2023.\n",
            elem_classes="sub-title-content",
        )

    with gr.Row(elem_classes="main-area"):
        with gr.Column(scale=1, elem_classes="left"):
            with gr.Row(elem_classes="problem-example-container"):
                with gr.Blocks(elem_classes="problem-example-title"):
                    gr.HTML("Problem Example", elem_classes="problem-example-title-content")

                with gr.Blocks(elem_classes="action-container"):
                    another_btn = gr.Button(
                        "Another Problem",
                        elem_classes="problem-example-another",
                        # Removed icon path to prevent errors
                    )
                    copy_btn = gr.Button("Copy", elem_classes="problem-example-copy")

            problem_example = gr.HTML(
                problem_example_text,
                elem_classes="problem-example-content",
            )

            with gr.Row(elem_classes="problem-input-container"):
                inp = gr.Textbox(placeholder="Enter your problem here...", label="Problem Input", lines=5)
                problem_markdown = gr.Markdown(
                    visible=False,
                    latex_delimiters=[
                        {"left": "[", "right": "]", "display": True},
                        {"left": "$", "right": "$", "display": False},
                        {"left": r"\(", "right": r"\)", "display": False},
                    ],
                )

            # Mirror the typed problem into the (initially hidden) rendered view.
            inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown])
            problem_input_ele_list.extend([inp, problem_markdown])

            with gr.Accordion("Advanced Options", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=config.temperature, step=0.1, label="Temperature"
                )

            with gr.Row() as btn_area:
                btn_clear = gr.Button("Clear", elem_classes="clear-btn")
                btn_run = gr.Button("Run", elem_classes="run-btn")
                btn_list.extend([btn_clear, btn_run])

        with gr.Column(scale=1, elem_classes="right"):
            gr.HTML("Solution", elem_classes="solution-title-content")
            out = gr.Markdown(
                elem_classes="solution-content",
                latex_delimiters=[
                    {"left": "[", "right": "]", "display": True},
                    {"left": "$", "right": "$", "display": False},
                    {"left": r"\(", "right": r"\)", "display": False},
                ],
            )

    # Hidden copy of the current example text, used as the source for "Copy".
    problem_example_text_hidden = gr.Markdown(value=problem_example_text, visible=False)

    def solve_problem_wrapper(inp_text, temperature):
        """Stream display-escaped solution text and flag completion via ``running_done``."""
        global running_done

        try:
            for after_tokens, stop in solve_problem(inp_text, temperature):
                yield preprocess_output(after_tokens)
                if stop:
                    running_done = True
        except Exception as e:
            running_done = True
            yield str(e)

    def mount_run_btn(btn):
        """Attach the three click handlers of the Run button to ``btn``."""
        btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature_slider], outputs=out)
        btn.click(get_running_btns, None, outputs=btn_list)
        btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list)

    def get_run_after_problem_input():
        """While running: hide the textbox and show the rendered problem."""
        return gr.Textbox(
            placeholder="Enter your problem here...", label="Problem Input", lines=5, visible=False
        ), gr.Markdown(
            visible=True,
            latex_delimiters=[
                {"left": "[", "right": "]", "display": True},
                {"left": "$", "right": "$", "display": False},
            ],
            elem_classes="problem-input-markdown",
        )

    def get_init_problem_input():
        """After Clear: show the textbox again and hide the rendered problem."""
        return gr.Textbox(
            placeholder="Enter your problem here...", label="Problem Input", lines=5, visible=True
        ), gr.Markdown(
            visible=False,
            latex_delimiters=[
                {"left": "[", "right": "]", "display": True},
                {"left": "$", "right": "$", "display": False},
            ],
        )

    def get_running_btns():
        """Show the Run button in its "running" style until the solver reports completion."""
        global running_done

        btn_clear = gr.Button("Clear")
        btn_run = gr.Button("", elem_classes="run-btn running-btn")
        yield [btn_clear, btn_run]

        time.sleep(3)

        btn_clear = gr.Button("Clear")
        btn_run = gr.Button("Run", elem_classes="run-btn")

        # Poll once per second for the completion flag set by the solver.
        while True:
            if running_done:
                running_done = False
                yield [btn_clear, btn_run]
                time.sleep(1)
                # NOTE(review): this registers listeners on a new Button from
                # inside an event handler, i.e. after the layout is built -
                # verify that Gradio honours (or even allows) this.
                mount_run_btn(btn_run)
                break
            time.sleep(1)

    # Copy the example that is currently displayed (kept in the hidden
    # component), not the one chosen when the page was first built.
    copy_btn.click(
        fn=lambda text: gr.update(value=text, interactive=True),
        inputs=[problem_example_text_hidden],
        outputs=inp,
    )

    btn_clear.click(
        fn=clear,
        inputs=[],
        outputs=[
            inp,
            temperature_slider,
            out,
            problem_example,
            problem_example_text_hidden,
        ],
    )
    btn_clear.click(get_init_problem_input, None, outputs=problem_input_ele_list)

    mount_run_btn(btn_run)

    demo.load(
        update_example_problem,
        inputs=None,
        outputs=[
            problem_example,
            problem_example_text_hidden,
        ],
    )

    another_btn.click(
        fn=update_example_problem,
        inputs=[],
        outputs=[
            problem_example,
            problem_example_text_hidden,
        ],
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=5).launch(share=True)