|
''' |
|
Adapted from https://github.com/lupantech/ScienceQA |
|
''' |
|
|
|
from dataclasses import dataclass |
|
from typing import List, Optional |
|
|
|
def get_question_text(problem):
    """Return the question string stored under ``problem['question']``."""
    return problem['question']
|
|
|
|
|
def get_context_text(problem, use_caption):
    """Build the context string for a problem.

    ``problem['hint']`` and (only when ``use_caption`` is truthy)
    ``problem['caption']`` are joined with a single space and stripped of
    surrounding whitespace.  ``"N/A"`` is returned when nothing is left.
    """
    hint = problem['hint']
    # The caption key is only read when captions are requested.
    caption = problem['caption'] if use_caption else ""
    merged = " ".join((hint, caption)).strip()
    return merged or "N/A"
|
|
|
|
|
def get_choice_text(probelm, options):
    """Render the answer choices as ``"(A) first (B) second ..."``.

    Each entry of ``probelm['choices']`` is paired, by position, with the
    label at the same index in ``options``; the labelled choices are joined
    with single spaces.
    """
    labelled = [
        "({}) {}".format(options[idx], text)
        for idx, text in enumerate(probelm['choices'])
    ]
    return " ".join(labelled)
|
|
|
def get_origin_answer(problem, options):
    """Return the text of the correct choice.

    ``problem['answer']`` is used as an index into ``problem['choices']``.
    ``options`` is accepted but not used.
    """
    choices = problem['choices']
    return choices[problem['answer']]
|
|
|
def get_answer(problem, options):
    """Return the option label found at index ``problem['answer']`` in ``options``."""
    answer_index = problem['answer']
    return options[answer_index]
|
|
|
|
|
def get_lecture_text(problem):
    """Return ``problem['lecture']`` with each newline replaced by a literal backslash-n."""
    return problem['lecture'].replace("\n", r"\n")
|
|
|
|
|
def get_solution_text(problem):
    """Return ``problem['solution']`` with each newline replaced by a literal backslash-n."""
    return problem['solution'].replace("\n", r"\n")
|
|
|
|
|
# Input-side templates, keyed by the part of ``format`` before the dash.
# ``{rationale}`` is either the caller-supplied ``curr_le_data`` or the
# fallback "Solution: <lecture> <solution>" line.
_INPUT_TEMPLATES = {
    "CQM": "Context: {context}\nQuestion: {question}\nOptions: {choice}\n",
    "QCM": "Question: {question}\nContext: {context}\nOptions: {choice}\n",
    "QM": "Question: {question}\nOptions: {choice}\n",
    "QC": "Question: {question}\nContext: {context}\n",
    "QCMG": "Question: {question}\nContext: {context}\nOptions: {choice}\n{rationale}\n",
    "CQMG": "Context: {context}\nQuestion: {question}\nOptions: {choice}\n{rationale}\n",
    "QCML": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n",
    "QCME": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n",
    "QCMLE": "Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n",
    "QCLM": "Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n",
    "QCEM": "Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n",
    "QCLEM": "Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n",
    "QCMA": "Question: {question}\nContext: {context}\nOptions: {choice}\nAnswer: The answer is {answer}.\n",
    "QCA": "Question: {question}\nContext: {context}\nAnswer: The answer is {answer}. \nBECAUSE:",
}

# Output-side templates for non-test examples, keyed by the part of
# ``format`` after the dash.  Throughout this module "L" stands for the
# lecture and "E" for the solution, so "AL" carries the lecture and "AE"
# the solution.
_OUTPUT_TEMPLATES = {
    "A": "Answer: The answer is {answer}.",
    "AL": "Answer: The answer is {answer}. BECAUSE: {lecture}",
    "AE": "Answer: The answer is {answer}. BECAUSE: {solution}",
    "ALE": "Answer: The answer is {answer}. BECAUSE: {lecture} {solution}",
    "AEL": "Answer: The answer is {answer}. BECAUSE: {solution} {lecture}",
    "LA": "Answer: {lecture} The answer is {answer}.",
    "EA": "Answer: {solution} The answer is {answer}.",
    "LEA": "Answer: {lecture} {solution} The answer is {answer}.",
    "ELA": "Answer: {solution} {lecture} The answer is {answer}.",
    "LE": "Solution: {lecture} {solution}.",
    "E": "Solution: {solution}",
}


def _squeeze_spaces(text):
    """Collapse double spaces (e.g. left by an empty field) and strip the ends."""
    return text.replace("  ", " ").strip()


def create_one_example(format, question, context, choice, answer, lecture, solution,
                       test_example=True, WithOutput=False, curr_le_data=None):
    """Render one prompt example according to an ``"<input>-<output>"`` format code.

    Args:
        format: Two codes joined by a single dash, e.g. ``"QCM-A"``.  The
            first selects the input layout, the second the output layout.
        question, context, choice, answer, lecture, solution: Text pieces
            substituted into the selected templates.
        test_example: When true, the output is only the cue ``"Answer:"``
            (output code ``"A"``) or ``"Solution:"`` (any other code) and the
            output code is not validated.
        WithOutput: When true, return ``(text, output)`` where ``text`` is the
            input followed by the bare cue and ``output`` is the rendered
            target; otherwise return the input and output concatenated.
        curr_le_data: Optional text used verbatim in place of the
            ``"Solution: <lecture> <solution>"`` line of the ``QCMG`` /
            ``CQMG`` input layouts.

    Returns:
        A single string, or a ``(text, output)`` tuple when ``WithOutput``.

    Raises:
        ValueError: If ``format`` does not split into exactly two parts, the
            input code is unknown, or (for non-test examples) the output code
            is unknown.
    """
    input_format, output_format = format.split("-")

    fallback_rationale = f"Solution: {lecture} {solution}"
    fields = {
        "question": question,
        "context": context,
        "choice": choice,
        "answer": answer,
        "lecture": lecture,
        "solution": solution,
        "rationale": curr_le_data if curr_le_data is not None else fallback_rationale,
    }

    input_template = _INPUT_TEMPLATES.get(input_format)
    if input_template is None:
        # Fail with a clear message instead of an UnboundLocalError later on.
        raise ValueError(f"Unsupported input format: {input_format!r}")
    prompt = input_template.format(**fields)

    # Cue that precedes the model's generation.
    cue = "Answer:" if output_format == 'A' else "Solution:"

    if test_example:
        output = cue
    else:
        output_template = _OUTPUT_TEMPLATES.get(output_format)
        if output_template is None:
            raise ValueError(f"Unsupported output format: {output_format!r}")
        output = output_template.format(**fields)

    if WithOutput:
        # Normalise whitespace first so that a dangling "BECAUSE: " (empty
        # rationale) is actually detected and removed.
        output = _squeeze_spaces(output)
        if output.endswith("BECAUSE:"):
            output = output.replace("BECAUSE:", "").strip()
        return _squeeze_spaces(prompt + cue), output

    text = _squeeze_spaces(prompt + output)
    if text.endswith("BECAUSE:"):
        text = text.replace("BECAUSE:", "").strip()
    return text
|
|
|
|
|
def build_prompt(problems, shot_qids, test_qid, args):
    """Assemble a few-shot prompt ending with the test question.

    Every id in ``shot_qids`` is rendered as a worked example
    (``test_example=False``), then ``test_qid`` is rendered as the open test
    example (``test_example=True``).  The pieces are joined by blank lines.

    ``args`` must provide ``use_caption``, ``options`` and ``prompt_format``.
    """

    def render(qid, is_test):
        # Gather the text fields of one problem and format them.
        problem = problems[qid]
        question = get_question_text(problem)
        context = get_context_text(problem, args.use_caption)
        choice = get_choice_text(problem, args.options)
        answer = get_answer(problem, args.options)
        lecture = get_lecture_text(problem)
        solution = get_solution_text(problem)
        return create_one_example(args.prompt_format,
                                  question,
                                  context,
                                  choice,
                                  answer,
                                  lecture,
                                  solution,
                                  test_example=is_test)

    pieces = [render(shot_qid, False) for shot_qid in shot_qids]
    pieces.append(render(test_qid, True))
    return '\n\n'.join(pieces)
|
|
|
def build_train_pair(problems, test_qid, args, curr_le_data=None):
    """Build the ``(prompt, target)`` training pair for one problem.

    The answer label is wrapped in parentheses before formatting, the example
    is rendered with ``test_example=False`` and ``WithOutput=True``, and every
    ``"Answer:"`` occurrence is removed from the target before it is stripped.

    ``args`` must provide ``use_caption``, ``options`` and ``prompt_format``.
    """
    problem = problems[test_qid]

    question = get_question_text(problem)
    context = get_context_text(problem, args.use_caption)
    choice = get_choice_text(problem, args.options)
    lecture = get_lecture_text(problem)
    solution = get_solution_text(problem)

    # e.g. "A" -> "(A)"
    option_label = get_answer(problem, args.options)
    answer = "(" + option_label + ")"

    example_text, raw_target = create_one_example(
        args.prompt_format,
        question,
        context,
        choice,
        answer,
        lecture,
        solution,
        test_example=False,
        WithOutput=True,
        curr_le_data=curr_le_data,
    )

    target = raw_target.replace("Answer:", "").strip()
    # Only one example is ever present, so the join is effectively a pass-through.
    prompt_input = '\n\n'.join([example_text])
    return prompt_input, target
|
|
|
@dataclass(frozen=True)
class InputFeatures:
    """
    Immutable record holding one set of model input features.

    Field names mirror the keyword names of the corresponding model inputs.
    """

    # Primary encoded inputs.
    input_ids: List[List[int]]
    attention_mask: Optional[List[List[int]]]
    token_type_ids: Optional[List[List[int]]]

    # Second set of encodings with the same layout; the "le_" prefix
    # presumably refers to lecture/explanation text -- TODO confirm with caller.
    le_input_ids: List[List[int]]
    le_attention_mask: Optional[List[List[int]]]
    le_token_type_ids: Optional[List[List[int]]]

    # Optional integer label for the example.
    label: Optional[int]