# PFEemp2024's picture
# fixing the dataset bug
# 7a9762c verified
import os
import zipfile
import gradio as gr
import nltk
import pandas as pd
import requests
from pyabsa import TADCheckpointManager
from textattack.attack_recipes import (
BAEGarg2019,
PWWSRen2019,
TextFoolerJin2019,
PSOZang2020,
IGAWang2019,
GeneticAlgorithmAlzantot2018,
DeepWordBugGao2018,
CLARE2020,
)
from textattack.attack_results import SuccessfulAttackResult
from utils import SentAttacker, get_agnews_example, get_sst2_example, get_amazon_example, get_imdb_example, diff_texts
# from utils import get_yahoo_example
# Registries filled by init(): attacker wrappers are stored under
# "tad-<dataset><attacker>" and classifiers under "tad-<dataset>".
sent_attackers = dict()
tad_classifiers = dict()

# Attacker name (as used in the registry keys) -> TextAttack recipe class.
attack_recipes = dict(
    bae=BAEGarg2019,
    pwws=PWWSRen2019,
    textfooler=TextFoolerJin2019,
    pso=PSOZang2020,
    iga=IGAWang2019,
    ga=GeneticAlgorithmAlzantot2018,
    deepwordbug=DeepWordBugGao2018,
    clare=CLARE2020,
)
def init():
    """Prepare resources and populate the classifier/attacker registries.

    Fetches the NLTK ``omw-1.4`` resource, extracts ``checkpoints.zip`` into
    the working directory when the ``TAD-SST2`` path is missing, and then
    fills the module-level ``tad_classifiers`` and ``sent_attackers`` dicts
    for every dataset/attacker pair listed below.

    Registry keys keep the dataset spelling exactly as listed here
    (for example ``"tad-MR"`` and ``"tad-MRpwws"``).
    """
    nltk.download("omw-1.4")

    # Unpack the bundled checkpoints only when they are not on disk yet.
    if not os.path.exists("TAD-SST2"):
        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile("checkpoints.zip", "r") as archive:
            archive.extractall(os.getcwd())

    for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
        for dataset in ["agnews10k", "sst2", "MR", "imdb"]:
            classifier_key = "tad-{}".format(dataset)

            # Each classifier is loaded once and shared by all attackers.
            if classifier_key not in tad_classifiers:
                tad_classifiers[classifier_key] = (
                    TADCheckpointManager.get_tad_text_classifier(
                        classifier_key.upper()
                    )
                )
            classifier = tad_classifiers[classifier_key]

            sent_attackers["tad-{}{}".format(dataset, attacker)] = SentAttacker(
                classifier, attack_recipes[attacker]
            )
            # "pwws" is the first attacker in the outer loop, so this entry
            # already exists on every pass.
            classifier.sent_attacker = sent_attackers["tad-{}pwws".format(dataset)]
# Texts that were already attacked; an empty or repeated input is replaced by a
# freshly sampled dataset example.
cache = set()


def _find_registered(registry, key):
    """Return ``registry[key]``, falling back to a case-insensitive key match.

    init() stores keys with the dataset spelled as listed there (for example
    ``"tad-MRpwws"``) while lookups here are built from lower-cased names, so
    an exact match is not guaranteed.

    Raises:
        KeyError: No key matches, even ignoring case.
    """
    if key in registry:
        return registry[key]
    wanted = key.lower()
    for name, value in registry.items():
        if name.lower() == wanted:
            return value
    raise KeyError(key)


def generate_adversarial_example(dataset, attacker, text=None, label=None):
    """Attack a natural example and run the defended classifier on the result.

    A dataset example is sampled when ``text`` is empty or was attacked
    before. The attack is retried with fresh examples until it succeeds on an
    example the model originally classified correctly and the classifier
    returns a truthy result.

    Args:
        dataset: Dataset name; compared case-insensitively.
        attacker: Attacker name; compared case-insensitively for the registry
            lookup and forwarded unchanged as ``defense`` to ``infer``.
        text: Optional natural example to attack.
        label: Original label of ``text``; converted with ``int``.

    Returns:
        An 11-tuple: the original text, its label, the classifier result's
        ``restored_text`` and ``label``, the adversarial text, three
        ``diff_texts`` results (original vs. itself, vs. the adversarial text,
        vs. the restored text), the attack's perturbed output, and two
        single-row DataFrames (classification summary and adversarial
        detection summary).

    Raises:
        KeyError: No attacker or classifier is registered for the names.
        ValueError: No example sampler exists for ``dataset`` and no text was
            supplied.
    """
    dataset_name = dataset.lower()
    sent_attacker = _find_registered(
        sent_attackers, "tad-{}{}".format(dataset_name, attacker.lower())
    )
    classifier = _find_registered(tad_classifiers, "tad-{}".format(dataset_name))

    # Loop instead of recursing so a run of failed attacks cannot exhaust the
    # interpreter stack.
    while True:
        if not text or text in cache:
            if "agnews" in dataset_name:
                text, label = get_agnews_example()
            elif "sst2" in dataset_name:
                text, label = get_sst2_example()
            elif "mr" in dataset_name:
                # Must be compared in lower case: "MR" can never occur in a
                # lower-cased name.
                text, label = get_amazon_example()
            elif "imdb" in dataset_name:
                text, label = get_imdb_example()
            elif not text:
                raise ValueError(
                    "No example sampler for dataset: {}".format(dataset)
                )

        cache.add(text)

        result = None
        attack_result = sent_attacker.attacker.simple_attack(text, int(label))
        if isinstance(attack_result, SuccessfulAttackResult):
            original = attack_result.original_result
            perturbed = attack_result.perturbed_result
            # Only keep attacks that flipped a prediction which was correct
            # before the perturbation.
            if (
                perturbed.output != original.ground_truth_output
                and original.output == original.ground_truth_output
            ):
                # with defense
                # NOTE(review): the "$LABEL$" suffix presumably carries the true
                # label, an is-adversarial flag and the perturbed label --
                # confirm against the classifier's expected input format.
                result = classifier.infer(
                    perturbed.attacked_text.text
                    + "$LABEL${},{},{}".format(
                        original.ground_truth_output,
                        1,
                        perturbed.output,
                    ),
                    print_result=True,
                    defense=attacker,
                )

        if result:
            break

        # Drop the failed example so the next pass samples a new one.
        text, label = None, None

    classification_df = {
        "is_repaired": result["is_fixed"],
        "pred_label": result["label"],
        "confidence": round(result["confidence"], 3),
        "is_correct": str(result["pred_label"]) == str(label),
    }

    # Left empty when the classifier reports the string "0" as is_adv_label.
    advdetection_df = {}
    if result["is_adv_label"] != "0":
        advdetection_df["is_adversarial"] = {
            "0": False,
            "1": True,
            0: False,
            1: True,
        }[result["is_adv_label"]]
        advdetection_df["perturbed_label"] = result["perturbed_label"]
        advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
        advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
        advdetection_df['is_correct'] = result['ref_is_adv_check']

    adversarial_text = attack_result.perturbed_result.attacked_text.text
    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        adversarial_text,
        diff_texts(text, text),
        diff_texts(text, adversarial_text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )
def run_demo(dataset, attacker, text=None, label=None):
    """Generate and repair an adversarial example, preferring the remote service.

    The request is first sent to the remote endpoint. If anything goes wrong
    (network error, timeout, HTTP error status, malformed response) the
    exception is printed and the example is generated locally with
    ``generate_adversarial_example`` instead.

    Args:
        dataset: Dataset name selected in the UI.
        attacker: Attacker name selected in the UI.
        text: Optional natural example typed by the user.
        label: Optional original label of ``text``.

    Returns:
        A 12-tuple: the eleven values produced by
        ``generate_adversarial_example`` (or their remote equivalents)
        followed by a status message.
    """
    data = {
        "dataset": dataset,
        "attacker": attacker,
        "text": text,
        "label": label,
    }
    try:
        # (connect, read) timeouts: without them a stalled remote would block
        # this handler forever instead of falling back to local generation.
        response = requests.post(
            'https://rpddemo.pagekite.me/api/generate_adversarial_example',
            json=data,
            timeout=(3, 300),
        )
        response.raise_for_status()
        result = response.json()
        print(result)
        return (
            result["text"],
            result["label"],
            result["restored_text"],
            result["result_label"],
            result["perturbed_text"],
            result["text_diff"],
            result["perturbed_diff"],
            result["restored_diff"],
            result["output"],
            pd.DataFrame(result["classification_df"]),
            pd.DataFrame(result["advdetection_df"]),
            result["message"]
        )
    except Exception as e:
        print(e)
        local_outputs = generate_adversarial_example(dataset, attacker, text, label)
        # The local generator yields eleven values; append the message so both
        # paths return the same twelve outputs.
        return (
            *local_outputs,
            "Remote service unavailable; the example was generated locally.",
        )
def check_gpu():
    """Report whether the remote endpoint answers without a server error.

    Returns:
        ``'GPU available'`` when the POST completes with a status code below
        500, otherwise (including on any exception) ``'GPU not available'``.
    """
    try:
        probe = requests.post(
            'https://rpddemo.pagekite.me/api/generate_adversarial_example',
            timeout=3,
        )
        reachable = probe.status_code < 500
    except Exception:
        # Any failure (connection error, timeout, ...) counts as unavailable.
        reachable = False
    return 'GPU available' if reachable else 'GPU not available'
if __name__ == "__main__":
    # Model loading is best-effort: a failure is printed and the UI is still
    # built and launched.
    try:
        init()
    except Exception as e:
        print(e)
        print("Failed to initialize the demo. Please try again later.")

    demo = gr.Blocks()

    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from statement order -- confirm against the rendered page.
    with demo:
        gr.Markdown("<h1 align='center'>Detection and Correction based on Word Importance Ranking (DCWIR) </h1>")
        gr.Markdown("<h2 align='center'>Clarifications</h2>")
        gr.Markdown("""
- This demo has no mechanism to ensure the adversarial example will be correctly repaired by DCWIR.
- The adversarial example and corrected adversarial example may be unnatural to read, while it is because the attackers usually generate unnatural perturbations.
- All the proposed attacks are Black Box attack where the attacker has no access to the model parameters.
""")
        gr.Markdown("<h2 align='center'>Natural Example Input</h2>")

        # --- Inputs: dataset / attacker selection ---
        with gr.Group():
            with gr.Row():
                input_dataset = gr.Radio(
                    choices=["SST2", "IMDB", "MR", "AGNews10K"],
                    value="SST2",
                    label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
                )
                input_attacker = gr.Radio(
                    choices=["BAE", "PWWS", "TextFooler", "DeepWordBug"],
                    value="TextFooler",
                    label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
                )

        # --- Inputs: optional user-supplied example and parameters ---
        with gr.Group(visible=True):
            with gr.Row():
                input_sentence = gr.Textbox(
                    placeholder="Input a natural example...",
                    label="Alternatively, input a natural example and its original label (from above datasets) to generate an adversarial example.",
                )
                input_label = gr.Textbox(
                    placeholder="Original label, (must be an integer, because we use digits to represent labels in training)",
                    label="Original Label",
                )
            gr.Markdown(
                "<h3 align='center'>Default parameters are set according to the main experiment setup in the report.</h3>",
            )
            # NOTE(review): these three fields are displayed but are not listed
            # in any event's inputs, so their values are currently unused.
            with gr.Row():
                wir_percentage = gr.Textbox(
                    placeholder="Enter percentage from WIR...",
                    label="Percentage from WIR",
                )
                frequency_threshold = gr.Textbox(
                    placeholder="Enter frequency threshold...",
                    label="Frequency Threshold",
                )
                max_candidates = gr.Textbox(
                    placeholder="Enter maximum number of candidates...",
                    label="Maximum Number of Candidates",
                )

        msg_text = gr.Textbox(
            label="Message",
            placeholder="This is a message box to show any error messages.",
        )
        button_gen = gr.Button(
            "Generate an adversarial example to repair using DCWIR (GPU: < 1 minute, CPU: 1-10 minutes)",
            variant="primary",
        )
        gpu_status_text = gr.Textbox(
            label='GPU status',
            placeholder="Please click to check",
        )
        button_check = gr.Button(
            "Check if GPU available",
            variant="primary"
        )
        button_check.click(
            fn=check_gpu,
            inputs=[],
            outputs=[
                gpu_status_text
            ]
        )

        # --- Outputs: original / adversarial / repaired examples ---
        gr.Markdown("<h2 align='center'>Generated Adversarial Example and Repaired Adversarial Example</h2>")
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    output_original_example = gr.Textbox(label="Original Example")
                    output_original_label = gr.Textbox(label="Original Label")
                with gr.Row():
                    output_adv_example = gr.Textbox(label="Adversarial Example")
                    output_adv_label = gr.Textbox(label="Predicted Label of the Adversarial Example")
                with gr.Row():
                    output_repaired_example = gr.Textbox(
                        label="Repaired Adversarial Example by Rapid"
                    )
                    output_repaired_label = gr.Textbox(label="Predicted Label of the Repaired Adversarial Example")

        # --- Outputs: character-level differences ---
        gr.Markdown("<h2 align='center'>Example Difference (Comparisons)</h2>")
        gr.Markdown("""
<p align='center'>The (+) and (-) in the boxes indicate the added and deleted characters in the adversarial example compared to the original input natural example.</p>
""")
        ori_text_diff = gr.HighlightedText(
            label="The Original Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )
        adv_text_diff = gr.HighlightedText(
            label="Character Editions of Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )
        restored_text_diff = gr.HighlightedText(
            label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )

        # --- Outputs: detection and classification tables ---
        gr.Markdown(
            "<h2 align='center'>The Output of Reactive Perturbation Defocusing</h2>"
        )
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    output_is_adv_df = gr.DataFrame(
                        label="Adversarial Example Detection Result"
                    )
                    gr.Markdown(
                        """
- The is_adversarial field indicates if an adversarial example is detected.
- The perturbed_label is the predicted label of the adversarial example.
- The confidence field represents the ratio of Inverted samples among the total number of generated candidates.
"""
                    )
            with gr.Column():
                with gr.Group():
                    output_df = gr.DataFrame(
                        label="Correction Classification Result"
                    )
                    gr.Markdown(
                        """
- If is_repaired=true, it has been Corrected by DCWIR.
- The pred_label field indicates the standard classification result.
- The confidence field represents ratio of the dominant class among all Inverted candidates.
- The is_correct field indicates whether the predicted label is correct.
"""
                    )

        # Bind functions to buttons
        # The order of `outputs` matches the order of the tuple returned by run_demo.
        button_gen.click(
            fn=run_demo,
            inputs=[input_dataset, input_attacker, input_sentence, input_label],
            outputs=[
                output_original_example,
                output_original_label,
                output_repaired_example,
                output_repaired_label,
                output_adv_example,
                ori_text_diff,
                adv_text_diff,
                restored_text_diff,
                output_adv_label,
                output_df,
                output_is_adv_df,
                msg_text
            ],
        )

    demo.queue(2).launch()