ajude commited on
Commit
59ce1f5
1 Parent(s): 5373df2

Tie model and language selection to selected tab

Browse files
Files changed (2) hide show
  1. app.py +8 -8
  2. core.py +21 -3
app.py CHANGED
@@ -101,23 +101,23 @@ with demo:
101
 
102
  demo.load(
103
  core.update_task_groups_and_fewshot,
104
- [gr.State(value=0), fewshot],
105
- [shown_tasks, fewshot, selected_tab],
106
  )
107
  fewshot.change(
108
  core.update_task_groups_and_fewshot,
109
- [selected_tab, fewshot],
110
- [shown_tasks, fewshot, selected_tab],
111
  )
112
  acc.select(
113
  core.update_task_groups_and_fewshot,
114
- inputs=[gr.State(value=0), fewshot],
115
- outputs=[shown_tasks, fewshot, selected_tab],
116
  )
117
  misc.select(
118
  core.update_task_groups_and_fewshot,
119
- inputs=[gr.State(value=1), fewshot],
120
- outputs=[shown_tasks, fewshot, selected_tab],
121
  )
122
  for comp, fn in [
123
  (search_bar, "submit"),
 
101
 
102
  demo.load(
103
  core.update_task_groups_and_fewshot,
104
+ [gr.State(value=0), model_types, langs_bar,fewshot],
105
+ [shown_tasks, fewshot, selected_tab, model_types, langs_bar],
106
  )
107
  fewshot.change(
108
  core.update_task_groups_and_fewshot,
109
+ [selected_tab, model_types, langs_bar, fewshot],
110
+ [shown_tasks, fewshot, selected_tab, model_types, langs_bar],
111
  )
112
  acc.select(
113
  core.update_task_groups_and_fewshot,
114
+ inputs=[gr.State(value=0), model_types, langs_bar, fewshot],
115
+ outputs=[shown_tasks, fewshot, selected_tab, model_types, langs_bar],
116
  )
117
  misc.select(
118
  core.update_task_groups_and_fewshot,
119
+ inputs=[gr.State(value=1), model_types, langs_bar, fewshot],
120
+ outputs=[shown_tasks, fewshot, selected_tab, model_types, langs_bar],
121
  )
122
  for comp, fn in [
123
  (search_bar, "submit"),
core.py CHANGED
@@ -4,10 +4,10 @@ import os
4
  import gradio as gr
5
  import numpy as np
6
  import pandas as pd
7
- import plotly.express as px
8
  from datasets import load_dataset
9
 
10
  import style
 
11
 
12
  ZERO_SHOT_ONLY = ["BELEBELE"]
13
  FEW_SHOT_ONLY = ["GSM8K", "TruthfulQA"]
@@ -127,7 +127,7 @@ def update_df(
127
  return sort_cols(df, fewshot)
128
 
129
 
130
- def update_task_groups_and_fewshot(current_selected_tab: int, is_fewshot_current: bool = False):
131
  selected_task_type = get_selected_task_type(current_selected_tab)
132
  available_tasks = get_available_task_groups(selected_task_type, is_fewshot_current)
133
  new_selected_tasks = available_tasks.copy()
@@ -149,7 +149,25 @@ def update_task_groups_and_fewshot(current_selected_tab: int, is_fewshot_current
149
  interactive=fewshot_available,
150
  )
151
 
152
- return [tasks_checkbox_group_update, fewshot_radio_update, current_selected_tab]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
 
155
  def get_selected_task_type(task_type_id):
 
4
  import gradio as gr
5
  import numpy as np
6
  import pandas as pd
 
7
  from datasets import load_dataset
8
 
9
  import style
10
+ from style import T_SYMBOLS, LANG_SYMBOLS
11
 
12
  ZERO_SHOT_ONLY = ["BELEBELE"]
13
  FEW_SHOT_ONLY = ["GSM8K", "TruthfulQA"]
 
127
  return sort_cols(df, fewshot)
128
 
129
 
130
+ def update_task_groups_and_fewshot(current_selected_tab: int, model_types, langs_bar, is_fewshot_current: bool = False, ):
131
  selected_task_type = get_selected_task_type(current_selected_tab)
132
  available_tasks = get_available_task_groups(selected_task_type, is_fewshot_current)
133
  new_selected_tasks = available_tasks.copy()
 
149
  interactive=fewshot_available,
150
  )
151
 
152
+ model_types = gr.CheckboxGroup(
153
+ label="Select model type",
154
+ choices=[
155
+ (
156
+ f"Pretrained {T_SYMBOLS['pretrained']}",
157
+ T_SYMBOLS["pretrained"],
158
+ ),
159
+ (f"Chat {T_SYMBOLS['chat']}", T_SYMBOLS["chat"]),
160
+ ],
161
+ value=list(T_SYMBOLS.values()),
162
+ interactive=True
163
+ )
164
+ langs_bar = gr.CheckboxGroup(
165
+ choices=[(LANG_SYMBOLS.get(l, l), l) for l in languages_list],
166
+ value=languages_list,
167
+ interactive=True,
168
+ )
169
+
170
+ return [tasks_checkbox_group_update, fewshot_radio_update, current_selected_tab, model_types, langs_bar]
171
 
172
 
173
  def get_selected_task_type(task_type_id):