# MobiLlama / fastchat / llm_judge / gen_judgment.py
# Uploaded by Ashmal: "Upload folder using huggingface_hub" (commit 5472531, verified).
"""
Usage:
python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all]
"""
import argparse
from concurrent.futures import ThreadPoolExecutor
import json
import numpy as np
from tqdm import tqdm
from fastchat.llm_judge.common import (
load_questions,
load_model_answers,
load_judge_prompts,
check_data,
play_a_match_pair,
play_a_match_single,
get_model_list,
Judge,
MatchPair,
MatchSingle,
NEED_REF_CATS,
)
def make_match(
    questions,
    models,
    model_answers,
    judge,
    baseline_model,
    ref_answers=None,
    multi_turn=False,
):
    """Build pairwise matches of every model against a fixed baseline model.

    Args:
        questions: Question dicts; each is read for ``"question_id"`` and
            ``"turns"``.
        models: Model ids to compare against ``baseline_model``. Entries
            equal to ``baseline_model`` are skipped.
        model_answers: Two-level lookup ``model_id -> question_id -> answer``.
        judge: Judge attached to every match; ``judge.model_name`` selects
            the reference answers when ``ref_answers`` is given.
        baseline_model: Model id that every other model is compared against.
        ref_answers: Optional two-level lookup
            ``judge model name -> question_id -> reference answer``. When
            given, each match is built with ``ref_answer=...``.
        multi_turn: If True, only questions with exactly two turns are used.

    Returns:
        A list of ``MatchPair`` objects, ordered by question, then by model.
    """
    matches = []
    for q in questions:
        # Multi-turn judging only applies to two-turn questions.
        if multi_turn and len(q["turns"]) != 2:
            continue
        q_id = q["question_id"]
        for m_1 in models:
            m_2 = baseline_model
            if m_1 == m_2:
                # Never compare the baseline against itself.
                continue
            a_1 = model_answers[m_1][q_id]
            a_2 = model_answers[m_2][q_id]
            if ref_answers is not None:
                ref = ref_answers[judge.model_name][q_id]
                match = MatchPair(
                    dict(q),
                    m_1,
                    m_2,
                    a_1,
                    a_2,
                    judge,
                    ref_answer=ref,
                    multi_turn=multi_turn,
                )
            else:
                match = MatchPair(
                    dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn
                )
            matches.append(match)
    return matches
def make_match_all_pairs(
    questions,
    models,
    model_answers,
    judge,
    baseline_model=None,
    ref_answers=None,
    multi_turn=False,
):
    """Build pairwise matches for every unordered pair of models.

    For models ``[m0, m1, m2]`` the pairs are ``(m0, m1)``, ``(m0, m2)``,
    ``(m1, m2)``, in that order, for each question.

    Args:
        questions: Question dicts; each is read for ``"question_id"`` and
            ``"turns"``.
        models: Model ids to pair up.
        model_answers: Two-level lookup ``model_id -> question_id -> answer``.
        judge: Judge attached to every match; ``judge.model_name`` selects
            the reference answers when ``ref_answers`` is given.
        baseline_model: Unused; accepted so this function has the same
            signature as ``make_match`` and ``make_match_single``.
        ref_answers: Optional two-level lookup
            ``judge model name -> question_id -> reference answer``. When
            given, each match is built with ``ref_answer=...``.
        multi_turn: If True, only questions with exactly two turns are used.

    Returns:
        A list of ``MatchPair`` objects, ordered by question, then by pair.
    """
    matches = []
    for q in questions:
        # Multi-turn judging only applies to two-turn questions.
        if multi_turn and len(q["turns"]) != 2:
            continue
        q_id = q["question_id"]
        for i, m_1 in enumerate(models):
            # Only models after m_1, so each unordered pair appears once.
            for m_2 in models[i + 1 :]:
                a_1 = model_answers[m_1][q_id]
                a_2 = model_answers[m_2][q_id]
                if ref_answers is not None:
                    ref = ref_answers[judge.model_name][q_id]
                    match = MatchPair(
                        dict(q),
                        m_1,
                        m_2,
                        a_1,
                        a_2,
                        judge,
                        ref_answer=ref,
                        multi_turn=multi_turn,
                    )
                else:
                    match = MatchPair(
                        dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn
                    )
                matches.append(match)
    return matches
def make_match_single(
    questions,
    models,
    model_answers,
    judge,
    baseline_model=None,
    ref_answers=None,
    multi_turn=False,
):
    """Build single-answer grading matches, one per (question, model).

    Args:
        questions: Question dicts; each is read for ``"question_id"`` and
            ``"turns"``.
        models: Model ids whose answers are graded.
        model_answers: Two-level lookup ``model_id -> question_id -> answer``.
        judge: Judge attached to every match; ``judge.model_name`` selects
            the reference answers when ``ref_answers`` is given.
        baseline_model: Unused; accepted so this function has the same
            signature as ``make_match`` and ``make_match_all_pairs``.
        ref_answers: Optional two-level lookup
            ``judge model name -> question_id -> reference answer``. When
            given, each match is built with ``ref_answer=...``.
        multi_turn: If True, only questions with exactly two turns are used.

    Returns:
        A list of ``MatchSingle`` objects, ordered by question, then by model.
    """
    matches = []
    for q in questions:
        # Multi-turn judging only applies to two-turn questions.
        if multi_turn and len(q["turns"]) != 2:
            continue
        q_id = q["question_id"]
        for m in models:
            a = model_answers[m][q_id]
            if ref_answers is not None:
                ref = ref_answers[judge.model_name][q_id]
                matches.append(
                    MatchSingle(
                        dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn
                    )
                )
            else:
                matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn))
    return matches
def make_judge_pairwise(judge_model, judge_prompts):
    """Create the four pairwise-comparison judges used by this script.

    Args:
        judge_model: Judge model name, passed to every ``Judge``.
        judge_prompts: Mapping from prompt name to prompt. It must contain
            ``"pair-v2"``, ``"pair-math-v1"``, ``"pair-v2-multi-turn"`` and
            ``"pair-math-v1-multi-turn"``.

    Returns:
        dict with keys ``"default"``, ``"math"``, ``"default-mt"`` and
        ``"math-mt"``. The math judges are built with ``ref_based=True``;
        the ``-mt`` judges are built with ``multi_turn=True``.
    """
    return {
        "default": Judge(judge_model, judge_prompts["pair-v2"]),
        "math": Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True),
        "default-mt": Judge(
            judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True
        ),
        "math-mt": Judge(
            judge_model,
            judge_prompts["pair-math-v1-multi-turn"],
            ref_based=True,
            multi_turn=True,
        ),
    }
def make_judge_single(judge_model, judge_prompts):
    """Create the four single-answer grading judges used by this script.

    Args:
        judge_model: Judge model name, passed to every ``Judge``.
        judge_prompts: Mapping from prompt name to prompt. It must contain
            ``"single-v1"``, ``"single-math-v1"``, ``"single-v1-multi-turn"``
            and ``"single-math-v1-multi-turn"``.

    Returns:
        dict with keys ``"default"``, ``"math"``, ``"default-mt"`` and
        ``"math-mt"``. The math judges are built with ``ref_based=True``;
        the ``-mt`` judges are built with ``multi_turn=True``.
    """
    return {
        "default": Judge(judge_model, judge_prompts["single-v1"]),
        "math": Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True),
        "default-mt": Judge(
            judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True
        ),
        "math-mt": Judge(
            judge_model,
            judge_prompts["single-math-v1-multi-turn"],
            ref_based=True,
            multi_turn=True,
        ),
    }
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--judge-file",
        type=str,
        default="data/judge_prompts.jsonl",
        help="The file of judge prompts.",
    )
    parser.add_argument("--judge-model", type=str, default="gpt-4")
    parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo")
    parser.add_argument(
        "--mode",
        type=str,
        default="single",
        choices=["pairwise-baseline", "pairwise-all", "single"],
        help=(
            "Evaluation mode. "
            "`pairwise-baseline` runs pairwise comparision against a baseline. "
            "`pairwise-all` runs pairwise comparision between all pairs. "
            "`single` runs single answer grading."
        ),
    )
    parser.add_argument(
        "--model-list",
        type=str,
        nargs="+",
        default=None,
        help="A list of models to be evaluated",
    )
    parser.add_argument(
        "--parallel", type=int, default=1, help="The number of concurrent API calls."
    )
    parser.add_argument(
        "--first-n",
        type=int,
        help="A debug option. Only use the first `n` questions.",
    )
    args = parser.parse_args()

    # Fail fast on values the rest of the script cannot handle:
    # ThreadPoolExecutor rejects max_workers <= 0, and a negative slice bound
    # would silently drop questions from the end of the list.
    if args.parallel < 1:
        parser.error("--parallel must be at least 1.")
    if args.first_n is not None and args.first_n < 0:
        parser.error("--first-n must be non-negative.")

    question_file = f"data/{args.bench_name}/question.jsonl"
    answer_dir = f"data/{args.bench_name}/model_answer"
    ref_answer_dir = f"data/{args.bench_name}/reference_answer"

    # Load questions
    questions = load_questions(question_file, None, None)

    # Load answers
    model_answers = load_model_answers(answer_dir)
    ref_answers = load_model_answers(ref_answer_dir)

    # Load judge
    judge_prompts = load_judge_prompts(args.judge_file)

    # Compare with None (not truthiness) so that `--first-n 0` is honored
    # instead of being treated as "option not given".
    if args.first_n is not None:
        questions = questions[: args.first_n]

    if args.model_list is None:
        models = get_model_list(answer_dir)
    else:
        models = args.model_list

    # Pick judges, match player, match builder and output file for the mode.
    if args.mode == "single":
        judges = make_judge_single(args.judge_model, judge_prompts)
        play_a_match_func = play_a_match_single
        output_file = (
            f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl"
        )
        make_match_func = make_match_single
        baseline_model = None
    else:
        judges = make_judge_pairwise(args.judge_model, judge_prompts)
        play_a_match_func = play_a_match_pair
        output_file = (
            f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl"
        )
        if args.mode == "pairwise-all":
            make_match_func = make_match_all_pairs
            baseline_model = None
        else:
            make_match_func = make_match
            baseline_model = args.baseline_model

    # NOTE(review): presumably validates that answers/references exist for
    # every requested model and judge -- see fastchat.llm_judge.common.
    check_data(questions, model_answers, ref_answers, models, judges)

    # Questions in NEED_REF_CATS are graded by the reference-based judges.
    question_math = [q for q in questions if q["category"] in NEED_REF_CATS]
    question_default = [q for q in questions if q["category"] not in NEED_REF_CATS]

    # Make matches: {default, math} x {first turn, multi-turn}.
    matches = []
    matches += make_match_func(
        question_default, models, model_answers, judges["default"], baseline_model
    )
    matches += make_match_func(
        question_math,
        models,
        model_answers,
        judges["math"],
        baseline_model,
        ref_answers,
    )
    matches += make_match_func(
        question_default,
        models,
        model_answers,
        judges["default-mt"],
        baseline_model,
        multi_turn=True,
    )
    matches += make_match_func(
        question_math,
        models,
        model_answers,
        judges["math-mt"],
        baseline_model,
        ref_answers,
        multi_turn=True,
    )

    match_stat = {}
    match_stat["bench_name"] = args.bench_name
    match_stat["mode"] = args.mode
    match_stat["judge"] = args.judge_model
    match_stat["baseline"] = baseline_model
    match_stat["model_list"] = models
    match_stat["total_num_questions"] = len(questions)
    match_stat["total_num_matches"] = len(matches)
    match_stat["output_path"] = output_file

    # Show match stats and prompt enter to continue
    print("Stats:")
    print(json.dumps(match_stat, indent=4))
    input("Press Enter to confirm...")

    # Play matches
    if args.parallel == 1:
        for match in tqdm(matches):
            play_a_match_func(match, output_file=output_file)
    else:

        def play_a_match_wrapper(match):
            play_a_match_func(match, output_file=output_file)

        # Fixed seed keeps the shuffled play order reproducible between runs.
        np.random.seed(0)
        np.random.shuffle(matches)

        with ThreadPoolExecutor(args.parallel) as executor:
            # Consume the lazy map so tqdm can report progress; iterating the
            # results also re-raises any exception raised inside a worker.
            for _ in tqdm(
                executor.map(play_a_match_wrapper, matches), total=len(matches)
            ):
                pass