jacopoteneggi commited on
Commit
5ead791
1 Parent(s): df15615
Files changed (7) hide show
  1. app.py +7 -28
  2. app_lib/main.py +34 -23
  3. app_lib/test.py +21 -7
  4. app_lib/user_input.py +1 -2
  5. app_lib/viz.py +40 -0
  6. header.md +5 -0
  7. style.css +39 -0
app.py CHANGED
@@ -1,9 +1,12 @@
1
- import numpy as np
2
- import open_clip
3
  import streamlit as st
4
 
5
  from app_lib.main import main
6
 
 
 
 
 
 
7
  if "sidebar_state" not in st.session_state:
8
  st.session_state.sidebar_state = "collapsed"
9
  if "disabled" not in st.session_state:
@@ -14,33 +17,9 @@ if "results" not in st.session_state:
14
  st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
15
 
16
  st.session_state.sidebar_state = "collapsed"
17
- st.markdown(
18
- """
19
- <style>
20
- textarea {
21
- font-family: monospace !important;
22
- }
23
- input {
24
- font-family: monospace !important;
25
- }
26
-
27
- [data-testid="stHorizontalBlock"] {
28
- align-items: center;
29
- }
30
- </style>
31
- """,
32
- unsafe_allow_html=True,
33
- )
34
-
35
- st.markdown(
36
- """
37
- # I Bet You Did Not Mean That
38
-
39
- Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantci Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
40
 
41
- ---
42
- """,
43
- )
44
 
45
  if __name__ == "__main__":
46
  main()
 
 
 
1
  import streamlit as st
2
 
3
  from app_lib.main import main
4
 
5
+ with open("style.css", "r") as f:
6
+ style = f.read()
7
+ with open("header.md", "r") as f:
8
+ header = f.read()
9
+
10
  if "sidebar_state" not in st.session_state:
11
  st.session_state.sidebar_state = "collapsed"
12
  if "disabled" not in st.session_state:
 
17
  st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
18
 
19
  st.session_state.sidebar_state = "collapsed"
20
+ st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ st.markdown(header)
 
 
23
 
24
  if __name__ == "__main__":
25
  main()
app_lib/main.py CHANGED
@@ -10,6 +10,7 @@ from app_lib.user_input import (
10
  get_model_name,
11
  )
12
  from app_lib.test import test
 
13
 
14
 
15
  def _disable():
@@ -20,50 +21,58 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
20
  columns = st.columns([0.40, 0.60])
21
 
22
  with columns[0]:
23
- model_name = get_model_name()
24
 
25
- row1 = st.columns(2)
26
- row2 = st.columns(2)
27
 
28
- with row1[0]:
29
  image = get_image()
30
  st.image(image, use_column_width=True)
31
- with row1[1]:
32
- class_name, class_ready, class_error = get_class_name()
33
- concepts, concepts_ready, concepts_error = get_concepts()
34
- cardinality = get_cardinality(concepts, concepts_ready)
35
 
36
- with row2[0]:
37
  change_image_button = st.button(
38
  "Change Image",
39
- use_container_width=True,
40
  disabled=st.session_state.disabled,
41
  )
42
  if change_image_button:
43
  st.session_state.sidebar_state = "expanded"
44
  st.experimental_rerun()
45
- with row2[1]:
46
- ready = class_ready and concepts_ready
 
 
 
 
47
 
48
- error_message = ""
49
- if class_error is not None:
50
- error_message += f"- {class_error}\n"
51
- if concepts_error is not None:
52
- error_message += f"- {concepts_error}\n"
53
- if error_message:
54
- st.error(error_message)
55
 
 
 
 
 
 
 
 
 
 
56
  test_button = st.button(
57
- "Test",
58
  use_container_width=True,
59
  on_click=_disable,
60
  disabled=st.session_state.disabled or not ready,
61
  )
62
 
 
 
 
63
  with columns[1]:
64
- _, centercol, _ = st.columns(3)
65
- with centercol:
66
- if test_button:
 
 
 
 
67
  test(
68
  image,
69
  class_name,
@@ -73,3 +82,5 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
73
  model_name,
74
  device,
75
  )
 
 
 
10
  get_model_name,
11
  )
12
  from app_lib.test import test
13
+ from app_lib.viz import viz_results
14
 
15
 
16
  def _disable():
 
21
  columns = st.columns([0.40, 0.60])
22
 
23
  with columns[0]:
24
+ st.header("Choose Image and Concepts")
25
 
26
+ image_col, concepts_col = st.columns(2)
 
27
 
28
+ with image_col:
29
  image = get_image()
30
  st.image(image, use_column_width=True)
 
 
 
 
31
 
 
32
  change_image_button = st.button(
33
  "Change Image",
34
+ use_container_width=False,
35
  disabled=st.session_state.disabled,
36
  )
37
  if change_image_button:
38
  st.session_state.sidebar_state = "expanded"
39
  st.experimental_rerun()
40
+ with concepts_col:
41
+ model_name = get_model_name()
42
+ class_name, class_ready, class_error = get_class_name()
43
+ concepts, concepts_ready, concepts_error = get_concepts()
44
+ cardinality = int(len(concepts) / 2)
45
+ # get_cardinality(concepts, concepts_ready)
46
 
47
+ ready = class_ready and concepts_ready
 
 
 
 
 
 
48
 
49
+ error_message = ""
50
+ if class_error is not None:
51
+ error_message += f"- {class_error}\n"
52
+ if concepts_error is not None:
53
+ error_message += f"- {concepts_error}\n"
54
+ if error_message:
55
+ st.error(error_message)
56
+
57
+ with st.container():
58
  test_button = st.button(
59
+ "Test Concepts",
60
  use_container_width=True,
61
  on_click=_disable,
62
  disabled=st.session_state.disabled or not ready,
63
  )
64
 
65
+ with st.popover("Advanced settings", disabled=st.session_state.disabled):
66
+ st.markdown("Hello World 👋")
67
+
68
  with columns[1]:
69
+ st.header("Results")
70
+
71
+ if test_button:
72
+ st.session_state.results = None
73
+
74
+ _, centercol, _ = st.columns(3)
75
+ with centercol:
76
  test(
77
  image,
78
  class_name,
 
82
  model_name,
83
  device,
84
  )
85
+
86
+ viz_results()
app_lib/test.py CHANGED
@@ -4,7 +4,7 @@ import open_clip
4
  import h5py
5
  import streamlit as st
6
  import numpy as np
7
- import time
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
 
10
  import ml_collections
@@ -170,12 +170,26 @@ def test(image, class_name, concepts, cardinality, dataset_name, model_name, dev
170
  results.append(future.result())
171
  progress_bar.progress((idx + 1) / len(concepts))
172
 
173
- # print(results)
174
- # wealth = np.empty((testing_config.tau_max, len(concepts)))
175
- # wealth[:] = np.nan
176
- # for _results in results:
177
- # concept_idx = concepts.index(_results["concept"])
178
- # _wealth =
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  st.session_state.disabled = False
181
  st.experimental_rerun()
 
4
  import h5py
5
  import streamlit as st
6
  import numpy as np
7
+ import pandas as pd
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
 
10
  import ml_collections
 
170
  results.append(future.result())
171
  progress_bar.progress((idx + 1) / len(concepts))
172
 
173
+ rejected = np.empty((testing_config.r, len(concepts)))
174
+ tau = np.empty((testing_config.r, len(concepts)))
175
+ wealth = np.empty((testing_config.r, testing_config.tau_max, len(concepts)))
176
+
177
+ for _results in results:
178
+ concept_idx = concepts.index(_results["concept"])
179
+
180
+ rejected[:, concept_idx] = np.array(_results["rejected"])
181
+ tau[:, concept_idx] = np.array(_results["tau"])
182
+ wealth[:, :, concept_idx] = np.array(_results["wealth"])
183
+
184
+ tau /= testing_config.tau_max
185
+
186
+ st.session_state.results = {
187
+ "significance_level": testing_config.significance_level,
188
+ "concepts": concepts,
189
+ "rejected": rejected,
190
+ "tau": tau,
191
+ "wealth": wealth,
192
+ }
193
 
194
  st.session_state.disabled = False
195
  st.experimental_rerun()
app_lib/user_input.py CHANGED
@@ -20,10 +20,9 @@ def _validate_concepts(concepts):
20
  return (False, "Maximum 10 concepts allowed")
21
  return (True, None)
22
 
23
-
24
  def get_model_name():
25
  return st.selectbox(
26
- "Choose a model to test",
27
  options=list(SUPPORTED_MODELS.keys()),
28
  help="Name of the vision-language model to test the predictions of.",
29
  disabled=st.session_state.disabled,
 
20
  return (False, "Maximum 10 concepts allowed")
21
  return (True, None)
22
 
 
23
  def get_model_name():
24
  return st.selectbox(
25
+ "Model to test",
26
  options=list(SUPPORTED_MODELS.keys()),
27
  help="Name of the vision-language model to test the predictions of.",
28
  disabled=st.session_state.disabled,
app_lib/viz.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.express as px
4
+
5
+
6
+ def _viz_wealth(results):
7
+ wealth = results["wealth"]
8
+ concepts = results["concepts"]
9
+ significance_level = results["significance_level"]
10
+
11
+ wealth_mu = wealth.mean(axis=0)
12
+
13
+ wealth_df = []
14
+ for concept_idx, concept in enumerate(concepts):
15
+ for t in range(wealth.shape[1]):
16
+ wealth_df.append(
17
+ {"time": t, "concept": concept, "wealth": wealth_mu[t, concept_idx]}
18
+ )
19
+ wealth_df = pd.DataFrame(wealth_df)
20
+
21
+ fig = px.line(wealth_df, x="time", y="wealth", color="concept")
22
+ fig.update_yaxes(range=[0, 3 * 1 / significance_level])
23
+ st.plotly_chart(fig, use_container_width=True)
24
+
25
+
26
+ def viz_results():
27
+ results = st.session_state.results
28
+
29
+ if results is None:
30
+ st.info("Run tests to show results", icon="ℹ️")
31
+ else:
32
+ rank_tab, wealth_tab = st.tabs(["Rank of importance", "Wealth process"])
33
+
34
+ with rank_tab:
35
+ st.subheader("Rank of Semantic Importance")
36
+ with wealth_tab:
37
+ st.subheader("Wealth Process of Testing Procedures")
38
+
39
+ if results is not None:
40
+ _viz_wealth(results)
header.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # 🤔 I Bet You Did Not Mean That
2
+
3
+ Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantic Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
4
+
5
+ ---
style.css ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ textarea {
2
+ font-family: monospace !important;
3
+ }
4
+
5
+ input {
6
+ font-family: monospace !important;
7
+ }
8
+
9
+ h1 {
10
+ padding-top: 0 !important;
11
+ }
12
+
13
+ [data-testid="stHorizontalBlock"] [data-testid="stHorizontalBlock"] {
14
+ align-items: center;
15
+ }
16
+
17
+ [data-testid="stButton"] {
18
+ display: flex;
19
+ justify-content: center;
20
+ }
21
+
22
+ [data-testid="stVerticalBlock"]:has(> [data-testid="stPopover"]) {
23
+ display: block;
24
+ }
25
+
26
+ [data-testid="stPopover"] {
27
+ button {
28
+ padding: 0;
29
+ border: 0;
30
+
31
+ div:first-of-type>p {
32
+ font-size: small;
33
+ }
34
+ }
35
+
36
+ button:hover>div:first-of-type>p {
37
+ text-decoration: underline;
38
+ }
39
+ }