Chenxi Whitehouse
commited on
Commit
•
2b4f5ff
1
Parent(s):
5e94756
add src file for models
Browse files
src/models/DualEncoderModule.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch
|
3 |
+
from transformers.optimization import AdamW
|
4 |
+
import torchmetrics
|
5 |
+
|
6 |
+
|
7 |
+
class DualEncoderModule(pl.LightningModule):
|
8 |
+
|
9 |
+
def __init__(self, tokenizer, model, learning_rate=1e-3):
|
10 |
+
super().__init__()
|
11 |
+
self.tokenizer = tokenizer
|
12 |
+
self.model = model
|
13 |
+
self.learning_rate = learning_rate
|
14 |
+
|
15 |
+
self.train_acc = torchmetrics.Accuracy(
|
16 |
+
task="multiclass", num_classes=model.num_labels
|
17 |
+
)
|
18 |
+
self.val_acc = torchmetrics.Accuracy(
|
19 |
+
task="multiclass", num_classes=model.num_labels
|
20 |
+
)
|
21 |
+
self.test_acc = torchmetrics.Accuracy(
|
22 |
+
task="multiclass", num_classes=model.num_labels
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, input_ids, **kwargs):
|
26 |
+
return self.model(input_ids, **kwargs)
|
27 |
+
|
28 |
+
def configure_optimizers(self):
|
29 |
+
optimizer = AdamW(self.parameters(), lr=self.learning_rate)
|
30 |
+
return optimizer
|
31 |
+
|
32 |
+
def training_step(self, batch, batch_idx):
|
33 |
+
pos_ids, pos_mask, neg_ids, neg_mask = batch
|
34 |
+
|
35 |
+
neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
|
36 |
+
neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
|
37 |
+
|
38 |
+
pos_outputs = self(
|
39 |
+
pos_ids,
|
40 |
+
attention_mask=pos_mask,
|
41 |
+
labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
|
42 |
+
pos_ids.get_device()
|
43 |
+
),
|
44 |
+
)
|
45 |
+
neg_outputs = self(
|
46 |
+
neg_ids,
|
47 |
+
attention_mask=neg_mask,
|
48 |
+
labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
|
49 |
+
neg_ids.get_device()
|
50 |
+
),
|
51 |
+
)
|
52 |
+
|
53 |
+
loss_scale = 1.0
|
54 |
+
loss = pos_outputs.loss + loss_scale * neg_outputs.loss
|
55 |
+
|
56 |
+
pos_logits = pos_outputs.logits
|
57 |
+
pos_preds = torch.argmax(pos_logits, axis=1)
|
58 |
+
self.train_acc(
|
59 |
+
pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
|
60 |
+
)
|
61 |
+
|
62 |
+
neg_logits = neg_outputs.logits
|
63 |
+
neg_preds = torch.argmax(neg_logits, axis=1)
|
64 |
+
self.train_acc(
|
65 |
+
neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
|
66 |
+
)
|
67 |
+
|
68 |
+
return {"loss": loss}
|
69 |
+
|
70 |
+
def validation_step(self, batch, batch_idx):
|
71 |
+
pos_ids, pos_mask, neg_ids, neg_mask = batch
|
72 |
+
|
73 |
+
neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
|
74 |
+
neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
|
75 |
+
|
76 |
+
pos_outputs = self(
|
77 |
+
pos_ids,
|
78 |
+
attention_mask=pos_mask,
|
79 |
+
labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
|
80 |
+
pos_ids.get_device()
|
81 |
+
),
|
82 |
+
)
|
83 |
+
neg_outputs = self(
|
84 |
+
neg_ids,
|
85 |
+
attention_mask=neg_mask,
|
86 |
+
labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
|
87 |
+
neg_ids.get_device()
|
88 |
+
),
|
89 |
+
)
|
90 |
+
|
91 |
+
loss_scale = 1.0
|
92 |
+
loss = pos_outputs.loss + loss_scale * neg_outputs.loss
|
93 |
+
|
94 |
+
pos_logits = pos_outputs.logits
|
95 |
+
pos_preds = torch.argmax(pos_logits, axis=1)
|
96 |
+
self.val_acc(
|
97 |
+
pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
|
98 |
+
)
|
99 |
+
|
100 |
+
neg_logits = neg_outputs.logits
|
101 |
+
neg_preds = torch.argmax(neg_logits, axis=1)
|
102 |
+
self.val_acc(
|
103 |
+
neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
|
104 |
+
)
|
105 |
+
|
106 |
+
self.log("val_acc", self.val_acc)
|
107 |
+
|
108 |
+
return {"loss": loss}
|
109 |
+
|
110 |
+
def test_step(self, batch, batch_idx):
|
111 |
+
pos_ids, pos_mask, neg_ids, neg_mask = batch
|
112 |
+
|
113 |
+
neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
|
114 |
+
neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
|
115 |
+
|
116 |
+
pos_outputs = self(
|
117 |
+
pos_ids,
|
118 |
+
attention_mask=pos_mask,
|
119 |
+
labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
|
120 |
+
pos_ids.get_device()
|
121 |
+
),
|
122 |
+
)
|
123 |
+
neg_outputs = self(
|
124 |
+
neg_ids,
|
125 |
+
attention_mask=neg_mask,
|
126 |
+
labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
|
127 |
+
neg_ids.get_device()
|
128 |
+
),
|
129 |
+
)
|
130 |
+
|
131 |
+
pos_logits = pos_outputs.logits
|
132 |
+
pos_preds = torch.argmax(pos_logits, axis=1)
|
133 |
+
self.test_acc(
|
134 |
+
pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
|
135 |
+
)
|
136 |
+
|
137 |
+
neg_logits = neg_outputs.logits
|
138 |
+
neg_preds = torch.argmax(neg_logits, axis=1)
|
139 |
+
self.test_acc(
|
140 |
+
neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
|
141 |
+
)
|
142 |
+
|
143 |
+
self.log("test_acc", self.test_acc)
|
src/models/SequenceClassificationModule.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch
|
3 |
+
from transformers.optimization import AdamW
|
4 |
+
import torchmetrics
|
5 |
+
from torchmetrics.classification import F1Score
|
6 |
+
|
7 |
+
|
8 |
+
class SequenceClassificationModule(pl.LightningModule):
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
self.tokenizer = tokenizer
|
15 |
+
self.model = model
|
16 |
+
self.learning_rate = learning_rate
|
17 |
+
|
18 |
+
self.train_acc = torchmetrics.Accuracy(
|
19 |
+
task="multiclass", num_classes=model.num_labels
|
20 |
+
)
|
21 |
+
self.val_acc = torchmetrics.Accuracy(
|
22 |
+
task="multiclass", num_classes=model.num_labels
|
23 |
+
)
|
24 |
+
self.test_acc = torchmetrics.Accuracy(
|
25 |
+
task="multiclass", num_classes=model.num_labels
|
26 |
+
)
|
27 |
+
|
28 |
+
self.train_f1 = F1Score(
|
29 |
+
task="multiclass", num_classes=model.num_labels, average="macro"
|
30 |
+
)
|
31 |
+
self.val_f1 = F1Score(
|
32 |
+
task="multiclass", num_classes=model.num_labels, average=None
|
33 |
+
)
|
34 |
+
self.test_f1 = F1Score(
|
35 |
+
task="multiclass", num_classes=model.num_labels, average=None
|
36 |
+
)
|
37 |
+
|
38 |
+
self.use_question_stance_approach = use_question_stance_approach
|
39 |
+
|
40 |
+
def forward(self, input_ids, **kwargs):
|
41 |
+
return self.model(input_ids, **kwargs)
|
42 |
+
|
43 |
+
def configure_optimizers(self):
|
44 |
+
optimizer = AdamW(self.parameters(), lr=self.learning_rate)
|
45 |
+
return optimizer
|
46 |
+
|
47 |
+
def training_step(self, batch, batch_idx):
|
48 |
+
x, x_mask, y = batch
|
49 |
+
|
50 |
+
outputs = self(x, attention_mask=x_mask, labels=y)
|
51 |
+
logits = outputs.logits
|
52 |
+
loss = outputs.loss
|
53 |
+
|
54 |
+
preds = torch.argmax(logits, axis=1)
|
55 |
+
|
56 |
+
self.log("train_loss", loss)
|
57 |
+
|
58 |
+
return {"loss": loss}
|
59 |
+
|
60 |
+
def validation_step(self, batch, batch_idx):
|
61 |
+
x, x_mask, y = batch
|
62 |
+
|
63 |
+
outputs = self(x, attention_mask=x_mask, labels=y)
|
64 |
+
logits = outputs.logits
|
65 |
+
loss = outputs.loss
|
66 |
+
|
67 |
+
preds = torch.argmax(logits, axis=1)
|
68 |
+
|
69 |
+
if not self.use_question_stance_approach:
|
70 |
+
self.val_acc(preds, y)
|
71 |
+
self.log("val_acc_step", self.val_acc)
|
72 |
+
|
73 |
+
self.val_f1(preds, y)
|
74 |
+
self.log("val_loss", loss)
|
75 |
+
|
76 |
+
return {"val_loss": loss, "src": x, "pred": preds, "target": y}
|
77 |
+
|
78 |
+
def validation_epoch_end(self, outs):
|
79 |
+
if self.use_question_stance_approach:
|
80 |
+
self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)
|
81 |
+
|
82 |
+
self.log("val_acc_epoch", self.val_acc)
|
83 |
+
|
84 |
+
f1 = self.val_f1.compute()
|
85 |
+
self.val_f1.reset()
|
86 |
+
|
87 |
+
self.log("val_f1_epoch", torch.mean(f1))
|
88 |
+
|
89 |
+
class_names = ["supported", "refuted", "nei", "conflicting"]
|
90 |
+
for i, c_name in enumerate(class_names):
|
91 |
+
self.log("val_f1_" + c_name, f1[i])
|
92 |
+
|
93 |
+
def test_step(self, batch, batch_idx):
|
94 |
+
x, x_mask, y = batch
|
95 |
+
|
96 |
+
outputs = self(x, attention_mask=x_mask)
|
97 |
+
logits = outputs.logits
|
98 |
+
|
99 |
+
preds = torch.argmax(logits, axis=1)
|
100 |
+
|
101 |
+
if not self.use_question_stance_approach:
|
102 |
+
self.test_acc(preds, y)
|
103 |
+
self.log("test_acc_step", self.test_acc)
|
104 |
+
self.test_f1(preds, y)
|
105 |
+
|
106 |
+
return {"src": x, "pred": preds, "target": y}
|
107 |
+
|
108 |
+
def test_epoch_end(self, outs):
|
109 |
+
if self.use_question_stance_approach:
|
110 |
+
self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)
|
111 |
+
|
112 |
+
self.log("test_acc_epoch", self.test_acc)
|
113 |
+
|
114 |
+
f1 = self.test_f1.compute()
|
115 |
+
self.test_f1.reset()
|
116 |
+
self.log("test_f1_epoch", torch.mean(f1))
|
117 |
+
|
118 |
+
class_names = ["supported", "refuted", "nei", "conflicting"]
|
119 |
+
for i, c_name in enumerate(class_names):
|
120 |
+
self.log("test_f1_" + c_name, f1[i])
|
121 |
+
|
122 |
+
def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
|
123 |
+
gold_labels = {}
|
124 |
+
question_support = {}
|
125 |
+
for out in outputs:
|
126 |
+
srcs = out["src"]
|
127 |
+
preds = out["pred"]
|
128 |
+
tgts = out["target"]
|
129 |
+
|
130 |
+
tokens = self.tokenizer.batch_decode(
|
131 |
+
srcs, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
132 |
+
)
|
133 |
+
|
134 |
+
for src, pred, tgt in zip(tokens, preds, tgts):
|
135 |
+
claim_id = src.split("[ question ]")[0]
|
136 |
+
|
137 |
+
if claim_id not in gold_labels:
|
138 |
+
gold_labels[claim_id] = tgt
|
139 |
+
question_support[claim_id] = []
|
140 |
+
|
141 |
+
question_support[claim_id].append(pred)
|
142 |
+
|
143 |
+
for k, gold_label in gold_labels.items():
|
144 |
+
support = question_support[k]
|
145 |
+
|
146 |
+
has_unanswerable = False
|
147 |
+
has_true = False
|
148 |
+
has_false = False
|
149 |
+
|
150 |
+
for v in support:
|
151 |
+
if v == 0:
|
152 |
+
has_true = True
|
153 |
+
if v == 1:
|
154 |
+
has_false = True
|
155 |
+
if v in (
|
156 |
+
2,
|
157 |
+
3,
|
158 |
+
): # TODO very ugly hack -- we cant have different numbers of labels for train and test so we do this
|
159 |
+
has_unanswerable = True
|
160 |
+
|
161 |
+
if has_unanswerable:
|
162 |
+
answer = 2
|
163 |
+
elif has_true and not has_false:
|
164 |
+
answer = 0
|
165 |
+
elif has_false and not has_true:
|
166 |
+
answer = 1
|
167 |
+
elif has_true and has_false:
|
168 |
+
answer = 3
|
169 |
+
|
170 |
+
# TODO this is very hacky and wont work if the device is literally anything other than cuda:0
|
171 |
+
acc_scorer(
|
172 |
+
torch.as_tensor([answer]).to("cuda:0"),
|
173 |
+
torch.as_tensor([gold_label]).to("cuda:0"),
|
174 |
+
)
|
175 |
+
f1_scorer(
|
176 |
+
torch.as_tensor([answer]).to("cuda:0"),
|
177 |
+
torch.as_tensor([gold_label]).to("cuda:0"),
|
178 |
+
)
|
src/reranking/rerank_questions.py
CHANGED
@@ -17,7 +17,7 @@ if __name__ == "__main__":
|
|
17 |
parser.add_argument(
|
18 |
"-i",
|
19 |
"--top_k_qa_file",
|
20 |
-
default="
|
21 |
help="Json file with claim and top k generated question-answer pairs.",
|
22 |
)
|
23 |
parser.add_argument(
|
|
|
17 |
parser.add_argument(
|
18 |
"-i",
|
19 |
"--top_k_qa_file",
|
20 |
+
default="data_store/dev_top_k_qa.json",
|
21 |
help="Json file with claim and top k generated question-answer pairs.",
|
22 |
)
|
23 |
parser.add_argument(
|