import gradio as gr from dataclasses import dataclass from concurrent.futures import ThreadPoolExecutor, TimeoutError import os import re import subprocess import tempfile import json import datasets import random import time from typing import Tuple, Dict, Any, List from sympy import N, simplify from sympy.parsing.latex import parse_latex from openai import OpenAI import base64 client = OpenAI( base_url=os.environ.get("SERVER_URL"), api_key=os.environ.get("HF_TOKEN"), ) @dataclass class Config: model_id: str # SELECT MODEL revision: str # SELECT REVISION # Append an optional system prompt to each problem system_prompt: str # Number of samples to generate per problem num_samples: int num_generations: int # Generation parameters do_sample: bool temperature: float top_p: float top_k: int max_new_tokens: int restart_on_fail: bool # Enable 4-bit quantization is_quantized: bool # Run on train or test data? is_submission: bool = True if os.getenv("KAGGLE_IS_COMPETITION_RERUN") else False validation_set: str = "kaggle-validation-set-medium" notebook_time_limit: int = 9 * 60 * 60 - 15 * 60 # 9 hours - 15 minute buffer # Debug by solving only the first problem debug: bool = False # Push solutions to the Hub push_to_hub: bool = False class PythonREPL: def __init__(self, timeout=5): self.timeout = timeout def execute(self, query: str) -> Tuple[bool, str]: query = "import math\nimport numpy as np\nimport sympy as sp\n" + query query = query.strip().split("\n") if "print(" not in query[-1]: if "#" in query[-1]: query[-1] = query[-1].split("#")[0] query[-1] = "print(" + query[-1] + ")" query = "\n".join(query) with tempfile.TemporaryDirectory() as temp_dir: temp_file_path = os.path.join(temp_dir, "tmp.py") with open(temp_file_path, "w") as f: f.write(query) result = subprocess.run( ["python3", temp_file_path], capture_output=True, check=False, text=True, timeout=self.timeout, ) if result.returncode == 0: output = result.stdout return True, output.strip() else: error_msg = result.stderr.strip() msgs = error_msg.split("\n") new_msgs = [] want_next = False for m in msgs: if "Traceback" in m: new_msgs.append(m) elif m == msgs[-1]: new_msgs.append(m) elif temp_file_path in m: st = m.index('"/') + 1 if '"/' in m else 0 ed = m.index(temp_file_path) + 1 if temp_file_path in m else None clr = m[st:ed] if not ed else m[st:] m = m.replace(clr, "") new_msgs.append(m) want_next = True elif want_next: new_msgs.append(m) want_next = False error_msg = "\n".join(new_msgs) return False, error_msg.strip() def __call__(self, query: str) -> Tuple[bool, str]: with ThreadPoolExecutor() as executor: future = executor.submit(self.execute, query) try: return future.result(timeout=self.timeout) except TimeoutError: return False, f"Timed out after {self.timeout} seconds." def execute_completion( executor: PythonREPL, completion: str, return_status: bool = False, last_code_block: bool = False, ) -> str | Tuple[str, bool]: # executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code] executions = re.findall(r"```python(.*?)```", completion, re.DOTALL) if len(executions) == 0: # directly return cot result return completion, False if return_status else completion else: if last_code_block: executions = [executions[-1]] # Python execution_outputs = [] successes = [] for code in executions: success = False if "subprocess" in code: output = "subprocess is not allowed" execution_outputs.append(output) successes.append(success) continue if "venv" in code: output = "venv is not allowed" execution_outputs.append(output) successes.append(success) continue try: success, output = executor(code) except TimeoutError as e: print("time out") output = e if not success and not return_status: output = "" execution_outputs.append(output) successes.append(success) output = str(execution_outputs[-1]).strip() success = successes[-1] if return_status: return output, success else: return output def postprocess_completion( text: str, return_status: bool = False, last_code_block=False, timeout=5 ) -> str | Tuple[str, bool]: executor = PythonREPL(timeout=timeout) result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block) del executor return result def apply_template(example: Dict[str, Any], prompt: str) -> Dict[str, Any]: return prompt.format(example["prompt"], "{}") def last_boxed_only_string(string): """ Extracts the last LaTeX boxed or framed expression from a string. Args: string (str): The input string containing LaTeX expressions. Returns: str or None: The last boxed or framed expression, if found; otherwise, None. """ idx = string.rfind("\\boxed") if idx < 0: idx = string.rfind("\\fbox") if idx < 0: return None i = idx right_brace_idx = None num_left_braces_open = 0 while i < len(string): if string[i] == "{": num_left_braces_open += 1 if string[i] == "}": num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i break i += 1 if right_brace_idx is None: retval = None else: retval = string[idx : right_brace_idx + 1] return retval def remove_boxed(s): """ Removes the LaTeX boxed command, returning the content inside the braces. Args: s (str): The string containing a LaTeX boxed expression. Returns: str or None: The content inside the boxed command, if valid; otherwise, None. """ left = "\\boxed{" try: assert s[: len(left)] == left assert s[-1] == "}" length = len(left) return s[length:-1] except Exception: return None def extract_boxed_answer(pred_str, strip_double_curly_brace=False): """ Extracts the answer from a LaTeX boxed expression within a prediction string. Args: pred_str (str): The string containing one or more LaTeX boxed expressions. strip_double_curly_brace (bool): If True, removes an additional layer of braces. Returns: str or None: The extracted answer, if any; otherwise, None. """ boxed_str = last_boxed_only_string(pred_str) if boxed_str is None: return None answer = remove_boxed(boxed_str) if answer is None: return None if strip_double_curly_brace: match = re.match("^\{(.*)\}$", answer) # noqa: W605 if match: answer = match.group(1) return answer def normalize_final_answer(final_answer: str) -> str: """ Normalizes a final answer string by removing or replacing various LaTeX and text elements. Args: final_answer (str): The answer string to normalize. Returns: str: The normalized answer string. """ match = re.search(r"(.*?)Problem:", final_answer, flags=re.S) if match: final_answer = match.group(1) # 返回匹配的第一部分,即"Problem"之前的所有文本 """Normalize a final answer to a quantitative reasoning question.""" # final_answer = final_answer.split('=')[-1] SUBSTITUTIONS = [ ("an ", ""), ("a ", ""), (".$", "$"), ("\\$", ""), (r"\ ", ""), (" ", ""), ("mbox", "text"), (",\\text{and}", ","), ("\\text{and}", ","), ("\\text{m}", "\\text{}"), ("\\le", "<"), ] REMOVED_EXPRESSIONS = [ "square", "ways", "integers", "dollars", "mph", "inches", "ft", "hours", "km", "units", "\\ldots", "sue", "points", "feet", "minutes", "digits", "cents", "degrees", "cm", "gm", "pounds", "meters", "meals", "edges", "students", "childrentickets", "multiples", "\\text{s}", "\\text{.}", "\\text{\ns}", "\\text{}^2", "\\text{}^3", "\\text{\n}", "\\text{}", r"\mathrm{th}", r"^\circ", r"^{\circ}", r"\;", r",\!", "{,}", '"', "\\dots", "\n", "\r", "\f", "\%", ] for before, after in SUBSTITUTIONS: final_answer = final_answer.replace(before, after) for expr in REMOVED_EXPRESSIONS: final_answer = final_answer.replace(expr, "") # Extract answer that is in LaTeX math, is bold, # is surrounded by a box, etc. final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) assert "\n" not in final_answer assert "\r" not in final_answer assert "\f" not in final_answer if len(re.findall(r"finalansweris(.*)", final_answer)) > 0: final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1] if len(re.findall(r"answer?is:?(.*)", final_answer)) > 0: final_answer = re.findall(r"answer?is:?(.*)", final_answer)[-1] if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0: final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1] if len(re.findall(r"\$(.*?)\$", final_answer)) > 0: final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1] final_answer = final_answer.strip() if "rac" in final_answer and "\\frac" not in final_answer: final_answer = final_answer.replace("rac", "\\frac") final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) final_answer = final_answer.replace("$", "") if final_answer.replace(",", "").isdigit(): final_answer = final_answer.replace(",", "") return final_answer def naive_parse(answer: str) -> str: """ Extracts and returns the numeric digits from the input string, processing them in reverse order until a non-numeric character is encountered after encountering the first numeric character. Args: answer (str): The input string to parse. Returns: str: A string consisting of the numeric digits extracted from the input, in their original order. Example: >>> naive_parse("abc123def") '123' >>> naive_parse("def456ghi") '456' >>> naive_parse("no numbers here") '' """ out = [] start = False end = False for l in reversed(list(answer)): if l in "0123456789" and not end: start = True out.append(l) else: if start: end = True out = reversed(out) return "".join(out) def validate_answer_is_numeric(x: str | int | float) -> int: FLOAT_TOLERANCE = 0.2 try: x = round(float(x)) f = float(x) if abs(x - f) > FLOAT_TOLERANCE: x = -1 except Exception: x = -1 return x def filter_answers(answers: List[str]) -> List[int]: formatted_answers = [validate_answer_is_numeric(a) for a in answers] # Filter for non-negative answers formatted_answers = [a for a in formatted_answers if a >= 0] # Compute modulo formatted_answers = [a % 1_000 for a in formatted_answers] # less than 2.1 billion or cannot convert to C int (32-bit) formatted_answers = [a for a in formatted_answers if a <= 999] return formatted_answers def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool: def do_answers_match(ref_answer: str, model_answer: str) -> bool: ref_sympy = parse_latex(ref_answer) model_sympy = parse_latex(model_answer) diff = simplify(ref_sympy - model_sympy) return True if -1e-12 < N(diff) < 1e-12 or diff.is_zero else False try: result = do_answers_match(ref_answer, model_answer) return result except Exception as e: print(e) return False def check_string_match(ref_answer: str, model_answer: str) -> bool: try: return ref_answer == model_answer except Exception as e: print(e) return False def check_answer(ref_answer: str, model_answer: str) -> bool: # check if strings are the same correct = check_string_match(ref_answer, model_answer) if correct: return True # use the sympy library to check if the expressions are the same correct = check_sympy_equivalence(ref_answer, model_answer) if correct: return True return False debug = False model_id = "Numina-Math-7B" revision = "main" system_prompt = "{}" validation_set = "kaggle-validation-set-medium" is_submission = True num_samples = 4 num_generations = 4 temperature = 0.8 is_quantized = False restart_on_fail = False top_p = 1.0 top_k = 0 max_new_tokens = 2048 # Papermill related variables push_to_hub = False notebook_name = "" config = Config( debug=debug, push_to_hub=push_to_hub, model_id=model_id, revision=revision, system_prompt=system_prompt, validation_set=validation_set, is_quantized=is_quantized, restart_on_fail=restart_on_fail, is_submission=is_submission, num_samples=num_samples, num_generations=num_generations, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens, ) print(f"=== Running submission with config ===\n\n{config}") def generate(message, temperature): """ Generates a chat completion response by streaming data from the client chat model. This function streams the response from the client chat model and yields the content of the response chunk by chunk. If an error occurs, it yields the error message. Parameters: message (str): The input message to be sent to the chat model. temperature (float): The sampling temperature to use. Higher values mean the model will take more risks. Yields: tuple: A tuple containing the content of the response and a boolean flag indicating if an error occurred. If no error occurred, the boolean flag will be False and the content will be the response text. If an error occurred, the boolean flag will be True and the content will be the error message. """ stream = client.chat.completions.create( model="tgi", messages=message, stream=True, max_tokens=1024, stop=["```output\n"], temperature=temperature, timeout=30, ) response = stream.response # The reason why the library method is not used here is that if an error occurs, # the returned data will not be a stream, and using the official library will result in an error. for chunk in response.iter_bytes(): chunk = chunk.decode("utf-8") chune_json = json.loads(chunk.replace("data:", "")) try: if "error" in chune_json and chune_json["error"]: yield chune_json["error"], True break content = chune_json["choices"][0]["delta"]["content"] if content is not None: yield content, False except Exception as e: print(f"func: generate error occurred\njson:{chune_json}\nerror:{e}") yield "", True def get_majority_text(data): from collections import Counter # Count the frequency of each answer in model_answers answer_counts = Counter(data["model_answers"]) # Find the majority response majority_response = answer_counts.most_common(1)[0][0] # Find the index of the first occurrence of the majority response majority_index = data["model_answers"].index(majority_response) # Return the corresponding text in gen_texts return data["gen_texts"][majority_index] def extract_solution(text): # Split the text at "### Solution:" parts = text.split("### Solution:", 1) if len(parts) > 1: # Return everything after "### Solution:" return parts[1].strip() else: # Return an empty string if "### Solution:" is not found return "" def process_code( example: Dict[str, Any], config: Config, restart_on_fail: bool = False, last_step: bool = False, ) -> Dict[str, Any]: gen_text = example["gen_texts"] num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL)) if num_python_blocks == 0: if restart_on_fail: print("no code has ever been generated, RESTARTING") # reset the text to the original example["gen_texts"] = example["text"] else: print("no code has ever been generated, STOP") example["should_prune"] = True example["has_code"] = False return example if gen_text[-10:] != "```output\n" and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]): num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL)) if num_output_blocks == 0: print("the model hallucinated the code answer") example["should_prune"] = True return example if "boxed" in gen_text[-100:]: try: answer = normalize_final_answer(extract_boxed_answer(gen_text[-100:])) except Exception: answer = "-1" else: answer = normalize_final_answer(gen_text[-100:]) example["model_answers"] = answer if not config.is_submission: example["corrects"] = check_answer(example["ground_truth"], answer) example["should_prune"] = True print("Answer is: ", answer, example["ground_truth"], example["corrects"]) return example if last_step: # no point in continuing if we are at the last step return example if gen_text[-10:] != "```output\n": # something else has gone wrong with the generation print("warning: output block not found: ", gen_text[-40:]) if restart_on_fail: example["gen_texts"] = example["text"] else: example["should_prune"] = True return example code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True) # add the code result for the next round of generation TRUNCATION_LIMIT = 200 if len(code_result) > TRUNCATION_LIMIT: code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)" example["gen_texts"] = gen_text + f"{code_result}\n```" return example def solve_problem(problem, temperature, progress=gr.Progress()): """ yield token: string, stop: bool """ problem = apply_template({"prompt": problem}, prompt=config.system_prompt) print(f"Problem: {problem}") sample = { "problem": problem, # not used for the submission TODO Remove "ground_truth": "unknown", # not used for the submission TODO Remove "text": "## Solution:\n", "gen_texts": "## Solution:\n", # used to store all the generated text "should_prune": False, "problem_index": -1, # not used for the submission TODO Remove "model_answers": "-1", "has_code": True, "corrects": False, # not used for the submission TODO Remove } for step in progress.tqdm( range(config.num_generations), desc="Generating candidates" ): # Depth of the tree (e.g. 6 steps = 5 code blocks) step_reponse = sample["gen_texts"] messages = [ {"role": "user", "content": sample["problem"]}, {"role": "assistant", "content": sample["gen_texts"]}, ] for reponse_message, error in generate(messages, temperature): if reponse_message is not None: step_reponse += reponse_message yield step_reponse, False if error: yield step_reponse, True return sample["gen_texts"] = step_reponse # TODO: Maybe it should just return the result of running the code sample = process_code( sample, config=config, restart_on_fail=config.restart_on_fail, last_step=(step == (config.num_generations - 1)), ) sample["gen_texts"] = sample["gen_texts"] + "\n" run_code_reponse = sample["gen_texts"].replace(step_reponse, "") for output_mseeage in run_code_reponse: if output_mseeage is not None: step_reponse += output_mseeage yield step_reponse, False if sample["should_prune"]: break yield sample["gen_texts"], True example_data = datasets.load_dataset( "AI-MO/kaggle-validation-set-medium-extended", split="train", use_auth_token=os.environ.get("HF_DATASET_TOKEN", None), ) with open("app.css", "r") as f: css = f.read() latex_delimiters = [ {"left": "[", "right": "]", "display": True}, ] def get_random_problem(): example = random.choice(list(example_data)) problem = example["problem"] return problem def update_example_problem(): problem_example_text = get_random_problem() return problem_example_text, problem_example_text def clear(): problem_example_text = get_random_problem() return "", 0.1, "", problem_example_text, problem_example_text def preprocess_output(text): return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)") with gr.Blocks(css=css, title="Math Olympiad Solver") as demo: running_done = False btn_list = [] problem_input_ele_list = [] problem_example_text = get_random_problem() with gr.Row(elem_classes="title"): gr.HTML("Math Olympiad Solver", elem_classes="title-content") with gr.Row(elem_classes="sub-title"): gr.HTML( "