Spaces:
Sleeping
Sleeping
import gradio as gr | |
import uuid | |
from io_utils import read_scanners, write_scanners, read_inference_type, write_inference_type | |
from wordings import INTRODUCTION_MD, CONFIRM_MAPPING_DETAILS_MD | |
from text_classification_ui_helpers import try_submit, check_dataset_and_get_config, check_dataset_and_get_split, check_model_and_show_prediction, write_column_mapping_to_config, get_logs_file | |
# How many mapping dropdowns are pre-allocated (hidden) in the UI for
# labels and for features respectively.
MAX_LABELS = MAX_FEATURES = 20

# Example ids shown as placeholders in the model / dataset text boxes.
EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
EXAMPLE_DATA_ID = "tweet_eval"

# YAML config file read by read_scanners / read_inference_type.
CONFIG_PATH = "./config.yaml"
def get_demo(): | |
with gr.Blocks() as demo: | |
with gr.Row(): | |
gr.Markdown(INTRODUCTION_MD) | |
with gr.Row(): | |
model_id_input = gr.Textbox( | |
label="Hugging Face model id", | |
placeholder=EXAMPLE_MODEL_ID + " (press enter to confirm)", | |
) | |
dataset_id_input = gr.Textbox( | |
label="Hugging Face Dataset id", | |
placeholder=EXAMPLE_DATA_ID + " (press enter to confirm)", | |
) | |
with gr.Row(): | |
dataset_config_input = gr.Dropdown(label='Dataset Config', visible=False) | |
dataset_split_input = gr.Dropdown(label='Dataset Split', visible=False) | |
with gr.Row(): | |
example_input = gr.Markdown('Example Input', visible=False) | |
with gr.Row(): | |
example_prediction = gr.Label(label='Model Prediction Sample', visible=False) | |
with gr.Row(): | |
with gr.Accordion(label='Label and Feature Mapping', visible=False, open=False) as column_mapping_accordion: | |
with gr.Row(): | |
gr.Markdown(CONFIRM_MAPPING_DETAILS_MD) | |
column_mappings = [] | |
with gr.Row(): | |
with gr.Column(): | |
for _ in range(MAX_LABELS): | |
column_mappings.append(gr.Dropdown(visible=False)) | |
with gr.Column(): | |
for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES): | |
column_mappings.append(gr.Dropdown(visible=False)) | |
with gr.Accordion(label='Model Wrap Advance Config (optional)', open=False): | |
run_local = gr.Checkbox(value=True, label="Run in this Space") | |
use_inference = read_inference_type('./config.yaml') == 'hf_inference_api' | |
run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API") | |
with gr.Accordion(label='Scanner Advance Config (optional)', open=False): | |
selected = read_scanners('./config.yaml') | |
# currently we remove data_leakage from the default scanners | |
# Reason: data_leakage barely raises any issues and takes too many requests | |
# when using inference API, causing rate limit error | |
scan_config = selected + ['data_leakage'] | |
scanners = gr.CheckboxGroup(choices=scan_config, value=selected, label='Scan Settings', visible=True) | |
with gr.Row(): | |
run_btn = gr.Button( | |
"Get Evaluation Result", | |
variant="primary", | |
interactive=True, | |
size="lg", | |
) | |
with gr.Row(): | |
uid = uuid.uuid4() | |
uid_label = gr.Textbox(label="Evaluation ID:", value=uid, visible=False) | |
logs = gr.Textbox(label="Giskard Bot Evaluation Log:", visible=False) | |
demo.load(get_logs_file, uid_label, logs, every=0.5) | |
gr.on(triggers=[label.change for label in column_mappings], | |
fn=write_column_mapping_to_config, | |
inputs=[dataset_id_input, dataset_config_input, dataset_split_input, *column_mappings]) | |
gr.on(triggers=[model_id_input.change, dataset_config_input.change, dataset_split_input.change], | |
fn=check_model_and_show_prediction, | |
inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input], | |
outputs=[example_input, example_prediction, column_mapping_accordion, *column_mappings]) | |
dataset_id_input.blur(check_dataset_and_get_config, dataset_id_input, dataset_config_input) | |
dataset_config_input.change( | |
check_dataset_and_get_split, | |
inputs=[dataset_id_input, dataset_config_input], | |
outputs=[dataset_split_input]) | |
scanners.change( | |
write_scanners, | |
inputs=scanners | |
) | |
run_inference.change( | |
write_inference_type, | |
inputs=[run_inference] | |
) | |
gr.on( | |
triggers=[ | |
run_btn.click, | |
], | |
fn=try_submit, | |
inputs=[ | |
model_id_input, | |
dataset_id_input, | |
dataset_config_input, | |
dataset_split_input, | |
run_local, | |
uid_label], | |
outputs=[run_btn, logs]) | |
def enable_run_btn(): | |
return gr.update(interactive=True) | |
gr.on( | |
triggers=[ | |
model_id_input.change, | |
dataset_config_input.change, | |
dataset_split_input.change, | |
run_inference.change, | |
run_local.change, | |
scanners.change], | |
fn=enable_run_btn, | |
inputs=None, | |
outputs=[run_btn]) | |
gr.on( | |
triggers=[label.change for label in column_mappings], | |
fn=enable_run_btn, | |
inputs=None, | |
outputs=[run_btn]) |