# amigov1 / eval / eval_usmle.py
# Uploaded by asach with huggingface_hub (commit d727a17).
import sys
import os
sys.path.append("..")
import re
import json
import fire
import string
from tqdm.autonotebook import tqdm
from medalpaca.inferer import Inferer
# Decoding presets. main() unpacks one of these (``sampling``) as keyword
# arguments into the Inferer call. NOTE(review): the keys look like Hugging Face
# ``generate`` / ``GenerationConfig`` options — confirm Inferer forwards them.

# Greedy decoding: a single beam, no sampling.
greedy_search = {
    "num_beams": 1,
    "do_sample": False,
    "max_new_tokens": 128,
    "early_stopping": False,
}

# Beam search over 4 beams, no sampling.
beam_search = {
    "num_beams": 4,
    "do_sample": False,
    "max_new_tokens": 128,
    "early_stopping": True,
}
# Backwards-compatible alias: this preset was originally published under the
# misspelled name ``beam_serach``. Both names refer to the same dict.
beam_serach = beam_search

# Temperature + top-k sampling (temperature 0.7, k=50).
sampling_top_k = {
    "do_sample": True,
    "num_beams": 1,
    "max_new_tokens": 128,
    "early_stopping": True,
    "temperature": 0.7,
    "top_k": 50,
}

# Temperature + nucleus (top-p) sampling (temperature 0.7, p=0.9).
# top_k=0 presumably disables top-k filtering, as it does in Hugging Face
# ``generate`` — confirm for Inferer.
sampling_top_p = {
    "do_sample": True,
    "top_k": 0,
    "num_beams": 1,
    "max_new_tokens": 128,
    "early_stopping": True,
    "temperature": 0.7,
    "top_p": 0.9,
}

# Lower-temperature sampling that combines top-k and top-p
# (temperature 0.4, k=50, p=0.9).
sampling = {
    "do_sample": True,
    "top_k": 50,
    "num_beams": 1,
    "max_new_tokens": 128,
    "early_stopping": True,
    "temperature": 0.4,
    "top_p": 0.9,
}
def format_question(d):
    """Render a multiple-choice item as one prompt string.

    The text of ``d["question"]`` comes first. It is followed by one
    ``"\\n<label>: <text>"`` line per entry of ``d["options"]``, in the
    mapping's iteration order.
    """
    stem = d["question"]
    choices = d["options"]
    return stem + "".join(f"\n{label}: {text}" for label, text in choices.items())
def strip_special_chars(input_str):
    """Trim leading and trailing characters that are not ASCII letters or digits.

    Interior characters are left untouched. Falsy input (e.g. ``""`` or
    ``None``) is returned unchanged. A string containing no ASCII letter or
    digit at all yields ``""``.
    """
    if not input_str:
        return input_str
    alnum = string.ascii_letters + string.digits
    first = next((pos for pos, ch in enumerate(input_str) if ch in alnum), None)
    if first is None:
        return ""
    # At least one kept character exists, so the backward scan always finds one.
    last = next(
        pos
        for pos in reversed(range(len(input_str)))
        if input_str[pos] in alnum
    )
    return input_str[first:last + 1]
def starts_with_capital_letter(input_str):
    """Return True if *input_str* opens with an answer-option prefix.

    An accepted prefix is a single uppercase ASCII letter, an optional ':' or
    '.', a space, and then at least one further character on the same line:
        'A: '
        'A. '
        'A '
    """
    answer_prefix = re.compile(r'^[A-Z](:|\.|) .+')
    return answer_prefix.match(input_str) is not None
def _load_jsonl(path):
    """Read a JSON Lines file and return its records as a list (blank lines skipped)."""
    with open(path, encoding="utf-8") as fp:
        return [json.loads(line) for line in fp if line.strip()]


def _save_answers(path, answers):
    """Write *answers* to *path* as JSON, atomically replacing any previous file.

    Writing to a temporary sibling and then ``os.replace``-ing it means an
    interrupted run never leaves a truncated, unparseable resume file behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as fp:
        json.dump(answers, fp)
    os.replace(tmp_path, path)


def _answer_question(model, question, ntries):
    """Query *model* for one question, retrying until the answer looks well-formed.

    Up to ``ntries`` attempts are made, and always at least one, using the
    ``sampling`` preset. Each response is passed through
    ``strip_special_chars``. The loop stops at the first response accepted by
    ``starts_with_capital_letter``. If no response is accepted, the last one is
    returned anyway so the question is still recorded.
    """
    attempts = max(1, ntries)
    prompt = format_question(question)
    response = ""
    for attempt in range(1, attempts + 1):
        response = strip_special_chars(
            model(
                instruction="Answer this multiple choice question.",
                input=prompt,
                output="The Answer to the question is:",
                **sampling,
            )
        )
        print(response)
        if starts_with_capital_letter(response):
            break
        if attempt < attempts:
            print(f"Output not satisfactory, retrying {attempt + 1}/{attempts}")
        else:
            print(f"Output not satisfactory after {attempts} attempts, keeping last response")
    return response


def main(
    model_name: str,  # "medalpaca/medalpaca-lora-13b-8bit",
    prompt_template: str,  # "../medalpaca/prompt_templates/medalpaca.json",
    base_model: str,  # "decapoda-research/llama-13b-hf",
    peft: bool,  # True,
    load_in_8bit: bool,  # True
    path_to_exams: str,  # eval/data/test/
    ntries: int = 5,
    skip_if_exists: bool = True,
):
    """Answer every question in ``<path_to_exams>/test.jsonl`` and save the results.

    Args:
        model_name: Forwarded to ``Inferer``. Its last '/'-separated component
            names the output file.
        prompt_template: Forwarded unchanged to ``Inferer``.
        base_model: Forwarded unchanged to ``Inferer``.
        peft: Forwarded unchanged to ``Inferer``.
        load_in_8bit: Forwarded unchanged to ``Inferer``.
        path_to_exams: Directory containing ``test.jsonl``. The output file is
            written to the same directory.
        ntries: Maximum generation attempts per question. Retries happen only
            while the response fails ``starts_with_capital_letter``.
        skip_if_exists: If True and the output file exists, resume from it:
            saved answers are kept and that many leading questions are skipped.
            If False, start from scratch and overwrite the output file.

    Results go to ``<path_to_exams>/step_<model basename>.json`` as a JSON list
    of the input records, each with an added ``"answer"`` key. The file is
    rewritten after every question so an interrupted run can be resumed.
    """
    # Load the data before the (expensive) model so a bad path fails fast.
    questions = _load_jsonl(os.path.join(path_to_exams, "test.jsonl"))
    print(f"Loaded {len(questions)} questions")

    model_tag = model_name.rstrip("/").split("/")[-1]
    outname = os.path.join(path_to_exams, f"step_{model_tag}.json")

    answers = []
    if skip_if_exists and os.path.exists(outname):
        with open(outname, "r", encoding="utf-8") as fp:
            answers = json.load(fp)
    # Answers are saved in question order, so the first len(answers) questions
    # are already done. NOTE(review): assumes test.jsonl has not changed since
    # the output file was written.
    done = len(answers)

    model = Inferer(
        model_name=model_name,
        prompt_template=prompt_template,
        base_model=base_model,
        peft=peft,
        load_in_8bit=load_in_8bit,
    )

    for question in tqdm(questions[done:], initial=min(done, len(questions)), total=len(questions)):
        question["answer"] = _answer_question(model, question, ntries)
        answers.append(question)
        _save_answers(outname, answers)
if __name__ == "__main__":
    # Expose main()'s parameters as command-line flags via python-fire.
    fire.Fire(component=main)