Bram Vanroy commited on
Commit
87eb8f9
1 Parent(s): 470893f

add visualization

Browse files
Files changed (1) hide show
  1. app.py +56 -1
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- from collections import defaultdict
3
  from dataclasses import dataclass, field, fields
4
  from functools import cached_property
5
  from pathlib import Path
@@ -10,6 +9,7 @@ import pandas as pd
10
  import gradio as gr
11
  from pandas import DataFrame
12
  from pandas.io.formats.style import Styler
 
13
 
14
  from content import *
15
 
@@ -176,6 +176,51 @@ class ResultSet:
176
  styler = styler.hide()
177
  return styler
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  def convert_number_to_kmb(number: int) -> str:
181
  """
@@ -282,6 +327,16 @@ with gr.Blocks() as demo:
282
  gr.Markdown("## LaTeX")
283
  gr.Code(results.latex_df.to_latex(convert_css=True))
284
 
 
 
 
 
 
 
 
 
 
 
285
  gr.Markdown(DISCLAIMER, elem_classes="markdown-text")
286
  gr.Markdown(CREDIT, elem_classes="markdown-text")
287
  gr.Markdown(CITATION, elem_classes="markdown-text")
 
1
  import json
 
2
  from dataclasses import dataclass, field, fields
3
  from functools import cached_property
4
  from pathlib import Path
 
9
  import gradio as gr
10
  from pandas import DataFrame
11
  from pandas.io.formats.style import Styler
12
+ import plotly.graph_objects as go
13
 
14
  from content import *
15
 
 
176
  styler = styler.hide()
177
  return styler
178
 
179
+ @cached_property
180
+ def viz_checkboxes(self):
181
+ model_col_name = self.column_names["short_name"]
182
+ avg_col = self.column_names["average"]
183
+ top3_models = self.df.sort_values(by=avg_col, ascending=False)[model_col_name].tolist()[:3]
184
+ return gr.CheckboxGroup(self.df[model_col_name].tolist(), label="Models", value=top3_models)
185
+
186
+ def plot(self, model_names: list[str]):
187
+ if not model_names:
188
+ return None
189
+
190
+ # Only get task columns and model name
191
+ task_columns = [col for attr, col in self.column_names.items() if attr in TASK_METRICS or attr == "short_name"]
192
+ df = self.df[task_columns]
193
+
194
+ # Rename the columns to the task names
195
+ reversed_col_names = {v: k for k, v in self.column_names.items() if v != "Model"}
196
+ df = df.rename(columns=reversed_col_names)
197
+
198
+ # Only keep the selected models
199
+ df = df[df["Model"].isin(model_names)]
200
+
201
+ # Melt the dataframe to long format
202
+ df = df.melt(id_vars=["Model"], var_name="Task", value_name="Score").sort_values(by="Task")
203
+
204
+ # Populate figure
205
+ fig = go.Figure()
206
+ for model_name in model_names:
207
+ model_df = df[df["Model"] == model_name]
208
+ scores = model_df["Score"].tolist()
209
+ tasks = model_df["Task"].tolist()
210
+
211
+ # Repeat the first point at the end to close the lines
212
+ # Cf. https://community.plotly.com/t/closing-line-for-radar-cart-and-popup-window-on-chart-radar/47711/4
213
+ scores.append(scores[0])
214
+ tasks.append(tasks[0])
215
+
216
+ fig.add_trace(go.Scatterpolar(r=scores, theta=tasks, name=model_name))
217
+
218
+ fig.update_layout(
219
+ title="Model performance on Dutch benchmarks",
220
+ )
221
+
222
+ return fig
223
+
224
 
225
  def convert_number_to_kmb(number: int) -> str:
226
  """
 
327
  gr.Markdown("## LaTeX")
328
  gr.Code(results.latex_df.to_latex(convert_css=True))
329
 
330
+ gr.Markdown("## Visualization")
331
+ with gr.Row():
332
+ with gr.Column():
333
+ buttons = results.viz_checkboxes
334
+
335
+ with gr.Column(scale=2):
336
+ plot = gr.Plot(container=True)
337
+ buttons.change(results.plot, inputs=buttons, outputs=[plot])
338
+ demo.load(results.plot, inputs=buttons, outputs=[plot])
339
+
340
  gr.Markdown(DISCLAIMER, elem_classes="markdown-text")
341
  gr.Markdown(CREDIT, elem_classes="markdown-text")
342
  gr.Markdown(CITATION, elem_classes="markdown-text")