Chenxi Whitehouse committed on
Commit a6e9308
1 Parent(s): ce6cd35

update src

README.md CHANGED
@@ -120,7 +120,7 @@ The result for dev and the test set below. We recommend using 0.25 as cut-off sc
 
 | Model             | Split | Q only | Q + A | Veracity @ 0.2 | @ 0.25 | @ 0.3 |
 |-------------------|-------|--------|-------|----------------|--------|-------|
-| AVeriTeC-BLOOM-7b | dev   |        |       |                |        |       |
+| AVeriTeC-BLOOM-7b | dev   | 0.24   | 0.19  | 0.19           | 0.09   | 0.05  |
 | AVeriTeC-BLOOM-7b | test  |        |       |                |        |       |
 
 ## Citation
src/prediction/evaluate_veracity.py CHANGED
@@ -23,7 +23,7 @@ def compute_all_pairwise_scores(src_data, tgt_data, metric):
     return scores
 
 
-def print_with_space(left, right, left_space=40):
+def print_with_space(left, right, left_space=45):
     print_spaces = " " * (left_space - len(left))
     print(left + print_spaces + right)
 
@@ -303,14 +303,9 @@ if __name__ == "__main__":
             str(v_score[i]),
         )
     print("--------------------")
+    print("AVeriTeC scores by type @ 0.25:")
     type_scores = scorer.evaluate_averitec_veracity_by_type(
-        predictions, references, threshold=0.2
-    )
-    for t, v in type_scores.items():
-        print_with_space(" * Veracity scores (" + t + "):", str(v))
-    print("--------------------")
-    type_scores = scorer.evaluate_averitec_veracity_by_type(
-        predictions, references, threshold=0.3
+        predictions, references, threshold=0.25
     )
     for t, v in type_scores.items():
         print_with_space(" * Veracity scores (" + t + "):", str(v))
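As a side note on the `left_space` change: the helper pads the label out to a fixed width so the scores line up in one column, and 45 leaves room for the longer per-type labels. A minimal sketch with hypothetical label strings, only to show the alignment:

```python
def print_with_space(left, right, left_space=45):
    # Pad the label out to `left_space` characters so the right-hand values align.
    # (If the label is already longer than left_space, the padding is simply empty.)
    print_spaces = " " * (left_space - len(left))
    print(left + print_spaces + right)

# Hypothetical labels and scores, just to illustrate the column alignment:
print_with_space(" * Veracity scores (Supported):", "0.09")
print_with_space(" * Veracity scores (Not Enough Evidence):", "0.05")
```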
src/prediction/veracity_prediction.py CHANGED
@@ -2,11 +2,9 @@ import argparse
 import json
 import tqdm
 import torch
+import pytorch_lightning as pl
 from transformers import BertTokenizer, BertForSequenceClassification
-from data_loaders.SequenceClassificationDataLoader import (
-    SequenceClassificationDataLoader,
-)
-from models.SequenceClassificationModule import SequenceClassificationModule
+from src.models.SequenceClassificationModule import SequenceClassificationModule
 
 
 LABEL = [
@@ -17,6 +15,50 @@ LABEL = [
 ]
 
 
+class SequenceClassificationDataLoader(pl.LightningDataModule):
+    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.data_file = data_file
+        self.batch_size = batch_size
+        self.add_extra_nee = add_extra_nee
+
+    def tokenize_strings(
+        self,
+        source_sentences,
+        max_length=512,
+        pad_to_max_length=False,
+        return_tensors="pt",
+    ):
+        encoded_dict = self.tokenizer(
+            source_sentences,
+            max_length=max_length,
+            padding="max_length" if pad_to_max_length else "longest",
+            truncation=True,
+            return_tensors=return_tensors,
+        )
+
+        input_ids = encoded_dict["input_ids"]
+        attention_masks = encoded_dict["attention_mask"]
+
+        return input_ids, attention_masks
+
+    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
+        if bool_explanation is not None and len(bool_explanation) > 0:
+            bool_explanation = ", because " + bool_explanation.lower().strip()
+        else:
+            bool_explanation = ""
+        return (
+            "[CLAIM] "
+            + claim.strip()
+            + " [QUESTION] "
+            + question.strip()
+            + " "
+            + answer.strip()
+            + bool_explanation
+        )
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model to predict the veracity label."
@@ -83,7 +125,9 @@ if __name__ == "__main__":
 
         tokenized_strings, attention_mask = dataLoader.tokenize_strings(example_strings)
         example_support = torch.argmax(
-            trained_model(tokenized_strings, attention_mask=attention_mask).logits,
+            trained_model(
+                tokenized_strings.to(device), attention_mask=attention_mask.to(device)
+            ).logits,
             axis=1,
         )
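Since the data loader is now defined inside `veracity_prediction.py` itself, a short usage sketch may help tie the pieces together. It reuses the `SequenceClassificationDataLoader` class added above; the tokenizer name, the placeholder 4-label classification head, the example claim/QA strings, and the `device` handling are illustrative assumptions, not the script's actual checkpoint loading:

```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Hypothetical stand-ins: the real script loads its fine-tuned
# SequenceClassificationModule checkpoint; a plain BERT head with a
# placeholder number of labels keeps this sketch self-contained.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
trained_model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=4
)
device = "cuda" if torch.cuda.is_available() else "cpu"
trained_model = trained_model.to(device).eval()

dataLoader = SequenceClassificationDataLoader(
    tokenizer=tokenizer, data_file=None, batch_size=1
)

# Flatten one claim and one QA pair into the "[CLAIM] ... [QUESTION] ..." format.
example_strings = [
    dataLoader.quadruple_to_string(
        "The Earth is flat.",
        "What shape is the Earth?",
        "It is an oblate spheroid.",
    )
]

# Same pattern as the updated script: tokenize, move both tensors to the
# model's device, then take the argmax over the classification logits.
tokenized_strings, attention_mask = dataLoader.tokenize_strings(example_strings)
with torch.no_grad():
    example_support = torch.argmax(
        trained_model(
            tokenized_strings.to(device), attention_mask=attention_mask.to(device)
        ).logits,
        axis=1,
    )
print(example_support.item())  # index into the script's LABEL list
```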