import argparse
import json
import time

import nltk
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from transformers import BloomForCausalLM, BloomTokenizerFast


def claim2prompts(example):
    """Turn one annotated claim into (BM25 lookup string, few-shot prompt) pairs."""
    claim = example["claim"]

    # claim_str = "Claim: " + claim + "||Evidence: "
    claim_str = "Evidence: "

    for question in example["questions"]:
        q_text = question["question"].strip()
        if len(q_text) == 0:
            continue

        if q_text[-1] != "?":
            q_text += "?"

        answer_strings = []
        for a in question["answers"]:
            if a["answer_type"] in ["Extractive", "Abstractive"]:
                answer_strings.append(a["answer"])
            if a["answer_type"] == "Boolean":
                answer_strings.append(
                    a["answer"]
                    + ", because "
                    + a["boolean_explanation"].lower().strip()
                )

        for a_text in answer_strings:
            # Guard against empty answer strings before indexing a_text[-1].
            if len(a_text) == 0:
                continue
            if a_text[-1] not in [".", "!", ":", "?"]:
                a_text += "."

            # prompt_lookup_str = claim + " " + a_text
            prompt_lookup_str = a_text
            this_q_claim_str = (
                claim_str + " " + a_text.strip() + "||Question answered: " + q_text
            )
            yield (
                prompt_lookup_str,
                this_q_claim_str.replace("\n", " ").replace("||", "\n"),
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Use a prompt to generate questions that could be answered by "
        "top-k retrieved evidence. Output generated questions."
    )
    parser.add_argument(
        "--reference_corpus",
        default="data/train.json",
        help="Training file used to build the few-shot prompt corpus.",
    )
    parser.add_argument(
        "--target_file",
        default="data/dev.json",
        help="Target split the questions are generated for (not read below).",
    )
    parser.add_argument(
        "-i",
        "--top_k_target_knowledge",
        default="data_store/dev_top_k.json",
        help="JSON-lines file with the top-k retrieved sentences per claim.",
    )
    parser.add_argument(
        "-o",
        "--output_questions",
        default="data_store/dev_bm25_questions.json",
        help="JSON-lines file the generated questions are written to.",
    )
    parser.add_argument(
        "--top_k",
        default=100,
        type=int,
        help="How many documents should we pick out with BM25.",
    )
    args = parser.parse_args()

    # Few-shot learning from the training set.
    nltk.download("punkt", quiet=True)  # word_tokenize needs the punkt models
    with open(args.reference_corpus, "r", encoding="utf-8") as json_file:
        train_examples = json.load(json_file)

    prompt_corpus, tokenized_corpus = [], []
    for example in train_examples:
        for lookup_str, prompt in claim2prompts(example):
            entry = nltk.word_tokenize(lookup_str)
            tokenized_corpus.append(entry)
            prompt_corpus.append(prompt)

    prompt_bm25 = BM25Okapi(tokenized_corpus)

    # Load the BLOOM model:
    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
    model = BloomForCausalLM.from_pretrained(
        "bigscience/bloom-7b1",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        offload_folder="./offload",
    )

    with open(args.output_questions, "w", encoding="utf-8") as output_file:
        with open(args.top_k_target_knowledge, "r", encoding="utf-8") as json_file:
            for i, line in enumerate(json_file):
                data = json.loads(line)
                top_k_sentences_urls = data[f"top_{args.top_k}"]
                claim = data["claim"]
                claim_id = data["claim_id"]

                bm25_qau = []  # question, answer, url
                # Generate questions for those top k:
                for sent_i, sentences_urls in enumerate(top_k_sentences_urls):
                    prompt_lookup_str = sentences_urls["sentence"]
                    url = sentences_urls["url"]

                    # Retrieve the 10 training answers most similar to this
                    # evidence sentence to use as few-shot examples.
                    prompt_s = prompt_bm25.get_scores(
                        nltk.word_tokenize(prompt_lookup_str)
                    )
                    prompt_n = 10
                    prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
                    prompt_docs = [prompt_corpus[j] for j in prompt_top_n]

                    claim_prompt = (
                        "Evidence: "
                        + prompt_lookup_str.replace("\n", " ")
                        + "\nQuestion answered: "
                    )
                    prompt = "\n\n".join(prompt_docs + [claim_prompt])
                    inputs = tokenizer([prompt], padding=True, return_tensors="pt").to(
                        model.device
                    )

                    st = time.time()
                    outputs = model.generate(
                        inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_length=5000,
                        num_beams=2,
                        no_repeat_ngram_size=2,
                        early_stopping=True,
                    )
                    print(
                        f"Generated QA for sent {sent_i} in file {i}. "
                        f"Time elapsed: {time.time() - st}"
                    )

                    # Decode only the newly generated tokens, not the prompt.
                    tgt_text = tokenizer.batch_decode(
                        outputs[:, inputs["input_ids"].shape[-1] :],
                        skip_special_tokens=True,
                    )[0]

                    # We are not allowed to generate more than 250 characters:
                    tgt_text = tgt_text[:250]

                    # Keep only the first generated question, paired with the
                    # evidence sentence it should answer and that sentence's URL.
                    qau_pair = [
                        tgt_text.strip().split("?")[0].replace("\n", " ") + "?",
                        prompt_lookup_str.replace("\n", " "),
                        url,
                    ]
                    bm25_qau.append(qau_pair)

                json_data = {
                    "claim_id": claim_id,
                    "claim": claim,
                    "bm25_qau": bm25_qau,
                }
                output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
                output_file.flush()
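

# Example invocation (a sketch: the script filename is hypothetical, the paths
# simply repeat the argparse defaults above, and the -i file is assumed to be
# JSON-lines with "claim", "claim_id", and "top_{top_k}" fields, as the read
# loop expects):
#
#   python question_generation_top_sentences.py \
#       --reference_corpus data/train.json \
#       -i data_store/dev_top_k.json \
#       -o data_store/dev_bm25_questions.json \
#       --top_k 100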