# Page-capture residue, not part of the program: "Spaces: Sleeping Sleeping"
import os | |
import zipfile | |
import gradio as gr | |
import nltk | |
import pandas as pd | |
import requests | |
from pyabsa import TADCheckpointManager | |
from textattack.attack_recipes import ( | |
BAEGarg2019, | |
PWWSRen2019, | |
TextFoolerJin2019, | |
PSOZang2020, | |
IGAWang2019, | |
GeneticAlgorithmAlzantot2018, | |
DeepWordBugGao2018, | |
CLARE2020, | |
) | |
from textattack.attack_results import SuccessfulAttackResult | |
from utils import SentAttacker, get_agnews_example, get_sst2_example, get_amazon_example, get_imdb_example, diff_texts | |
# from utils import get_yahoo_example | |
# Registry of SentAttacker wrappers, filled by init(); keys have the form
# "tad-<dataset><attacker>".
sent_attackers = dict()

# Registry of TAD text classifiers, filled by init(); keys have the form
# "tad-<dataset>".
tad_classifiers = dict()

# Short attacker name -> TextAttack recipe class used to build a SentAttacker.
attack_recipes = dict(
    bae=BAEGarg2019,
    pwws=PWWSRen2019,
    textfooler=TextFoolerJin2019,
    pso=PSOZang2020,
    iga=IGAWang2019,
    ga=GeneticAlgorithmAlzantot2018,
    deepwordbug=DeepWordBugGao2018,
    clare=CLARE2020,
)
def init():
    """Load every classifier and attacker used by the local code path.

    Steps, in order:

    1. Download the NLTK "omw-1.4" resource.
    2. If the ``TAD-SST2`` directory is missing, unpack ``checkpoints.zip``
       into the current working directory.
    3. For each (attacker, dataset) pair, make sure a TAD classifier is
       registered in ``tad_classifiers`` and register a ``SentAttacker`` for
       the pair in ``sent_attackers``.
    """
    nltk.download("omw-1.4")

    if not os.path.exists("TAD-SST2"):
        # The context manager closes the archive handle even when extraction
        # fails part-way (the handle used to be left open).
        with zipfile.ZipFile("checkpoints.zip", "r") as archive:
            archive.extractall(os.getcwd())

    for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
        for dataset in ["agnews10k", "sst2", "MR", "imdb"]:
            classifier_key = "tad-{}".format(dataset)

            # Each classifier is loaded once and shared by all attackers.
            if classifier_key not in tad_classifiers:
                tad_classifiers[
                    classifier_key
                ] = TADCheckpointManager.get_tad_text_classifier(
                    classifier_key.upper()
                )

            sent_attackers["tad-{}{}".format(dataset, attacker)] = SentAttacker(
                tad_classifiers[classifier_key], attack_recipes[attacker]
            )

            # "pwws" is the first attacker in the outer loop, so its entry
            # already exists for every dataset by the time this line runs.
            tad_classifiers[classifier_key].sent_attacker = sent_attackers[
                "tad-{}pwws".format(dataset)
            ]
# Texts that have already been attacked. An empty or previously seen input is
# replaced by a freshly sampled dataset example.
cache = set()


def _lookup_ignoring_case(registry, key):
    """Return ``registry[key]``, falling back to a case-insensitive key match.

    init() registers the MR dataset under mixed-case keys ("tad-MR",
    "tad-MRpwws", ...), so a lookup with a lower-cased name would miss them.

    Raises:
        KeyError: if no key matches, even when case is ignored.
    """
    if key in registry:
        return registry[key]
    folded = key.lower()
    for name, value in registry.items():
        if name.lower() == folded:
            return value
    raise KeyError(key)


def _sample_example(dataset):
    """Return what the example getter matching *dataset* returns.

    Raises:
        ValueError: if *dataset* matches none of the known dataset names.
    """
    name = dataset.lower()
    if "agnews" in name:
        return get_agnews_example()
    if "sst2" in name:
        return get_sst2_example()
    # The name is lower-cased, so the needle must be lower-case too; the
    # previous upper-case "MR" test could never match.
    if "mr" in name:
        return get_amazon_example()
    if "imdb" in name:
        return get_imdb_example()
    raise ValueError("Unknown dataset: {!r}".format(dataset))


def generate_adversarial_example(dataset, attacker, text=None, label=None):
    """Attack one example locally and run the defense on the adversarial text.

    When *text* is empty or has been used before, an example is sampled for
    *dataset* instead. If the attack does not succeed (or the example does not
    qualify), the function calls itself again without text so that a new
    example is sampled.

    Args:
        dataset: Dataset name; matched case-insensitively.
        attacker: Attacker name; matched case-insensitively.
        text: Optional natural example to attack.
        label: Label of *text*; converted with ``int()`` before the attack.

    Returns:
        An 11-tuple: original text, original label, restored text, label
        reported for the restored text, adversarial text, three
        ``diff_texts`` results (original vs. itself, vs. the adversarial text,
        vs. the restored text), the model output on the adversarial text, and
        two single-row DataFrames (classification result, adversarial
        detection result).

    Raises:
        ValueError: if an example must be sampled for an unknown dataset.
        KeyError: if no attacker/classifier is registered for the pair.
    """
    if not text or text in cache:
        text, label = _sample_example(dataset)
    cache.add(text)

    sent_attacker = _lookup_ignoring_case(
        sent_attackers, "tad-{}{}".format(dataset, attacker)
    )
    attack_result = sent_attacker.attacker.simple_attack(text, int(label))

    result = None
    if isinstance(attack_result, SuccessfulAttackResult):
        original = attack_result.original_result
        perturbed = attack_result.perturbed_result
        # Only defend examples that were classified correctly before the
        # attack and are misclassified after it.
        if (
            perturbed.output != original.ground_truth_output
            and original.output == original.ground_truth_output
        ):
            classifier = _lookup_ignoring_case(
                tad_classifiers, "tad-{}".format(dataset)
            )
            # with defense
            result = classifier.infer(
                perturbed.attacked_text.text
                + "$LABEL${},{},{}".format(
                    original.ground_truth_output,
                    1,
                    perturbed.output,
                ),
                print_result=True,
                defense=attacker,
            )

    if not result:
        # Nothing to show for this example; retry with a sampled one.
        return generate_adversarial_example(dataset, attacker)

    classification_df = {}
    classification_df["is_repaired"] = result["is_fixed"]
    classification_df["pred_label"] = result["label"]
    classification_df["confidence"] = round(result["confidence"], 3)
    # NOTE(review): this reads result["pred_label"] while the line above reads
    # result["label"]; confirm that the classifier output has both keys.
    classification_df["is_correct"] = str(result["pred_label"]) == str(label)

    advdetection_df = {}
    if result["is_adv_label"] != "0":
        # The flag is mapped from either its string or its integer form.
        advdetection_df["is_adversarial"] = {
            "0": False,
            "1": True,
            0: False,
            1: True,
        }[result["is_adv_label"]]
        advdetection_df["perturbed_label"] = result["perturbed_label"]
        advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
        advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
        advdetection_df['is_correct'] = result['ref_is_adv_check']

    adversarial_text = attack_result.perturbed_result.attacked_text.text
    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        adversarial_text,
        diff_texts(text, text),
        diff_texts(text, adversarial_text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )
def run_demo(dataset, attacker, text=None, label=None):
    """Produce the twelve UI outputs, preferring the remote service.

    The request is first sent to the remote endpoint. If anything goes wrong
    (network error, timeout, invalid JSON, missing field), the error is
    printed and the example is generated locally with
    ``generate_adversarial_example``.

    Args:
        dataset: Dataset name selected in the UI.
        attacker: Attacker name selected in the UI.
        text: Optional natural example typed by the user.
        label: Optional original label typed by the user.

    Returns:
        A 12-tuple in the order of the ``button_gen.click`` outputs; the last
        element is the text shown in the message box.
    """
    payload = {
        "dataset": dataset,
        "attacker": attacker,
        "text": text,
        "label": label,
    }
    try:
        # (connect, read) timeout: give up quickly when the service is
        # unreachable, but leave a running job plenty of time to finish.
        # Without a timeout the request could block forever.
        response = requests.post(
            'https://rpddemo.pagekite.me/api/generate_adversarial_example',
            json=payload,
            timeout=(3, 300),
        )
        result = response.json()
        print(result)
        return (
            result["text"],
            result["label"],
            result["restored_text"],
            result["result_label"],
            result["perturbed_text"],
            result["text_diff"],
            result["perturbed_diff"],
            result["restored_diff"],
            result["output"],
            pd.DataFrame(result["classification_df"]),
            pd.DataFrame(result["advdetection_df"]),
            result["message"],
        )
    except Exception as e:
        # Best-effort by design: any remote failure falls back to local.
        print(e)
        failure = str(e)

    local_outputs = generate_adversarial_example(dataset, attacker, text, label)
    # The local generator returns eleven values but the UI binds twelve
    # outputs, so the message-box text is appended here.
    message = "Remote service unavailable ({}); the example was generated locally instead.".format(
        failure
    )
    return (*local_outputs, message)
def check_gpu():
    """Return a status string saying whether the remote endpoint responded.

    An empty POST is sent with a three-second timeout. Any response with a
    status code below 500 counts as available; a status of 500 or above, or
    any exception, counts as not available.
    """
    endpoint = 'https://rpddemo.pagekite.me/api/generate_adversarial_example'
    try:
        status = requests.post(endpoint, timeout=3).status_code
    except Exception:
        return 'GPU not available'
    return 'GPU available' if status < 500 else 'GPU not available'
if __name__ == "__main__":
    # Script entry point: load the local models (best effort), build the
    # Gradio interface and launch it.
    try:
        init()
    except Exception as e:
        # The UI is still started when local initialisation fails.
        print(e)
        print("Failed to initialize the demo. Please try again later.")

    demo = gr.Blocks()

    with demo:
        gr.Markdown("<h1 align='center'>Detection and Correction based on Word Importance Ranking (DCWIR) </h1>")
        gr.Markdown("<h2 align='center'>Clarifications</h2>")
        gr.Markdown("""
- This demo has no mechanism to ensure the adversarial example will be correctly repaired by DCWIR.
- The adversarial example and corrected adversarial example may be unnatural to read, while it is because the attackers usually generate unnatural perturbations.
- All the proposed attacks are Black Box attack where the attacker has no access to the model parameters.
""")

        # ---- Inputs -------------------------------------------------------
        gr.Markdown("<h2 align='center'>Natural Example Input</h2>")
        with gr.Group():
            with gr.Row():
                input_dataset = gr.Radio(
                    choices=["SST2", "IMDB", "MR", "AGNews10K"],
                    value="SST2",
                    label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
                )
                input_attacker = gr.Radio(
                    choices=["BAE", "PWWS", "TextFooler", "DeepWordBug"],
                    value="TextFooler",
                    label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
                )
        with gr.Group(visible=True):
            with gr.Row():
                input_sentence = gr.Textbox(
                    placeholder="Input a natural example...",
                    label="Alternatively, input a natural example and its original label (from above datasets) to generate an adversarial example.",
                )
                input_label = gr.Textbox(
                    placeholder="Original label, (must be a integer, because we use digits to represent labels in training)",
                    label="Original Label",
                )

        # Closing tag fixed: the heading used to open <h3> and close </h2>.
        gr.Markdown(
            "<h3 align='center'>Default parameters are set according to the main experiment setup in the report.</h3>",
        )
        # NOTE(review): these three boxes are displayed but are not listed in
        # the inputs of button_gen.click below, so their values are not used.
        with gr.Row():
            wir_percentage = gr.Textbox(
                placeholder="Enter percentage from WIR...",
                label="Percentage from WIR",
            )
            frequency_threshold = gr.Textbox(
                placeholder="Enter frequency threshold...",
                label="Frequency Threshold",
            )
            max_candidates = gr.Textbox(
                placeholder="Enter maximum number of candidates...",
                label="Maximum Number of Candidates",
            )

        msg_text = gr.Textbox(
            label="Message",
            placeholder="This is a message box to show any error messages.",
        )
        button_gen = gr.Button(
            "Generate an adversarial example to repair using DCWIR (GPU: < 1 minute, CPU: 1-10 minutes)",
            variant="primary",
        )
        gpu_status_text = gr.Textbox(
            label='GPU status',
            placeholder="Please click to check",
        )
        button_check = gr.Button(
            "Check if GPU available",
            variant="primary"
        )
        button_check.click(
            fn=check_gpu,
            inputs=[],
            outputs=[
                gpu_status_text
            ]
        )

        # ---- Outputs: texts and labels --------------------------------------
        gr.Markdown("<h2 align='center'>Generated Adversarial Example and Repaired Adversarial Example</h2>")
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    output_original_example = gr.Textbox(label="Original Example")
                    output_original_label = gr.Textbox(label="Original Label")
                with gr.Row():
                    output_adv_example = gr.Textbox(label="Adversarial Example")
                    output_adv_label = gr.Textbox(label="Predicted Label of the Adversarial Example")
                with gr.Row():
                    output_repaired_example = gr.Textbox(
                        label="Repaired Adversarial Example by Rapid"
                    )
                    output_repaired_label = gr.Textbox(label="Predicted Label of the Repaired Adversarial Example")

        # ---- Outputs: character-level differences ---------------------------
        # Closing tag fixed: the heading used to open <h2> and close </p>.
        gr.Markdown("<h2 align='center'>Example Difference (Comparisons)</h2>")
        gr.Markdown("""
<p align='center'>The (+) and (-) in the boxes indicate the added and deleted characters in the adversarial example compared to the original input natural example.</p>
""")
        ori_text_diff = gr.HighlightedText(
            label="The Original Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )
        adv_text_diff = gr.HighlightedText(
            label="Character Editions of Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )
        restored_text_diff = gr.HighlightedText(
            label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )

        # ---- Outputs: result tables -----------------------------------------
        # Markup fixed: the heading used to be wrapped in a Markdown "## "
        # heading as well and closed with </p> instead of </h2>.
        gr.Markdown(
            "<h2 align='center'>The Output of Reactive Perturbation Defocusing</h2>"
        )
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    output_is_adv_df = gr.DataFrame(
                        label="Adversarial Example Detection Result"
                    )
                    gr.Markdown(
                        """
- The is_adversarial field indicates if an adversarial example is detected.
- The perturbed_label is the predicted label of the adversarial example.
- The confidence field represents the ratio of Inverted samples among the total number of generated candidates.
"""
                    )
            with gr.Column():
                with gr.Group():
                    output_df = gr.DataFrame(
                        label="Correction Classification Result"
                    )
                    gr.Markdown(
                        """
- If is_corrected=true, it has been Corrected by DCWIR.
- The pred_label field indicates the standard classification result.
- The confidence field represents ratio of the dominant class among all Inverted candidates.
- The is_correct field indicates whether the predicted label is correct.
"""
                    )

        # Bind functions to buttons
        # The order of the outputs matches the order of the values returned by
        # run_demo's remote path (twelve entries).
        button_gen.click(
            fn=run_demo,
            inputs=[input_dataset, input_attacker, input_sentence, input_label],
            outputs=[
                output_original_example,
                output_original_label,
                output_repaired_example,
                output_repaired_label,
                output_adv_example,
                ori_text_diff,
                adv_text_diff,
                restored_text_diff,
                output_adv_label,
                output_df,
                output_is_adv_df,
                msg_text
            ],
        )

    demo.queue(2).launch()