# File size: 3,767 Bytes
# eaaaf3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import json
import tqdm
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from data_loaders.SequenceClassificationDataLoader import (
    SequenceClassificationDataLoader,
)
from models.SequenceClassificationModule import SequenceClassificationModule


# Veracity verdict names. The script below looks a verdict up by position
# (``LABEL[answer]``), so each string's index is its class id.
# NOTE(review): the order presumably has to match the class ids the
# checkpoint was trained with -- confirm before reordering.
LABEL = [
    "Supported",  # class id 0
    "Refuted",  # class id 1
    "Not Enough Evidence",  # class id 2
    "Conflicting Evidence/Cherrypicking",  # class id 3
]


def aggregate_veracity(evidence_classes):
    """Combine per-evidence class ids into a single claim-level class id.

    Class ids follow the order of ``LABEL``: 0 = supported, 1 = refuted,
    2 = not enough evidence, 3 = conflicting evidence/cherrypicking.

    Rules, in priority order:

    * no evidence at all                      -> 2
    * any evidence classified as 2 or 3       -> 2
    * only supporting evidence                -> 0
    * only refuting evidence                  -> 1
    * both supporting and refuting evidence   -> 3

    Args:
        evidence_classes: Iterable of per-evidence class ids. Elements may be
            plain ints or anything ``int()`` accepts (e.g. 0-dim tensors).

    Returns:
        The aggregated class id as an ``int`` in ``{0, 1, 2, 3}``.
    """
    # Normalise to plain ints so membership tests do not depend on the
    # element type (tensors hash by identity, ints by value).
    seen = {int(c) for c in evidence_classes}

    # Classes 2 and 3 on a single QA pair are both treated as "unanswerable":
    # train and test cannot use different label sets, so they are merged here.
    if not seen or seen & {2, 3}:
        return 2

    has_true = 0 in seen
    has_false = 1 in seen
    if has_true and not has_false:
        return 0
    if has_false and not has_true:
        return 1
    return 3


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model to predict the veracity label."
    )
    parser.add_argument(
        "-i",
        "--claim_with_evidence_file",
        default="data/dev_top3_questions.json",
        help="Json file with claim and top question-answer pairs as evidence.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="data_store/dev_veracity.json",
        help="Json file with the veracity predictions.",
    )
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_veracity.ckpt",
        help="Checkpoint file of the trained veracity classifier.",
    )
    args = parser.parse_args()

    # Read with an explicit encoding so it matches how the output is written
    # and does not depend on the platform's locale default.
    with open(args.claim_with_evidence_file, encoding="utf-8") as f:
        examples = json.load(f)

    bert_model_name = "bert-base-uncased"

    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(
        bert_model_name, num_labels=4, problem_type="single_label_classification"
    )
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = SequenceClassificationModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)
    # Inference only: switch off dropout and other train-time behaviour.
    trained_model.eval()

    # Only the string-formatting and tokenization helpers of the data loader
    # are used below; ``data_file`` is passed a placeholder value.
    dataLoader = SequenceClassificationDataLoader(
        tokenizer=tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    predictions = []

    # No gradients are needed for prediction; skip building autograd graphs.
    with torch.no_grad():
        for example in tqdm.tqdm(examples):
            example_strings = [
                dataLoader.quadruple_to_string(
                    example["claim"], evidence["question"], evidence["answer"], ""
                )
                for evidence in example["evidence"]
            ]

            if example_strings:
                tokenized_strings, attention_mask = dataLoader.tokenize_strings(
                    example_strings
                )
                # The model lives on ``device``, so its inputs must too.
                # NOTE(review): assumes tokenize_strings returns torch tensors
                # -- confirm against SequenceClassificationDataLoader.
                logits = trained_model(
                    tokenized_strings.to(device),
                    attention_mask=attention_mask.to(device),
                ).logits
                example_support = torch.argmax(logits, dim=1).tolist()
                label = LABEL[aggregate_veracity(example_support)]
            else:
                # If we found no evidence e.g. because google returned 0 pages,
                # just output NEI. The claim is still written to the output so
                # that every input claim has a prediction.
                label = "Not Enough Evidence"

            predictions.append(
                {
                    "claim_id": example["claim_id"],
                    "claim": example["claim"],
                    "evidence": example["evidence"],
                    "label": label,
                }
            )

    with open(args.output_file, "w", encoding="utf-8") as output_file:
        json.dump(predictions, output_file, ensure_ascii=False, indent=4)