import argparse
import json
import tqdm
import torch
import pytorch_lightning as pl
from transformers import BertTokenizer, BertForSequenceClassification
from src.models.SequenceClassificationModule import SequenceClassificationModule


# Index i below corresponds to class i of the classifier's four output logits
# (num_labels=4), so LABEL[argmax] recovers the predicted class name.
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]


class SequenceClassificationDataLoader(pl.LightningDataModule):
    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee
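        # data_file, batch_size, and add_extra_nee are stored for the full
        # data-module interface; only the two helper methods below are used
        # by this script.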

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
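        # Batch-encode the evidence strings: pad to the longest sequence in the
        # batch (or to max_length if pad_to_max_length is set) and truncate
        # anything longer than max_length tokens.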
        encoded_dict = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding="max_length" if pad_to_max_length else "longest",
            truncation=True,
            return_tensors=return_tensors,
        )

        input_ids = encoded_dict["input_ids"]
        attention_masks = encoded_dict["attention_mask"]

        return input_ids, attention_masks

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
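        # Linearise one (claim, question, answer[, explanation]) tuple into the
        # flat "[CLAIM] ... [QUESTION] ..." string format the classifier expects.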
        if bool_explanation is not None and len(bool_explanation) > 0:
            bool_explanation = ", because " + bool_explanation.lower().strip()
        else:
            bool_explanation = ""
        return (
            "[CLAIM] "
            + claim.strip()
            + " [QUESTION] "
            + question.strip()
            + " "
            + answer.strip()
            + bool_explanation
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Given a claim and its top-3 QA pairs as evidence, use a fine-tuned BERT model to predict the veracity label."
    )
    parser.add_argument(
        "-i",
        "--claim_with_evidence_file",
        default="data_store/dev_top_3_rerank_qa.json",
        help="Json file with claim and top question-answer pairs as evidence.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="data_store/dev_veracity_prediction.json",
        help="Json file with the veracity predictions.",
    )
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_veracity.ckpt",
    )
    args = parser.parse_args()

    examples = []
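    # The input file is JSON Lines: one object per line with "claim_id",
    # "claim", and an "evidence" list of {"question": ..., "answer": ...} records.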
    with open(args.claim_with_evidence_file) as f:
        for line in f:
            examples.append(json.loads(line))

    bert_model_name = "bert-base-uncased"

    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(
        bert_model_name, num_labels=4, problem_type="single_label_classification"
    )
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = SequenceClassificationModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)
    trained_model.eval()  # disable dropout so inference is deterministic

    data_loader = SequenceClassificationDataLoader(
        tokenizer=tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    predictions = []

    for example in tqdm.tqdm(examples):
        example_strings = []
        for evidence in example["evidence"]:
            example_strings.append(
                data_loader.quadruple_to_string(
                    example["claim"], evidence["question"], evidence["answer"], ""
                )
            )

        if len(example_strings) == 0:
            # No evidence found (e.g. the search returned 0 pages): emit NEI
            # directly so the claim still appears in the output file.
            predictions.append({
                "claim_id": example["claim_id"],
                "claim": example["claim"],
                "evidence": example["evidence"],
                "pred_label": "Not Enough Evidence",
            })
            continue

        tokenized_strings, attention_mask = data_loader.tokenize_strings(example_strings)
        with torch.no_grad():  # inference only; no gradients needed
            example_support = torch.argmax(
                trained_model(
                    tokenized_strings.to(device), attention_mask=attention_mask.to(device)
                ).logits,
                dim=1,  # one predicted class per evidence string
            )

        has_unanswerable = False
        has_true = False
        has_false = False

        for v in example_support:
            if v == 0:
                has_true = True
            if v == 1:
                has_false = True
            if v in (2, 3):
                # TODO: a workaround -- train and test cannot use different
                # label sets, so classes 2 and 3 both count as unanswerable.
                has_unanswerable = True

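        # Aggregate the per-evidence votes: any unanswerable evidence wins,
        # unanimous support (or refutation) keeps that label, and a mix of
        # support and refutation falls back to conflicting evidence.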
        if has_unanswerable:
            answer = 2
        elif has_true and not has_false:
            answer = 0
        elif not has_true and has_false:
            answer = 1
        else:
            answer = 3

        json_data = {
            "claim_id": example["claim_id"],
            "claim": example["claim"],
            "evidence": example["evidence"],
            "pred_label": LABEL[answer],
        }
        predictions.append(json_data)

    with open(args.output_file, "w", encoding="utf-8") as output_file:
        json.dump(predictions, output_file, ensure_ascii=False, indent=4)
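# Example invocation (script name is illustrative; flags as defined above):
#   python veracity_prediction.py \
#       --claim_with_evidence_file data_store/dev_top_3_rerank_qa.json \
#       --output_file data_store/dev_veracity_prediction.json \
#       --best_checkpoint pretrained_models/bert_veracity.ckpt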