Spaces:
Runtime error
Runtime error
PFEemp2024
commited on
Commit
•
63775f2
1
Parent(s):
c2c01a0
add necessary file
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- anonymous_demo/__init__.py +5 -0
- anonymous_demo/core/__init__.py +0 -0
- anonymous_demo/core/tad/__init__.py +0 -0
- anonymous_demo/core/tad/classic/__bert__/README.MD +3 -0
- anonymous_demo/core/tad/classic/__bert__/__init__.py +1 -0
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py +0 -0
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py +121 -0
- anonymous_demo/core/tad/classic/__bert__/models/__init__.py +1 -0
- anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py +46 -0
- anonymous_demo/core/tad/classic/__init__.py +0 -0
- anonymous_demo/core/tad/models/__init__.py +9 -0
- anonymous_demo/core/tad/prediction/__init__.py +0 -0
- anonymous_demo/core/tad/prediction/tad_classifier.py +518 -0
- anonymous_demo/functional/__init__.py +3 -0
- anonymous_demo/functional/checkpoint/__init__.py +1 -0
- anonymous_demo/functional/checkpoint/checkpoint_manager.py +19 -0
- anonymous_demo/functional/config/__init__.py +1 -0
- anonymous_demo/functional/config/config_manager.py +64 -0
- anonymous_demo/functional/config/tad_config_manager.py +229 -0
- anonymous_demo/functional/dataset/__init__.py +1 -0
- anonymous_demo/functional/dataset/dataset_manager.py +45 -0
- anonymous_demo/network/__init__.py +0 -0
- anonymous_demo/network/lcf_pooler.py +28 -0
- anonymous_demo/network/lsa.py +73 -0
- anonymous_demo/network/sa_encoder.py +199 -0
- anonymous_demo/utils/__init__.py +0 -0
- anonymous_demo/utils/demo_utils.py +247 -0
- anonymous_demo/utils/logger.py +38 -0
- checkpoints.zip +3 -0
- text_defense/201.SST2/stsa.binary.dev.dat +0 -0
- text_defense/201.SST2/stsa.binary.test.dat +0 -0
- text_defense/201.SST2/stsa.binary.train.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.test.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.train.dat +0 -0
- text_defense/202.IMDB10K/imdb10k.valid.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.test.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.train.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.valid.dat +0 -0
- text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat +0 -0
- text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat +0 -0
- textattack/__init__.py +39 -0
- textattack/__main__.py +6 -0
- textattack/attack.py +492 -0
- textattack/attack_args.py +763 -0
- textattack/attack_recipes/__init__.py +43 -0
- textattack/attack_recipes/a2t_yoo_2021.py +74 -0
- textattack/attack_recipes/attack_recipe.py +30 -0
- textattack/attack_recipes/bae_garg_2019.py +123 -0
- textattack/attack_recipes/bert_attack_li_2020.py +95 -0
- textattack/attack_recipes/checklist_ribeiro_2020.py +53 -0
anonymous_demo/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "1.0.0"
|
2 |
+
|
3 |
+
__name__ = "anonymous_demo"
|
4 |
+
|
5 |
+
from anonymous_demo.functional import TADCheckpointManager
|
anonymous_demo/core/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/classic/__bert__/README.MD
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
## This is the simple migration from ABSA-PyTorch under MIT license
|
2 |
+
|
3 |
+
Project Address: https://github.com/songyouwei/ABSA-PyTorch
|
anonymous_demo/core/tad/classic/__bert__/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .models import *
|
anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tqdm
|
2 |
+
from findfile import find_cwd_dir
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
|
6 |
+
|
7 |
+
class Tokenizer4Pretraining:
|
8 |
+
def __init__(self, max_seq_len, opt, **kwargs):
|
9 |
+
if kwargs.pop("offline", False):
|
10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
11 |
+
find_cwd_dir(opt.pretrained_bert.split("/")[-1]),
|
12 |
+
do_lower_case="uncased" in opt.pretrained_bert,
|
13 |
+
)
|
14 |
+
else:
|
15 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
16 |
+
opt.pretrained_bert, do_lower_case="uncased" in opt.pretrained_bert
|
17 |
+
)
|
18 |
+
self.max_seq_len = max_seq_len
|
19 |
+
|
20 |
+
def text_to_sequence(self, text, reverse=False, padding="post", truncating="post"):
|
21 |
+
return self.tokenizer.encode(
|
22 |
+
text,
|
23 |
+
truncation=True,
|
24 |
+
padding="max_length",
|
25 |
+
max_length=self.max_seq_len,
|
26 |
+
return_tensors="pt",
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class BERTTADDataset(Dataset):
|
31 |
+
def __init__(self, tokenizer, opt):
|
32 |
+
self.bert_baseline_input_colses = {"bert": ["text_bert_indices"]}
|
33 |
+
|
34 |
+
self.tokenizer = tokenizer
|
35 |
+
self.opt = opt
|
36 |
+
self.all_data = []
|
37 |
+
|
38 |
+
def parse_sample(self, text):
|
39 |
+
return [text]
|
40 |
+
|
41 |
+
def prepare_infer_sample(self, text: str, ignore_error):
|
42 |
+
self.process_data(self.parse_sample(text), ignore_error=ignore_error)
|
43 |
+
|
44 |
+
def process_data(self, samples, ignore_error=True):
|
45 |
+
all_data = []
|
46 |
+
if len(samples) > 100:
|
47 |
+
it = tqdm.tqdm(
|
48 |
+
samples, postfix="preparing text classification inference dataloader..."
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
it = samples
|
52 |
+
for text in it:
|
53 |
+
try:
|
54 |
+
# handle for empty lines in inference datasets
|
55 |
+
if text is None or "" == text.strip():
|
56 |
+
raise RuntimeError("Invalid Input!")
|
57 |
+
|
58 |
+
if "!ref!" in text:
|
59 |
+
text, _, labels = text.strip().partition("!ref!")
|
60 |
+
text = text.strip()
|
61 |
+
if labels.count(",") == 2:
|
62 |
+
label, is_adv, adv_train_label = labels.strip().split(",")
|
63 |
+
label, is_adv, adv_train_label = (
|
64 |
+
label.strip(),
|
65 |
+
is_adv.strip(),
|
66 |
+
adv_train_label.strip(),
|
67 |
+
)
|
68 |
+
elif labels.count(",") == 1:
|
69 |
+
label, is_adv = labels.strip().split(",")
|
70 |
+
label, is_adv = label.strip(), is_adv.strip()
|
71 |
+
adv_train_label = "-100"
|
72 |
+
elif labels.count(",") == 0:
|
73 |
+
label = labels.strip()
|
74 |
+
adv_train_label = "-100"
|
75 |
+
is_adv = "-100"
|
76 |
+
else:
|
77 |
+
label = "-100"
|
78 |
+
adv_train_label = "-100"
|
79 |
+
is_adv = "-100"
|
80 |
+
|
81 |
+
label = int(label)
|
82 |
+
adv_train_label = int(adv_train_label)
|
83 |
+
is_adv = int(is_adv)
|
84 |
+
|
85 |
+
else:
|
86 |
+
text = text.strip()
|
87 |
+
label = -100
|
88 |
+
adv_train_label = -100
|
89 |
+
is_adv = -100
|
90 |
+
|
91 |
+
text_indices = self.tokenizer.text_to_sequence("{}".format(text))
|
92 |
+
|
93 |
+
data = {
|
94 |
+
"text_bert_indices": text_indices[0],
|
95 |
+
"text_raw": text,
|
96 |
+
"label": label,
|
97 |
+
"adv_train_label": adv_train_label,
|
98 |
+
"is_adv": is_adv,
|
99 |
+
# 'label': self.opt.label_to_index.get(label, -100) if isinstance(label, str) else label,
|
100 |
+
#
|
101 |
+
# 'adv_train_label': self.opt.adv_train_label_to_index.get(adv_train_label, -100) if isinstance(adv_train_label, str) else adv_train_label,
|
102 |
+
#
|
103 |
+
# 'is_adv': self.opt.is_adv_to_index.get(is_adv, -100) if isinstance(is_adv, str) else is_adv,
|
104 |
+
}
|
105 |
+
|
106 |
+
all_data.append(data)
|
107 |
+
|
108 |
+
except Exception as e:
|
109 |
+
if ignore_error:
|
110 |
+
print("Ignore error while processing:", text)
|
111 |
+
else:
|
112 |
+
raise e
|
113 |
+
|
114 |
+
self.all_data = all_data
|
115 |
+
return self.all_data
|
116 |
+
|
117 |
+
def __getitem__(self, index):
|
118 |
+
return self.all_data[index]
|
119 |
+
|
120 |
+
def __len__(self):
|
121 |
+
return len(self.all_data)
|
anonymous_demo/core/tad/classic/__bert__/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .tad_bert import TADBERT
|
anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers.models.bert.modeling_bert import BertPooler
|
4 |
+
|
5 |
+
from anonymous_demo.network.sa_encoder import Encoder
|
6 |
+
|
7 |
+
|
8 |
+
class TADBERT(nn.Module):
|
9 |
+
inputs = ["text_bert_indices"]
|
10 |
+
|
11 |
+
def __init__(self, bert, opt):
|
12 |
+
super(TADBERT, self).__init__()
|
13 |
+
self.opt = opt
|
14 |
+
self.bert = bert
|
15 |
+
self.pooler = BertPooler(bert.config)
|
16 |
+
self.dense1 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)
|
17 |
+
self.dense2 = nn.Linear(self.opt.hidden_dim, self.opt.adv_det_dim)
|
18 |
+
self.dense3 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)
|
19 |
+
|
20 |
+
self.encoder1 = Encoder(self.bert.config, opt=opt)
|
21 |
+
self.encoder2 = Encoder(self.bert.config, opt=opt)
|
22 |
+
self.encoder3 = Encoder(self.bert.config, opt=opt)
|
23 |
+
|
24 |
+
def forward(self, inputs):
|
25 |
+
text_raw_indices = inputs[0]
|
26 |
+
last_hidden_state = self.bert(text_raw_indices)["last_hidden_state"]
|
27 |
+
|
28 |
+
sent_logits = self.dense1(self.pooler(last_hidden_state))
|
29 |
+
advdet_logits = self.dense2(self.pooler(last_hidden_state))
|
30 |
+
adv_tr_logits = self.dense3(self.pooler(last_hidden_state))
|
31 |
+
|
32 |
+
att_score = torch.nn.functional.normalize(
|
33 |
+
last_hidden_state.abs().sum(dim=1, keepdim=False)
|
34 |
+
- last_hidden_state.abs().min(dim=1, keepdim=True)[0],
|
35 |
+
p=1,
|
36 |
+
dim=1,
|
37 |
+
)
|
38 |
+
|
39 |
+
outputs = {
|
40 |
+
"sent_logits": sent_logits,
|
41 |
+
"advdet_logits": advdet_logits,
|
42 |
+
"adv_tr_logits": adv_tr_logits,
|
43 |
+
"last_hidden_state": last_hidden_state,
|
44 |
+
"att_score": att_score,
|
45 |
+
}
|
46 |
+
return outputs
|
anonymous_demo/core/tad/classic/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/models/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import anonymous_demo.core.tad.classic.__bert__.models
|
2 |
+
|
3 |
+
|
4 |
+
class BERTTADModelList(list):
|
5 |
+
TADBERT = anonymous_demo.core.tad.classic.__bert__.TADBERT
|
6 |
+
|
7 |
+
def __init__(self):
|
8 |
+
model_list = [self.TADBERT]
|
9 |
+
super().__init__(model_list)
|
anonymous_demo/core/tad/prediction/__init__.py
ADDED
File without changes
|
anonymous_demo/core/tad/prediction/tad_classifier.py
ADDED
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
import time
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import tqdm
|
8 |
+
from findfile import find_file, find_cwd_dir
|
9 |
+
from termcolor import colored
|
10 |
+
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from transformers import (
|
13 |
+
AutoTokenizer,
|
14 |
+
AutoModel,
|
15 |
+
AutoConfig,
|
16 |
+
DebertaV2ForMaskedLM,
|
17 |
+
RobertaForMaskedLM,
|
18 |
+
BertForMaskedLM,
|
19 |
+
)
|
20 |
+
|
21 |
+
from ....functional.dataset.dataset_manager import detect_infer_dataset
|
22 |
+
|
23 |
+
from ..models import BERTTADModelList
|
24 |
+
from ..classic.__bert__.dataset_utils.data_utils_for_inference import (
|
25 |
+
BERTTADDataset,
|
26 |
+
Tokenizer4Pretraining,
|
27 |
+
)
|
28 |
+
|
29 |
+
from ....utils.demo_utils import (
|
30 |
+
print_args,
|
31 |
+
TransformerConnectionError,
|
32 |
+
get_device,
|
33 |
+
build_embedding_matrix,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def init_attacker(tad_classifier, defense):
|
38 |
+
try:
|
39 |
+
from textattack import Attacker
|
40 |
+
from textattack.attack_recipes import (
|
41 |
+
BAEGarg2019,
|
42 |
+
PWWSRen2019,
|
43 |
+
TextFoolerJin2019,
|
44 |
+
PSOZang2020,
|
45 |
+
IGAWang2019,
|
46 |
+
GeneticAlgorithmAlzantot2018,
|
47 |
+
DeepWordBugGao2018,
|
48 |
+
)
|
49 |
+
from textattack.datasets import Dataset
|
50 |
+
from textattack.models.wrappers import HuggingFaceModelWrapper
|
51 |
+
|
52 |
+
class DemoModelWrapper(HuggingFaceModelWrapper):
|
53 |
+
def __init__(self, model):
|
54 |
+
self.model = model # pipeline = pipeline
|
55 |
+
|
56 |
+
def __call__(self, text_inputs, **kwargs):
|
57 |
+
outputs = []
|
58 |
+
for text_input in text_inputs:
|
59 |
+
raw_outputs = self.model.infer(
|
60 |
+
text_input, print_result=False, **kwargs
|
61 |
+
)
|
62 |
+
outputs.append(raw_outputs["probs"])
|
63 |
+
return outputs
|
64 |
+
|
65 |
+
class SentAttacker:
|
66 |
+
def __init__(self, model, recipe_class=BAEGarg2019):
|
67 |
+
model = model
|
68 |
+
model_wrapper = DemoModelWrapper(model)
|
69 |
+
|
70 |
+
recipe = recipe_class.build(model_wrapper)
|
71 |
+
|
72 |
+
_dataset = [("", 0)]
|
73 |
+
_dataset = Dataset(_dataset)
|
74 |
+
|
75 |
+
self.attacker = Attacker(recipe, _dataset)
|
76 |
+
|
77 |
+
attackers = {
|
78 |
+
"bae": BAEGarg2019,
|
79 |
+
"pwws": PWWSRen2019,
|
80 |
+
"textfooler": TextFoolerJin2019,
|
81 |
+
"pso": PSOZang2020,
|
82 |
+
"iga": IGAWang2019,
|
83 |
+
"ga": GeneticAlgorithmAlzantot2018,
|
84 |
+
"wordbugger": DeepWordBugGao2018,
|
85 |
+
}
|
86 |
+
return SentAttacker(tad_classifier, attackers[defense])
|
87 |
+
except Exception as e:
|
88 |
+
print("Original error:", e)
|
89 |
+
|
90 |
+
|
91 |
+
def get_mlm_and_tokenizer(text_classifier, config):
|
92 |
+
if isinstance(text_classifier, TADTextClassifier):
|
93 |
+
base_model = text_classifier.model.bert.base_model
|
94 |
+
else:
|
95 |
+
base_model = text_classifier.bert.base_model
|
96 |
+
pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)
|
97 |
+
if "deberta-v3" in config.pretrained_bert:
|
98 |
+
MLM = DebertaV2ForMaskedLM(pretrained_config)
|
99 |
+
MLM.deberta = base_model
|
100 |
+
elif "roberta" in config.pretrained_bert:
|
101 |
+
MLM = RobertaForMaskedLM(pretrained_config)
|
102 |
+
MLM.roberta = base_model
|
103 |
+
else:
|
104 |
+
MLM = BertForMaskedLM(pretrained_config)
|
105 |
+
MLM.bert = base_model
|
106 |
+
return MLM, AutoTokenizer.from_pretrained(config.pretrained_bert)
|
107 |
+
|
108 |
+
|
109 |
+
class TADTextClassifier:
|
110 |
+
def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
|
111 |
+
"""
|
112 |
+
from_train_model: load inference model from trained model
|
113 |
+
"""
|
114 |
+
self.cal_perplexity = cal_perplexity
|
115 |
+
# load from a training
|
116 |
+
if not isinstance(model_arg, str):
|
117 |
+
print("Load text classifier from training")
|
118 |
+
self.model = model_arg[0]
|
119 |
+
self.opt = model_arg[1]
|
120 |
+
self.tokenizer = model_arg[2]
|
121 |
+
else:
|
122 |
+
try:
|
123 |
+
if "fine-tuned" in model_arg:
|
124 |
+
raise ValueError(
|
125 |
+
"Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!"
|
126 |
+
)
|
127 |
+
print("Load text classifier from", model_arg)
|
128 |
+
state_dict_path = find_file(
|
129 |
+
model_arg, key=".state_dict", exclude_key=["__MACOSX"]
|
130 |
+
)
|
131 |
+
model_path = find_file(
|
132 |
+
model_arg, key=".model", exclude_key=["__MACOSX"]
|
133 |
+
)
|
134 |
+
tokenizer_path = find_file(
|
135 |
+
model_arg, key=".tokenizer", exclude_key=["__MACOSX"]
|
136 |
+
)
|
137 |
+
config_path = find_file(
|
138 |
+
model_arg, key=".config", exclude_key=["__MACOSX"]
|
139 |
+
)
|
140 |
+
|
141 |
+
print("config: {}".format(config_path))
|
142 |
+
print("state_dict: {}".format(state_dict_path))
|
143 |
+
print("model: {}".format(model_path))
|
144 |
+
print("tokenizer: {}".format(tokenizer_path))
|
145 |
+
|
146 |
+
with open(config_path, mode="rb") as f:
|
147 |
+
self.opt = pickle.load(f)
|
148 |
+
self.opt.device = get_device(kwargs.pop("auto_device", True))[0]
|
149 |
+
|
150 |
+
if state_dict_path or model_path:
|
151 |
+
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
152 |
+
if state_dict_path:
|
153 |
+
if kwargs.pop("offline", False):
|
154 |
+
self.bert = AutoModel.from_pretrained(
|
155 |
+
find_cwd_dir(
|
156 |
+
self.opt.pretrained_bert.split("/")[-1]
|
157 |
+
)
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
self.bert = AutoModel.from_pretrained(
|
161 |
+
self.opt.pretrained_bert
|
162 |
+
)
|
163 |
+
self.model = self.opt.model(self.bert, self.opt)
|
164 |
+
self.model.load_state_dict(
|
165 |
+
torch.load(state_dict_path, map_location="cpu")
|
166 |
+
)
|
167 |
+
elif model_path:
|
168 |
+
self.model = torch.load(model_path, map_location="cpu")
|
169 |
+
|
170 |
+
try:
|
171 |
+
self.tokenizer = Tokenizer4Pretraining(
|
172 |
+
max_seq_len=self.opt.max_seq_len, opt=self.opt, **kwargs
|
173 |
+
)
|
174 |
+
except ValueError:
|
175 |
+
if tokenizer_path:
|
176 |
+
with open(tokenizer_path, mode="rb") as f:
|
177 |
+
self.tokenizer = pickle.load(f)
|
178 |
+
else:
|
179 |
+
raise TransformerConnectionError()
|
180 |
+
|
181 |
+
except Exception as e:
|
182 |
+
raise RuntimeError(
|
183 |
+
"Exception: {} Fail to load the model from {}! ".format(
|
184 |
+
e, model_arg
|
185 |
+
)
|
186 |
+
)
|
187 |
+
|
188 |
+
self.infer_dataloader = None
|
189 |
+
self.opt.eval_batch_size = kwargs.pop("eval_batch_size", 128)
|
190 |
+
|
191 |
+
self.opt.initializer = self.opt.initializer
|
192 |
+
|
193 |
+
if self.cal_perplexity:
|
194 |
+
try:
|
195 |
+
self.MLM, self.MLM_tokenizer = get_mlm_and_tokenizer(self, self.opt)
|
196 |
+
except Exception as e:
|
197 |
+
self.MLM, self.MLM_tokenizer = None, None
|
198 |
+
|
199 |
+
self.to(self.opt.device)
|
200 |
+
|
201 |
+
def to(self, device=None):
|
202 |
+
self.opt.device = device
|
203 |
+
self.model.to(device)
|
204 |
+
if hasattr(self, "MLM"):
|
205 |
+
self.MLM.to(self.opt.device)
|
206 |
+
|
207 |
+
def cpu(self):
|
208 |
+
self.opt.device = "cpu"
|
209 |
+
self.model.to("cpu")
|
210 |
+
if hasattr(self, "MLM"):
|
211 |
+
self.MLM.to("cpu")
|
212 |
+
|
213 |
+
def cuda(self, device="cuda:0"):
|
214 |
+
self.opt.device = device
|
215 |
+
self.model.to(device)
|
216 |
+
if hasattr(self, "MLM"):
|
217 |
+
self.MLM.to(device)
|
218 |
+
|
219 |
+
def _log_write_args(self):
|
220 |
+
n_trainable_params, n_nontrainable_params = 0, 0
|
221 |
+
for p in self.model.parameters():
|
222 |
+
n_params = torch.prod(torch.tensor(p.shape))
|
223 |
+
if p.requires_grad:
|
224 |
+
n_trainable_params += n_params
|
225 |
+
else:
|
226 |
+
n_nontrainable_params += n_params
|
227 |
+
print(
|
228 |
+
"n_trainable_params: {0}, n_nontrainable_params: {1}".format(
|
229 |
+
n_trainable_params, n_nontrainable_params
|
230 |
+
)
|
231 |
+
)
|
232 |
+
for arg in vars(self.opt):
|
233 |
+
if getattr(self.opt, arg) is not None:
|
234 |
+
print(">>> {0}: {1}".format(arg, getattr(self.opt, arg)))
|
235 |
+
|
236 |
+
def batch_infer(
|
237 |
+
self,
|
238 |
+
target_file=None,
|
239 |
+
print_result=True,
|
240 |
+
save_result=False,
|
241 |
+
ignore_error=True,
|
242 |
+
defense: str = None,
|
243 |
+
):
|
244 |
+
save_path = os.path.join(os.getcwd(), "tad_text_classification.result.json")
|
245 |
+
|
246 |
+
target_file = detect_infer_dataset(target_file, task="text_defense")
|
247 |
+
if not target_file:
|
248 |
+
raise FileNotFoundError("Can not find inference datasets!")
|
249 |
+
|
250 |
+
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
251 |
+
dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
|
252 |
+
|
253 |
+
dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
|
254 |
+
self.infer_dataloader = DataLoader(
|
255 |
+
dataset=dataset,
|
256 |
+
batch_size=self.opt.eval_batch_size,
|
257 |
+
pin_memory=True,
|
258 |
+
shuffle=False,
|
259 |
+
)
|
260 |
+
return self._infer(
|
261 |
+
save_path=save_path if save_result else None,
|
262 |
+
print_result=print_result,
|
263 |
+
defense=defense,
|
264 |
+
)
|
265 |
+
|
266 |
+
def infer(
|
267 |
+
self,
|
268 |
+
text: str = None,
|
269 |
+
print_result=True,
|
270 |
+
ignore_error=True,
|
271 |
+
defense: str = None,
|
272 |
+
):
|
273 |
+
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
274 |
+
dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
|
275 |
+
|
276 |
+
if text:
|
277 |
+
dataset.prepare_infer_sample(text, ignore_error=ignore_error)
|
278 |
+
else:
|
279 |
+
raise RuntimeError("Please specify your datasets path!")
|
280 |
+
self.infer_dataloader = DataLoader(
|
281 |
+
dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False
|
282 |
+
)
|
283 |
+
return self._infer(print_result=print_result, defense=defense)[0]
|
284 |
+
|
285 |
+
def _infer(self, save_path=None, print_result=True, defense=None):
|
286 |
+
_params = filter(lambda p: p.requires_grad, self.model.parameters())
|
287 |
+
|
288 |
+
correct = {True: "Correct", False: "Wrong"}
|
289 |
+
results = []
|
290 |
+
|
291 |
+
with torch.no_grad():
|
292 |
+
self.model.eval()
|
293 |
+
n_correct = 0
|
294 |
+
n_labeled = 0
|
295 |
+
|
296 |
+
n_advdet_correct = 0
|
297 |
+
n_advdet_labeled = 0
|
298 |
+
if len(self.infer_dataloader.dataset) >= 100:
|
299 |
+
it = tqdm.tqdm(self.infer_dataloader, postfix="inferring...")
|
300 |
+
else:
|
301 |
+
it = self.infer_dataloader
|
302 |
+
for _, sample in enumerate(it):
|
303 |
+
inputs = [
|
304 |
+
sample[col].to(self.opt.device) for col in self.opt.inputs_cols
|
305 |
+
]
|
306 |
+
outputs = self.model(inputs)
|
307 |
+
logits, advdet_logits, adv_tr_logits = (
|
308 |
+
outputs["sent_logits"],
|
309 |
+
outputs["advdet_logits"],
|
310 |
+
outputs["adv_tr_logits"],
|
311 |
+
)
|
312 |
+
probs, advdet_probs, adv_tr_probs = (
|
313 |
+
torch.softmax(logits, dim=-1),
|
314 |
+
torch.softmax(advdet_logits, dim=-1),
|
315 |
+
torch.softmax(adv_tr_logits, dim=-1),
|
316 |
+
)
|
317 |
+
|
318 |
+
for i, (prob, advdet_prob, adv_tr_prob) in enumerate(
|
319 |
+
zip(probs, advdet_probs, adv_tr_probs)
|
320 |
+
):
|
321 |
+
text_raw = sample["text_raw"][i]
|
322 |
+
|
323 |
+
pred_label = int(prob.argmax(axis=-1))
|
324 |
+
pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
|
325 |
+
pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
|
326 |
+
ref_label = (
|
327 |
+
int(sample["label"][i])
|
328 |
+
if int(sample["label"][i]) in self.opt.index_to_label
|
329 |
+
else ""
|
330 |
+
)
|
331 |
+
ref_is_adv_label = (
|
332 |
+
int(sample["is_adv"][i])
|
333 |
+
if int(sample["is_adv"][i]) in self.opt.index_to_is_adv
|
334 |
+
else ""
|
335 |
+
)
|
336 |
+
ref_adv_tr_label = (
|
337 |
+
int(sample["adv_train_label"][i])
|
338 |
+
if int(sample["adv_train_label"][i])
|
339 |
+
in self.opt.index_to_adv_train_label
|
340 |
+
else ""
|
341 |
+
)
|
342 |
+
|
343 |
+
if self.cal_perplexity:
|
344 |
+
ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
|
345 |
+
ids["labels"] = ids["input_ids"].clone()
|
346 |
+
ids = ids.to(self.opt.device)
|
347 |
+
loss = self.MLM(**ids)["loss"]
|
348 |
+
perplexity = float(torch.exp(loss / ids["input_ids"].size(1)))
|
349 |
+
else:
|
350 |
+
perplexity = "N.A."
|
351 |
+
|
352 |
+
result = {
|
353 |
+
"text": text_raw,
|
354 |
+
"label": self.opt.index_to_label[pred_label],
|
355 |
+
"probs": prob.cpu().numpy(),
|
356 |
+
"confidence": float(max(prob)),
|
357 |
+
"ref_label": self.opt.index_to_label[ref_label]
|
358 |
+
if isinstance(ref_label, int)
|
359 |
+
else ref_label,
|
360 |
+
"ref_label_check": correct[pred_label == ref_label]
|
361 |
+
if ref_label != -100
|
362 |
+
else "",
|
363 |
+
"is_fixed": False,
|
364 |
+
"is_adv_label": self.opt.index_to_is_adv[pred_is_adv_label],
|
365 |
+
"is_adv_probs": advdet_prob.cpu().numpy(),
|
366 |
+
"is_adv_confidence": float(max(advdet_prob)),
|
367 |
+
"ref_is_adv_label": self.opt.index_to_is_adv[ref_is_adv_label]
|
368 |
+
if isinstance(ref_is_adv_label, int)
|
369 |
+
else ref_is_adv_label,
|
370 |
+
"ref_is_adv_check": correct[
|
371 |
+
pred_is_adv_label == ref_is_adv_label
|
372 |
+
]
|
373 |
+
if ref_is_adv_label != -100
|
374 |
+
and isinstance(ref_is_adv_label, int)
|
375 |
+
else "",
|
376 |
+
"pred_adv_tr_label": self.opt.index_to_label[pred_adv_tr_label],
|
377 |
+
"ref_adv_tr_label": self.opt.index_to_label[ref_adv_tr_label],
|
378 |
+
"perplexity": perplexity,
|
379 |
+
}
|
380 |
+
if defense:
|
381 |
+
try:
|
382 |
+
if not hasattr(self, "sent_attacker"):
|
383 |
+
self.sent_attacker = init_attacker(
|
384 |
+
self, defense.lower()
|
385 |
+
)
|
386 |
+
if result["is_adv_label"] == "1":
|
387 |
+
res = self.sent_attacker.attacker.simple_attack(
|
388 |
+
text_raw, int(result["label"])
|
389 |
+
)
|
390 |
+
new_infer_res = self.infer(
|
391 |
+
res.perturbed_result.attacked_text.text,
|
392 |
+
print_result=False,
|
393 |
+
)
|
394 |
+
result["perturbed_label"] = result["label"]
|
395 |
+
result["label"] = new_infer_res["label"]
|
396 |
+
result["probs"] = new_infer_res["probs"]
|
397 |
+
result["ref_label_check"] = (
|
398 |
+
correct[int(result["label"]) == ref_label]
|
399 |
+
if ref_label != -100
|
400 |
+
else ""
|
401 |
+
)
|
402 |
+
result[
|
403 |
+
"restored_text"
|
404 |
+
] = res.perturbed_result.attacked_text.text
|
405 |
+
result["is_fixed"] = True
|
406 |
+
else:
|
407 |
+
result["restored_text"] = ""
|
408 |
+
result["is_fixed"] = False
|
409 |
+
|
410 |
+
except Exception as e:
|
411 |
+
print(
|
412 |
+
"Error:{}, try install TextAttack and tensorflow_text after 10 seconds...".format(
|
413 |
+
e
|
414 |
+
)
|
415 |
+
)
|
416 |
+
time.sleep(10)
|
417 |
+
raise RuntimeError("Installation done, please run again...")
|
418 |
+
|
419 |
+
if ref_label != -100:
|
420 |
+
n_labeled += 1
|
421 |
+
|
422 |
+
if result["label"] == result["ref_label"]:
|
423 |
+
n_correct += 1
|
424 |
+
|
425 |
+
if ref_is_adv_label != -100:
|
426 |
+
n_advdet_labeled += 1
|
427 |
+
if ref_is_adv_label == pred_is_adv_label:
|
428 |
+
n_advdet_correct += 1
|
429 |
+
|
430 |
+
results.append(result)
|
431 |
+
|
432 |
+
try:
|
433 |
+
if print_result:
|
434 |
+
for ex_id, result in enumerate(results):
|
435 |
+
text_printing = result["text"][:]
|
436 |
+
text_info = ""
|
437 |
+
if result["label"] != "-100":
|
438 |
+
if not result["ref_label"]:
|
439 |
+
text_info += " -> <CLS:{}(ref:{} confidence:{})>".format(
|
440 |
+
result["label"],
|
441 |
+
result["ref_label"],
|
442 |
+
result["confidence"],
|
443 |
+
)
|
444 |
+
elif result["label"] == result["ref_label"]:
|
445 |
+
text_info += colored(
|
446 |
+
" -> <CLS:{}(ref:{} confidence:{})>".format(
|
447 |
+
result["label"],
|
448 |
+
result["ref_label"],
|
449 |
+
result["confidence"],
|
450 |
+
),
|
451 |
+
"green",
|
452 |
+
)
|
453 |
+
else:
|
454 |
+
text_info += colored(
|
455 |
+
" -> <CLS:{}(ref:{} confidence:{})>".format(
|
456 |
+
result["label"],
|
457 |
+
result["ref_label"],
|
458 |
+
result["confidence"],
|
459 |
+
),
|
460 |
+
"red",
|
461 |
+
)
|
462 |
+
|
463 |
+
# AdvDet
|
464 |
+
if result["is_adv_label"] != "-100":
|
465 |
+
if not result["ref_is_adv_label"]:
|
466 |
+
text_info += " -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
467 |
+
result["is_adv_label"],
|
468 |
+
result["ref_is_adv_check"],
|
469 |
+
result["is_adv_confidence"],
|
470 |
+
)
|
471 |
+
elif result["is_adv_label"] == result["ref_is_adv_label"]:
|
472 |
+
text_info += colored(
|
473 |
+
" -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
474 |
+
result["is_adv_label"],
|
475 |
+
result["ref_is_adv_label"],
|
476 |
+
result["is_adv_confidence"],
|
477 |
+
),
|
478 |
+
"green",
|
479 |
+
)
|
480 |
+
else:
|
481 |
+
text_info += colored(
|
482 |
+
" -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
483 |
+
result["is_adv_label"],
|
484 |
+
result["ref_is_adv_label"],
|
485 |
+
result["is_adv_confidence"],
|
486 |
+
),
|
487 |
+
"red",
|
488 |
+
)
|
489 |
+
text_printing += text_info
|
490 |
+
if self.cal_perplexity:
|
491 |
+
text_printing += colored(
|
492 |
+
" --> <perplexity:{}>".format(result["perplexity"]),
|
493 |
+
"yellow",
|
494 |
+
)
|
495 |
+
print("Example {}: {}".format(ex_id, text_printing))
|
496 |
+
if save_path:
|
497 |
+
with open(save_path, "w", encoding="utf8") as fout:
|
498 |
+
json.dump(str(results), fout, ensure_ascii=False)
|
499 |
+
print("inference result saved in: {}".format(save_path))
|
500 |
+
except Exception as e:
|
501 |
+
print("Can not save result: {}, Exception: {}".format(text_raw, e))
|
502 |
+
|
503 |
+
if len(results) > 1:
|
504 |
+
print(
|
505 |
+
"CLS Acc:{}%".format(100 * n_correct / n_labeled if n_labeled else "")
|
506 |
+
)
|
507 |
+
print(
|
508 |
+
"AdvDet Acc:{}%".format(
|
509 |
+
100 * n_advdet_correct / n_advdet_labeled
|
510 |
+
if n_advdet_labeled
|
511 |
+
else ""
|
512 |
+
)
|
513 |
+
)
|
514 |
+
|
515 |
+
return results
|
516 |
+
|
517 |
+
def clear_input_samples(self):
|
518 |
+
self.dataset.all_data = []
|
anonymous_demo/functional/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from anonymous_demo.functional.checkpoint.checkpoint_manager import TADCheckpointManager
|
2 |
+
|
3 |
+
from anonymous_demo.functional.config import TADConfigManager
|
anonymous_demo/functional/checkpoint/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .checkpoint_manager import TADCheckpointManager
|
anonymous_demo/functional/checkpoint/checkpoint_manager.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from findfile import find_file
|
3 |
+
|
4 |
+
from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier
|
5 |
+
from anonymous_demo.utils.demo_utils import retry
|
6 |
+
|
7 |
+
|
8 |
+
class CheckpointManager:
|
9 |
+
pass
|
10 |
+
|
11 |
+
|
12 |
+
class TADCheckpointManager(CheckpointManager):
|
13 |
+
@staticmethod
|
14 |
+
@retry
|
15 |
+
def get_tad_text_classifier(checkpoint: str = None, eval_batch_size=128, **kwargs):
|
16 |
+
tad_text_classifier = TADTextClassifier(
|
17 |
+
checkpoint, eval_batch_size=eval_batch_size, **kwargs
|
18 |
+
)
|
19 |
+
return tad_text_classifier
|
anonymous_demo/functional/config/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .tad_config_manager import TADConfigManager
|
anonymous_demo/functional/config/config_manager.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import Namespace
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
one_shot_messages = set()
|
6 |
+
|
7 |
+
|
8 |
+
def config_check(args):
|
9 |
+
pass
|
10 |
+
|
11 |
+
|
12 |
+
class ConfigManager(Namespace):
|
13 |
+
def __init__(self, args=None, **kwargs):
|
14 |
+
"""
|
15 |
+
The ConfigManager is a subclass of argparse.Namespace and based on parameter dict and count the call-frequency of each parameter
|
16 |
+
:param args: A parameter dict
|
17 |
+
:param kwargs: Same param as Namespce
|
18 |
+
"""
|
19 |
+
if not args:
|
20 |
+
args = {}
|
21 |
+
super().__init__(**kwargs)
|
22 |
+
|
23 |
+
if isinstance(args, Namespace):
|
24 |
+
self.args = vars(args)
|
25 |
+
self.args_call_count = {arg: 0 for arg in vars(args)}
|
26 |
+
else:
|
27 |
+
self.args = args
|
28 |
+
self.args_call_count = {arg: 0 for arg in args}
|
29 |
+
|
30 |
+
def __getattribute__(self, arg_name):
|
31 |
+
if arg_name == "args" or arg_name == "args_call_count":
|
32 |
+
return super().__getattribute__(arg_name)
|
33 |
+
try:
|
34 |
+
value = super().__getattribute__("args")[arg_name]
|
35 |
+
args_call_count = super().__getattribute__("args_call_count")
|
36 |
+
args_call_count[arg_name] += 1
|
37 |
+
super().__setattr__("args_call_count", args_call_count)
|
38 |
+
return value
|
39 |
+
|
40 |
+
except Exception as e:
|
41 |
+
return super().__getattribute__(arg_name)
|
42 |
+
|
43 |
+
def __setattr__(self, arg_name, value):
|
44 |
+
if arg_name == "args" or arg_name == "args_call_count":
|
45 |
+
super().__setattr__(arg_name, value)
|
46 |
+
return
|
47 |
+
try:
|
48 |
+
args = super().__getattribute__("args")
|
49 |
+
args[arg_name] = value
|
50 |
+
super().__setattr__("args", args)
|
51 |
+
args_call_count = super().__getattribute__("args_call_count")
|
52 |
+
|
53 |
+
if arg_name in args_call_count:
|
54 |
+
# args_call_count[arg_name] += 1
|
55 |
+
super().__setattr__("args_call_count", args_call_count)
|
56 |
+
|
57 |
+
else:
|
58 |
+
args_call_count[arg_name] = 0
|
59 |
+
super().__setattr__("args_call_count", args_call_count)
|
60 |
+
|
61 |
+
except Exception as e:
|
62 |
+
super().__setattr__(arg_name, value)
|
63 |
+
|
64 |
+
config_check(args)
|
anonymous_demo/functional/config/tad_config_manager.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
from anonymous_demo.functional.config.config_manager import ConfigManager
|
4 |
+
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
|
5 |
+
|
6 |
+
_tad_config_template = {
|
7 |
+
"model": TADBERT,
|
8 |
+
"optimizer": "adamw",
|
9 |
+
"learning_rate": 0.00002,
|
10 |
+
"patience": 99999,
|
11 |
+
"pretrained_bert": "microsoft/mdeberta-v3-base",
|
12 |
+
"cache_dataset": True,
|
13 |
+
"warmup_step": -1,
|
14 |
+
"show_metric": False,
|
15 |
+
"max_seq_len": 80,
|
16 |
+
"dropout": 0,
|
17 |
+
"l2reg": 0.000001,
|
18 |
+
"num_epoch": 10,
|
19 |
+
"batch_size": 16,
|
20 |
+
"initializer": "xavier_uniform_",
|
21 |
+
"seed": 52,
|
22 |
+
"polarities_dim": 3,
|
23 |
+
"log_step": 10,
|
24 |
+
"evaluate_begin": 0,
|
25 |
+
"cross_validate_fold": -1,
|
26 |
+
"use_amp": False,
|
27 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
28 |
+
}
|
29 |
+
|
30 |
+
_tad_config_base = {
|
31 |
+
"model": TADBERT,
|
32 |
+
"optimizer": "adamw",
|
33 |
+
"learning_rate": 0.00002,
|
34 |
+
"pretrained_bert": "microsoft/deberta-v3-base",
|
35 |
+
"cache_dataset": True,
|
36 |
+
"warmup_step": -1,
|
37 |
+
"show_metric": False,
|
38 |
+
"max_seq_len": 80,
|
39 |
+
"patience": 99999,
|
40 |
+
"dropout": 0,
|
41 |
+
"l2reg": 0.000001,
|
42 |
+
"num_epoch": 10,
|
43 |
+
"batch_size": 16,
|
44 |
+
"initializer": "xavier_uniform_",
|
45 |
+
"seed": 52,
|
46 |
+
"polarities_dim": 3,
|
47 |
+
"log_step": 10,
|
48 |
+
"evaluate_begin": 0,
|
49 |
+
"cross_validate_fold": -1
|
50 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
51 |
+
}
|
52 |
+
|
53 |
+
_tad_config_english = {
|
54 |
+
"model": TADBERT,
|
55 |
+
"optimizer": "adamw",
|
56 |
+
"learning_rate": 0.00002,
|
57 |
+
"patience": 99999,
|
58 |
+
"pretrained_bert": "microsoft/deberta-v3-base",
|
59 |
+
"cache_dataset": True,
|
60 |
+
"warmup_step": -1,
|
61 |
+
"show_metric": False,
|
62 |
+
"max_seq_len": 80,
|
63 |
+
"dropout": 0,
|
64 |
+
"l2reg": 0.000001,
|
65 |
+
"num_epoch": 10,
|
66 |
+
"batch_size": 16,
|
67 |
+
"initializer": "xavier_uniform_",
|
68 |
+
"seed": 52,
|
69 |
+
"polarities_dim": 3,
|
70 |
+
"log_step": 10,
|
71 |
+
"evaluate_begin": 0,
|
72 |
+
"cross_validate_fold": -1
|
73 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
74 |
+
}
|
75 |
+
|
76 |
+
_tad_config_multilingual = {
|
77 |
+
"model": TADBERT,
|
78 |
+
"optimizer": "adamw",
|
79 |
+
"learning_rate": 0.00002,
|
80 |
+
"patience": 99999,
|
81 |
+
"pretrained_bert": "microsoft/mdeberta-v3-base",
|
82 |
+
"cache_dataset": True,
|
83 |
+
"warmup_step": -1,
|
84 |
+
"show_metric": False,
|
85 |
+
"max_seq_len": 80,
|
86 |
+
"dropout": 0,
|
87 |
+
"l2reg": 0.000001,
|
88 |
+
"num_epoch": 10,
|
89 |
+
"batch_size": 16,
|
90 |
+
"initializer": "xavier_uniform_",
|
91 |
+
"seed": 52,
|
92 |
+
"polarities_dim": 3,
|
93 |
+
"log_step": 10,
|
94 |
+
"evaluate_begin": 0,
|
95 |
+
"cross_validate_fold": -1
|
96 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
97 |
+
}
|
98 |
+
|
99 |
+
_tad_config_chinese = {
|
100 |
+
"model": TADBERT,
|
101 |
+
"optimizer": "adamw",
|
102 |
+
"learning_rate": 0.00002,
|
103 |
+
"patience": 99999,
|
104 |
+
"cache_dataset": True,
|
105 |
+
"warmup_step": -1,
|
106 |
+
"show_metric": False,
|
107 |
+
"pretrained_bert": "bert-base-chinese",
|
108 |
+
"max_seq_len": 80,
|
109 |
+
"dropout": 0,
|
110 |
+
"l2reg": 0.000001,
|
111 |
+
"num_epoch": 10,
|
112 |
+
"batch_size": 16,
|
113 |
+
"initializer": "xavier_uniform_",
|
114 |
+
"seed": 52,
|
115 |
+
"polarities_dim": 3,
|
116 |
+
"log_step": 10,
|
117 |
+
"evaluate_begin": 0,
|
118 |
+
"cross_validate_fold": -1
|
119 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
120 |
+
}
|
121 |
+
|
122 |
+
|
123 |
+
class TADConfigManager(ConfigManager):
|
124 |
+
def __init__(self, args, **kwargs):
|
125 |
+
"""
|
126 |
+
Available Params: {'model': BERT,
|
127 |
+
'optimizer': "adamw",
|
128 |
+
'learning_rate': 0.00002,
|
129 |
+
'pretrained_bert': "roberta-base",
|
130 |
+
'cache_dataset': True,
|
131 |
+
'warmup_step': -1,
|
132 |
+
'show_metric': False,
|
133 |
+
'max_seq_len': 80,
|
134 |
+
'patience': 99999,
|
135 |
+
'dropout': 0,
|
136 |
+
'l2reg': 0.000001,
|
137 |
+
'num_epoch': 10,
|
138 |
+
'batch_size': 16,
|
139 |
+
'initializer': 'xavier_uniform_',
|
140 |
+
'seed': {52, 25}
|
141 |
+
'embed_dim': 768,
|
142 |
+
'hidden_dim': 768,
|
143 |
+
'polarities_dim': 3,
|
144 |
+
'log_step': 10,
|
145 |
+
'evaluate_begin': 0,
|
146 |
+
'cross_validate_fold': -1 # split train and test datasets into 5 folds and repeat 3 training
|
147 |
+
}
|
148 |
+
:param args:
|
149 |
+
:param kwargs:
|
150 |
+
"""
|
151 |
+
super().__init__(args, **kwargs)
|
152 |
+
|
153 |
+
@staticmethod
|
154 |
+
def set_tad_config(configType: str, newitem: dict):
|
155 |
+
if isinstance(newitem, dict):
|
156 |
+
if configType == "template":
|
157 |
+
_tad_config_template.update(newitem)
|
158 |
+
elif configType == "base":
|
159 |
+
_tad_config_base.update(newitem)
|
160 |
+
elif configType == "english":
|
161 |
+
_tad_config_english.update(newitem)
|
162 |
+
elif configType == "chinese":
|
163 |
+
_tad_config_chinese.update(newitem)
|
164 |
+
elif configType == "multilingual":
|
165 |
+
_tad_config_multilingual.update(newitem)
|
166 |
+
elif configType == "glove":
|
167 |
+
_tad_config_glove.update(newitem)
|
168 |
+
else:
|
169 |
+
raise ValueError(
|
170 |
+
"Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
raise TypeError(
|
174 |
+
"Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
|
175 |
+
)
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def set_tad_config_template(newitem):
|
179 |
+
TADConfigManager.set_tad_config("template", newitem)
|
180 |
+
|
181 |
+
@staticmethod
|
182 |
+
def set_tad_config_base(newitem):
|
183 |
+
TADConfigManager.set_tad_config("base", newitem)
|
184 |
+
|
185 |
+
@staticmethod
|
186 |
+
def set_tad_config_english(newitem):
|
187 |
+
TADConfigManager.set_tad_config("english", newitem)
|
188 |
+
|
189 |
+
@staticmethod
|
190 |
+
def set_tad_config_chinese(newitem):
|
191 |
+
TADConfigManager.set_tad_config("chinese", newitem)
|
192 |
+
|
193 |
+
@staticmethod
|
194 |
+
def set_tad_config_multilingual(newitem):
|
195 |
+
TADConfigManager.set_tad_config("multilingual", newitem)
|
196 |
+
|
197 |
+
@staticmethod
|
198 |
+
def set_tad_config_glove(newitem):
|
199 |
+
TADConfigManager.set_tad_config("glove", newitem)
|
200 |
+
|
201 |
+
@staticmethod
|
202 |
+
def get_tad_config_template() -> ConfigManager:
|
203 |
+
_tad_config_template.update(_tad_config_template)
|
204 |
+
return TADConfigManager(copy.deepcopy(_tad_config_template))
|
205 |
+
|
206 |
+
@staticmethod
|
207 |
+
def get_tad_config_base() -> ConfigManager:
|
208 |
+
_tad_config_template.update(_tad_config_base)
|
209 |
+
return TADConfigManager(copy.deepcopy(_tad_config_template))
|
210 |
+
|
211 |
+
@staticmethod
|
212 |
+
def get_tad_config_english() -> ConfigManager:
|
213 |
+
_tad_config_template.update(_tad_config_english)
|
214 |
+
return TADConfigManager(copy.deepcopy(_tad_config_template))
|
215 |
+
|
216 |
+
@staticmethod
|
217 |
+
def get_tad_config_chinese() -> ConfigManager:
|
218 |
+
_tad_config_template.update(_tad_config_chinese)
|
219 |
+
return TADConfigManager(copy.deepcopy(_tad_config_template))
|
220 |
+
|
221 |
+
@staticmethod
|
222 |
+
def get_tad_config_multilingual() -> ConfigManager:
|
223 |
+
_tad_config_template.update(_tad_config_multilingual)
|
224 |
+
return TADConfigManager(copy.deepcopy(_tad_config_template))
|
225 |
+
|
226 |
+
@staticmethod
|
227 |
+
def get_tad_config_glove() -> ConfigManager:
|
228 |
+
_tad_config_template.update(_tad_config_glove)
|
229 |
+
return TADConfigManager(copy.deepcopy(_tad_config_template))
|
anonymous_demo/functional/dataset/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from anonymous_demo.functional.dataset.dataset_manager import detect_infer_dataset
|
anonymous_demo/functional/dataset/dataset_manager.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from findfile import find_files, find_dir
|
3 |
+
|
4 |
+
filter_key_words = [
|
5 |
+
".py",
|
6 |
+
".md",
|
7 |
+
"readme",
|
8 |
+
"log",
|
9 |
+
"result",
|
10 |
+
"zip",
|
11 |
+
".state_dict",
|
12 |
+
".model",
|
13 |
+
".png",
|
14 |
+
"acc_",
|
15 |
+
"f1_",
|
16 |
+
".backup",
|
17 |
+
".bak",
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def detect_infer_dataset(dataset_path, task="apc"):
|
22 |
+
dataset_file = []
|
23 |
+
if isinstance(dataset_path, str) and os.path.isfile(dataset_path):
|
24 |
+
dataset_file.append(dataset_path)
|
25 |
+
return dataset_file
|
26 |
+
|
27 |
+
for d in dataset_path:
|
28 |
+
if not os.path.exists(d):
|
29 |
+
search_path = find_dir(
|
30 |
+
os.getcwd(),
|
31 |
+
[d, task, "dataset"],
|
32 |
+
exclude_key=filter_key_words,
|
33 |
+
disable_alert=False,
|
34 |
+
)
|
35 |
+
dataset_file += find_files(
|
36 |
+
search_path,
|
37 |
+
[".inference", d],
|
38 |
+
exclude_key=["train."] + filter_key_words,
|
39 |
+
)
|
40 |
+
else:
|
41 |
+
dataset_file += find_files(
|
42 |
+
d, [".inference", task], exclude_key=["train."] + filter_key_words
|
43 |
+
)
|
44 |
+
|
45 |
+
return dataset_file
|
anonymous_demo/network/__init__.py
ADDED
File without changes
|
anonymous_demo/network/lcf_pooler.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class LCF_Pooler(nn.Module):
|
7 |
+
def __init__(self, config):
|
8 |
+
super().__init__()
|
9 |
+
self.config = config
|
10 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
11 |
+
self.activation = nn.Tanh()
|
12 |
+
|
13 |
+
def forward(self, hidden_states, lcf_vec):
|
14 |
+
device = hidden_states.device
|
15 |
+
lcf_vec = lcf_vec.detach().cpu().numpy()
|
16 |
+
|
17 |
+
pooled_output = numpy.zeros(
|
18 |
+
(hidden_states.shape[0], hidden_states.shape[2]), dtype=numpy.float32
|
19 |
+
)
|
20 |
+
hidden_states = hidden_states.detach().cpu().numpy()
|
21 |
+
for i, vec in enumerate(lcf_vec):
|
22 |
+
lcf_ids = [j for j in range(len(vec)) if sum(vec[j] - 1.0) == 0]
|
23 |
+
pooled_output[i] = hidden_states[i][lcf_ids[len(lcf_ids) // 2]]
|
24 |
+
|
25 |
+
pooled_output = torch.Tensor(pooled_output).to(device)
|
26 |
+
pooled_output = self.dense(pooled_output)
|
27 |
+
pooled_output = self.activation(pooled_output)
|
28 |
+
return pooled_output
|
anonymous_demo/network/lsa.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from anonymous_demo.network.sa_encoder import Encoder
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class LSA(nn.Module):
|
7 |
+
def __init__(self, bert, opt):
|
8 |
+
super(LSA, self).__init__()
|
9 |
+
self.opt = opt
|
10 |
+
|
11 |
+
self.encoder = Encoder(bert.config, opt)
|
12 |
+
self.encoder_left = Encoder(bert.config, opt)
|
13 |
+
self.encoder_right = Encoder(bert.config, opt)
|
14 |
+
self.linear_window_3h = nn.Linear(opt.embed_dim * 3, opt.embed_dim)
|
15 |
+
self.linear_window_2h = nn.Linear(opt.embed_dim * 2, opt.embed_dim)
|
16 |
+
self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
|
17 |
+
self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
|
18 |
+
|
19 |
+
def forward(
|
20 |
+
self,
|
21 |
+
global_context_features,
|
22 |
+
spc_mask_vec,
|
23 |
+
lcf_matrix,
|
24 |
+
left_lcf_matrix,
|
25 |
+
right_lcf_matrix,
|
26 |
+
):
|
27 |
+
masked_global_context_features = torch.mul(
|
28 |
+
spc_mask_vec, global_context_features
|
29 |
+
)
|
30 |
+
|
31 |
+
# # --------------------------------------------------- #
|
32 |
+
lcf_features = torch.mul(global_context_features, lcf_matrix)
|
33 |
+
lcf_features = self.encoder(lcf_features)
|
34 |
+
# # --------------------------------------------------- #
|
35 |
+
left_lcf_features = torch.mul(masked_global_context_features, left_lcf_matrix)
|
36 |
+
left_lcf_features = self.encoder_left(left_lcf_features)
|
37 |
+
# # --------------------------------------------------- #
|
38 |
+
right_lcf_features = torch.mul(masked_global_context_features, right_lcf_matrix)
|
39 |
+
right_lcf_features = self.encoder_right(right_lcf_features)
|
40 |
+
# # --------------------------------------------------- #
|
41 |
+
if "lr" == self.opt.window or "rl" == self.opt.window:
|
42 |
+
if self.eta1 <= 0 and self.opt.eta != -1:
|
43 |
+
torch.nn.init.uniform_(self.eta1)
|
44 |
+
print("reset eta1 to: {}".format(self.eta1.item()))
|
45 |
+
if self.eta2 <= 0 and self.opt.eta != -1:
|
46 |
+
torch.nn.init.uniform_(self.eta2)
|
47 |
+
print("reset eta2 to: {}".format(self.eta2.item()))
|
48 |
+
if self.opt.eta >= 0:
|
49 |
+
cat_features = torch.cat(
|
50 |
+
(
|
51 |
+
lcf_features,
|
52 |
+
self.eta1 * left_lcf_features,
|
53 |
+
self.eta2 * right_lcf_features,
|
54 |
+
),
|
55 |
+
-1,
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
cat_features = torch.cat(
|
59 |
+
(lcf_features, left_lcf_features, right_lcf_features), -1
|
60 |
+
)
|
61 |
+
sent_out = self.linear_window_3h(cat_features)
|
62 |
+
elif "l" == self.opt.window:
|
63 |
+
sent_out = self.linear_window_2h(
|
64 |
+
torch.cat((lcf_features, self.eta1 * left_lcf_features), -1)
|
65 |
+
)
|
66 |
+
elif "r" == self.opt.window:
|
67 |
+
sent_out = self.linear_window_2h(
|
68 |
+
torch.cat((lcf_features, self.eta2 * right_lcf_features), -1)
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
raise KeyError("Invalid parameter:", self.opt.window)
|
72 |
+
|
73 |
+
return sent_out
|
anonymous_demo/network/sa_encoder.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
class BertSelfAttention(nn.Module):
|
9 |
+
def __init__(self, config):
|
10 |
+
super().__init__()
|
11 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
12 |
+
config, "embedding_size"
|
13 |
+
):
|
14 |
+
raise ValueError(
|
15 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
16 |
+
f"heads ({config.num_attention_heads})"
|
17 |
+
)
|
18 |
+
|
19 |
+
self.num_attention_heads = config.num_attention_heads
|
20 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
21 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
22 |
+
|
23 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
24 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
25 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
26 |
+
|
27 |
+
self.dropout = nn.Dropout(
|
28 |
+
config.attention_probs_dropout_prob
|
29 |
+
if hasattr(config, "attention_probs_dropout_prob")
|
30 |
+
else 0
|
31 |
+
)
|
32 |
+
self.position_embedding_type = getattr(
|
33 |
+
config, "position_embedding_type", "absolute"
|
34 |
+
)
|
35 |
+
if (
|
36 |
+
self.position_embedding_type == "relative_key"
|
37 |
+
or self.position_embedding_type == "relative_key_query"
|
38 |
+
):
|
39 |
+
self.max_position_embeddings = config.max_position_embeddings
|
40 |
+
self.distance_embedding = nn.Embedding(
|
41 |
+
2 * config.max_position_embeddings - 1, self.attention_head_size
|
42 |
+
)
|
43 |
+
|
44 |
+
self.is_decoder = config.is_decoder
|
45 |
+
|
46 |
+
def transpose_for_scores(self, x):
|
47 |
+
new_x_shape = x.size()[:-1] + (
|
48 |
+
self.num_attention_heads,
|
49 |
+
self.attention_head_size,
|
50 |
+
)
|
51 |
+
x = x.view(*new_x_shape)
|
52 |
+
return x.permute(0, 2, 1, 3)
|
53 |
+
|
54 |
+
def forward(
|
55 |
+
self,
|
56 |
+
hidden_states,
|
57 |
+
attention_mask=None,
|
58 |
+
head_mask=None,
|
59 |
+
encoder_hidden_states=None,
|
60 |
+
encoder_attention_mask=None,
|
61 |
+
past_key_value=None,
|
62 |
+
output_attentions=False,
|
63 |
+
):
|
64 |
+
mixed_query_layer = self.query(hidden_states)
|
65 |
+
|
66 |
+
# If this is instantiated as a cross-attention module, the keys
|
67 |
+
# and values come from an encoder; the attention mask needs to be
|
68 |
+
# such that the encoder's padding tokens are not attended to.
|
69 |
+
is_cross_attention = encoder_hidden_states is not None
|
70 |
+
|
71 |
+
if is_cross_attention and past_key_value is not None:
|
72 |
+
# reuse k,v, cross_attentions
|
73 |
+
key_layer = past_key_value[0]
|
74 |
+
value_layer = past_key_value[1]
|
75 |
+
attention_mask = encoder_attention_mask
|
76 |
+
elif is_cross_attention:
|
77 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
78 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
79 |
+
attention_mask = encoder_attention_mask
|
80 |
+
elif past_key_value is not None:
|
81 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
82 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
83 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
84 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
85 |
+
else:
|
86 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
87 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
88 |
+
|
89 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
90 |
+
|
91 |
+
if self.is_decoder:
|
92 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
93 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
94 |
+
# key/value_states (first "if" case)
|
95 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
96 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
97 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
98 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
99 |
+
past_key_value = (key_layer, value_layer)
|
100 |
+
|
101 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
102 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
103 |
+
|
104 |
+
if (
|
105 |
+
self.position_embedding_type == "relative_key"
|
106 |
+
or self.position_embedding_type == "relative_key_query"
|
107 |
+
):
|
108 |
+
seq_length = hidden_states.size()[1]
|
109 |
+
position_ids_l = torch.arange(
|
110 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
111 |
+
).view(-1, 1)
|
112 |
+
position_ids_r = torch.arange(
|
113 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
114 |
+
).view(1, -1)
|
115 |
+
distance = position_ids_l - position_ids_r
|
116 |
+
positional_embedding = self.distance_embedding(
|
117 |
+
distance + self.max_position_embeddings - 1
|
118 |
+
)
|
119 |
+
positional_embedding = positional_embedding.to(
|
120 |
+
dtype=query_layer.dtype
|
121 |
+
) # fp16 compatibility
|
122 |
+
|
123 |
+
if self.position_embedding_type == "relative_key":
|
124 |
+
relative_position_scores = torch.einsum(
|
125 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
126 |
+
)
|
127 |
+
attention_scores = attention_scores + relative_position_scores
|
128 |
+
elif self.position_embedding_type == "relative_key_query":
|
129 |
+
relative_position_scores_query = torch.einsum(
|
130 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
131 |
+
)
|
132 |
+
relative_position_scores_key = torch.einsum(
|
133 |
+
"bhrd,lrd->bhlr", key_layer, positional_embedding
|
134 |
+
)
|
135 |
+
attention_scores = (
|
136 |
+
attention_scores
|
137 |
+
+ relative_position_scores_query
|
138 |
+
+ relative_position_scores_key
|
139 |
+
)
|
140 |
+
|
141 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
142 |
+
if attention_mask is not None:
|
143 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
144 |
+
attention_scores = attention_scores + attention_mask
|
145 |
+
|
146 |
+
# Normalize the attention scores to probabilities.
|
147 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
148 |
+
|
149 |
+
# This is actually dropping out entire tokens to attend to, which might
|
150 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
151 |
+
attention_probs = self.dropout(attention_probs)
|
152 |
+
|
153 |
+
# Mask heads if we want to
|
154 |
+
if head_mask is not None:
|
155 |
+
attention_probs = attention_probs * head_mask
|
156 |
+
|
157 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
158 |
+
|
159 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
160 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
161 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
162 |
+
|
163 |
+
outputs = (
|
164 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
165 |
+
)
|
166 |
+
|
167 |
+
if self.is_decoder:
|
168 |
+
outputs = outputs + (past_key_value,)
|
169 |
+
return outputs
|
170 |
+
|
171 |
+
|
172 |
+
class Encoder(nn.Module):
|
173 |
+
def __init__(self, config, opt, layer_num=1):
|
174 |
+
super(Encoder, self).__init__()
|
175 |
+
self.opt = opt
|
176 |
+
self.config = config
|
177 |
+
self.encoder = nn.ModuleList(
|
178 |
+
[SelfAttention(config, opt) for _ in range(layer_num)]
|
179 |
+
)
|
180 |
+
self.tanh = torch.nn.Tanh()
|
181 |
+
|
182 |
+
def forward(self, x):
|
183 |
+
for i, enc in enumerate(self.encoder):
|
184 |
+
x = self.tanh(enc(x)[0])
|
185 |
+
return x
|
186 |
+
|
187 |
+
|
188 |
+
class SelfAttention(nn.Module):
|
189 |
+
def __init__(self, config, opt):
|
190 |
+
super(SelfAttention, self).__init__()
|
191 |
+
self.opt = opt
|
192 |
+
self.config = config
|
193 |
+
self.SA = BertSelfAttention(config)
|
194 |
+
|
195 |
+
def forward(self, inputs):
|
196 |
+
zero_vec = np.zeros((inputs.size(0), 1, 1, self.opt.max_seq_len))
|
197 |
+
zero_tensor = torch.tensor(zero_vec).float().to(inputs.device)
|
198 |
+
SA_out = self.SA(inputs, zero_tensor)
|
199 |
+
return SA_out
|
anonymous_demo/utils/__init__.py
ADDED
File without changes
|
anonymous_demo/utils/demo_utils.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
import signal
|
5 |
+
import threading
|
6 |
+
import time
|
7 |
+
import zipfile
|
8 |
+
|
9 |
+
import gdown
|
10 |
+
import numpy as np
|
11 |
+
import requests
|
12 |
+
import torch
|
13 |
+
import tqdm
|
14 |
+
from autocuda import auto_cuda, auto_cuda_name
|
15 |
+
from findfile import find_files, find_cwd_file, find_file
|
16 |
+
from termcolor import colored
|
17 |
+
from functools import wraps
|
18 |
+
|
19 |
+
from update_checker import parse_version
|
20 |
+
|
21 |
+
from anonymous_demo import __version__
|
22 |
+
|
23 |
+
|
24 |
+
def save_args(config, save_path):
|
25 |
+
f = open(os.path.join(save_path), mode="w", encoding="utf8")
|
26 |
+
for arg in config.args:
|
27 |
+
if config.args_call_count[arg]:
|
28 |
+
f.write("{}: {}\n".format(arg, config.args[arg]))
|
29 |
+
f.close()
|
30 |
+
|
31 |
+
|
32 |
+
def print_args(config, logger=None, mode=0):
|
33 |
+
args = [key for key in sorted(config.args.keys())]
|
34 |
+
for arg in args:
|
35 |
+
if logger:
|
36 |
+
logger.info(
|
37 |
+
"{0}:{1}\t-->\tCalling Count:{2}".format(
|
38 |
+
arg, config.args[arg], config.args_call_count[arg]
|
39 |
+
)
|
40 |
+
)
|
41 |
+
else:
|
42 |
+
print(
|
43 |
+
"{0}:{1}\t-->\tCalling Count:{2}".format(
|
44 |
+
arg, config.args[arg], config.args_call_count[arg]
|
45 |
+
)
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
def check_and_fix_labels(label_set: set, label_name, all_data, opt):
|
50 |
+
if "-100" in label_set:
|
51 |
+
label_to_index = {
|
52 |
+
origin_label: int(idx) - 1 if origin_label != "-100" else -100
|
53 |
+
for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
|
54 |
+
}
|
55 |
+
index_to_label = {
|
56 |
+
int(idx) - 1 if origin_label != "-100" else -100: origin_label
|
57 |
+
for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
|
58 |
+
}
|
59 |
+
else:
|
60 |
+
label_to_index = {
|
61 |
+
origin_label: int(idx)
|
62 |
+
for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
|
63 |
+
}
|
64 |
+
index_to_label = {
|
65 |
+
int(idx): origin_label
|
66 |
+
for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
|
67 |
+
}
|
68 |
+
if "index_to_label" not in opt.args:
|
69 |
+
opt.index_to_label = index_to_label
|
70 |
+
opt.label_to_index = label_to_index
|
71 |
+
|
72 |
+
if opt.index_to_label != index_to_label:
|
73 |
+
opt.index_to_label.update(index_to_label)
|
74 |
+
opt.label_to_index.update(label_to_index)
|
75 |
+
num_label = {l: 0 for l in label_set}
|
76 |
+
num_label["Sum"] = len(all_data)
|
77 |
+
for item in all_data:
|
78 |
+
try:
|
79 |
+
num_label[item[label_name]] += 1
|
80 |
+
item[label_name] = label_to_index[item[label_name]]
|
81 |
+
except Exception as e:
|
82 |
+
# print(e)
|
83 |
+
num_label[item.polarity] += 1
|
84 |
+
item.polarity = label_to_index[item.polarity]
|
85 |
+
print("Dataset Label Details: {}".format(num_label))
|
86 |
+
|
87 |
+
|
88 |
+
def check_and_fix_IOB_labels(label_map, opt):
|
89 |
+
index_to_IOB_label = {
|
90 |
+
int(label_map[origin_label]): origin_label for origin_label in label_map
|
91 |
+
}
|
92 |
+
opt.index_to_IOB_label = index_to_IOB_label
|
93 |
+
|
94 |
+
|
95 |
+
def get_device(auto_device):
|
96 |
+
if isinstance(auto_device, str) and auto_device == "allcuda":
|
97 |
+
device = "cuda"
|
98 |
+
elif isinstance(auto_device, str):
|
99 |
+
device = auto_device
|
100 |
+
elif isinstance(auto_device, bool):
|
101 |
+
device = auto_cuda() if auto_device else "cpu"
|
102 |
+
else:
|
103 |
+
device = auto_cuda()
|
104 |
+
try:
|
105 |
+
torch.device(device)
|
106 |
+
except RuntimeError as e:
|
107 |
+
print(
|
108 |
+
colored("Device assignment error: {}, redirect to CPU".format(e), "red")
|
109 |
+
)
|
110 |
+
device = "cpu"
|
111 |
+
device_name = auto_cuda_name()
|
112 |
+
return device, device_name
|
113 |
+
|
114 |
+
|
115 |
+
def _load_word_vec(path, word2idx=None, embed_dim=300):
|
116 |
+
fin = open(path, "r", encoding="utf-8", newline="\n", errors="ignore")
|
117 |
+
word_vec = {}
|
118 |
+
for line in tqdm.tqdm(fin.readlines(), postfix="Loading embedding file..."):
|
119 |
+
tokens = line.rstrip().split()
|
120 |
+
word, vec = " ".join(tokens[:-embed_dim]), tokens[-embed_dim:]
|
121 |
+
if word in word2idx.keys():
|
122 |
+
word_vec[word] = np.asarray(vec, dtype="float32")
|
123 |
+
return word_vec
|
124 |
+
|
125 |
+
|
126 |
+
def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
|
127 |
+
if not os.path.exists("run"):
|
128 |
+
os.makedirs("run")
|
129 |
+
embed_matrix_path = "run/{}".format(os.path.join(opt.dataset_name, dat_fname))
|
130 |
+
if os.path.exists(embed_matrix_path):
|
131 |
+
print(
|
132 |
+
colored(
|
133 |
+
"Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)".format(
|
134 |
+
embed_matrix_path
|
135 |
+
),
|
136 |
+
"green",
|
137 |
+
)
|
138 |
+
)
|
139 |
+
embedding_matrix = pickle.load(open(embed_matrix_path, "rb"))
|
140 |
+
else:
|
141 |
+
glove_path = prepare_glove840_embedding(embed_matrix_path)
|
142 |
+
embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
|
143 |
+
|
144 |
+
word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)
|
145 |
+
|
146 |
+
for word, i in tqdm.tqdm(
|
147 |
+
word2idx.items(),
|
148 |
+
postfix=colored("Building embedding_matrix {}".format(dat_fname), "yellow"),
|
149 |
+
):
|
150 |
+
vec = word_vec.get(word)
|
151 |
+
if vec is not None:
|
152 |
+
embedding_matrix[i] = vec
|
153 |
+
pickle.dump(embedding_matrix, open(embed_matrix_path, "wb"))
|
154 |
+
return embedding_matrix
|
155 |
+
|
156 |
+
|
157 |
+
def pad_and_truncate(
|
158 |
+
sequence, maxlen, dtype="int64", padding="post", truncating="post", value=0
|
159 |
+
):
|
160 |
+
x = (np.ones(maxlen) * value).astype(dtype)
|
161 |
+
if truncating == "pre":
|
162 |
+
trunc = sequence[-maxlen:]
|
163 |
+
else:
|
164 |
+
trunc = sequence[:maxlen]
|
165 |
+
trunc = np.asarray(trunc, dtype=dtype)
|
166 |
+
if padding == "post":
|
167 |
+
x[: len(trunc)] = trunc
|
168 |
+
else:
|
169 |
+
x[-len(trunc) :] = trunc
|
170 |
+
return x
|
171 |
+
|
172 |
+
|
173 |
+
class TransformerConnectionError(ValueError):
|
174 |
+
def __init__(self):
|
175 |
+
pass
|
176 |
+
|
177 |
+
|
178 |
+
def retry(f):
|
179 |
+
@wraps(f)
|
180 |
+
def decorated(*args, **kwargs):
|
181 |
+
count = 5
|
182 |
+
while count:
|
183 |
+
try:
|
184 |
+
return f(*args, **kwargs)
|
185 |
+
except (
|
186 |
+
TransformerConnectionError,
|
187 |
+
requests.exceptions.RequestException,
|
188 |
+
requests.exceptions.ConnectionError,
|
189 |
+
requests.exceptions.HTTPError,
|
190 |
+
requests.exceptions.ConnectTimeout,
|
191 |
+
requests.exceptions.ProxyError,
|
192 |
+
requests.exceptions.SSLError,
|
193 |
+
requests.exceptions.BaseHTTPError,
|
194 |
+
) as e:
|
195 |
+
print(colored("Training Exception: {}, will retry later".format(e)))
|
196 |
+
time.sleep(60)
|
197 |
+
count -= 1
|
198 |
+
|
199 |
+
return decorated
|
200 |
+
|
201 |
+
|
202 |
+
def save_json(dic, save_path):
|
203 |
+
if isinstance(dic, str):
|
204 |
+
dic = eval(dic)
|
205 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
206 |
+
# f.write(str(dict))
|
207 |
+
str_ = json.dumps(dic, ensure_ascii=False)
|
208 |
+
f.write(str_)
|
209 |
+
|
210 |
+
|
211 |
+
def load_json(save_path):
|
212 |
+
with open(save_path, "r", encoding="utf-8") as f:
|
213 |
+
data = f.readline().strip()
|
214 |
+
print(type(data), data)
|
215 |
+
dic = json.loads(data)
|
216 |
+
return dic
|
217 |
+
|
218 |
+
|
219 |
+
def init_optimizer(optimizer):
|
220 |
+
optimizers = {
|
221 |
+
"adadelta": torch.optim.Adadelta, # default lr=1.0
|
222 |
+
"adagrad": torch.optim.Adagrad, # default lr=0.01
|
223 |
+
"adam": torch.optim.Adam, # default lr=0.001
|
224 |
+
"adamax": torch.optim.Adamax, # default lr=0.002
|
225 |
+
"asgd": torch.optim.ASGD, # default lr=0.01
|
226 |
+
"rmsprop": torch.optim.RMSprop, # default lr=0.01
|
227 |
+
"sgd": torch.optim.SGD,
|
228 |
+
"adamw": torch.optim.AdamW,
|
229 |
+
torch.optim.Adadelta: torch.optim.Adadelta, # default lr=1.0
|
230 |
+
torch.optim.Adagrad: torch.optim.Adagrad, # default lr=0.01
|
231 |
+
torch.optim.Adam: torch.optim.Adam, # default lr=0.001
|
232 |
+
torch.optim.Adamax: torch.optim.Adamax, # default lr=0.002
|
233 |
+
torch.optim.ASGD: torch.optim.ASGD, # default lr=0.01
|
234 |
+
torch.optim.RMSprop: torch.optim.RMSprop, # default lr=0.01
|
235 |
+
torch.optim.SGD: torch.optim.SGD,
|
236 |
+
torch.optim.AdamW: torch.optim.AdamW,
|
237 |
+
}
|
238 |
+
if optimizer in optimizers:
|
239 |
+
return optimizers[optimizer]
|
240 |
+
elif hasattr(torch.optim, optimizer.__name__):
|
241 |
+
return optimizer
|
242 |
+
else:
|
243 |
+
raise KeyError(
|
244 |
+
"Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer".format(
|
245 |
+
optimizer
|
246 |
+
)
|
247 |
+
)
|
anonymous_demo/utils/logger.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
|
6 |
+
import termcolor
|
7 |
+
|
8 |
+
today = time.strftime("%Y%m%d %H%M%S", time.localtime(time.time()))
|
9 |
+
|
10 |
+
|
11 |
+
def get_logger(log_path, log_name="", log_type="training_log"):
|
12 |
+
if not log_path:
|
13 |
+
log_dir = os.path.join(log_path, "logs")
|
14 |
+
else:
|
15 |
+
log_dir = os.path.join(".", "logs")
|
16 |
+
|
17 |
+
full_path = os.path.join(log_dir, log_name + "_" + today)
|
18 |
+
if not os.path.exists(full_path):
|
19 |
+
os.makedirs(full_path)
|
20 |
+
log_path = os.path.join(full_path, "{}.log".format(log_type))
|
21 |
+
logger = logging.getLogger(log_name)
|
22 |
+
if not logger.handlers:
|
23 |
+
formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
|
24 |
+
|
25 |
+
file_handler = logging.FileHandler(log_path, encoding="utf8")
|
26 |
+
file_handler.setFormatter(formatter)
|
27 |
+
file_handler.setLevel(logging.INFO)
|
28 |
+
|
29 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
30 |
+
console_handler.formatter = formatter
|
31 |
+
console_handler.setLevel(logging.INFO)
|
32 |
+
|
33 |
+
logger.addHandler(file_handler)
|
34 |
+
logger.addHandler(console_handler)
|
35 |
+
|
36 |
+
logger.setLevel(logging.INFO)
|
37 |
+
|
38 |
+
return logger
|
checkpoints.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f77ae4a45785183900ee874cb318a16b0e2f173b31749a2555215aca93672f26
|
3 |
+
size 2456834455
|
text_defense/201.SST2/stsa.binary.dev.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/201.SST2/stsa.binary.test.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/201.SST2/stsa.binary.train.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/202.IMDB10K/imdb10k.test.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/202.IMDB10K/imdb10k.train.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/202.IMDB10K/imdb10k.valid.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/204.AGNews10K/AGNews10K.test.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/204.AGNews10K/AGNews10K.train.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/204.AGNews10K/AGNews10K.valid.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat
ADDED
The diff for this file is too large to render.
See raw diff
|
|
textattack/__init__.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Welcome to the API references for TextAttack!
|
2 |
+
|
3 |
+
What is TextAttack?
|
4 |
+
|
5 |
+
`TextAttack <https://github.com/QData/TextAttack>`__ is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.
|
6 |
+
|
7 |
+
TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It's also useful for NLP model training, adversarial training, and data augmentation.
|
8 |
+
|
9 |
+
TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
|
10 |
+
"""
|
11 |
+
from .attack_args import AttackArgs, CommandLineAttackArgs
|
12 |
+
from .augment_args import AugmenterArgs
|
13 |
+
from .dataset_args import DatasetArgs
|
14 |
+
from .model_args import ModelArgs
|
15 |
+
from .training_args import TrainingArgs, CommandLineTrainingArgs
|
16 |
+
from .attack import Attack
|
17 |
+
from .attacker import Attacker
|
18 |
+
from .trainer import Trainer
|
19 |
+
from .metrics import Metric
|
20 |
+
|
21 |
+
from . import (
|
22 |
+
attack_recipes,
|
23 |
+
attack_results,
|
24 |
+
augmentation,
|
25 |
+
commands,
|
26 |
+
constraints,
|
27 |
+
datasets,
|
28 |
+
goal_function_results,
|
29 |
+
goal_functions,
|
30 |
+
loggers,
|
31 |
+
metrics,
|
32 |
+
models,
|
33 |
+
search_methods,
|
34 |
+
shared,
|
35 |
+
transformations,
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
name = "textattack"
|
textattack/__main__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
import textattack
|
5 |
+
|
6 |
+
textattack.commands.textattack_cli.main()
|
textattack/attack.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Attack Class
|
3 |
+
============
|
4 |
+
"""
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
from typing import List, Union
|
8 |
+
|
9 |
+
import lru
|
10 |
+
import torch
|
11 |
+
|
12 |
+
import textattack
|
13 |
+
from textattack.attack_results import (
|
14 |
+
FailedAttackResult,
|
15 |
+
MaximizedAttackResult,
|
16 |
+
SkippedAttackResult,
|
17 |
+
SuccessfulAttackResult,
|
18 |
+
)
|
19 |
+
from textattack.constraints import Constraint, PreTransformationConstraint
|
20 |
+
from textattack.goal_function_results import GoalFunctionResultStatus
|
21 |
+
from textattack.goal_functions import GoalFunction
|
22 |
+
from textattack.models.wrappers import ModelWrapper
|
23 |
+
from textattack.search_methods import SearchMethod
|
24 |
+
from textattack.shared import AttackedText, utils
|
25 |
+
from textattack.transformations import CompositeTransformation, Transformation
|
26 |
+
|
27 |
+
|
28 |
+
class Attack:
|
29 |
+
"""An attack generates adversarial examples on text.
|
30 |
+
|
31 |
+
An attack is comprised of a goal function, constraints, transformation, and a search method. Use :meth:`attack` method to attack one sample at a time.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
goal_function (:class:`~textattack.goal_functions.GoalFunction`):
|
35 |
+
A function for determining how well a perturbation is doing at achieving the attack's goal.
|
36 |
+
constraints (list of :class:`~textattack.constraints.Constraint` or :class:`~textattack.constraints.PreTransformationConstraint`):
|
37 |
+
A list of constraints to add to the attack, defining which perturbations are valid.
|
38 |
+
transformation (:class:`~textattack.transformations.Transformation`):
|
39 |
+
The transformation applied at each step of the attack.
|
40 |
+
search_method (:class:`~textattack.search_methods.SearchMethod`):
|
41 |
+
The method for exploring the search space of possible perturbations
|
42 |
+
transformation_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
|
43 |
+
The number of items to keep in the transformations cache
|
44 |
+
constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
|
45 |
+
The number of items to keep in the constraints cache
|
46 |
+
|
47 |
+
Example::
|
48 |
+
|
49 |
+
>>> import textattack
|
50 |
+
>>> import transformers
|
51 |
+
|
52 |
+
>>> # Load model, tokenizer, and model_wrapper
|
53 |
+
>>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
|
54 |
+
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
|
55 |
+
>>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
|
56 |
+
|
57 |
+
>>> # Construct our four components for `Attack`
|
58 |
+
>>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
|
59 |
+
>>> from textattack.constraints.semantics import WordEmbeddingDistance
|
60 |
+
|
61 |
+
>>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
|
62 |
+
>>> constraints = [
|
63 |
+
... RepeatModification(),
|
64 |
+
... StopwordModification()
|
65 |
+
... WordEmbeddingDistance(min_cos_sim=0.9)
|
66 |
+
... ]
|
67 |
+
>>> transformation = WordSwapEmbedding(max_candidates=50)
|
68 |
+
>>> search_method = GreedyWordSwapWIR(wir_method="delete")
|
69 |
+
|
70 |
+
>>> # Construct the actual attack
|
71 |
+
>>> attack = Attack(goal_function, constraints, transformation, search_method)
|
72 |
+
|
73 |
+
>>> input_text = "I really enjoyed the new movie that came out last month."
|
74 |
+
>>> label = 1 #Positive
|
75 |
+
>>> attack_result = attack.attack(input_text, label)
|
76 |
+
"""
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
goal_function: GoalFunction,
|
81 |
+
constraints: List[Union[Constraint, PreTransformationConstraint]],
|
82 |
+
transformation: Transformation,
|
83 |
+
search_method: SearchMethod,
|
84 |
+
transformation_cache_size=2**15,
|
85 |
+
constraint_cache_size=2**15,
|
86 |
+
):
|
87 |
+
"""Initialize an attack object.
|
88 |
+
|
89 |
+
Attacks can be run multiple times.
|
90 |
+
"""
|
91 |
+
assert isinstance(
|
92 |
+
goal_function, GoalFunction
|
93 |
+
), f"`goal_function` must be of type `textattack.goal_functions.GoalFunction`, but got type `{type(goal_function)}`."
|
94 |
+
assert isinstance(
|
95 |
+
constraints, list
|
96 |
+
), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
|
97 |
+
for c in constraints:
|
98 |
+
assert isinstance(
|
99 |
+
c, (Constraint, PreTransformationConstraint)
|
100 |
+
), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
|
101 |
+
assert isinstance(
|
102 |
+
transformation, Transformation
|
103 |
+
), f"`transformation` must be of type `textattack.transformations.Transformation`, but got type `{type(transformation)}`."
|
104 |
+
assert isinstance(
|
105 |
+
search_method, SearchMethod
|
106 |
+
), f"`search_method` must be of type `textattack.search_methods.SearchMethod`, but got type `{type(search_method)}`."
|
107 |
+
|
108 |
+
self.goal_function = goal_function
|
109 |
+
self.search_method = search_method
|
110 |
+
self.transformation = transformation
|
111 |
+
self.is_black_box = (
|
112 |
+
getattr(transformation, "is_black_box", True) and search_method.is_black_box
|
113 |
+
)
|
114 |
+
|
115 |
+
if not self.search_method.check_transformation_compatibility(
|
116 |
+
self.transformation
|
117 |
+
):
|
118 |
+
raise ValueError(
|
119 |
+
f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
|
120 |
+
)
|
121 |
+
|
122 |
+
self.constraints = []
|
123 |
+
self.pre_transformation_constraints = []
|
124 |
+
for constraint in constraints:
|
125 |
+
if isinstance(
|
126 |
+
constraint,
|
127 |
+
textattack.constraints.PreTransformationConstraint,
|
128 |
+
):
|
129 |
+
self.pre_transformation_constraints.append(constraint)
|
130 |
+
else:
|
131 |
+
self.constraints.append(constraint)
|
132 |
+
|
133 |
+
# Check if we can use transformation cache for our transformation.
|
134 |
+
if not self.transformation.deterministic:
|
135 |
+
self.use_transformation_cache = False
|
136 |
+
elif isinstance(self.transformation, CompositeTransformation):
|
137 |
+
self.use_transformation_cache = True
|
138 |
+
for t in self.transformation.transformations:
|
139 |
+
if not t.deterministic:
|
140 |
+
self.use_transformation_cache = False
|
141 |
+
break
|
142 |
+
else:
|
143 |
+
self.use_transformation_cache = True
|
144 |
+
self.transformation_cache_size = transformation_cache_size
|
145 |
+
self.transformation_cache = lru.LRU(transformation_cache_size)
|
146 |
+
|
147 |
+
self.constraint_cache_size = constraint_cache_size
|
148 |
+
self.constraints_cache = lru.LRU(constraint_cache_size)
|
149 |
+
|
150 |
+
# Give search method access to functions for getting transformations and evaluating them
|
151 |
+
self.search_method.get_transformations = self.get_transformations
|
152 |
+
# Give search method access to self.goal_function for model query count, etc.
|
153 |
+
self.search_method.goal_function = self.goal_function
|
154 |
+
# The search method only needs access to the first argument. The second is only used
|
155 |
+
# by the attack class when checking whether to skip the sample
|
156 |
+
self.search_method.get_goal_results = self.goal_function.get_results
|
157 |
+
|
158 |
+
# Give search method access to get indices which need to be ordered / searched
|
159 |
+
self.search_method.get_indices_to_order = self.get_indices_to_order
|
160 |
+
|
161 |
+
self.search_method.filter_transformations = self.filter_transformations
|
162 |
+
|
163 |
+
def clear_cache(self, recursive=True):
|
164 |
+
self.constraints_cache.clear()
|
165 |
+
if self.use_transformation_cache:
|
166 |
+
self.transformation_cache.clear()
|
167 |
+
if recursive:
|
168 |
+
self.goal_function.clear_cache()
|
169 |
+
for constraint in self.constraints:
|
170 |
+
if hasattr(constraint, "clear_cache"):
|
171 |
+
constraint.clear_cache()
|
172 |
+
|
173 |
+
def cpu_(self):
|
174 |
+
"""Move any `torch.nn.Module` models that are part of Attack to CPU."""
|
175 |
+
visited = set()
|
176 |
+
|
177 |
+
def to_cpu(obj):
|
178 |
+
visited.add(id(obj))
|
179 |
+
if isinstance(obj, torch.nn.Module):
|
180 |
+
obj.cpu()
|
181 |
+
elif isinstance(
|
182 |
+
obj,
|
183 |
+
(
|
184 |
+
Attack,
|
185 |
+
GoalFunction,
|
186 |
+
Transformation,
|
187 |
+
SearchMethod,
|
188 |
+
Constraint,
|
189 |
+
PreTransformationConstraint,
|
190 |
+
ModelWrapper,
|
191 |
+
),
|
192 |
+
):
|
193 |
+
for key in obj.__dict__:
|
194 |
+
s_obj = obj.__dict__[key]
|
195 |
+
if id(s_obj) not in visited:
|
196 |
+
to_cpu(s_obj)
|
197 |
+
elif isinstance(obj, (list, tuple)):
|
198 |
+
for item in obj:
|
199 |
+
if id(item) not in visited and isinstance(
|
200 |
+
item, (Transformation, Constraint, PreTransformationConstraint)
|
201 |
+
):
|
202 |
+
to_cpu(item)
|
203 |
+
|
204 |
+
to_cpu(self)
|
205 |
+
|
206 |
+
def cuda_(self):
|
207 |
+
"""Move any `torch.nn.Module` models that are part of Attack to GPU."""
|
208 |
+
visited = set()
|
209 |
+
|
210 |
+
def to_cuda(obj):
|
211 |
+
visited.add(id(obj))
|
212 |
+
if isinstance(obj, torch.nn.Module):
|
213 |
+
obj.to(textattack.shared.utils.device)
|
214 |
+
elif isinstance(
|
215 |
+
obj,
|
216 |
+
(
|
217 |
+
Attack,
|
218 |
+
GoalFunction,
|
219 |
+
Transformation,
|
220 |
+
SearchMethod,
|
221 |
+
Constraint,
|
222 |
+
PreTransformationConstraint,
|
223 |
+
ModelWrapper,
|
224 |
+
),
|
225 |
+
):
|
226 |
+
for key in obj.__dict__:
|
227 |
+
s_obj = obj.__dict__[key]
|
228 |
+
if id(s_obj) not in visited:
|
229 |
+
to_cuda(s_obj)
|
230 |
+
elif isinstance(obj, (list, tuple)):
|
231 |
+
for item in obj:
|
232 |
+
if id(item) not in visited and isinstance(
|
233 |
+
item, (Transformation, Constraint, PreTransformationConstraint)
|
234 |
+
):
|
235 |
+
to_cuda(item)
|
236 |
+
|
237 |
+
to_cuda(self)
|
238 |
+
|
239 |
+
def get_indices_to_order(self, current_text, **kwargs):
|
240 |
+
"""Applies ``pre_transformation_constraints`` to ``text`` to get all
|
241 |
+
the indices that can be used to search and order.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
current_text: The current ``AttackedText`` for which we need to find indices are eligible to be ordered.
|
245 |
+
Returns:
|
246 |
+
The length and the filtered list of indices which search methods can use to search/order.
|
247 |
+
"""
|
248 |
+
|
249 |
+
indices_to_order = self.transformation(
|
250 |
+
current_text,
|
251 |
+
pre_transformation_constraints=self.pre_transformation_constraints,
|
252 |
+
return_indices=True,
|
253 |
+
**kwargs,
|
254 |
+
)
|
255 |
+
|
256 |
+
len_text = len(indices_to_order)
|
257 |
+
|
258 |
+
# Convert indices_to_order to list for easier shuffling later
|
259 |
+
return len_text, list(indices_to_order)
|
260 |
+
|
261 |
+
def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
|
262 |
+
"""Applies ``self.transformation`` to ``text``, then filters the list
|
263 |
+
of possible transformations through the applicable constraints.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
current_text: The current ``AttackedText`` on which to perform the transformations.
|
267 |
+
original_text: The original ``AttackedText`` from which the attack started.
|
268 |
+
Returns:
|
269 |
+
A filtered list of transformations where each transformation matches the constraints
|
270 |
+
"""
|
271 |
+
transformed_texts = self.transformation(
|
272 |
+
current_text,
|
273 |
+
pre_transformation_constraints=self.pre_transformation_constraints,
|
274 |
+
**kwargs,
|
275 |
+
)
|
276 |
+
|
277 |
+
return transformed_texts
|
278 |
+
|
279 |
+
def get_transformations(self, current_text, original_text=None, **kwargs):
|
280 |
+
"""Applies ``self.transformation`` to ``text``, then filters the list
|
281 |
+
of possible transformations through the applicable constraints.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
current_text: The current ``AttackedText`` on which to perform the transformations.
|
285 |
+
original_text: The original ``AttackedText`` from which the attack started.
|
286 |
+
Returns:
|
287 |
+
A filtered list of transformations where each transformation matches the constraints
|
288 |
+
"""
|
289 |
+
if not self.transformation:
|
290 |
+
raise RuntimeError(
|
291 |
+
"Cannot call `get_transformations` without a transformation."
|
292 |
+
)
|
293 |
+
|
294 |
+
if self.use_transformation_cache:
|
295 |
+
cache_key = tuple([current_text] + sorted(kwargs.items()))
|
296 |
+
if utils.hashable(cache_key) and cache_key in self.transformation_cache:
|
297 |
+
# promote transformed_text to the top of the LRU cache
|
298 |
+
self.transformation_cache[cache_key] = self.transformation_cache[
|
299 |
+
cache_key
|
300 |
+
]
|
301 |
+
transformed_texts = list(self.transformation_cache[cache_key])
|
302 |
+
else:
|
303 |
+
transformed_texts = self._get_transformations_uncached(
|
304 |
+
current_text, original_text, **kwargs
|
305 |
+
)
|
306 |
+
if utils.hashable(cache_key):
|
307 |
+
self.transformation_cache[cache_key] = tuple(transformed_texts)
|
308 |
+
else:
|
309 |
+
transformed_texts = self._get_transformations_uncached(
|
310 |
+
current_text, original_text, **kwargs
|
311 |
+
)
|
312 |
+
|
313 |
+
return self.filter_transformations(
|
314 |
+
transformed_texts, current_text, original_text
|
315 |
+
)
|
316 |
+
|
317 |
+
def _filter_transformations_uncached(
|
318 |
+
self, transformed_texts, current_text, original_text=None
|
319 |
+
):
|
320 |
+
"""Filters a list of potential transformed texts based on
|
321 |
+
``self.constraints``
|
322 |
+
|
323 |
+
Args:
|
324 |
+
transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
|
325 |
+
current_text: The current ``AttackedText`` on which the transformation was applied.
|
326 |
+
original_text: The original ``AttackedText`` from which the attack started.
|
327 |
+
"""
|
328 |
+
filtered_texts = transformed_texts[:]
|
329 |
+
for C in self.constraints:
|
330 |
+
if len(filtered_texts) == 0:
|
331 |
+
break
|
332 |
+
if C.compare_against_original:
|
333 |
+
if not original_text:
|
334 |
+
raise ValueError(
|
335 |
+
f"Missing `original_text` argument when constraint {type(C)} is set to compare against `original_text`"
|
336 |
+
)
|
337 |
+
|
338 |
+
filtered_texts = C.call_many(filtered_texts, original_text)
|
339 |
+
else:
|
340 |
+
filtered_texts = C.call_many(filtered_texts, current_text)
|
341 |
+
# Default to false for all original transformations.
|
342 |
+
for original_transformed_text in transformed_texts:
|
343 |
+
self.constraints_cache[(current_text, original_transformed_text)] = False
|
344 |
+
# Set unfiltered transformations to True in the cache.
|
345 |
+
for filtered_text in filtered_texts:
|
346 |
+
self.constraints_cache[(current_text, filtered_text)] = True
|
347 |
+
return filtered_texts
|
348 |
+
|
349 |
+
def filter_transformations(
|
350 |
+
self, transformed_texts, current_text, original_text=None
|
351 |
+
):
|
352 |
+
"""Filters a list of potential transformed texts based on
|
353 |
+
``self.constraints`` Utilizes an LRU cache to attempt to avoid
|
354 |
+
recomputing common transformations.
|
355 |
+
|
356 |
+
Args:
|
357 |
+
transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
|
358 |
+
current_text: The current ``AttackedText`` on which the transformation was applied.
|
359 |
+
original_text: The original ``AttackedText`` from which the attack started.
|
360 |
+
"""
|
361 |
+
# Remove any occurences of current_text in transformed_texts
|
362 |
+
transformed_texts = [
|
363 |
+
t for t in transformed_texts if t.text != current_text.text
|
364 |
+
]
|
365 |
+
# Populate cache with transformed_texts
|
366 |
+
uncached_texts = []
|
367 |
+
filtered_texts = []
|
368 |
+
for transformed_text in transformed_texts:
|
369 |
+
if (current_text, transformed_text) not in self.constraints_cache:
|
370 |
+
uncached_texts.append(transformed_text)
|
371 |
+
else:
|
372 |
+
# promote transformed_text to the top of the LRU cache
|
373 |
+
self.constraints_cache[
|
374 |
+
(current_text, transformed_text)
|
375 |
+
] = self.constraints_cache[(current_text, transformed_text)]
|
376 |
+
if self.constraints_cache[(current_text, transformed_text)]:
|
377 |
+
filtered_texts.append(transformed_text)
|
378 |
+
filtered_texts += self._filter_transformations_uncached(
|
379 |
+
uncached_texts, current_text, original_text=original_text
|
380 |
+
)
|
381 |
+
# Sort transformations to ensure order is preserved between runs
|
382 |
+
filtered_texts.sort(key=lambda t: t.text)
|
383 |
+
return filtered_texts
|
384 |
+
|
385 |
+
def _attack(self, initial_result):
|
386 |
+
"""Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
|
387 |
+
``initial_result``.
|
388 |
+
|
389 |
+
Args:
|
390 |
+
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
A ``SuccessfulAttackResult``, ``FailedAttackResult``,
|
394 |
+
or ``MaximizedAttackResult``.
|
395 |
+
"""
|
396 |
+
final_result = self.search_method(initial_result)
|
397 |
+
self.clear_cache()
|
398 |
+
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
399 |
+
result = SuccessfulAttackResult(
|
400 |
+
initial_result,
|
401 |
+
final_result,
|
402 |
+
)
|
403 |
+
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
|
404 |
+
result = FailedAttackResult(
|
405 |
+
initial_result,
|
406 |
+
final_result,
|
407 |
+
)
|
408 |
+
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
|
409 |
+
result = MaximizedAttackResult(
|
410 |
+
initial_result,
|
411 |
+
final_result,
|
412 |
+
)
|
413 |
+
else:
|
414 |
+
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
|
415 |
+
return result
|
416 |
+
|
417 |
+
def attack(self, example, ground_truth_output):
|
418 |
+
"""Attack a single example.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
|
422 |
+
Example to attack. It can be a single string or an `OrderedDict` where
|
423 |
+
keys represent the input fields (e.g. "premise", "hypothesis") and the values are the actual input textx.
|
424 |
+
Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
|
425 |
+
ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
|
426 |
+
Ground truth output of `example`.
|
427 |
+
For classification tasks, it should be an integer representing the ground truth label.
|
428 |
+
For regression tasks (e.g. STS), it should be the target value.
|
429 |
+
For seq2seq tasks (e.g. translation), it should be the target string.
|
430 |
+
Returns:
|
431 |
+
:class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
|
432 |
+
"""
|
433 |
+
assert isinstance(
|
434 |
+
example, (str, OrderedDict, AttackedText)
|
435 |
+
), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
|
436 |
+
if isinstance(example, (str, OrderedDict)):
|
437 |
+
example = AttackedText(example)
|
438 |
+
|
439 |
+
assert isinstance(
|
440 |
+
ground_truth_output, (int, str)
|
441 |
+
), "`ground_truth_output` must either be `str` or `int`."
|
442 |
+
goal_function_result, _ = self.goal_function.init_attack_example(
|
443 |
+
example, ground_truth_output
|
444 |
+
)
|
445 |
+
if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
|
446 |
+
return SkippedAttackResult(goal_function_result)
|
447 |
+
else:
|
448 |
+
result = self._attack(goal_function_result)
|
449 |
+
return result
|
450 |
+
|
451 |
+
def __repr__(self):
|
452 |
+
"""Prints attack parameters in a human-readable string.
|
453 |
+
|
454 |
+
Inspired by the readability of printing PyTorch nn.Modules:
|
455 |
+
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
|
456 |
+
"""
|
457 |
+
main_str = "Attack" + "("
|
458 |
+
lines = []
|
459 |
+
|
460 |
+
lines.append(utils.add_indent(f"(search_method): {self.search_method}", 2))
|
461 |
+
# self.goal_function
|
462 |
+
lines.append(utils.add_indent(f"(goal_function): {self.goal_function}", 2))
|
463 |
+
# self.transformation
|
464 |
+
lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
|
465 |
+
# self.constraints
|
466 |
+
constraints_lines = []
|
467 |
+
constraints = self.constraints + self.pre_transformation_constraints
|
468 |
+
if len(constraints):
|
469 |
+
for i, constraint in enumerate(constraints):
|
470 |
+
constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
|
471 |
+
constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
|
472 |
+
else:
|
473 |
+
constraints_str = "None"
|
474 |
+
lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
|
475 |
+
# self.is_black_box
|
476 |
+
lines.append(utils.add_indent(f"(is_black_box): {self.is_black_box}", 2))
|
477 |
+
main_str += "\n " + "\n ".join(lines) + "\n"
|
478 |
+
main_str += ")"
|
479 |
+
return main_str
|
480 |
+
|
481 |
+
def __getstate__(self):
|
482 |
+
state = self.__dict__.copy()
|
483 |
+
state["transformation_cache"] = None
|
484 |
+
state["constraints_cache"] = None
|
485 |
+
return state
|
486 |
+
|
487 |
+
def __setstate__(self, state):
|
488 |
+
self.__dict__ = state
|
489 |
+
self.transformation_cache = lru.LRU(self.transformation_cache_size)
|
490 |
+
self.constraints_cache = lru.LRU(self.constraint_cache_size)
|
491 |
+
|
492 |
+
__str__ = __repr__
|
textattack/attack_args.py
ADDED
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
AttackArgs Class
|
3 |
+
================
|
4 |
+
"""
|
5 |
+
|
6 |
+
from dataclasses import dataclass, field
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
from typing import Dict, Optional
|
12 |
+
|
13 |
+
import textattack
|
14 |
+
from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file
|
15 |
+
|
16 |
+
from .attack import Attack
|
17 |
+
from .dataset_args import DatasetArgs
|
18 |
+
from .model_args import ModelArgs
|
19 |
+
|
20 |
+
ATTACK_RECIPE_NAMES = {
|
21 |
+
"alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
|
22 |
+
"bae": "textattack.attack_recipes.BAEGarg2019",
|
23 |
+
"bert-attack": "textattack.attack_recipes.BERTAttackLi2020",
|
24 |
+
"faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
|
25 |
+
"deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
|
26 |
+
"hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
|
27 |
+
"input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
|
28 |
+
"kuleshov": "textattack.attack_recipes.Kuleshov2017",
|
29 |
+
"morpheus": "textattack.attack_recipes.MorpheusTan2020",
|
30 |
+
"seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
|
31 |
+
"textbugger": "textattack.attack_recipes.TextBuggerLi2018",
|
32 |
+
"textfooler": "textattack.attack_recipes.TextFoolerJin2019",
|
33 |
+
"pwws": "textattack.attack_recipes.PWWSRen2019",
|
34 |
+
"iga": "textattack.attack_recipes.IGAWang2019",
|
35 |
+
"pruthi": "textattack.attack_recipes.Pruthi2019",
|
36 |
+
"pso": "textattack.attack_recipes.PSOZang2020",
|
37 |
+
"checklist": "textattack.attack_recipes.CheckList2020",
|
38 |
+
"clare": "textattack.attack_recipes.CLARE2020",
|
39 |
+
"a2t": "textattack.attack_recipes.A2TYoo2021",
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
|
44 |
+
"random-synonym-insertion": "textattack.transformations.RandomSynonymInsertion",
|
45 |
+
"word-deletion": "textattack.transformations.WordDeletion",
|
46 |
+
"word-swap-embedding": "textattack.transformations.WordSwapEmbedding",
|
47 |
+
"word-swap-homoglyph": "textattack.transformations.WordSwapHomoglyphSwap",
|
48 |
+
"word-swap-inflections": "textattack.transformations.WordSwapInflections",
|
49 |
+
"word-swap-neighboring-char-swap": "textattack.transformations.WordSwapNeighboringCharacterSwap",
|
50 |
+
"word-swap-random-char-deletion": "textattack.transformations.WordSwapRandomCharacterDeletion",
|
51 |
+
"word-swap-random-char-insertion": "textattack.transformations.WordSwapRandomCharacterInsertion",
|
52 |
+
"word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
|
53 |
+
"word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
|
54 |
+
"word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
|
55 |
+
"word-swap-hownet": "textattack.transformations.WordSwapHowNet",
|
56 |
+
"word-swap-qwerty": "textattack.transformations.WordSwapQWERTY",
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
|
61 |
+
"word-swap-gradient": "textattack.transformations.WordSwapGradientBased"
|
62 |
+
}
|
63 |
+
|
64 |
+
|
65 |
+
CONSTRAINT_CLASS_NAMES = {
|
66 |
+
#
|
67 |
+
# Semantics constraints
|
68 |
+
#
|
69 |
+
"embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
|
70 |
+
"bert": "textattack.constraints.semantics.sentence_encoders.BERT",
|
71 |
+
"infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
|
72 |
+
"thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
|
73 |
+
"use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
|
74 |
+
"muse": "textattack.constraints.semantics.sentence_encoders.MultilingualUniversalSentenceEncoder",
|
75 |
+
"bert-score": "textattack.constraints.semantics.BERTScore",
|
76 |
+
#
|
77 |
+
# Grammaticality constraints
|
78 |
+
#
|
79 |
+
"lang-tool": "textattack.constraints.grammaticality.LanguageTool",
|
80 |
+
"part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
|
81 |
+
"goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
|
82 |
+
"gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
|
83 |
+
"learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
|
84 |
+
"cola": "textattack.constraints.grammaticality.COLA",
|
85 |
+
#
|
86 |
+
# Overlap constraints
|
87 |
+
#
|
88 |
+
"bleu": "textattack.constraints.overlap.BLEU",
|
89 |
+
"chrf": "textattack.constraints.overlap.chrF",
|
90 |
+
"edit-distance": "textattack.constraints.overlap.LevenshteinEditDistance",
|
91 |
+
"meteor": "textattack.constraints.overlap.METEOR",
|
92 |
+
"max-words-perturbed": "textattack.constraints.overlap.MaxWordsPerturbed",
|
93 |
+
#
|
94 |
+
# Pre-transformation constraints
|
95 |
+
#
|
96 |
+
"repeat": "textattack.constraints.pre_transformation.RepeatModification",
|
97 |
+
"stopword": "textattack.constraints.pre_transformation.StopwordModification",
|
98 |
+
"max-word-index": "textattack.constraints.pre_transformation.MaxWordIndexModification",
|
99 |
+
}
|
100 |
+
|
101 |
+
|
102 |
+
SEARCH_METHOD_CLASS_NAMES = {
|
103 |
+
"beam-search": "textattack.search_methods.BeamSearch",
|
104 |
+
"greedy": "textattack.search_methods.GreedySearch",
|
105 |
+
"ga-word": "textattack.search_methods.GeneticAlgorithm",
|
106 |
+
"greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
|
107 |
+
"pso": "textattack.search_methods.ParticleSwarmOptimization",
|
108 |
+
}
|
109 |
+
|
110 |
+
|
111 |
+
GOAL_FUNCTION_CLASS_NAMES = {
|
112 |
+
#
|
113 |
+
# Classification goal functions
|
114 |
+
#
|
115 |
+
"targeted-classification": "textattack.goal_functions.classification.TargetedClassification",
|
116 |
+
"untargeted-classification": "textattack.goal_functions.classification.UntargetedClassification",
|
117 |
+
"input-reduction": "textattack.goal_functions.classification.InputReduction",
|
118 |
+
#
|
119 |
+
# Text goal functions
|
120 |
+
#
|
121 |
+
"minimize-bleu": "textattack.goal_functions.text.MinimizeBleu",
|
122 |
+
"non-overlapping-output": "textattack.goal_functions.text.NonOverlappingOutput",
|
123 |
+
"text-to-text": "textattack.goal_functions.text.TextToTextGoalFunction",
|
124 |
+
}
|
125 |
+
|
126 |
+
|
127 |
+
@dataclass
|
128 |
+
class AttackArgs:
|
129 |
+
"""Attack arguments to be passed to :class:`~textattack.Attacker`.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
num_examples (:obj:`int`, 'optional`, defaults to :obj:`10`):
|
133 |
+
The number of examples to attack. :obj:`-1` for entire dataset.
|
134 |
+
num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
|
135 |
+
The number of successful adversarial examples we want. This is different from :obj:`num_examples`
|
136 |
+
as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
|
137 |
+
until we have `N` successful cases.
|
138 |
+
|
139 |
+
.. note::
|
140 |
+
If set, this argument overrides `num_examples` argument.
|
141 |
+
num_examples_offset (:obj: `int`, `optional`, defaults to :obj:`0`):
|
142 |
+
The offset index to start at in the dataset.
|
143 |
+
attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
144 |
+
Whether to run attack until total of `N` examples have been attacked (and not skipped).
|
145 |
+
shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
146 |
+
If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
|
147 |
+
the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
|
148 |
+
:obj:`shuffle` can now be used with checkpoint saving.
|
149 |
+
query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
|
150 |
+
The maximum number of model queries allowed per example attacked.
|
151 |
+
If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).
|
152 |
+
|
153 |
+
.. note::
|
154 |
+
Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
|
155 |
+
checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
|
156 |
+
If set, checkpoint will be saved after attacking every `N` examples. If :obj:`None` is passed, no checkpoints will be saved.
|
157 |
+
checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
|
158 |
+
The directory to save checkpoint files.
|
159 |
+
random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
|
160 |
+
Random seed for reproducibility.
|
161 |
+
parallel (:obj:`False`, `optional`, defaults to :obj:`False`):
|
162 |
+
If :obj:`True`, run attack using multiple CPUs/GPUs.
|
163 |
+
num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
|
164 |
+
Number of worker processes to run per device in parallel mode (i.e. :obj:`parallel=True`). For example, if you are using GPUs and :obj:`num_workers_per_device=2`,
|
165 |
+
then 2 processes will be running in each GPU.
|
166 |
+
log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
|
167 |
+
If set, save attack logs as a `.txt` file to the directory specified by this argument.
|
168 |
+
If the last part of the provided path ends with `.txt` extension, it is assumed to the desired path of the log file.
|
169 |
+
log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
|
170 |
+
If set, save attack logs as a CSV file to the directory specified by this argument.
|
171 |
+
If the last part of the provided path ends with `.csv` extension, it is assumed to the desired path of the log file.
|
172 |
+
csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
|
173 |
+
Method for choosing how to mark perturbed parts of the text. Options are :obj:`"file"`, :obj:`"plain"`, and :obj:`"html"`.
|
174 |
+
:obj:`"file"` wraps perturbed parts with double brackets :obj:`[[ <text> ]]` while :obj:`"plain"` does not mark the text in any way.
|
175 |
+
log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
|
176 |
+
If set, Visdom logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
|
177 |
+
Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
|
178 |
+
three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
|
179 |
+
log_to_wandb(:obj:`dict`, `optional`, defaults to :obj:`None`):
|
180 |
+
If set, WandB logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
|
181 |
+
Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
|
182 |
+
key and its corresponding value: :obj:`"project"`.
|
183 |
+
disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
184 |
+
Disable displaying individual attack results to stdout.
|
185 |
+
silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
186 |
+
Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
|
187 |
+
enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
188 |
+
Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc.
|
189 |
+
"""
|
190 |
+
|
191 |
+
num_examples: int = 10
|
192 |
+
num_successful_examples: int = None
|
193 |
+
num_examples_offset: int = 0
|
194 |
+
attack_n: bool = False
|
195 |
+
shuffle: bool = False
|
196 |
+
query_budget: int = None
|
197 |
+
checkpoint_interval: int = None
|
198 |
+
checkpoint_dir: str = "checkpoints"
|
199 |
+
random_seed: int = 765 # equivalent to sum((ord(c) for c in "TEXTATTACK"))
|
200 |
+
parallel: bool = False
|
201 |
+
num_workers_per_device: int = 1
|
202 |
+
log_to_txt: str = None
|
203 |
+
log_to_csv: str = None
|
204 |
+
log_summary_to_json: str = None
|
205 |
+
csv_coloring_style: str = "file"
|
206 |
+
log_to_visdom: dict = None
|
207 |
+
log_to_wandb: dict = None
|
208 |
+
disable_stdout: bool = False
|
209 |
+
silent: bool = False
|
210 |
+
enable_advance_metrics: bool = False
|
211 |
+
metrics: Optional[Dict] = None
|
212 |
+
|
213 |
+
def __post_init__(self):
|
214 |
+
if self.num_successful_examples:
|
215 |
+
self.num_examples = None
|
216 |
+
if self.num_examples:
|
217 |
+
assert (
|
218 |
+
self.num_examples >= 0 or self.num_examples == -1
|
219 |
+
), "`num_examples` must be greater than or equal to 0 or equal to -1."
|
220 |
+
if self.num_successful_examples:
|
221 |
+
assert (
|
222 |
+
self.num_successful_examples >= 0
|
223 |
+
), "`num_examples` must be greater than or equal to 0."
|
224 |
+
|
225 |
+
if self.query_budget:
|
226 |
+
assert self.query_budget > 0, "`query_budget` must be greater than 0."
|
227 |
+
|
228 |
+
if self.checkpoint_interval:
|
229 |
+
assert (
|
230 |
+
self.checkpoint_interval > 0
|
231 |
+
), "`checkpoint_interval` must be greater than 0."
|
232 |
+
|
233 |
+
assert (
|
234 |
+
self.num_workers_per_device > 0
|
235 |
+
), "`num_workers_per_device` must be greater than 0."
|
236 |
+
|
237 |
+
@classmethod
|
238 |
+
def _add_parser_args(cls, parser):
|
239 |
+
"""Add listed args to command line parser."""
|
240 |
+
default_obj = cls()
|
241 |
+
num_ex_group = parser.add_mutually_exclusive_group(required=False)
|
242 |
+
num_ex_group.add_argument(
|
243 |
+
"--num-examples",
|
244 |
+
"-n",
|
245 |
+
type=int,
|
246 |
+
default=default_obj.num_examples,
|
247 |
+
help="The number of examples to process, -1 for entire dataset.",
|
248 |
+
)
|
249 |
+
num_ex_group.add_argument(
|
250 |
+
"--num-successful-examples",
|
251 |
+
type=int,
|
252 |
+
default=default_obj.num_successful_examples,
|
253 |
+
help="The number of successful adversarial examples we want.",
|
254 |
+
)
|
255 |
+
parser.add_argument(
|
256 |
+
"--num-examples-offset",
|
257 |
+
"-o",
|
258 |
+
type=int,
|
259 |
+
required=False,
|
260 |
+
default=default_obj.num_examples_offset,
|
261 |
+
help="The offset to start at in the dataset.",
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--query-budget",
|
265 |
+
"-q",
|
266 |
+
type=int,
|
267 |
+
default=default_obj.query_budget,
|
268 |
+
help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
|
269 |
+
)
|
270 |
+
parser.add_argument(
|
271 |
+
"--shuffle",
|
272 |
+
action="store_true",
|
273 |
+
default=default_obj.shuffle,
|
274 |
+
help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
|
275 |
+
)
|
276 |
+
parser.add_argument(
|
277 |
+
"--attack-n",
|
278 |
+
action="store_true",
|
279 |
+
default=default_obj.attack_n,
|
280 |
+
help="Whether to run attack until `n` examples have been attacked (not skipped).",
|
281 |
+
)
|
282 |
+
parser.add_argument(
|
283 |
+
"--checkpoint-dir",
|
284 |
+
required=False,
|
285 |
+
type=str,
|
286 |
+
default=default_obj.checkpoint_dir,
|
287 |
+
help="The directory to save checkpoint files.",
|
288 |
+
)
|
289 |
+
parser.add_argument(
|
290 |
+
"--checkpoint-interval",
|
291 |
+
required=False,
|
292 |
+
type=int,
|
293 |
+
default=default_obj.checkpoint_interval,
|
294 |
+
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--random-seed",
|
298 |
+
default=default_obj.random_seed,
|
299 |
+
type=int,
|
300 |
+
help="Random seed for reproducibility.",
|
301 |
+
)
|
302 |
+
parser.add_argument(
|
303 |
+
"--parallel",
|
304 |
+
action="store_true",
|
305 |
+
default=default_obj.parallel,
|
306 |
+
help="Run attack using multiple GPUs.",
|
307 |
+
)
|
308 |
+
parser.add_argument(
|
309 |
+
"--num-workers-per-device",
|
310 |
+
default=default_obj.num_workers_per_device,
|
311 |
+
type=int,
|
312 |
+
help="Number of worker processes to run per device.",
|
313 |
+
)
|
314 |
+
parser.add_argument(
|
315 |
+
"--log-to-txt",
|
316 |
+
nargs="?",
|
317 |
+
default=default_obj.log_to_txt,
|
318 |
+
const="",
|
319 |
+
type=str,
|
320 |
+
help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
|
321 |
+
"If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.",
|
322 |
+
)
|
323 |
+
parser.add_argument(
|
324 |
+
"--log-to-csv",
|
325 |
+
nargs="?",
|
326 |
+
default=default_obj.log_to_csv,
|
327 |
+
const="",
|
328 |
+
type=str,
|
329 |
+
help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
|
330 |
+
"If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.",
|
331 |
+
)
|
332 |
+
parser.add_argument(
|
333 |
+
"--log-summary-to-json",
|
334 |
+
nargs="?",
|
335 |
+
default=default_obj.log_summary_to_json,
|
336 |
+
const="",
|
337 |
+
type=str,
|
338 |
+
help="Path to which to save attack summary as a JSON file. Set this argument if you want to save attack results summary in a JSON. "
|
339 |
+
"If the last part of the path ends with `.json` extension, the path is assumed to path for output file.",
|
340 |
+
)
|
341 |
+
parser.add_argument(
|
342 |
+
"--csv-coloring-style",
|
343 |
+
default=default_obj.csv_coloring_style,
|
344 |
+
type=str,
|
345 |
+
help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
|
346 |
+
'"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
|
347 |
+
)
|
348 |
+
parser.add_argument(
|
349 |
+
"--log-to-visdom",
|
350 |
+
nargs="?",
|
351 |
+
default=None,
|
352 |
+
const='{"env": "main", "port": 8097, "hostname": "localhost"}',
|
353 |
+
type=json.loads,
|
354 |
+
help="Set this argument if you want to log attacks to Visdom. The dictionary should have the following "
|
355 |
+
'three keys and their corresponding values: `"env", "port", "hostname"`. '
|
356 |
+
'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
|
357 |
+
)
|
358 |
+
parser.add_argument(
|
359 |
+
"--log-to-wandb",
|
360 |
+
nargs="?",
|
361 |
+
default=None,
|
362 |
+
const='{"project": "textattack"}',
|
363 |
+
type=json.loads,
|
364 |
+
help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
|
365 |
+
'key and its corresponding value: `"project"`. '
|
366 |
+
'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
|
367 |
+
)
|
368 |
+
parser.add_argument(
|
369 |
+
"--disable-stdout",
|
370 |
+
action="store_true",
|
371 |
+
default=default_obj.disable_stdout,
|
372 |
+
help="Disable logging attack results to stdout",
|
373 |
+
)
|
374 |
+
parser.add_argument(
|
375 |
+
"--silent",
|
376 |
+
action="store_true",
|
377 |
+
default=default_obj.silent,
|
378 |
+
help="Disable all logging",
|
379 |
+
)
|
380 |
+
parser.add_argument(
|
381 |
+
"--enable-advance-metrics",
|
382 |
+
action="store_true",
|
383 |
+
default=default_obj.enable_advance_metrics,
|
384 |
+
help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
|
385 |
+
)
|
386 |
+
|
387 |
+
return parser
|
388 |
+
|
389 |
+
@classmethod
|
390 |
+
def create_loggers_from_args(cls, args):
|
391 |
+
"""Creates AttackLogManager from an AttackArgs object."""
|
392 |
+
assert isinstance(
|
393 |
+
args, cls
|
394 |
+
), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."
|
395 |
+
|
396 |
+
# Create logger
|
397 |
+
attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)
|
398 |
+
|
399 |
+
# Get current time for file naming
|
400 |
+
timestamp = time.strftime("%Y-%m-%d-%H-%M")
|
401 |
+
|
402 |
+
# if '--log-to-txt' specified with arguments
|
403 |
+
if args.log_to_txt is not None:
|
404 |
+
if args.log_to_txt.lower().endswith(".txt"):
|
405 |
+
txt_file_path = args.log_to_txt
|
406 |
+
else:
|
407 |
+
txt_file_path = os.path.join(args.log_to_txt, f"{timestamp}-log.txt")
|
408 |
+
|
409 |
+
dir_path = os.path.dirname(txt_file_path)
|
410 |
+
dir_path = dir_path if dir_path else "."
|
411 |
+
if not os.path.exists(dir_path):
|
412 |
+
os.makedirs(os.path.dirname(txt_file_path))
|
413 |
+
|
414 |
+
color_method = "file"
|
415 |
+
attack_log_manager.add_output_file(txt_file_path, color_method)
|
416 |
+
|
417 |
+
# if '--log-to-csv' specified with arguments
|
418 |
+
if args.log_to_csv is not None:
|
419 |
+
if args.log_to_csv.lower().endswith(".csv"):
|
420 |
+
csv_file_path = args.log_to_csv
|
421 |
+
else:
|
422 |
+
csv_file_path = os.path.join(args.log_to_csv, f"{timestamp}-log.csv")
|
423 |
+
|
424 |
+
dir_path = os.path.dirname(csv_file_path)
|
425 |
+
dir_path = dir_path if dir_path else "."
|
426 |
+
if not os.path.exists(dir_path):
|
427 |
+
os.makedirs(dir_path)
|
428 |
+
|
429 |
+
color_method = (
|
430 |
+
None if args.csv_coloring_style == "plain" else args.csv_coloring_style
|
431 |
+
)
|
432 |
+
attack_log_manager.add_output_csv(csv_file_path, color_method)
|
433 |
+
|
434 |
+
# if '--log-summary-to-json' specified with arguments
|
435 |
+
if args.log_summary_to_json is not None:
|
436 |
+
if args.log_summary_to_json.lower().endswith(".json"):
|
437 |
+
summary_json_file_path = args.log_summary_to_json
|
438 |
+
else:
|
439 |
+
summary_json_file_path = os.path.join(
|
440 |
+
args.log_summary_to_json, f"{timestamp}-attack_summary_log.json"
|
441 |
+
)
|
442 |
+
|
443 |
+
dir_path = os.path.dirname(summary_json_file_path)
|
444 |
+
dir_path = dir_path if dir_path else "."
|
445 |
+
if not os.path.exists(dir_path):
|
446 |
+
os.makedirs(os.path.dirname(summary_json_file_path))
|
447 |
+
|
448 |
+
attack_log_manager.add_output_summary_json(summary_json_file_path)
|
449 |
+
|
450 |
+
# Visdom
|
451 |
+
if args.log_to_visdom is not None:
|
452 |
+
attack_log_manager.enable_visdom(**args.log_to_visdom)
|
453 |
+
|
454 |
+
# Weights & Biases
|
455 |
+
if args.log_to_wandb is not None:
|
456 |
+
attack_log_manager.enable_wandb(**args.log_to_wandb)
|
457 |
+
|
458 |
+
# Stdout
|
459 |
+
if not args.disable_stdout and not sys.stdout.isatty():
|
460 |
+
attack_log_manager.disable_color()
|
461 |
+
elif not args.disable_stdout:
|
462 |
+
attack_log_manager.enable_stdout()
|
463 |
+
|
464 |
+
return attack_log_manager
|
465 |
+
|
466 |
+
|
467 |
+
@dataclass
|
468 |
+
class _CommandLineAttackArgs:
|
469 |
+
"""Attack args for command line execution. This requires more arguments to
|
470 |
+
create ``Attack`` object as specified.
|
471 |
+
|
472 |
+
Args:
|
473 |
+
transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
|
474 |
+
Name of transformation to use.
|
475 |
+
constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
|
476 |
+
List of names of constraints to use.
|
477 |
+
goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
|
478 |
+
Name of goal function to use.
|
479 |
+
search_method (:obj:`str`, `optional`, defualts to :obj:`"greedy-word-wir"`):
|
480 |
+
Name of search method to use.
|
481 |
+
attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
|
482 |
+
Name of attack recipe to use.
|
483 |
+
.. note::
|
484 |
+
Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
|
485 |
+
attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
|
486 |
+
Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.
|
487 |
+
.. note::
|
488 |
+
If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method
|
489 |
+
interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
490 |
+
If `True`, carry attack in interactive mode.
|
491 |
+
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
492 |
+
If `True`, attack in parallel.
|
493 |
+
model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
|
494 |
+
The batch size for making queries to the victim model.
|
495 |
+
model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
|
496 |
+
The maximum number of items to keep in the model results cache at once.
|
497 |
+
constraint-cache-size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
|
498 |
+
The maximum number of items to keep in the constraints cache at once.
|
499 |
+
"""
|
500 |
+
|
501 |
+
transformation: str = "word-swap-embedding"
|
502 |
+
constraints: list = field(default_factory=lambda: ["repeat", "stopword"])
|
503 |
+
goal_function: str = "untargeted-classification"
|
504 |
+
search_method: str = "greedy-word-wir"
|
505 |
+
attack_recipe: str = None
|
506 |
+
attack_from_file: str = None
|
507 |
+
interactive: bool = False
|
508 |
+
parallel: bool = False
|
509 |
+
model_batch_size: int = 32
|
510 |
+
model_cache_size: int = 2**18
|
511 |
+
constraint_cache_size: int = 2**18
|
512 |
+
|
513 |
+
@classmethod
|
514 |
+
def _add_parser_args(cls, parser):
|
515 |
+
"""Add listed args to command line parser."""
|
516 |
+
default_obj = cls()
|
517 |
+
transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
|
518 |
+
WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
|
519 |
+
)
|
520 |
+
parser.add_argument(
|
521 |
+
"--transformation",
|
522 |
+
type=str,
|
523 |
+
required=False,
|
524 |
+
default=default_obj.transformation,
|
525 |
+
help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
|
526 |
+
+ str(transformation_names),
|
527 |
+
)
|
528 |
+
parser.add_argument(
|
529 |
+
"--constraints",
|
530 |
+
type=str,
|
531 |
+
required=False,
|
532 |
+
nargs="*",
|
533 |
+
default=default_obj.constraints,
|
534 |
+
help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
|
535 |
+
+ str(CONSTRAINT_CLASS_NAMES.keys()),
|
536 |
+
)
|
537 |
+
goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
|
538 |
+
parser.add_argument(
|
539 |
+
"--goal-function",
|
540 |
+
"-g",
|
541 |
+
default=default_obj.goal_function,
|
542 |
+
help=f"The goal function to use. choices: {goal_function_choices}",
|
543 |
+
)
|
544 |
+
attack_group = parser.add_mutually_exclusive_group(required=False)
|
545 |
+
search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
|
546 |
+
attack_group.add_argument(
|
547 |
+
"--search-method",
|
548 |
+
"--search",
|
549 |
+
"-s",
|
550 |
+
type=str,
|
551 |
+
required=False,
|
552 |
+
default=default_obj.search_method,
|
553 |
+
help=f"The search method to use. choices: {search_choices}",
|
554 |
+
)
|
555 |
+
attack_group.add_argument(
|
556 |
+
"--attack-recipe",
|
557 |
+
"--recipe",
|
558 |
+
"-r",
|
559 |
+
type=str,
|
560 |
+
required=False,
|
561 |
+
default=default_obj.attack_recipe,
|
562 |
+
help="full attack recipe (overrides provided goal function, transformation & constraints)",
|
563 |
+
choices=ATTACK_RECIPE_NAMES.keys(),
|
564 |
+
)
|
565 |
+
attack_group.add_argument(
|
566 |
+
"--attack-from-file",
|
567 |
+
type=str,
|
568 |
+
required=False,
|
569 |
+
default=default_obj.attack_from_file,
|
570 |
+
help="Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.",
|
571 |
+
)
|
572 |
+
parser.add_argument(
|
573 |
+
"--interactive",
|
574 |
+
action="store_true",
|
575 |
+
default=default_obj.interactive,
|
576 |
+
help="Whether to run attacks interactively.",
|
577 |
+
)
|
578 |
+
parser.add_argument(
|
579 |
+
"--model-batch-size",
|
580 |
+
type=int,
|
581 |
+
default=default_obj.model_batch_size,
|
582 |
+
help="The batch size for making calls to the model.",
|
583 |
+
)
|
584 |
+
parser.add_argument(
|
585 |
+
"--model-cache-size",
|
586 |
+
type=int,
|
587 |
+
default=default_obj.model_cache_size,
|
588 |
+
help="The maximum number of items to keep in the model results cache at once.",
|
589 |
+
)
|
590 |
+
parser.add_argument(
|
591 |
+
"--constraint-cache-size",
|
592 |
+
type=int,
|
593 |
+
default=default_obj.constraint_cache_size,
|
594 |
+
help="The maximum number of items to keep in the constraints cache at once.",
|
595 |
+
)
|
596 |
+
|
597 |
+
return parser
|
598 |
+
|
599 |
+
@classmethod
|
600 |
+
def _create_transformation_from_args(cls, args, model_wrapper):
|
601 |
+
"""Create `Transformation` based on provided `args` and
|
602 |
+
`model_wrapper`."""
|
603 |
+
|
604 |
+
transformation_name = args.transformation
|
605 |
+
if ARGS_SPLIT_TOKEN in transformation_name:
|
606 |
+
transformation_name, params = transformation_name.split(ARGS_SPLIT_TOKEN)
|
607 |
+
|
608 |
+
if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
|
609 |
+
transformation = eval(
|
610 |
+
f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model, {params})"
|
611 |
+
)
|
612 |
+
elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
|
613 |
+
transformation = eval(
|
614 |
+
f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
|
615 |
+
)
|
616 |
+
else:
|
617 |
+
raise ValueError(
|
618 |
+
f"Error: unsupported transformation {transformation_name}"
|
619 |
+
)
|
620 |
+
else:
|
621 |
+
if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
|
622 |
+
transformation = eval(
|
623 |
+
f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model)"
|
624 |
+
)
|
625 |
+
elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
|
626 |
+
transformation = eval(
|
627 |
+
f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
|
628 |
+
)
|
629 |
+
else:
|
630 |
+
raise ValueError(
|
631 |
+
f"Error: unsupported transformation {transformation_name}"
|
632 |
+
)
|
633 |
+
return transformation
|
634 |
+
|
635 |
+
@classmethod
|
636 |
+
def _create_goal_function_from_args(cls, args, model_wrapper):
|
637 |
+
"""Create `GoalFunction` based on provided `args` and
|
638 |
+
`model_wrapper`."""
|
639 |
+
|
640 |
+
goal_function = args.goal_function
|
641 |
+
if ARGS_SPLIT_TOKEN in goal_function:
|
642 |
+
goal_function_name, params = goal_function.split(ARGS_SPLIT_TOKEN)
|
643 |
+
if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
|
644 |
+
raise ValueError(
|
645 |
+
f"Error: unsupported goal_function {goal_function_name}"
|
646 |
+
)
|
647 |
+
goal_function = eval(
|
648 |
+
f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model_wrapper, {params})"
|
649 |
+
)
|
650 |
+
elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
|
651 |
+
goal_function = eval(
|
652 |
+
f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model_wrapper)"
|
653 |
+
)
|
654 |
+
else:
|
655 |
+
raise ValueError(f"Error: unsupported goal_function {goal_function}")
|
656 |
+
if args.query_budget:
|
657 |
+
goal_function.query_budget = args.query_budget
|
658 |
+
goal_function.model_cache_size = args.model_cache_size
|
659 |
+
goal_function.batch_size = args.model_batch_size
|
660 |
+
return goal_function
|
661 |
+
|
662 |
+
@classmethod
|
663 |
+
def _create_constraints_from_args(cls, args):
|
664 |
+
"""Create list of `Constraints` based on provided `args`."""
|
665 |
+
|
666 |
+
if not args.constraints:
|
667 |
+
return []
|
668 |
+
|
669 |
+
_constraints = []
|
670 |
+
for constraint in args.constraints:
|
671 |
+
if ARGS_SPLIT_TOKEN in constraint:
|
672 |
+
constraint_name, params = constraint.split(ARGS_SPLIT_TOKEN)
|
673 |
+
if constraint_name not in CONSTRAINT_CLASS_NAMES:
|
674 |
+
raise ValueError(f"Error: unsupported constraint {constraint_name}")
|
675 |
+
_constraints.append(
|
676 |
+
eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
|
677 |
+
)
|
678 |
+
elif constraint in CONSTRAINT_CLASS_NAMES:
|
679 |
+
_constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
|
680 |
+
else:
|
681 |
+
raise ValueError(f"Error: unsupported constraint {constraint}")
|
682 |
+
|
683 |
+
return _constraints
|
684 |
+
|
685 |
+
@classmethod
|
686 |
+
def _create_attack_from_args(cls, args, model_wrapper):
|
687 |
+
"""Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
|
688 |
+
``Attack`` object."""
|
689 |
+
|
690 |
+
assert isinstance(
|
691 |
+
args, cls
|
692 |
+
), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."
|
693 |
+
|
694 |
+
if args.attack_recipe:
|
695 |
+
if ARGS_SPLIT_TOKEN in args.attack_recipe:
|
696 |
+
recipe_name, params = args.attack_recipe.split(ARGS_SPLIT_TOKEN)
|
697 |
+
if recipe_name not in ATTACK_RECIPE_NAMES:
|
698 |
+
raise ValueError(f"Error: unsupported recipe {recipe_name}")
|
699 |
+
recipe = eval(
|
700 |
+
f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
|
701 |
+
)
|
702 |
+
elif args.attack_recipe in ATTACK_RECIPE_NAMES:
|
703 |
+
recipe = eval(
|
704 |
+
f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
|
705 |
+
)
|
706 |
+
else:
|
707 |
+
raise ValueError(f"Invalid recipe {args.attack_recipe}")
|
708 |
+
if args.query_budget:
|
709 |
+
recipe.goal_function.query_budget = args.query_budget
|
710 |
+
recipe.goal_function.model_cache_size = args.model_cache_size
|
711 |
+
recipe.constraint_cache_size = args.constraint_cache_size
|
712 |
+
return recipe
|
713 |
+
elif args.attack_from_file:
|
714 |
+
if ARGS_SPLIT_TOKEN in args.attack_from_file:
|
715 |
+
attack_file, attack_name = args.attack_from_file.split(ARGS_SPLIT_TOKEN)
|
716 |
+
else:
|
717 |
+
attack_file, attack_name = args.attack_from_file, "attack"
|
718 |
+
attack_module = load_module_from_file(attack_file)
|
719 |
+
if not hasattr(attack_module, attack_name):
|
720 |
+
raise ValueError(
|
721 |
+
f"Loaded `{attack_file}` but could not find `{attack_name}`."
|
722 |
+
)
|
723 |
+
attack_func = getattr(attack_module, attack_name)
|
724 |
+
return attack_func(model_wrapper)
|
725 |
+
else:
|
726 |
+
goal_function = cls._create_goal_function_from_args(args, model_wrapper)
|
727 |
+
transformation = cls._create_transformation_from_args(args, model_wrapper)
|
728 |
+
constraints = cls._create_constraints_from_args(args)
|
729 |
+
if ARGS_SPLIT_TOKEN in args.search_method:
|
730 |
+
search_name, params = args.search_method.split(ARGS_SPLIT_TOKEN)
|
731 |
+
if search_name not in SEARCH_METHOD_CLASS_NAMES:
|
732 |
+
raise ValueError(f"Error: unsupported search {search_name}")
|
733 |
+
search_method = eval(
|
734 |
+
f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})"
|
735 |
+
)
|
736 |
+
elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
|
737 |
+
search_method = eval(
|
738 |
+
f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()"
|
739 |
+
)
|
740 |
+
else:
|
741 |
+
raise ValueError(f"Error: unsupported attack {args.search_method}")
|
742 |
+
|
743 |
+
return Attack(
|
744 |
+
goal_function,
|
745 |
+
constraints,
|
746 |
+
transformation,
|
747 |
+
search_method,
|
748 |
+
constraint_cache_size=args.constraint_cache_size,
|
749 |
+
)
|
750 |
+
|
751 |
+
|
752 |
+
# This neat trick allows use to reorder the arguments to avoid TypeErrors commonly found when inheriting dataclass.
|
753 |
+
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
754 |
+
@dataclass
|
755 |
+
class CommandLineAttackArgs(AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs):
|
756 |
+
@classmethod
|
757 |
+
def _add_parser_args(cls, parser):
|
758 |
+
"""Add listed args to command line parser."""
|
759 |
+
parser = ModelArgs._add_parser_args(parser)
|
760 |
+
parser = DatasetArgs._add_parser_args(parser)
|
761 |
+
parser = _CommandLineAttackArgs._add_parser_args(parser)
|
762 |
+
parser = AttackArgs._add_parser_args(parser)
|
763 |
+
return parser
|
textattack/attack_recipes/__init__.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""".. _attack_recipes:
|
2 |
+
|
3 |
+
Attack Recipes Package:
|
4 |
+
========================
|
5 |
+
|
6 |
+
We provide a number of pre-built attack recipes, which correspond to attacks from the literature. To run an attack recipe from the command line, run::
|
7 |
+
|
8 |
+
textattack attack --recipe [recipe_name]
|
9 |
+
|
10 |
+
To initialize an attack in Python script, use::
|
11 |
+
|
12 |
+
<recipe name>.build(model_wrapper)
|
13 |
+
|
14 |
+
For example, ``attack = InputReductionFeng2018.build(model)`` creates `attack`, an object of type ``Attack`` with the goal function, transformation, constraints, and search method specified in that paper. This object can then be used just like any other attack; for example, by calling ``attack.attack_dataset``.
|
15 |
+
|
16 |
+
TextAttack supports the following attack recipes (each recipe's documentation contains a link to the corresponding paper):
|
17 |
+
|
18 |
+
.. contents:: :local:
|
19 |
+
"""
|
20 |
+
|
21 |
+
from .attack_recipe import AttackRecipe
|
22 |
+
|
23 |
+
from .a2t_yoo_2021 import A2TYoo2021
|
24 |
+
from .bae_garg_2019 import BAEGarg2019
|
25 |
+
from .bert_attack_li_2020 import BERTAttackLi2020
|
26 |
+
from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
|
27 |
+
from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019
|
28 |
+
from .deepwordbug_gao_2018 import DeepWordBugGao2018
|
29 |
+
from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
|
30 |
+
from .input_reduction_feng_2018 import InputReductionFeng2018
|
31 |
+
from .kuleshov_2017 import Kuleshov2017
|
32 |
+
from .morpheus_tan_2020 import MorpheusTan2020
|
33 |
+
from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
|
34 |
+
from .textbugger_li_2018 import TextBuggerLi2018
|
35 |
+
from .textfooler_jin_2019 import TextFoolerJin2019
|
36 |
+
from .pwws_ren_2019 import PWWSRen2019
|
37 |
+
from .iga_wang_2019 import IGAWang2019
|
38 |
+
from .pruthi_2019 import Pruthi2019
|
39 |
+
from .pso_zang_2020 import PSOZang2020
|
40 |
+
from .checklist_ribeiro_2020 import CheckList2020
|
41 |
+
from .clare_li_2020 import CLARE2020
|
42 |
+
from .french_recipe import FrenchRecipe
|
43 |
+
from .spanish_recipe import SpanishRecipe
|
textattack/attack_recipes/a2t_yoo_2021.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A2T (A2T: Attack for Adversarial Training Recipe)
|
3 |
+
==================================================
|
4 |
+
|
5 |
+
"""
|
6 |
+
|
7 |
+
from textattack import Attack
|
8 |
+
from textattack.constraints.grammaticality import PartOfSpeech
|
9 |
+
from textattack.constraints.pre_transformation import (
|
10 |
+
InputColumnModification,
|
11 |
+
MaxModificationRate,
|
12 |
+
RepeatModification,
|
13 |
+
StopwordModification,
|
14 |
+
)
|
15 |
+
from textattack.constraints.semantics import WordEmbeddingDistance
|
16 |
+
from textattack.constraints.semantics.sentence_encoders import BERT
|
17 |
+
from textattack.goal_functions import UntargetedClassification
|
18 |
+
from textattack.search_methods import GreedyWordSwapWIR
|
19 |
+
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
|
20 |
+
|
21 |
+
from .attack_recipe import AttackRecipe
|
22 |
+
|
23 |
+
|
24 |
+
class A2TYoo2021(AttackRecipe):
|
25 |
+
"""Towards Improving Adversarial Training of NLP Models.
|
26 |
+
|
27 |
+
(Yoo et al., 2021)
|
28 |
+
|
29 |
+
https://arxiv.org/abs/2109.00544
|
30 |
+
"""
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def build(model_wrapper, mlm=False):
|
34 |
+
"""Build attack recipe.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
|
38 |
+
Model wrapper containing both the model and the tokenizer.
|
39 |
+
mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
40 |
+
If :obj:`True`, load `A2T-MLM` attack. Otherwise, load regular `A2T` attack.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
:class:`~textattack.Attack`: A2T attack.
|
44 |
+
"""
|
45 |
+
constraints = [RepeatModification(), StopwordModification()]
|
46 |
+
input_column_modification = InputColumnModification(
|
47 |
+
["premise", "hypothesis"], {"premise"}
|
48 |
+
)
|
49 |
+
constraints.append(input_column_modification)
|
50 |
+
constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
|
51 |
+
constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
|
52 |
+
sent_encoder = BERT(
|
53 |
+
model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
|
54 |
+
)
|
55 |
+
constraints.append(sent_encoder)
|
56 |
+
|
57 |
+
if mlm:
|
58 |
+
transformation = transformation = WordSwapMaskedLM(
|
59 |
+
method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
|
60 |
+
)
|
61 |
+
else:
|
62 |
+
transformation = WordSwapEmbedding(max_candidates=20)
|
63 |
+
constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))
|
64 |
+
|
65 |
+
#
|
66 |
+
# Goal is untargeted classification
|
67 |
+
#
|
68 |
+
goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
|
69 |
+
#
|
70 |
+
# Greedily swap words with "Word Importance Ranking".
|
71 |
+
#
|
72 |
+
search_method = GreedyWordSwapWIR(wir_method="gradient")
|
73 |
+
|
74 |
+
return Attack(goal_function, constraints, transformation, search_method)
|
textattack/attack_recipes/attack_recipe.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Attack Recipe Class
|
3 |
+
========================
|
4 |
+
|
5 |
+
"""
|
6 |
+
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
|
9 |
+
from textattack import Attack
|
10 |
+
|
11 |
+
|
12 |
+
class AttackRecipe(Attack, ABC):
|
13 |
+
"""A recipe for building an NLP adversarial attack from the literature."""
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
@abstractmethod
|
17 |
+
def build(model_wrapper, **kwargs):
|
18 |
+
"""Creates pre-built :class:`~textattack.Attack` that correspond to
|
19 |
+
attacks from the literature.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
|
23 |
+
:class:`~textattack.models.wrappers.ModelWrapper` that contains the victim model and tokenizer.
|
24 |
+
This is passed to :class:`~textattack.goal_functions.GoalFunction` when constructing the attack.
|
25 |
+
kwargs:
|
26 |
+
Additional keyword arguments.
|
27 |
+
Returns:
|
28 |
+
:class:`~textattack.Attack`
|
29 |
+
"""
|
30 |
+
raise NotImplementedError()
|
textattack/attack_recipes/bae_garg_2019.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BAE (BAE: BERT-Based Adversarial Examples)
|
3 |
+
============================================
|
4 |
+
|
5 |
+
"""
|
6 |
+
from textattack.constraints.grammaticality import PartOfSpeech
|
7 |
+
from textattack.constraints.pre_transformation import (
|
8 |
+
RepeatModification,
|
9 |
+
StopwordModification,
|
10 |
+
)
|
11 |
+
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
|
12 |
+
from textattack.goal_functions import UntargetedClassification
|
13 |
+
from textattack.search_methods import GreedyWordSwapWIR
|
14 |
+
from textattack.transformations import WordSwapMaskedLM
|
15 |
+
|
16 |
+
from .attack_recipe import AttackRecipe
|
17 |
+
|
18 |
+
|
19 |
+
class BAEGarg2019(AttackRecipe):
|
20 |
+
"""Siddhant Garg and Goutham Ramakrishnan, 2019.
|
21 |
+
|
22 |
+
BAE: BERT-based Adversarial Examples for Text Classification.
|
23 |
+
|
24 |
+
https://arxiv.org/pdf/2004.01970
|
25 |
+
|
26 |
+
This is "attack mode" 1 from the paper, BAE-R, word replacement.
|
27 |
+
|
28 |
+
We present 4 attack modes for BAE based on the
|
29 |
+
R and I operations, where for each token t in S:
|
30 |
+
• BAE-R: Replace token t (See Algorithm 1)
|
31 |
+
• BAE-I: Insert a token to the left or right of t
|
32 |
+
• BAE-R/I: Either replace token t or insert a
|
33 |
+
token to the left or right of t
|
34 |
+
• BAE-R+I: First replace token t, then insert a
|
35 |
+
token to the left or right of t
|
36 |
+
"""
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def build(model_wrapper):
|
40 |
+
# "In this paper, we present a simple yet novel technique: BAE (BERT-based
|
41 |
+
# Adversarial Examples), which uses a language model (LM) for token
|
42 |
+
# replacement to best fit the overall context. We perturb an input sentence
|
43 |
+
# by either replacing a token or inserting a new token in the sentence, by
|
44 |
+
# means of masking a part of the input and using a LM to fill in the mask."
|
45 |
+
#
|
46 |
+
# We only consider the top K=50 synonyms from the MLM predictions.
|
47 |
+
#
|
48 |
+
# [from email correspondance with the author]
|
49 |
+
# "When choosing the top-K candidates from the BERT masked LM, we filter out
|
50 |
+
# the sub-words and only retain the whole words (by checking if they are
|
51 |
+
# present in the GloVE vocabulary)"
|
52 |
+
#
|
53 |
+
transformation = WordSwapMaskedLM(
|
54 |
+
method="bae", max_candidates=50, min_confidence=0.0
|
55 |
+
)
|
56 |
+
#
|
57 |
+
# Don't modify the same word twice or stopwords.
|
58 |
+
#
|
59 |
+
constraints = [RepeatModification(), StopwordModification()]
|
60 |
+
|
61 |
+
# For the R operations we add an additional check for
|
62 |
+
# grammatical correctness of the generated adversarial example by filtering
|
63 |
+
# out predicted tokens that do not form the same part of speech (POS) as the
|
64 |
+
# original token t_i in the sentence.
|
65 |
+
constraints.append(PartOfSpeech(allow_verb_noun_swap=True))
|
66 |
+
|
67 |
+
# "To ensure semantic similarity on introducing perturbations in the input
|
68 |
+
# text, we filter the set of top-K masked tokens (K is a pre-defined
|
69 |
+
# constant) predicted by BERT-MLM using a Universal Sentence Encoder (USE)
|
70 |
+
# (Cer et al., 2018)-based sentence similarity scorer."
|
71 |
+
#
|
72 |
+
# "[We] set a threshold of 0.8 for the cosine similarity between USE-based
|
73 |
+
# embeddings of the adversarial and input text."
|
74 |
+
#
|
75 |
+
# [from email correspondence with the author]
|
76 |
+
# "For a fair comparison of the benefits of using a BERT-MLM in our paper,
|
77 |
+
# we retained the majority of TextFooler's specifications. Thus we:
|
78 |
+
# 1. Use the USE for comparison within a window of size 15 around the word
|
79 |
+
# being replaced/inserted.
|
80 |
+
# 2. Set the similarity score threshold to 0.1 for inputs shorter than the
|
81 |
+
# window size (this translates roughly to almost always accepting the new text).
|
82 |
+
# 3. Perform the USE similarity thresholding of 0.8 with respect to the text
|
83 |
+
# just before the replacement/insertion and not the original text (For
|
84 |
+
# example: at the 3rd R/I operation, we compute the USE score on a window
|
85 |
+
# of size 15 of the text obtained after the first 2 R/I operations and not
|
86 |
+
# the original text).
|
87 |
+
# ...
|
88 |
+
# To address point (3) from above, compare the USE with the original text
|
89 |
+
# at each iteration instead of the current one (While doing this change
|
90 |
+
# for the R-operation is trivial, doing it for the I-operation with the
|
91 |
+
# window based USE comparison might be more involved)."
|
92 |
+
#
|
93 |
+
# Finally, since the BAE code is based on the TextFooler code, we need to
|
94 |
+
# adjust the threshold to account for the missing / pi in the cosine
|
95 |
+
# similarity comparison. So the final threshold is 1 - (1 - 0.8) / pi
|
96 |
+
# = 1 - (0.2 / pi) = 0.936338023.
|
97 |
+
use_constraint = UniversalSentenceEncoder(
|
98 |
+
threshold=0.936338023,
|
99 |
+
metric="cosine",
|
100 |
+
compare_against_original=True,
|
101 |
+
window_size=15,
|
102 |
+
skip_text_shorter_than_window=True,
|
103 |
+
)
|
104 |
+
constraints.append(use_constraint)
|
105 |
+
#
|
106 |
+
# Goal is untargeted classification.
|
107 |
+
#
|
108 |
+
goal_function = UntargetedClassification(model_wrapper)
|
109 |
+
#
|
110 |
+
# "We estimate the token importance Ii of each token
|
111 |
+
# t_i ∈ S = [t1, . . . , tn], by deleting ti from S and computing the
|
112 |
+
# decrease in probability of predicting the correct label y, similar
|
113 |
+
# to (Jin et al., 2019).
|
114 |
+
#
|
115 |
+
# • "If there are multiple tokens can cause C to misclassify S when they
|
116 |
+
# replace the mask, we choose the token which makes Sadv most similar to
|
117 |
+
# the original S based on the USE score."
|
118 |
+
# • "If no token causes misclassification, we choose the perturbation that
|
119 |
+
# decreases the prediction probability P(C(Sadv)=y) the most."
|
120 |
+
#
|
121 |
+
search_method = GreedyWordSwapWIR(wir_method="delete")
|
122 |
+
|
123 |
+
return BAEGarg2019(goal_function, constraints, transformation, search_method)
|
textattack/attack_recipes/bert_attack_li_2020.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BERT-Attack:
|
3 |
+
============================================================
|
4 |
+
|
5 |
+
(BERT-Attack: Adversarial Attack Against BERT Using BERT)
|
6 |
+
|
7 |
+
.. warning::
|
8 |
+
This attack is super slow
|
9 |
+
(see https://github.com/QData/TextAttack/issues/586)
|
10 |
+
Consider using smaller values for "max_candidates".
|
11 |
+
|
12 |
+
"""
|
13 |
+
from textattack import Attack
|
14 |
+
from textattack.constraints.overlap import MaxWordsPerturbed
|
15 |
+
from textattack.constraints.pre_transformation import (
|
16 |
+
RepeatModification,
|
17 |
+
StopwordModification,
|
18 |
+
)
|
19 |
+
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
|
20 |
+
from textattack.goal_functions import UntargetedClassification
|
21 |
+
from textattack.search_methods import GreedyWordSwapWIR
|
22 |
+
from textattack.transformations import WordSwapMaskedLM
|
23 |
+
|
24 |
+
from .attack_recipe import AttackRecipe
|
25 |
+
|
26 |
+
|
27 |
+
class BERTAttackLi2020(AttackRecipe):
|
28 |
+
"""Li, L.., Ma, R., Guo, Q., Xiangyang, X., Xipeng, Q. (2020).
|
29 |
+
|
30 |
+
BERT-ATTACK: Adversarial Attack Against BERT Using BERT
|
31 |
+
|
32 |
+
https://arxiv.org/abs/2004.09984
|
33 |
+
|
34 |
+
This is "attack mode" 1 from the paper, BAE-R, word replacement.
|
35 |
+
"""
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def build(model_wrapper):
|
39 |
+
# [from correspondence with the author]
|
40 |
+
# Candidate size K is set to 48 for all data-sets.
|
41 |
+
transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)
|
42 |
+
#
|
43 |
+
# Don't modify the same word twice or stopwords.
|
44 |
+
#
|
45 |
+
constraints = [RepeatModification(), StopwordModification()]
|
46 |
+
|
47 |
+
# "We only take ε percent of the most important words since we tend to keep
|
48 |
+
# perturbations minimum."
|
49 |
+
#
|
50 |
+
# [from correspondence with the author]
|
51 |
+
# "Word percentage allowed to change is set to 0.4 for most data-sets, this
|
52 |
+
# parameter is trivial since most attacks only need a few changes. This
|
53 |
+
# epsilon is only used to avoid too much queries on those very hard samples."
|
54 |
+
constraints.append(MaxWordsPerturbed(max_percent=0.4))
|
55 |
+
|
56 |
+
# "As used in TextFooler (Jin et al., 2019), we also use Universal Sentence
|
57 |
+
# Encoder (Cer et al., 2018) to measure the semantic consistency between the
|
58 |
+
# adversarial sample and the original sequence. To balance between semantic
|
59 |
+
# preservation and attack success rate, we set up a threshold of semantic
|
60 |
+
# similarity score to filter the less similar examples."
|
61 |
+
#
|
62 |
+
# [from correspondence with author]
|
63 |
+
# "Over the full texts, after generating all the adversarial samples, we filter
|
64 |
+
# out low USE score samples. Thus the success rate is lower but the USE score
|
65 |
+
# can be higher. (actually USE score is not a golden metric, so we simply
|
66 |
+
# measure the USE score over the final texts for a comparison with TextFooler).
|
67 |
+
# For datasets like IMDB, we set a higher threshold between 0.4-0.7; for
|
68 |
+
# datasets like MNLI, we set threshold between 0-0.2."
|
69 |
+
#
|
70 |
+
# Since the threshold in the real world can't be determined from the training
|
71 |
+
# data, the TextAttack implementation uses a fixed threshold - determined to
|
72 |
+
# be 0.2 to be most fair.
|
73 |
+
use_constraint = UniversalSentenceEncoder(
|
74 |
+
threshold=0.2,
|
75 |
+
metric="cosine",
|
76 |
+
compare_against_original=True,
|
77 |
+
window_size=None,
|
78 |
+
)
|
79 |
+
constraints.append(use_constraint)
|
80 |
+
#
|
81 |
+
# Goal is untargeted classification.
|
82 |
+
#
|
83 |
+
goal_function = UntargetedClassification(model_wrapper)
|
84 |
+
#
|
85 |
+
# "We first select the words in the sequence which have a high significance
|
86 |
+
# influence on the final output logit. Let S = [w0, ··· , wi ··· ] denote
|
87 |
+
# the input sentence, and oy(S) denote the logit output by the target model
|
88 |
+
# for correct label y, the importance score Iwi is defined as
|
89 |
+
# Iwi = oy(S) − oy(S\wi), where S\wi = [w0, ··· , wi−1, [MASK], wi+1, ···]
|
90 |
+
# is the sentence after replacing wi with [MASK]. Then we rank all the words
|
91 |
+
# according to the ranking score Iwi in descending order to create word list
|
92 |
+
# L."
|
93 |
+
search_method = GreedyWordSwapWIR(wir_method="unk")
|
94 |
+
|
95 |
+
return Attack(goal_function, constraints, transformation, search_method)
|
textattack/attack_recipes/checklist_ribeiro_2020.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
CheckList:
|
3 |
+
=========================
|
4 |
+
|
5 |
+
(Beyond Accuracy: Behavioral Testing of NLP models with CheckList)
|
6 |
+
|
7 |
+
"""
|
8 |
+
from textattack import Attack
|
9 |
+
from textattack.constraints.pre_transformation import RepeatModification
|
10 |
+
from textattack.goal_functions import UntargetedClassification
|
11 |
+
from textattack.search_methods import GreedySearch
|
12 |
+
from textattack.transformations import (
|
13 |
+
CompositeTransformation,
|
14 |
+
WordSwapChangeLocation,
|
15 |
+
WordSwapChangeName,
|
16 |
+
WordSwapChangeNumber,
|
17 |
+
WordSwapContract,
|
18 |
+
WordSwapExtend,
|
19 |
+
)
|
20 |
+
|
21 |
+
from .attack_recipe import AttackRecipe
|
22 |
+
|
23 |
+
|
24 |
+
class CheckList2020(AttackRecipe):
|
25 |
+
"""An implementation of the attack used in "Beyond Accuracy: Behavioral
|
26 |
+
Testing of NLP models with CheckList", Ribeiro et al., 2020.
|
27 |
+
|
28 |
+
This attack focuses on a number of attacks used in the Invariance Testing
|
29 |
+
Method: Contraction, Extension, Changing Names, Number, Location
|
30 |
+
|
31 |
+
https://arxiv.org/abs/2005.04118
|
32 |
+
"""
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def build(model_wrapper):
|
36 |
+
transformation = CompositeTransformation(
|
37 |
+
[
|
38 |
+
WordSwapExtend(),
|
39 |
+
WordSwapContract(),
|
40 |
+
WordSwapChangeName(),
|
41 |
+
WordSwapChangeNumber(),
|
42 |
+
WordSwapChangeLocation(),
|
43 |
+
]
|
44 |
+
)
|
45 |
+
|
46 |
+
# Need this constraint to prevent extend and contract modifying each others' changes and forming infinite loop
|
47 |
+
constraints = [RepeatModification()]
|
48 |
+
|
49 |
+
# Untargeted attack & GreedySearch
|
50 |
+
goal_function = UntargetedClassification(model_wrapper)
|
51 |
+
search_method = GreedySearch()
|
52 |
+
|
53 |
+
return Attack(goal_function, constraints, transformation, search_method)
|