AVeriTeC / src / reranking / question_generation_top_sentences.py
import argparse
import time
import json
import nltk
from rank_bm25 import BM25Okapi
import numpy as np
import torch
from transformers import BloomTokenizerFast, BloomForCausalLM
def claim2prompts(example):
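    """Yield (lookup string, prompt text) pairs for one reference example.

    The lookup string is an answer from the training data; the prompt text
    pairs that answer, presented as evidence, with the question it answers.
    """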
claim = example["claim"]
# claim_str = "Claim: " + claim + "||Evidence: "
claim_str = "Evidence: "
for question in example["questions"]:
q_text = question["question"].strip()
if len(q_text) == 0:
continue
        if q_text[-1] != "?":
q_text += "?"
answer_strings = []
for a in question["answers"]:
if a["answer_type"] in ["Extractive", "Abstractive"]:
answer_strings.append(a["answer"])
if a["answer_type"] == "Boolean":
answer_strings.append(
a["answer"]
+ ", because "
+ a["boolean_explanation"].lower().strip()
)
for a_text in answer_strings:
            if a_text[-1] not in [".", "!", ":", "?"]:
a_text += "."
# prompt_lookup_str = claim + " " + a_text
prompt_lookup_str = a_text
this_q_claim_str = (
claim_str + " " + a_text.strip() + "||Question answered: " + q_text
)
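            # "||" becomes a newline, so each prompt reads:
            # Evidence: <answer> / Question answered: <question>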
yield (
prompt_lookup_str,
this_q_claim_str.replace("\n", " ").replace("||", "\n"),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Use a prompt to generate questions that could be answered by top-k retrieved evidence. Output generated questions."
)
parser.add_argument("--reference_corpus", default="data/train.json", help="")
parser.add_argument("--target_file", default="data/dev.json", help="")
parser.add_argument(
"-i",
"--top_k_target_knowledge",
default="data_store/dev_top_k.json",
help="Directory where the sentences for the scraped data is saved.",
)
parser.add_argument(
"-o",
"--output_questions",
default="data_store/dev_bm25_questions.json",
help="Directory where the sentences for the scraped data is saved.",
)
parser.add_argument(
"--top_k",
default=100,
type=int,
help="How many documents should we pick out with BM25",
)
args = parser.parse_args()
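    # nltk.word_tokenize below requires the NLTK "punkt" tokenizer data
    # (downloadable once with nltk.download("punkt")).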
# few-shot learning from the training set
with open(args.reference_corpus, "r", encoding="utf-8") as json_file:
train_examples = json.load(json_file)
prompt_corpus, tokenized_corpus = [], []
for example in train_examples:
for lookup_str, prompt in claim2prompts(example):
entry = nltk.word_tokenize(lookup_str)
tokenized_corpus.append(entry)
prompt_corpus.append(prompt)
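    # BM25 index over the tokenized reference answers; used below to pick the
    # closest few-shot examples for each retrieved sentence.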
prompt_bm25 = BM25Okapi(tokenized_corpus)
    # Load the BLOOM-7B1 model in bfloat16, sharding it across available devices:
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
model = BloomForCausalLM.from_pretrained(
"bigscience/bloom-7b1",
device_map="auto",
torch_dtype=torch.bfloat16,
offload_folder="./offload",
)
with open(args.output_questions, "w", encoding="utf-8") as output_file:
with open(args.top_k_target_knowledge, "r", encoding="utf-8") as json_file:
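            # Each line of the knowledge file is a JSON record holding one claim
            # and its top-k retrieved sentences with their source URLs.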
for i, line in enumerate(json_file):
data = json.loads(line)
top_k_sentences_urls = data[f"top_{args.top_k}"]
claim = data["claim"]
claim_id = data["claim_id"]
bm25_qau = [] # question, answer, url
# Generate questions for those top k:
for sent_i, sentences_urls in enumerate(top_k_sentences_urls):
prompt_lookup_str = sentences_urls["sentence"]
url = sentences_urls["url"]
prompt_s = prompt_bm25.get_scores(
nltk.word_tokenize(prompt_lookup_str)
)
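                    # Take the 10 reference answers most similar to this sentence;
                    # their prompts become the few-shot demonstrations.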
prompt_n = 10
prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
                    prompt_docs = [prompt_corpus[idx] for idx in prompt_top_n]
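                    # Few-shot prompt: the retrieved demonstrations, then the target
                    # sentence as evidence with "Question answered:" left open for
                    # the model to complete.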
claim_prompt = (
"Evidence: "
+ prompt_lookup_str.replace("\n", " ")
+ "\nQuestion answered: "
)
prompt = "\n\n".join(prompt_docs + [claim_prompt])
inputs = tokenizer([prompt], padding=True, return_tensors="pt").to(
model.device
)
st = time.time()
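                    # Generate with beam search; max_length bounds the prompt plus
                    # the generated tokens.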
                    outputs = model.generate(
                        inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_length=5000,
                        num_beams=2,
                        no_repeat_ngram_size=2,
                        early_stopping=True,
                    )
print(
f"Generated QA for sent {sent_i} in file {i}. Time elapsed: {time.time() - st}"
)
tgt_text = tokenizer.batch_decode(
outputs[:, inputs["input_ids"].shape[-1] :],
skip_special_tokens=True,
)[0]
# We are not allowed to generate more than 250 characters:
tgt_text = tgt_text[:250]
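                    # Keep the first generated sentence (up to the first "?") as the
                    # question; the retrieved sentence itself serves as the answer.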
qau_pair = [
tgt_text.strip().split("?")[0].replace("\n", " ") + "?",
prompt_lookup_str.replace("\n", " "),
url,
]
bm25_qau.append(qau_pair)
json_data = {
"claim_id": claim_id,
"claim": claim,
"bm25_qau": bm25_qau,
}
output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
output_file.flush()