|
import json |
|
import logging |
|
import os |
|
import subprocess |
|
import time |
|
|
|
import datasets |
|
import gradio as gr |
|
import huggingface_hub |
|
from transformers.pipelines import TextClassificationPipeline |
|
|
|
from io_utils import ( |
|
convert_column_mapping_to_json, |
|
read_inference_type, |
|
read_scanners, |
|
write_inference_type, |
|
write_scanners, |
|
) |
|
from text_classification import ( |
|
check_column_mapping_keys_validity, |
|
text_classification_fix_column_mapping, |
|
) |
|
from wordings import CONFIRM_MAPPING_DETAILS_FAIL_MD, CONFIRM_MAPPING_DETAILS_MD |
|
|
|
# Names of environment variables (not their values) read via os.environ.get
# when launching a local scan.
# Repository where the scan report discussion is opened.
HF_REPO_ID = "HF_REPO_ID"
# Fallback for HF_REPO_ID when that variable is unset or empty; presumably
# provided by the Hugging Face Spaces runtime -- TODO confirm.
HF_SPACE_ID = "SPACE_ID"
# Token passed to the scanner as --hf_token.
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
|
|
|
|
|
def check_model(model_id):
    """Look up *model_id* on the Hub and try to build a pipeline for it.

    Returns:
        ``(None, None)`` if fetching the model metadata raised;
        ``(model_id, pipeline)`` if the pipeline was built;
        ``(model_id, exception)`` if importing or building the pipeline raised.
    """
    try:
        metadata = huggingface_hub.model_info(model_id)
        task = metadata.pipeline_tag
    except Exception:
        # Metadata lookup failed for any reason.
        return None, None

    try:
        # Imported inside the try so an import failure is reported to the
        # caller exactly like a pipeline-construction failure.
        from transformers import pipeline as build_pipeline

        loaded = build_pipeline(task=task, model=model_id)
    except Exception as load_error:
        return model_id, load_error
    else:
        return model_id, loaded
|
|
|
|
|
def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
    """Check that a dataset, config and split exist and can be loaded.

    The outcome is encoded in the returned ``(dataset_id, config, split)``:

    * ``(None, dataset_config, dataset_split)`` -- the configs could not be
      listed, or the dataset could not be loaded.
    * ``(dataset_id, configs, dataset_split)`` with ``configs`` a list --
      *dataset_config* is not one of the available configs.
    * ``(dataset_id, None, splits)`` with ``splits`` a list -- *dataset_split*
      is not one of the loaded splits.
    * ``(dataset_id, None, None)`` -- ``load_dataset`` returned neither a
      ``DatasetDict`` nor a ``Dataset``.
    * ``(dataset_id, dataset_config, dataset_split)`` -- all checks passed.
    """
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
    except Exception:
        return None, dataset_config, dataset_split

    if dataset_config not in configs:
        return dataset_id, configs, dataset_split

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config)
    except Exception as e:
        # Report a load failure through the same signal as a failed config
        # lookup instead of letting it propagate out of the UI callback.
        logging.warning(
            "Failed to load dataset %s with config %s: %s",
            dataset_id,
            dataset_config,
            e,
        )
        return None, dataset_config, dataset_split

    if isinstance(ds, datasets.DatasetDict):
        if dataset_split not in ds.keys():
            return dataset_id, None, list(ds.keys())
    elif not isinstance(ds, datasets.Dataset):
        return dataset_id, None, None
    return dataset_id, dataset_config, dataset_split
|
|
|
|
|
def _validation_blocked_updates():
    """Return the UI updates used when validation cannot go any further.

    One update per validation output, in this order: submit button (disabled),
    loading row (shown), then preview row, sample input, sample prediction,
    label-mapping table and feature-mapping table (all hidden).
    """
    return (
        gr.update(interactive=False),
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )


def try_validate(
    m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping="{}"
):
    """Validate the model/dataset selection and preview the column mapping.

    Args:
        m_id: Model id as returned by ``check_model`` (``None`` when the model
            metadata could not be fetched).
        ppl: Pipeline returned by ``check_model``; an ``Exception`` when it
            could not be built, ``None`` when the model was not reachable.
        dataset_id, dataset_config, dataset_split: Dataset selection to check.
        column_mapping: JSON string with a user-provided column mapping;
            unparsable input falls back to an empty mapping.

    Returns:
        A 7-tuple of Gradio updates for: submit button, loading row, preview
        row, sample input, sample prediction, label-mapping table and
        feature-mapping table. The submit button is enabled only when every
        check passes.
    """
    if m_id is None or ppl is None:
        gr.Warning(
            "Model is not accessible. Please set your HF_TOKEN if it is a private model."
        )
        return _validation_blocked_updates()
    if isinstance(ppl, Exception):
        gr.Warning(f"Failed to load model: {ppl}")
        return _validation_blocked_updates()

    d_id, config, split = check_dataset(
        dataset_id=dataset_id,
        dataset_config=dataset_config,
        dataset_split=dataset_split,
    )
    if d_id is None:
        gr.Warning(
            f'Dataset "{dataset_id}" is not accessible. Please set your HF_TOKEN if it is a private dataset.'
        )
        return _validation_blocked_updates()
    if isinstance(config, list):
        gr.Warning(
            f'Dataset "{dataset_id}" does not have "{dataset_config}" config. Please choose a valid config.'
        )
        return _validation_blocked_updates()
    if isinstance(split, list):
        gr.Warning(
            f'Dataset "{dataset_id}" does not have "{dataset_split}" split. Please choose a valid split.'
        )
        return _validation_blocked_updates()
    if config is None or split is None:
        # check_dataset returns (dataset_id, None, None) when the loaded object
        # is neither a Dataset nor a DatasetDict.
        gr.Warning(
            f'Dataset "{dataset_id}" could not be loaded as a Hugging Face dataset.'
        )
        return _validation_blocked_updates()

    if not isinstance(ppl, TextClassificationPipeline):
        # Only text classification is handled below; any other pipeline would
        # leave the preview values undefined.
        gr.Warning(
            f'Model "{m_id}" is not a text classification model. Only text classification models are supported.'
        )
        return _validation_blocked_updates()

    try:
        column_mapping = json.loads(column_mapping)
    except (TypeError, ValueError):
        # Not a JSON string (or missing): fall back to an empty mapping.
        column_mapping = {}

    (
        _,
        prediction_input,
        prediction_result,
        id2label_df,
        feature_df,
    ) = text_classification_fix_column_mapping(
        column_mapping, ppl, d_id, config, split
    )

    sample_input = gr.update(
        value=f"**Sample Input**: {prediction_input}", visible=True
    )
    # The preview output is a Row: it has no value, only visibility.
    show_preview = gr.update(visible=True)

    if prediction_result is None and id2label_df is not None:
        gr.Warning(
            'The model failed to predict with the first row in the dataset. Please provide feature mappings in "Advance" settings.'
        )
        return (
            gr.update(interactive=False),
            gr.update(visible=False),
            show_preview,
            sample_input,
            gr.update(visible=False),
            gr.update(value=id2label_df, visible=True, interactive=True),
            gr.update(value=feature_df, visible=True, interactive=True),
        )
    if id2label_df is None:
        gr.Warning(
            'The prediction result does not conform to the labels in the dataset. Please provide label mappings in "Advance" settings.'
        )
        return (
            gr.update(interactive=False),
            gr.update(visible=False),
            show_preview,
            sample_input,
            gr.update(value=prediction_result, visible=True),
            gr.update(visible=True, interactive=True),
            gr.update(visible=True, interactive=True),
        )

    gr.Info(
        "Model and dataset validations passed. You can submit the evaluation task."
    )
    return (
        gr.update(interactive=True),
        gr.update(visible=False),
        show_preview,
        sample_input,
        gr.update(value=prediction_result, visible=True),
        gr.update(value=id2label_df, visible=True, interactive=True),
        gr.update(value=feature_df, visible=True, interactive=True),
    )
|
|
|
|
|
def try_submit(
    m_id,
    d_id,
    config,
    split,
    id2label_mapping_dataframe,
    feature_mapping_dataframe,
    local,
):
    """Launch the scan for the validated model/dataset selection.

    Args:
        m_id, d_id, config, split: Model id and dataset selection to scan.
        id2label_mapping_dataframe: Table with a "Model Prediction Labels"
            column; each row's index becomes the label id.
        feature_mapping_dataframe: Table with "Model Input Features" and
            "Dataset Features" columns, matched by row index.
        local: When true, run ``giskard_scanner`` in a subprocess and wait
            for it to finish; otherwise only a placeholder message is shown.

    Returns:
        An update that re-enables the submit button.
    """
    # Row index (as a string) -> model prediction label.
    label_mapping = {
        str(i): label
        for i, label in id2label_mapping_dataframe["Model Prediction Labels"].items()
    }
    # Model input feature -> dataset feature, matched by row index.
    feature_mapping = {
        feature_mapping_dataframe["Model Input Features"][i]: feature
        for i, feature in feature_mapping_dataframe["Dataset Features"].items()
    }

    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return gr.update(interactive=True)

    hf_token = os.environ.get(HF_WRITE_TOKEN)
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    # subprocess rejects None arguments with a TypeError; report missing
    # configuration in the UI instead.
    missing = [
        name
        for name, value in (
            (HF_WRITE_TOKEN, hf_token),
            (f"{HF_REPO_ID} or {HF_SPACE_ID}", discussion_repo),
        )
        if value is None
    ]
    if missing:
        gr.Warning(
            f"Cannot start the evaluation: environment variable(s) {', '.join(missing)} not set."
        )
        return gr.update(interactive=True)

    command = [
        "giskard_scanner",
        "--loader",
        "huggingface",
        "--model",
        m_id,
        "--dataset",
        d_id,
        "--dataset_config",
        config,
        "--dataset_split",
        split,
        "--hf_token",
        hf_token,
        "--discussion_repo",
        discussion_repo,
        "--output_format",
        "markdown",
        "--output_portal",
        "huggingface",
        "--feature_mapping",
        json.dumps(feature_mapping),
        "--label_mapping",
        json.dumps(label_mapping),
        "--scan_config",
        # NOTE(review): the UI reads/writes "./config.yaml", and no cwd is
        # passed to the subprocess, so this resolves one directory up from
        # the current working directory -- confirm this is intended.
        "../config.yaml",
    ]

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    start = time.time()
    logging.info("Start local evaluation on %s", eval_str)

    # Blocks until the scanner exits. Its stderr is merged into stdout, which
    # is inherited from this process.
    result = subprocess.run(command, stderr=subprocess.STDOUT).returncode

    # Compute the duration once so the log and the UI report the same value.
    summary = (
        f"Finished local evaluation exit code {result} on {eval_str}: "
        f"{time.time() - start:.2f}s"
    )
    logging.info(summary)
    gr.Info(summary)

    return gr.update(interactive=True)
|
|
|
|
|
def get_demo():
    """Build the evaluation UI components and register their event handlers.

    NOTE(review): components are created without opening a ``gr.Blocks`` here,
    so this presumably runs inside one opened by the caller -- confirm.
    """

    def check_dataset_and_get_config(dataset_id):
        """Return a visible dropdown listing the configs of *dataset_id*.

        The first config is pre-selected. Returns ``None`` (after logging the
        error) when the configs cannot be listed.
        """
        try:
            configs = datasets.get_dataset_config_names(dataset_id)
            return gr.Dropdown(configs, value=configs[0], visible=True)
        except Exception as e:
            # Best-effort: this runs on every blur/submit of the dataset
            # textbox, including while the id is still being typed, so the
            # failure is logged rather than shown as a UI warning.
            logging.warning(
                "Failed to list configs of dataset %r: %s", dataset_id, e
            )
            return None

    def check_dataset_and_get_split(dataset_config, dataset_id):
        """Return a visible dropdown listing the splits of the dataset.

        Loads *dataset_id* with *dataset_config* to discover its splits and
        pre-selects the first one. On failure, shows a UI warning and
        returns ``None``.
        """
        try:
            splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
            return gr.Dropdown(splits, value=splits[0], visible=True)
        except Exception as e:
            gr.Warning(
                f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
            )
            return None

    def clear_column_mapping_tables():
        """Empty and hide both mapping tables.

        Returns exactly one update per wired output: the label-mapping table,
        then the feature-mapping table.
        """
        return [
            gr.update(value=[], visible=False, interactive=True),
            gr.update(value=[], visible=False, interactive=True),
        ]

    def _table_value(table):
        # NOTE(review): when a gr.DataFrame is an event *input*, Gradio passes
        # its value (a pandas DataFrame by default), which has no ``.value``
        # attribute. Unwrap ``.value`` only if present; confirm this matches
        # what convert_column_mapping_to_json expects.
        return getattr(table, "value", table)

    def _blocked_outputs():
        # Disable submit, show the loading row, hide all preview outputs.
        return (
            gr.update(interactive=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    def gate_validate_btn(
        model_id,
        dataset_id,
        dataset_config,
        dataset_split,
        id2label_mapping_dataframe=None,
        feature_mapping_dataframe=None,
    ):
        """Re-validate the selection and return the seven validation updates.

        When the mapping tables are supplied (i.e. the user edited one), their
        contents are converted into a JSON column mapping and checked first.
        Validation only runs once model, dataset, config and split are all set.
        """
        column_mapping = "{}"
        m_id, ppl = check_model(model_id=model_id)

        if id2label_mapping_dataframe is not None:
            labels = convert_column_mapping_to_json(
                _table_value(id2label_mapping_dataframe), label="data"
            )
            features = convert_column_mapping_to_json(
                _table_value(feature_mapping_dataframe), label="text"
            )
            column_mapping = json.dumps({**labels, **features}, indent=2)

        if check_column_mapping_keys_validity(column_mapping, ppl) is False:
            gr.Warning("Label mapping table has invalid contents. Please check again.")
            # Disable submit and show the loading row; leave the rest as-is.
            return (
                gr.update(interactive=False),
                gr.update(visible=True),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

        if not (model_id and dataset_id and dataset_config and dataset_split):
            return _blocked_outputs()

        # Pass check_model's id (None when the model is unreachable) so that
        # try_validate can report an inaccessible model.
        return try_validate(
            m_id,
            ppl,
            dataset_id,
            dataset_config,
            dataset_split,
            column_mapping,
        )

    with gr.Row():
        gr.Markdown(CONFIRM_MAPPING_DETAILS_MD)
    with gr.Row():
        run_local = gr.Checkbox(value=True, label="Run in this Space")
        use_inference = read_inference_type("./config.yaml") == "hf_inference_api"
        run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")

    with gr.Row():
        selected = read_scanners("./config.yaml")
        # "data_leakage" is always offered as a choice, but listed only once
        # even if the config already selects it.
        scan_config = selected + [s for s in ["data_leakage"] if s not in selected]
        scanners = gr.CheckboxGroup(
            choices=scan_config, value=selected, label="Scan Settings", visible=True
        )

    with gr.Row():
        model_id_input = gr.Textbox(
            label="Hugging Face model id",
            placeholder="cardiffnlp/twitter-roberta-base-sentiment-latest",
        )

        dataset_id_input = gr.Textbox(
            label="Hugging Face Dataset id",
            placeholder="tweet_eval",
        )
    with gr.Row():
        dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
        dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)

    with gr.Row(visible=True) as loading_row:
        gr.Markdown(
            """
<p style="text-align: center;">
🚀🐢Please validate your model and dataset first...
</p>
"""
        )

    with gr.Row(visible=False) as preview_row:
        gr.Markdown(
            """
<h1 style="text-align: center;">
Confirm Pre-processing Details
</h1>
Based on your model and dataset, we inferred this label mapping and feature mapping. <b>If the mapping is incorrect, please modify it in the table below.</b>
"""
        )

    with gr.Row():
        id2label_mapping_dataframe = gr.DataFrame(
            label="Preview of label mapping", interactive=True, visible=False
        )
        feature_mapping_dataframe = gr.DataFrame(
            label="Preview of feature mapping", interactive=True, visible=False
        )
    with gr.Row():
        example_input = gr.Markdown("Sample Input: ", visible=False)

    with gr.Row():
        example_labels = gr.Label(label="Model Prediction Sample", visible=False)

    run_btn = gr.Button(
        "Get Evaluation Result",
        variant="primary",
        interactive=False,
        size="lg",
    )

    # Outputs updated by every validation run (order matches try_validate).
    validation_outputs = [
        run_btn,
        loading_row,
        preview_row,
        example_input,
        example_labels,
        id2label_mapping_dataframe,
        feature_mapping_dataframe,
    ]
    selection_inputs = [
        model_id_input,
        dataset_id_input,
        dataset_config_input,
        dataset_split_input,
    ]
    mapping_tables = [id2label_mapping_dataframe, feature_mapping_dataframe]
    mapping_inputs = selection_inputs + mapping_tables

    model_id_input.blur(clear_column_mapping_tables, outputs=mapping_tables)

    dataset_id_input.blur(
        check_dataset_and_get_config, dataset_id_input, dataset_config_input
    )
    dataset_id_input.submit(
        check_dataset_and_get_config, dataset_id_input, dataset_config_input
    )

    dataset_config_input.change(
        check_dataset_and_get_split,
        inputs=[dataset_config_input, dataset_id_input],
        outputs=[dataset_split_input],
    )

    dataset_id_input.blur(clear_column_mapping_tables, outputs=mapping_tables)

    # Re-validate when the dataset selection changes...
    for selector in (dataset_config_input, dataset_split_input):
        selector.change(
            gate_validate_btn,
            inputs=selection_inputs,
            outputs=validation_outputs,
        )
    # ...or when the user edits one of the mapping tables.
    for table in mapping_tables:
        table.input(
            gate_validate_btn,
            inputs=mapping_inputs,
            outputs=validation_outputs,
        )

    scanners.change(write_scanners, inputs=scanners)
    run_inference.change(write_inference_type, inputs=[run_inference])

    run_btn.click(
        try_submit,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            id2label_mapping_dataframe,
            feature_mapping_dataframe,
            run_local,
        ],
        outputs=[
            run_btn,
        ],
    )
|
|