File size: 5,816 Bytes
c33448f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb18007
c33448f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import logging

import datasets
import gradio as gr
import pandas as pd
import datetime

from fetch_utils import (check_dataset_and_get_config,
                         check_dataset_and_get_split)

import leaderboard
logger = logging.getLogger(__name__)
# Time of the last leaderboard load. It starts at the epoch ("never loaded").
# get_demo and its tab-select refresh handler overwrite it via `global`.
# A module-level `global` statement is a no-op, so none is used here.
update_time = datetime.datetime.fromtimestamp(0)

def get_records_from_dataset_repo(dataset_id):
    """Load the first split of the first config of a Hub dataset as a DataFrame.

    The config and split lists come from ``check_dataset_and_get_config`` and
    ``check_dataset_and_get_split``. The first entry of each is used.

    Args:
        dataset_id: Dataset repository id passed to ``datasets.load_dataset``.

    Returns:
        The records as a ``pandas.DataFrame``. If loading or conversion fails,
        a warning is logged and an empty ``pd.DataFrame()`` is returned instead.
    """
    dataset_config = check_dataset_and_get_config(dataset_id)
    logger.info("Dataset %s has configs %s", dataset_id, dataset_config)
    config = dataset_config[0]

    dataset_split = check_dataset_and_get_split(dataset_id, config)
    logger.info("Dataset %s has splits %s", dataset_id, dataset_split)

    try:
        # Index the split inside the try so an empty split list is handled the
        # same way as any other load failure.
        ds = datasets.load_dataset(dataset_id, config, split=dataset_split[0])
        return ds.to_pandas()
    except Exception as e:
        # Best-effort: callers get an empty frame rather than an exception.
        logger.warning(
            "Failed to load dataset %s with config %s (splits %s): %s",
            dataset_id,
            config,
            dataset_split,
            e,
        )
        return pd.DataFrame()

    
def _unique_ids(ds, column):
    """Return ``"Any"`` followed by the distinct values of ``ds[column]``.

    The values are sorted (compared via ``str`` so ``None`` or mixed types
    cannot break the sort). This keeps the dropdown order stable across runs;
    ``set`` iteration order is arbitrary.
    """
    unique_values = sorted(set(ds[column].tolist()), key=str)
    # Same named logger as this module's `logger`. Only a count is logged,
    # not the whole column.
    logging.getLogger(__name__).debug(
        "Column %s has %d distinct values", column, len(unique_values)
    )
    return ["Any"] + unique_values


def get_model_ids(ds):
    """Return model-filter choices: ``"Any"`` then each distinct ``model_id`` in *ds*."""
    return _unique_ids(ds, "model_id")


def get_dataset_ids(ds):
    """Return dataset-filter choices: ``"Any"`` then each distinct ``dataset_id`` in *ds*."""
    return _unique_ids(ds, "dataset_id")


def get_types(ds):
    """Map each column dtype of *ds* to a Gradio ``DataFrame`` ``datatype`` string.

    Mapping, in order:

    - bool dtypes -> ``"bool"``
    - other numeric dtypes (any int/uint/float width, nullable ints) -> ``"number"``
    - datetimes -> ``"date"``
    - object/string columns -> ``"markdown"``, so the HTML links built by
      ``get_display_df`` render
    - anything else -> ``"str"``

    Dtype predicates are used rather than substring replacement on dtype names,
    which mangled e.g. ``uint64`` into ``"unumber"``.

    Returns:
        A list with one datatype string per column, in column order.
    """
    ptypes = pd.api.types

    def _to_gradio_type(dtype):
        # Bool first: pandas also reports bool dtypes as numeric.
        if ptypes.is_bool_dtype(dtype):
            return "bool"
        if ptypes.is_numeric_dtype(dtype):
            return "number"
        if ptypes.is_datetime64_any_dtype(dtype):
            return "date"
        if ptypes.is_object_dtype(dtype) or ptypes.is_string_dtype(dtype):
            return "markdown"
        return "str"

    return [_to_gradio_type(dtype) for dtype in ds.dtypes]


def get_display_df(df):
    """Return a copy of *df* with id/link columns rendered as clickable HTML anchors.

    Columns that are present are rewritten as follows:

    - ``model_id`` links to ``https://huggingface.co/<model_id>``
    - ``dataset_id`` links to ``https://huggingface.co/datasets/<dataset_id>``
    - ``report_link`` links to its own value

    Other columns are left unchanged, and *df* itself is not modified.

    Cell values are treated as data, not markup. They are HTML-escaped before
    being embedded, so a value containing ``<``, ``>`` or quotes cannot inject
    markup into the rendered table.
    """
    import html

    def _anchor(href, label):
        # rel="noopener noreferrer" keeps the opened tab from reaching back via
        # window.opener (relevant because report_link can point anywhere).
        return (
            f'<a href="{html.escape(href)}" target="_blank" '
            f'rel="noopener noreferrer" style="color:blue">🔗{html.escape(label)}</a>'
        )

    # Maps column name -> URL template filled with the cell value.
    url_templates = {
        "model_id": "https://huggingface.co/{}",
        "dataset_id": "https://huggingface.co/datasets/{}",
        "report_link": "{}",
    }

    display_df = df.copy()
    for column, template in url_templates.items():
        if column in display_df.columns:
            display_df[column] = display_df[column].apply(
                # Bind `template` now; a plain closure would see only the last one.
                lambda value, template=template: _anchor(
                    template.format(value), str(value)
                )
            )
    return display_df

def _filter_records(records, model_id, dataset_id, columns, task):
    """Return the rows of *records* for *task*, restricted to *columns*.

    A falsy or ``"Any"`` *model_id* / *dataset_id* leaves that dimension
    unfiltered. *task* is always applied.
    """
    df = records[records["task"] == task]
    if model_id and model_id != "Any":
        df = df[df["model_id"] == model_id]
    if dataset_id and dataset_id != "Any":
        df = df[df["dataset_id"] == dataset_id]
    return df[columns]


def get_demo(leaderboard_tab):
    """Build the leaderboard widgets and wire up filtering and throttled refresh.

    Loads the records for ``leaderboard.LEADERBOARD`` into
    ``leaderboard.records``. It then creates:

    - task, model and dataset dropdowns
    - a column picker
    - the results table

    Finally it registers the change/select handlers. The ``gr.Row`` contexts
    presumably attach to an enclosing Blocks/tab opened by the caller; confirm
    at the call site.

    Args:
        leaderboard_tab: Component whose ``select`` event triggers a reload of
            the records, at most once every 10 minutes.
    """
    global update_time
    update_time = datetime.datetime.now()
    logger.info("Loading leaderboard records")
    leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
    records = leaderboard.records

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)

    column_names = records.columns.tolist()
    task_choices = ["text_classification", "tabular"]
    default_task = task_choices[0]
    default_columns = ["model_id", "dataset_id", "total_issues", "report_link"]
    # Apply the same filter the widgets' initial values select, so the first
    # render already agrees with the Task dropdown (not every task's rows).
    default_df = _filter_records(
        records, model_ids[0], dataset_ids[0], default_columns, default_task
    )
    types = get_types(default_df)
    display_df = get_display_df(default_df)  # the styled dataframe to display

    with gr.Row():
        task_select = gr.Dropdown(
            label="Task",
            choices=task_choices,
            value=default_task,
            interactive=True,
        )
        model_select = gr.Dropdown(
            label="Model id", choices=model_ids, value=model_ids[0], interactive=True
        )
        dataset_select = gr.Dropdown(
            label="Dataset id",
            choices=dataset_ids,
            value=dataset_ids[0],
            interactive=True,
        )

    with gr.Row():
        columns_select = gr.CheckboxGroup(
            label="Show columns",
            choices=column_names,
            value=default_columns,
            interactive=True,
        )

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)

    def update_leaderboard_records(model_id, dataset_id, columns, task):
        """Reload records (at most every 10 minutes) and re-render the table."""
        global update_time
        if datetime.datetime.now() - update_time < datetime.timedelta(minutes=10):
            return gr.update()
        update_time = datetime.datetime.now()
        logger.info("Updating leaderboard records")
        new_records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
        if new_records.empty:
            # get_records_from_dataset_repo returns an empty frame on failure.
            # Installing it would make filter_table fail on the missing "task"
            # column, so keep serving the previously loaded records instead.
            logger.warning(
                "Leaderboard refresh returned no records; keeping previous records"
            )
            return gr.update()
        leaderboard.records = new_records
        return filter_table(model_id, dataset_id, columns, task)

    leaderboard_tab.select(
        fn=update_leaderboard_records,
        inputs=[model_select, dataset_select, columns_select, task_select],
        outputs=[leaderboard_df],
    )

    @gr.on(
        triggers=[
            model_select.change,
            dataset_select.change,
            columns_select.change,
            task_select.change,
        ],
        inputs=[model_select, dataset_select, columns_select, task_select],
        outputs=[leaderboard_df],
    )
    def filter_table(model_id, dataset_id, columns, task):
        """Re-render the table for the current widget selections."""
        logger.info("Filtering leaderboard records")
        # Read the shared records at call time so refreshed data is used.
        df = _filter_records(leaderboard.records, model_id, dataset_id, columns, task)
        types = get_types(df)
        display_df = get_display_df(df)
        return gr.update(value=display_df, datatype=types, interactive=False)