Chenxi Whitehouse committed on
Commit
2b4f5ff
1 Parent(s): 5e94756

add src file for models

src/models/DualEncoderModule.py ADDED
@@ -0,0 +1,143 @@
+import pytorch_lightning as pl
+import torch
+from transformers.optimization import AdamW
+import torchmetrics
+
+
+class DualEncoderModule(pl.LightningModule):
+
+    def __init__(self, tokenizer, model, learning_rate=1e-3):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.model = model
+        self.learning_rate = learning_rate
+
+        self.train_acc = torchmetrics.Accuracy(
+            task="multiclass", num_classes=model.num_labels
+        )
+        self.val_acc = torchmetrics.Accuracy(
+            task="multiclass", num_classes=model.num_labels
+        )
+        self.test_acc = torchmetrics.Accuracy(
+            task="multiclass", num_classes=model.num_labels
+        )
+
+    def forward(self, input_ids, **kwargs):
+        return self.model(input_ids, **kwargs)
+
+    def configure_optimizers(self):
+        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
+        return optimizer
+
+    def training_step(self, batch, batch_idx):
+        pos_ids, pos_mask, neg_ids, neg_mask = batch
+
+        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
+        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
+
+        pos_outputs = self(
+            pos_ids,
+            attention_mask=pos_mask,
+            labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
+                pos_ids.get_device()
+            ),
+        )
+        neg_outputs = self(
+            neg_ids,
+            attention_mask=neg_mask,
+            labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
+                neg_ids.get_device()
+            ),
+        )
+
+        loss_scale = 1.0
+        loss = pos_outputs.loss + loss_scale * neg_outputs.loss
+
+        pos_logits = pos_outputs.logits
+        pos_preds = torch.argmax(pos_logits, axis=1)
+        self.train_acc(
+            pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
+        )
+
+        neg_logits = neg_outputs.logits
+        neg_preds = torch.argmax(neg_logits, axis=1)
+        self.train_acc(
+            neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
+        )
+
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        pos_ids, pos_mask, neg_ids, neg_mask = batch
+
+        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
+        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
+
+        pos_outputs = self(
+            pos_ids,
+            attention_mask=pos_mask,
+            labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
+                pos_ids.get_device()
+            ),
+        )
+        neg_outputs = self(
+            neg_ids,
+            attention_mask=neg_mask,
+            labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
+                neg_ids.get_device()
+            ),
+        )
+
+        loss_scale = 1.0
+        loss = pos_outputs.loss + loss_scale * neg_outputs.loss
+
+        pos_logits = pos_outputs.logits
+        pos_preds = torch.argmax(pos_logits, axis=1)
+        self.val_acc(
+            pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
+        )
+
+        neg_logits = neg_outputs.logits
+        neg_preds = torch.argmax(neg_logits, axis=1)
+        self.val_acc(
+            neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
+        )
+
+        self.log("val_acc", self.val_acc)
+
+        return {"loss": loss}
+
+    def test_step(self, batch, batch_idx):
+        pos_ids, pos_mask, neg_ids, neg_mask = batch
+
+        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
+        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
+
+        pos_outputs = self(
+            pos_ids,
+            attention_mask=pos_mask,
+            labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
+                pos_ids.get_device()
+            ),
+        )
+        neg_outputs = self(
+            neg_ids,
+            attention_mask=neg_mask,
+            labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
+                neg_ids.get_device()
+            ),
+        )
+
+        pos_logits = pos_outputs.logits
+        pos_preds = torch.argmax(pos_logits, axis=1)
+        self.test_acc(
+            pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
+        )
+
+        neg_logits = neg_outputs.logits
+        neg_preds = torch.argmax(neg_logits, axis=1)
+        self.test_acc(
+            neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
+        )
+
+        self.log("test_acc", self.test_acc)
src/models/SequenceClassificationModule.py ADDED
@@ -0,0 +1,178 @@
+import pytorch_lightning as pl
+import torch
+from transformers.optimization import AdamW
+import torchmetrics
+from torchmetrics.classification import F1Score
+
+
+class SequenceClassificationModule(pl.LightningModule):
+
+    def __init__(
+        self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3
+    ):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.model = model
+        self.learning_rate = learning_rate
+
+        self.train_acc = torchmetrics.Accuracy(
+            task="multiclass", num_classes=model.num_labels
+        )
+        self.val_acc = torchmetrics.Accuracy(
+            task="multiclass", num_classes=model.num_labels
+        )
+        self.test_acc = torchmetrics.Accuracy(
+            task="multiclass", num_classes=model.num_labels
+        )
+
+        self.train_f1 = F1Score(
+            task="multiclass", num_classes=model.num_labels, average="macro"
+        )
+        self.val_f1 = F1Score(
+            task="multiclass", num_classes=model.num_labels, average=None
+        )
+        self.test_f1 = F1Score(
+            task="multiclass", num_classes=model.num_labels, average=None
+        )
+
+        self.use_question_stance_approach = use_question_stance_approach
+
+    def forward(self, input_ids, **kwargs):
+        return self.model(input_ids, **kwargs)
+
+    def configure_optimizers(self):
+        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
+        return optimizer
+
+    def training_step(self, batch, batch_idx):
+        x, x_mask, y = batch
+
+        outputs = self(x, attention_mask=x_mask, labels=y)
+        logits = outputs.logits
+        loss = outputs.loss
+
+        preds = torch.argmax(logits, axis=1)
+
+        self.log("train_loss", loss)
+
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        x, x_mask, y = batch
+
+        outputs = self(x, attention_mask=x_mask, labels=y)
+        logits = outputs.logits
+        loss = outputs.loss
+
+        preds = torch.argmax(logits, axis=1)
+
+        if not self.use_question_stance_approach:
+            self.val_acc(preds, y)
+            self.log("val_acc_step", self.val_acc)
+
+            self.val_f1(preds, y)
+        self.log("val_loss", loss)
+
+        return {"val_loss": loss, "src": x, "pred": preds, "target": y}
+
+    def validation_epoch_end(self, outs):
+        if self.use_question_stance_approach:
+            self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)
+
+        self.log("val_acc_epoch", self.val_acc)
+
+        f1 = self.val_f1.compute()
+        self.val_f1.reset()
+
+        self.log("val_f1_epoch", torch.mean(f1))
+
+        class_names = ["supported", "refuted", "nei", "conflicting"]
+        for i, c_name in enumerate(class_names):
+            self.log("val_f1_" + c_name, f1[i])
+
+    def test_step(self, batch, batch_idx):
+        x, x_mask, y = batch
+
+        outputs = self(x, attention_mask=x_mask)
+        logits = outputs.logits
+
+        preds = torch.argmax(logits, axis=1)
+
+        if not self.use_question_stance_approach:
+            self.test_acc(preds, y)
+            self.log("test_acc_step", self.test_acc)
+            self.test_f1(preds, y)
+
+        return {"src": x, "pred": preds, "target": y}
+
+    def test_epoch_end(self, outs):
+        if self.use_question_stance_approach:
+            self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)
+
+        self.log("test_acc_epoch", self.test_acc)
+
+        f1 = self.test_f1.compute()
+        self.test_f1.reset()
+        self.log("test_f1_epoch", torch.mean(f1))
+
+        class_names = ["supported", "refuted", "nei", "conflicting"]
+        for i, c_name in enumerate(class_names):
+            self.log("test_f1_" + c_name, f1[i])
+
+    def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
+        gold_labels = {}
+        question_support = {}
+        for out in outputs:
+            srcs = out["src"]
+            preds = out["pred"]
+            tgts = out["target"]
+
+            tokens = self.tokenizer.batch_decode(
+                srcs, skip_special_tokens=True, clean_up_tokenization_spaces=True
+            )
+
+            for src, pred, tgt in zip(tokens, preds, tgts):
+                claim_id = src.split("[ question ]")[0]
+
+                if claim_id not in gold_labels:
+                    gold_labels[claim_id] = tgt
+                    question_support[claim_id] = []
+
+                question_support[claim_id].append(pred)
+
+        for k, gold_label in gold_labels.items():
+            support = question_support[k]
+
+            has_unanswerable = False
+            has_true = False
+            has_false = False
+
+            for v in support:
+                if v == 0:
+                    has_true = True
+                if v == 1:
+                    has_false = True
+                if v in (
+                    2,
+                    3,
+                ):  # TODO very ugly hack -- we cant have different numbers of labels for train and test so we do this
+                    has_unanswerable = True
+
+            if has_unanswerable:
+                answer = 2
+            elif has_true and not has_false:
+                answer = 0
+            elif has_false and not has_true:
+                answer = 1
+            elif has_true and has_false:
+                answer = 3
+
+            # TODO this is very hacky and wont work if the device is literally anything other than cuda:0
+            acc_scorer(
+                torch.as_tensor([answer]).to("cuda:0"),
+                torch.as_tensor([gold_label]).to("cuda:0"),
+            )
+            f1_scorer(
+                torch.as_tensor([answer]).to("cuda:0"),
+                torch.as_tensor([gold_label]).to("cuda:0"),
+            )
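The handle_end_of_epoch_scoring method above collapses per-question stance predictions into one claim-level verdict. A standalone restatement of that rule (the function name is hypothetical; label meanings are inferred from the class_names list: 0 = supported, 1 = refuted, 2 = nei, 3 = conflicting):

    def aggregate_question_stances(preds):
        # Mirrors the per-claim loop in handle_end_of_epoch_scoring.
        has_true = any(p == 0 for p in preds)   # some question supports the claim
        has_false = any(p == 1 for p in preds)  # some question refutes it
        has_unanswerable = any(p in (2, 3) for p in preds)

        if has_unanswerable:
            return 2  # any unanswerable question -> nei
        if has_true and not has_false:
            return 0  # unanimous support -> supported
        if has_false and not has_true:
            return 1  # unanimous refutation -> refuted
        return 3      # mixed support and refutation -> conflicting

    assert aggregate_question_stances([0, 0, 1]) == 3
    assert aggregate_question_stances([0, 2]) == 2

As the first TODO notes, per-question predictions of 2 and 3 are both folded into "unanswerable" because the train and test label sets must have the same size; "conflicting" (3) is only produced at the claim level, from mixed support and refutation.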
src/reranking/rerank_questions.py CHANGED
@@ -17,7 +17,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "-i",
         "--top_k_qa_file",
-        default="data/dev_top_k_qa.json",
+        default="data_store/dev_top_k_qa.json",
         help="Json file with claim and top k generated question-answer pairs.",
     )
     parser.add_argument(
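With the new default, the reranker reads from data_store/ instead of data/. A hypothetical invocation after this change (only the -i/--top_k_qa_file flag appears in this hunk; the script may take other arguments not shown here):

    python src/reranking/rerank_questions.py --top_k_qa_file data_store/dev_top_k_qa.json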