Spaces:
Running
Running
ZeroCommand
committed on
Commit
•
f04482d
1
Parent(s):
be473e6
update log area
Browse files
- app.py +23 -8
- app_text_classification.py +12 -179
- io_utils.py +44 -1
- run_jobs.py +29 -0
- text_classification_ui_helpers.py +184 -0
- tmp/pipe +0 -0
app.py
CHANGED
@@ -3,15 +3,30 @@
|
|
3 |
# from pathlib import Path
|
4 |
|
5 |
import gradio as gr
|
6 |
-
|
7 |
from app_text_classification import get_demo as get_demo_text_classification
|
8 |
from app_leaderboard import get_demo as get_demo_leaderboard
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
|
11 |
-
with gr.Tab("Text Classification"):
|
12 |
-
get_demo_text_classification()
|
13 |
-
with gr.Tab("Leaderboard"):
|
14 |
-
get_demo_leaderboard()
|
15 |
|
16 |
-
demo.queue(max_size=100)
|
17 |
-
demo.launch(share=False)
|
|
|
3 |
# from pathlib import Path
|
4 |
|
5 |
import gradio as gr
|
6 |
+
import atexit
|
7 |
from app_text_classification import get_demo as get_demo_text_classification
|
8 |
from app_leaderboard import get_demo as get_demo_leaderboard
|
9 |
+
from run_jobs import start_process_run_job, stop_thread
|
10 |
+
import threading
|
11 |
+
|
12 |
+
if threading.current_thread() is not threading.main_thread():
|
13 |
+
t = threading.current_thread()
|
14 |
+
print(t.do_run)
|
15 |
+
try:
|
16 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
|
17 |
+
with gr.Tab("Text Classification"):
|
18 |
+
get_demo_text_classification()
|
19 |
+
with gr.Tab("Leaderboard"):
|
20 |
+
get_demo_leaderboard()
|
21 |
+
|
22 |
+
start_process_run_job()
|
23 |
+
|
24 |
+
demo.queue(max_size=100)
|
25 |
+
demo.launch(share=False)
|
26 |
+
atexit.register(stop_thread)
|
27 |
+
|
28 |
+
except Exception:
|
29 |
+
print("stop background thread")
|
30 |
+
stop_thread()
|
31 |
|
|
|
|
|
|
|
|
|
|
|
32 |
|
|
|
|
app_text_classification.py
CHANGED
@@ -1,22 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
-
import
|
3 |
-
import
|
4 |
-
import
|
5 |
-
import subprocess
|
6 |
-
import logging
|
7 |
-
import collections
|
8 |
-
|
9 |
-
import json
|
10 |
-
|
11 |
-
from transformers.pipelines import TextClassificationPipeline
|
12 |
-
|
13 |
-
from text_classification import get_labels_and_features_from_dataset, check_model, get_example_prediction
|
14 |
-
from io_utils import read_scanners, write_scanners, read_inference_type, read_column_mapping, write_column_mapping, write_inference_type
|
15 |
-
from wordings import INTRODUCTION_MD, CONFIRM_MAPPING_DETAILS_MD, CONFIRM_MAPPING_DETAILS_FAIL_RAW
|
16 |
-
|
17 |
-
HF_REPO_ID = 'HF_REPO_ID'
|
18 |
-
HF_SPACE_ID = 'SPACE_ID'
|
19 |
-
HF_WRITE_TOKEN = 'HF_WRITE_TOKEN'
|
20 |
|
21 |
MAX_LABELS = 20
|
22 |
MAX_FEATURES = 20
|
@@ -25,75 +10,6 @@ EXAMPLE_MODEL_ID = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
|
|
25 |
EXAMPLE_DATA_ID = 'tweet_eval'
|
26 |
CONFIG_PATH='./config.yaml'
|
27 |
|
28 |
-
def try_submit(m_id, d_id, config, split, local):
|
29 |
-
all_mappings = read_column_mapping(CONFIG_PATH)
|
30 |
-
|
31 |
-
if "labels" not in all_mappings.keys():
|
32 |
-
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
33 |
-
return gr.update(interactive=True)
|
34 |
-
label_mapping = all_mappings["labels"]
|
35 |
-
|
36 |
-
if "features" not in all_mappings.keys():
|
37 |
-
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
38 |
-
return gr.update(interactive=True)
|
39 |
-
feature_mapping = all_mappings["features"]
|
40 |
-
|
41 |
-
# TODO: Set column mapping for some dataset such as `amazon_polarity`
|
42 |
-
if local:
|
43 |
-
command = [
|
44 |
-
"python",
|
45 |
-
"cli.py",
|
46 |
-
"--loader", "huggingface",
|
47 |
-
"--model", m_id,
|
48 |
-
"--dataset", d_id,
|
49 |
-
"--dataset_config", config,
|
50 |
-
"--dataset_split", split,
|
51 |
-
"--hf_token", os.environ.get(HF_WRITE_TOKEN),
|
52 |
-
"--discussion_repo", os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
|
53 |
-
"--output_format", "markdown",
|
54 |
-
"--output_portal", "huggingface",
|
55 |
-
"--feature_mapping", json.dumps(feature_mapping),
|
56 |
-
"--label_mapping", json.dumps(label_mapping),
|
57 |
-
"--scan_config", "../config.yaml",
|
58 |
-
]
|
59 |
-
|
60 |
-
eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
|
61 |
-
start = time.time()
|
62 |
-
logging.info(f"Start local evaluation on {eval_str}")
|
63 |
-
|
64 |
-
evaluator = subprocess.Popen(
|
65 |
-
command,
|
66 |
-
cwd=os.path.join(os.path.dirname(os.path.realpath(__file__)), "cicd"),
|
67 |
-
stderr=subprocess.STDOUT,
|
68 |
-
)
|
69 |
-
result = evaluator.wait()
|
70 |
-
|
71 |
-
logging.info(f"Finished local evaluation exit code {result} on {eval_str}: {time.time() - start:.2f}s")
|
72 |
-
|
73 |
-
gr.Info(f"Finished local evaluation exit code {result} on {eval_str}: {time.time() - start:.2f}s")
|
74 |
-
else:
|
75 |
-
gr.Info("TODO: Submit task to an endpoint")
|
76 |
-
|
77 |
-
return gr.update(interactive=True) # Submit button
|
78 |
-
|
79 |
-
|
80 |
-
def check_dataset_and_get_config(dataset_id):
|
81 |
-
try:
|
82 |
-
configs = datasets.get_dataset_config_names(dataset_id)
|
83 |
-
return gr.Dropdown(configs, value=configs[0], visible=True)
|
84 |
-
except Exception:
|
85 |
-
# Dataset may not exist
|
86 |
-
pass
|
87 |
-
|
88 |
-
def check_dataset_and_get_split(dataset_id, dataset_config):
|
89 |
-
try:
|
90 |
-
splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
|
91 |
-
return gr.Dropdown(splits, value=splits[0], visible=True)
|
92 |
-
except Exception:
|
93 |
-
# Dataset may not exist
|
94 |
-
# gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
|
95 |
-
pass
|
96 |
-
|
97 |
def get_demo():
|
98 |
with gr.Row():
|
99 |
gr.Markdown(INTRODUCTION_MD)
|
@@ -147,102 +63,18 @@ def get_demo():
|
|
147 |
interactive=True,
|
148 |
size="lg",
|
149 |
)
|
|
|
|
|
|
|
150 |
|
151 |
-
|
|
|
152 |
inputs=[dataset_id_input, dataset_config_input, dataset_split_input, *column_mappings])
|
153 |
-
def write_column_mapping_to_config(dataset_id, dataset_config, dataset_split, *labels):
|
154 |
-
ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
|
155 |
-
if labels is None:
|
156 |
-
return
|
157 |
-
labels = [*labels]
|
158 |
-
all_mappings = read_column_mapping(CONFIG_PATH)
|
159 |
|
160 |
-
|
161 |
-
|
162 |
-
for i, label in enumerate(labels[:MAX_LABELS]):
|
163 |
-
if label:
|
164 |
-
all_mappings["labels"][label] = ds_labels[i]
|
165 |
-
|
166 |
-
if "features" not in all_mappings.keys():
|
167 |
-
all_mappings["features"] = dict()
|
168 |
-
for i, feat in enumerate(labels[MAX_LABELS:(MAX_LABELS + MAX_FEATURES)]):
|
169 |
-
if feat:
|
170 |
-
all_mappings["features"][feat] = ds_features[i]
|
171 |
-
write_column_mapping(all_mappings)
|
172 |
-
|
173 |
-
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
|
174 |
-
model_labels = list(model_id2label.values())
|
175 |
-
lables = [gr.Dropdown(label=f"{label}", choices=model_labels, value=model_id2label[i], interactive=True, visible=True) for i, label in enumerate(ds_labels[:MAX_LABELS])]
|
176 |
-
lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
|
177 |
-
# TODO: Substitute 'text' with more features for zero-shot
|
178 |
-
features = [gr.Dropdown(label=f"{feature}", choices=ds_features, value=ds_features[0], interactive=True, visible=True) for feature in ['text']]
|
179 |
-
features += [gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))]
|
180 |
-
return lables + features
|
181 |
-
|
182 |
-
@gr.on(triggers=[model_id_input.change, dataset_config_input.change])
|
183 |
-
def clear_column_mapping_config():
|
184 |
-
write_column_mapping(None)
|
185 |
-
|
186 |
-
@gr.on(triggers=[model_id_input.change, dataset_config_input.change, dataset_split_input.change],
|
187 |
inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
|
188 |
outputs=[example_input, example_prediction, column_mapping_accordion, *column_mappings])
|
189 |
-
def check_model_and_show_prediction(model_id, dataset_id, dataset_config, dataset_split):
|
190 |
-
ppl = check_model(model_id)
|
191 |
-
if ppl is None or not isinstance(ppl, TextClassificationPipeline):
|
192 |
-
gr.Warning("Please check your model.")
|
193 |
-
return (
|
194 |
-
gr.update(visible=False),
|
195 |
-
gr.update(visible=False),
|
196 |
-
*[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
|
197 |
-
)
|
198 |
-
|
199 |
-
dropdown_placement = [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
|
200 |
-
|
201 |
-
if ppl is None: # pipeline not found
|
202 |
-
gr.Warning("Model not found")
|
203 |
-
return (
|
204 |
-
gr.update(visible=False),
|
205 |
-
gr.update(visible=False),
|
206 |
-
gr.update(visible=False, open=False),
|
207 |
-
*dropdown_placement
|
208 |
-
)
|
209 |
-
model_id2label = ppl.model.config.id2label
|
210 |
-
ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
|
211 |
-
|
212 |
-
# when dataset does not have labels or features
|
213 |
-
if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
|
214 |
-
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
215 |
-
return (
|
216 |
-
gr.update(visible=False),
|
217 |
-
gr.update(visible=False),
|
218 |
-
gr.update(visible=False, open=False),
|
219 |
-
*dropdown_placement
|
220 |
-
)
|
221 |
-
|
222 |
-
column_mappings = list_labels_and_features_from_dataset(
|
223 |
-
ds_labels,
|
224 |
-
ds_features,
|
225 |
-
model_id2label,
|
226 |
-
)
|
227 |
-
|
228 |
-
# when labels or features are not aligned
|
229 |
-
# show manually column mapping
|
230 |
-
if collections.Counter(model_id2label.items()) != collections.Counter(ds_labels) or ds_features[0] != 'text':
|
231 |
-
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
232 |
-
return (
|
233 |
-
gr.update(visible=False),
|
234 |
-
gr.update(visible=False),
|
235 |
-
gr.update(visible=True, open=True),
|
236 |
-
*column_mappings
|
237 |
-
)
|
238 |
-
|
239 |
-
prediction_input, prediction_output = get_example_prediction(ppl, dataset_id, dataset_config, dataset_split)
|
240 |
-
return (
|
241 |
-
gr.update(value=prediction_input, visible=True),
|
242 |
-
gr.update(value=prediction_output, visible=True),
|
243 |
-
gr.update(visible=True, open=False),
|
244 |
-
*column_mappings
|
245 |
-
)
|
246 |
|
247 |
dataset_id_input.blur(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
|
248 |
|
@@ -267,4 +99,5 @@ def get_demo():
|
|
267 |
],
|
268 |
fn=try_submit,
|
269 |
inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, run_local],
|
270 |
-
outputs=[run_btn])
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from io_utils import read_scanners, write_scanners, read_inference_type, write_inference_type
|
3 |
+
from wordings import INTRODUCTION_MD, CONFIRM_MAPPING_DETAILS_MD
|
4 |
+
from text_classification_ui_helpers import try_submit, check_dataset_and_get_config, check_dataset_and_get_split, check_model_and_show_prediction, write_column_mapping_to_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
MAX_LABELS = 20
|
7 |
MAX_FEATURES = 20
|
|
|
10 |
EXAMPLE_DATA_ID = 'tweet_eval'
|
11 |
CONFIG_PATH='./config.yaml'
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def get_demo():
|
14 |
with gr.Row():
|
15 |
gr.Markdown(INTRODUCTION_MD)
|
|
|
63 |
interactive=True,
|
64 |
size="lg",
|
65 |
)
|
66 |
+
|
67 |
+
with gr.Row():
|
68 |
+
logs = gr.Textbox(label="Giskard Bot Evaluation Log:", visible=False)
|
69 |
|
70 |
+
gr.on(triggers=[label.change for label in column_mappings],
|
71 |
+
fn=write_column_mapping_to_config,
|
72 |
inputs=[dataset_id_input, dataset_config_input, dataset_split_input, *column_mappings])
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
gr.on(triggers=[model_id_input.change, dataset_config_input.change, dataset_split_input.change],
|
75 |
+
fn=check_model_and_show_prediction,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
|
77 |
outputs=[example_input, example_prediction, column_mapping_accordion, *column_mappings])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
dataset_id_input.blur(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
|
80 |
|
|
|
99 |
],
|
100 |
fn=try_submit,
|
101 |
inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, run_local],
|
102 |
+
outputs=[run_btn, logs])
|
103 |
+
|
io_utils.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
import yaml
|
|
|
|
|
2 |
|
3 |
YAML_PATH = "./config.yaml"
|
|
|
4 |
|
5 |
class Dumper(yaml.Dumper):
|
6 |
def increase_indent(self, flow=False, *args, **kwargs):
|
@@ -56,7 +59,9 @@ def read_column_mapping(path):
|
|
56 |
def write_column_mapping(mapping):
|
57 |
with open(YAML_PATH, "r") as f:
|
58 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
59 |
-
if
|
|
|
|
|
60 |
del config["column_mapping"]
|
61 |
else:
|
62 |
config["column_mapping"] = mapping
|
@@ -71,3 +76,41 @@ def convert_column_mapping_to_json(df, label=""):
|
|
71 |
for _, row in df.iterrows():
|
72 |
column_mapping[label].append(row.tolist())
|
73 |
return column_mapping
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import yaml
|
2 |
+
import subprocess
|
3 |
+
import os
|
4 |
|
5 |
YAML_PATH = "./config.yaml"
|
6 |
+
PIPE_PATH = "./tmp/pipe"
|
7 |
|
8 |
class Dumper(yaml.Dumper):
|
9 |
def increase_indent(self, flow=False, *args, **kwargs):
|
|
|
59 |
def write_column_mapping(mapping):
|
60 |
with open(YAML_PATH, "r") as f:
|
61 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
62 |
+
if config is None:
|
63 |
+
return
|
64 |
+
if mapping is None and "column_mapping" in config.keys():
|
65 |
del config["column_mapping"]
|
66 |
else:
|
67 |
config["column_mapping"] = mapping
|
|
|
76 |
for _, row in df.iterrows():
|
77 |
column_mapping[label].append(row.tolist())
|
78 |
return column_mapping
|
79 |
+
|
80 |
+
def write_log_to_user_file(id, log):
|
81 |
+
with open(f"./tmp/{id}_log", "a") as f:
|
82 |
+
f.write(log)
|
83 |
+
|
84 |
+
def save_job_to_pipe(id, job, lock):
|
85 |
+
if not os.path.exists('./tmp'):
|
86 |
+
os.makedirs('./tmp')
|
87 |
+
job = [str(i) for i in job]
|
88 |
+
job = ",".join(job)
|
89 |
+
print(job)
|
90 |
+
with lock:
|
91 |
+
with open(PIPE_PATH, "a") as f:
|
92 |
+
# write each element in job
|
93 |
+
f.write(f'{id}@{job}\n')
|
94 |
+
|
95 |
+
def pop_job_from_pipe():
|
96 |
+
if not os.path.exists(PIPE_PATH):
|
97 |
+
return
|
98 |
+
with open(PIPE_PATH, "r+") as f:
|
99 |
+
jobs = f.readlines()
|
100 |
+
f.write("\n".join(jobs[1:]))
|
101 |
+
f.close()
|
102 |
+
if len(jobs) == 0:
|
103 |
+
return
|
104 |
+
job_info = jobs[0].split('\n')[0].split("@")
|
105 |
+
if len(job_info) != 2:
|
106 |
+
raise ValueError("Invalid job info: ", job_info)
|
107 |
+
print(f"Running job {job_info}")
|
108 |
+
command = job_info[1].split(",")
|
109 |
+
print(command)
|
110 |
+
log_file = open(f"./tmp/{job_info[0]}_log", "w")
|
111 |
+
subprocess.Popen(
|
112 |
+
command,
|
113 |
+
cwd=os.path.join(os.path.dirname(os.path.realpath(__file__)), "cicd"),
|
114 |
+
stdout=log_file,
|
115 |
+
stderr=log_file,
|
116 |
+
)
|
run_jobs.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io_utils import pop_job_from_pipe
|
2 |
+
import time
|
3 |
+
import threading
|
4 |
+
|
5 |
+
def start_process_run_job():
|
6 |
+
try:
|
7 |
+
print("Running jobs in thread")
|
8 |
+
global thread
|
9 |
+
thread = threading.Thread(target=run_job)
|
10 |
+
thread.daemon = True
|
11 |
+
thread.do_run = True
|
12 |
+
thread.start()
|
13 |
+
|
14 |
+
except Exception as e:
|
15 |
+
print("Failed to start thread: ", e)
|
16 |
+
def stop_thread():
|
17 |
+
print("Stop thread")
|
18 |
+
thread.do_run = False
|
19 |
+
|
20 |
+
def run_job():
|
21 |
+
while True:
|
22 |
+
print(thread.do_run)
|
23 |
+
try:
|
24 |
+
pop_job_from_pipe()
|
25 |
+
time.sleep(10)
|
26 |
+
except KeyboardInterrupt:
|
27 |
+
print("KeyboardInterrupt stop background thread")
|
28 |
+
stop_thread()
|
29 |
+
break
|
text_classification_ui_helpers.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from wordings import CONFIRM_MAPPING_DETAILS_FAIL_RAW
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import logging
|
6 |
+
import uuid
|
7 |
+
import threading
|
8 |
+
from io_utils import read_column_mapping, write_column_mapping, save_job_to_pipe, write_log_to_user_file
|
9 |
+
import datasets
|
10 |
+
import collections
|
11 |
+
from text_classification import get_labels_and_features_from_dataset, check_model, get_example_prediction
|
12 |
+
from transformers.pipelines import TextClassificationPipeline
|
13 |
+
|
14 |
+
MAX_LABELS = 20
|
15 |
+
MAX_FEATURES = 20
|
16 |
+
|
17 |
+
HF_REPO_ID = 'HF_REPO_ID'
|
18 |
+
HF_SPACE_ID = 'SPACE_ID'
|
19 |
+
HF_WRITE_TOKEN = 'HF_WRITE_TOKEN'
|
20 |
+
CONFIG_PATH = "./config.yaml"
|
21 |
+
|
22 |
+
def check_dataset_and_get_config(dataset_id):
|
23 |
+
try:
|
24 |
+
write_column_mapping(None)
|
25 |
+
configs = datasets.get_dataset_config_names(dataset_id)
|
26 |
+
return gr.Dropdown(configs, value=configs[0], visible=True)
|
27 |
+
except Exception:
|
28 |
+
# Dataset may not exist
|
29 |
+
pass
|
30 |
+
|
31 |
+
def check_dataset_and_get_split(dataset_id, dataset_config):
|
32 |
+
try:
|
33 |
+
splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
|
34 |
+
return gr.Dropdown(splits, value=splits[0], visible=True)
|
35 |
+
except Exception:
|
36 |
+
# Dataset may not exist
|
37 |
+
# gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
|
38 |
+
pass
|
39 |
+
|
40 |
+
def write_column_mapping_to_config(dataset_id, dataset_config, dataset_split, *labels):
|
41 |
+
ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
|
42 |
+
if labels is None:
|
43 |
+
return
|
44 |
+
labels = [*labels]
|
45 |
+
all_mappings = read_column_mapping(CONFIG_PATH)
|
46 |
+
|
47 |
+
if "labels" not in all_mappings.keys():
|
48 |
+
all_mappings["labels"] = dict()
|
49 |
+
for i, label in enumerate(labels[:MAX_LABELS]):
|
50 |
+
if label:
|
51 |
+
all_mappings["labels"][label] = ds_labels[i]
|
52 |
+
|
53 |
+
if "features" not in all_mappings.keys():
|
54 |
+
all_mappings["features"] = dict()
|
55 |
+
for i, feat in enumerate(labels[MAX_LABELS:(MAX_LABELS + MAX_FEATURES)]):
|
56 |
+
if feat:
|
57 |
+
all_mappings["features"][feat] = ds_features[i]
|
58 |
+
write_column_mapping(all_mappings)
|
59 |
+
|
60 |
+
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
|
61 |
+
model_labels = list(model_id2label.values())
|
62 |
+
lables = [gr.Dropdown(label=f"{label}", choices=model_labels, value=model_id2label[i], interactive=True, visible=True) for i, label in enumerate(ds_labels[:MAX_LABELS])]
|
63 |
+
lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
|
64 |
+
# TODO: Substitute 'text' with more features for zero-shot
|
65 |
+
features = [gr.Dropdown(label=f"{feature}", choices=ds_features, value=ds_features[0], interactive=True, visible=True) for feature in ['text']]
|
66 |
+
features += [gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))]
|
67 |
+
return lables + features
|
68 |
+
|
69 |
+
def check_model_and_show_prediction(model_id, dataset_id, dataset_config, dataset_split):
|
70 |
+
ppl = check_model(model_id)
|
71 |
+
if ppl is None or not isinstance(ppl, TextClassificationPipeline):
|
72 |
+
gr.Warning("Please check your model.")
|
73 |
+
return (
|
74 |
+
gr.update(visible=False),
|
75 |
+
gr.update(visible=False),
|
76 |
+
*[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
|
77 |
+
)
|
78 |
+
|
79 |
+
dropdown_placement = [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
|
80 |
+
|
81 |
+
if ppl is None: # pipeline not found
|
82 |
+
gr.Warning("Model not found")
|
83 |
+
return (
|
84 |
+
gr.update(visible=False),
|
85 |
+
gr.update(visible=False),
|
86 |
+
gr.update(visible=False, open=False),
|
87 |
+
*dropdown_placement
|
88 |
+
)
|
89 |
+
model_id2label = ppl.model.config.id2label
|
90 |
+
ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
|
91 |
+
|
92 |
+
# when dataset does not have labels or features
|
93 |
+
if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
|
94 |
+
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
95 |
+
return (
|
96 |
+
gr.update(visible=False),
|
97 |
+
gr.update(visible=False),
|
98 |
+
gr.update(visible=False, open=False),
|
99 |
+
*dropdown_placement
|
100 |
+
)
|
101 |
+
|
102 |
+
column_mappings = list_labels_and_features_from_dataset(
|
103 |
+
ds_labels,
|
104 |
+
ds_features,
|
105 |
+
model_id2label,
|
106 |
+
)
|
107 |
+
|
108 |
+
# when labels or features are not aligned
|
109 |
+
# show manually column mapping
|
110 |
+
if collections.Counter(model_id2label.values()) != collections.Counter(ds_labels) or ds_features[0] != 'text':
|
111 |
+
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
112 |
+
return (
|
113 |
+
gr.update(visible=False),
|
114 |
+
gr.update(visible=False),
|
115 |
+
gr.update(visible=True, open=True),
|
116 |
+
*column_mappings
|
117 |
+
)
|
118 |
+
|
119 |
+
prediction_input, prediction_output = get_example_prediction(ppl, dataset_id, dataset_config, dataset_split)
|
120 |
+
return (
|
121 |
+
gr.update(value=prediction_input, visible=True),
|
122 |
+
gr.update(value=prediction_output, visible=True),
|
123 |
+
gr.update(visible=True, open=False),
|
124 |
+
*column_mappings
|
125 |
+
)
|
126 |
+
|
127 |
+
def get_logs_file(uid):
|
128 |
+
file = open(f"./tmp/{uid}_log")
|
129 |
+
contents = file.readlines()
|
130 |
+
file.close()
|
131 |
+
return '\n'.join(contents)
|
132 |
+
|
133 |
+
def try_submit(m_id, d_id, config, split, local):
|
134 |
+
all_mappings = read_column_mapping(CONFIG_PATH)
|
135 |
+
|
136 |
+
if all_mappings is None:
|
137 |
+
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
138 |
+
return gr.update(interactive=True)
|
139 |
+
|
140 |
+
if "labels" not in all_mappings.keys():
|
141 |
+
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
142 |
+
return gr.update(interactive=True)
|
143 |
+
label_mapping = all_mappings["labels"]
|
144 |
+
|
145 |
+
if "features" not in all_mappings.keys():
|
146 |
+
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
147 |
+
return gr.update(interactive=True)
|
148 |
+
feature_mapping = all_mappings["features"]
|
149 |
+
|
150 |
+
# TODO: Set column mapping for some dataset such as `amazon_polarity`
|
151 |
+
if local:
|
152 |
+
command = [
|
153 |
+
"python",
|
154 |
+
"cli.py",
|
155 |
+
"--loader", "huggingface",
|
156 |
+
"--model", m_id,
|
157 |
+
"--dataset", d_id,
|
158 |
+
"--dataset_config", config,
|
159 |
+
"--dataset_split", split,
|
160 |
+
"--hf_token", os.environ.get(HF_WRITE_TOKEN),
|
161 |
+
"--discussion_repo", os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
|
162 |
+
"--output_format", "markdown",
|
163 |
+
"--output_portal", "huggingface",
|
164 |
+
"--feature_mapping", json.dumps(feature_mapping),
|
165 |
+
"--label_mapping", json.dumps(label_mapping),
|
166 |
+
"--scan_config", "../config.yaml",
|
167 |
+
]
|
168 |
+
|
169 |
+
eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
|
170 |
+
logging.info(f"Start local evaluation on {eval_str}")
|
171 |
+
uid = uuid.uuid4()
|
172 |
+
save_job_to_pipe(uid, command, threading.Lock())
|
173 |
+
write_log_to_user_file(uid, f"Start local evaluation on {eval_str}. Please wait for your job to start...\n")
|
174 |
+
gr.Info(f"Start local evaluation on {eval_str}")
|
175 |
+
|
176 |
+
return (
|
177 |
+
gr.update(interactive=False),
|
178 |
+
gr.update(value=get_logs_file(uid), visible=True, interactive=False))
|
179 |
+
|
180 |
+
else:
|
181 |
+
gr.Info("TODO: Submit task to an endpoint")
|
182 |
+
|
183 |
+
return (gr.update(interactive=True), # Submit button
|
184 |
+
gr.update(visible=False))
|
tmp/pipe
ADDED
The diff for this file is too large to render.
See raw diff
|
|