"""Predict claim veracity labels from question-answer evidence pairs
using a fine-tuned BERT sequence classifier."""

import argparse
import json

import pytorch_lightning as pl
import torch
import tqdm
from transformers import BertTokenizer, BertForSequenceClassification

from src.models.SequenceClassificationModule import SequenceClassificationModule

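# Veracity classes, indexed by the class id the classifier predicts.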
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]

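# LightningDataModule kept for its text-preparation helpers; this script
# never uses its DataLoader hooks.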
class SequenceClassificationDataLoader(pl.LightningDataModule):
    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

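    # Tokenize a batch of strings, truncating to max_length and padding to
    # the longest sequence in the batch by default.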
    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        encoded_dict = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding="max_length" if pad_to_max_length else "longest",
            truncation=True,
            return_tensors=return_tensors,
        )

        input_ids = encoded_dict["input_ids"]
        attention_masks = encoded_dict["attention_mask"]

        return input_ids, attention_masks

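    # Flatten a (claim, question, answer, explanation) quadruple into the
    # single "[CLAIM] ... [QUESTION] ..." string fed to the classifier.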
    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        if bool_explanation is not None and len(bool_explanation) > 0:
            bool_explanation = ", because " + bool_explanation.lower().strip()
        else:
            bool_explanation = ""
        return (
            "[CLAIM] "
            + claim.strip()
            + " [QUESTION] "
            + question.strip()
            + " "
            + answer.strip()
            + bool_explanation
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Given a claim and its top 3 QA pairs as evidence, use a fine-tuned BERT model to predict the veracity label."
    )
    parser.add_argument(
        "-i",
        "--claim_with_evidence_file",
        default="data_store/dev_top_3_rerank_qa.json",
        help="Json file with claim and top question-answer pairs as evidence.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="data_store/dev_veracity_prediction.json",
        help="Json file with the veracity predictions.",
    )
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_veracity.ckpt",
        help="Checkpoint of the fine-tuned veracity classifier.",
    )
    args = parser.parse_args()

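    # The input file is JSON Lines: one claim object with its QA evidence per line.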
    examples = []
    with open(args.claim_with_evidence_file) as f:
        for line in f:
            examples.append(json.loads(line))

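    # Rebuild the BERT backbone, restore the fine-tuned weights from the
    # Lightning checkpoint, and switch to evaluation mode for inference.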
    bert_model_name = "bert-base-uncased"

    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(
        bert_model_name, num_labels=4, problem_type="single_label_classification"
    )
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = SequenceClassificationModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)
    trained_model.eval()  # disable dropout so predictions are deterministic

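    # data_file is unused at inference time; the module is instantiated only
    # for its tokenize_strings and quadruple_to_string helpers.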
    dataLoader = SequenceClassificationDataLoader(
        tokenizer=tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    predictions = []

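    # For each claim, classify every QA pair separately, then aggregate the
    # per-pair predictions into a single claim-level label.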
    for example in tqdm.tqdm(examples):
        example_strings = []
        for evidence in example["evidence"]:
            example_strings.append(
                dataLoader.quadruple_to_string(
                    example["claim"], evidence["question"], evidence["answer"], ""
                )
            )

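        # Claims with no retrieved evidence default to "Not Enough Evidence".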
        if len(example_strings) == 0:
            # Append the fallback prediction directly so the claim still
            # appears in the output file.
            predictions.append(
                {
                    "claim_id": example["claim_id"],
                    "claim": example["claim"],
                    "evidence": example["evidence"],
                    "pred_label": "Not Enough Evidence",
                }
            )
            continue

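        # Batch-tokenize all evidence strings and classify them in one forward
        # pass; argmax over the logits yields one class id per QA pair.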
        tokenized_strings, attention_mask = dataLoader.tokenize_strings(example_strings)
        with torch.no_grad():
            example_support = torch.argmax(
                trained_model(
                    tokenized_strings.to(device), attention_mask=attention_mask.to(device)
                ).logits,
                dim=1,
            )

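        # Vote flags over the per-pair class ids:
        # 0 = Supported, 1 = Refuted, 2 = Not Enough Evidence, 3 = Conflicting.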
        has_unanswerable = False
        has_true = False
        has_false = False

        for v in example_support:
            if v == 0:
                has_true = True
            if v == 1:
                has_false = True
            if v in (2, 3):
                has_unanswerable = True

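        # Aggregate: any unanswerable/conflicting pair forces "Not Enough
        # Evidence"; otherwise unanimous votes map to Supported or Refuted,
        # and mixed true/false votes become Conflicting.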
        if has_unanswerable:
            answer = 2
        elif has_true and not has_false:
            answer = 0
        elif not has_true and has_false:
            answer = 1
        else:
            answer = 3

        json_data = {
            "claim_id": example["claim_id"],
            "claim": example["claim"],
            "evidence": example["evidence"],
            "pred_label": LABEL[answer],
        }
        predictions.append(json_data)

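    # Write all predictions as one JSON array (not JSON Lines like the input).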
    with open(args.output_file, "w", encoding="utf-8") as output_file:
        json.dump(predictions, output_file, ensure_ascii=False, indent=4)