Spaces:
Runtime error
Runtime error
jacopoteneggi
commited on
Commit
•
5ead791
1
Parent(s):
df15615
Update
Browse files- app.py +7 -28
- app_lib/main.py +34 -23
- app_lib/test.py +21 -7
- app_lib/user_input.py +1 -2
- app_lib/viz.py +40 -0
- header.md +5 -0
- style.css +39 -0
app.py
CHANGED
@@ -1,9 +1,12 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import open_clip
|
3 |
import streamlit as st
|
4 |
|
5 |
from app_lib.main import main
|
6 |
|
|
|
|
|
|
|
|
|
|
|
7 |
if "sidebar_state" not in st.session_state:
|
8 |
st.session_state.sidebar_state = "collapsed"
|
9 |
if "disabled" not in st.session_state:
|
@@ -14,33 +17,9 @@ if "results" not in st.session_state:
|
|
14 |
st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
|
15 |
|
16 |
st.session_state.sidebar_state = "collapsed"
|
17 |
-
st.markdown(
|
18 |
-
"""
|
19 |
-
<style>
|
20 |
-
textarea {
|
21 |
-
font-family: monospace !important;
|
22 |
-
}
|
23 |
-
input {
|
24 |
-
font-family: monospace !important;
|
25 |
-
}
|
26 |
-
|
27 |
-
[data-testid="stHorizontalBlock"] {
|
28 |
-
align-items: center;
|
29 |
-
}
|
30 |
-
</style>
|
31 |
-
""",
|
32 |
-
unsafe_allow_html=True,
|
33 |
-
)
|
34 |
-
|
35 |
-
st.markdown(
|
36 |
-
"""
|
37 |
-
# I Bet You Did Not Mean That
|
38 |
-
|
39 |
-
Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantci Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
|
40 |
|
41 |
-
|
42 |
-
""",
|
43 |
-
)
|
44 |
|
45 |
if __name__ == "__main__":
|
46 |
main()
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
from app_lib.main import main
|
4 |
|
5 |
+
with open("style.css", "r") as f:
|
6 |
+
style = f.read()
|
7 |
+
with open("header.md", "r") as f:
|
8 |
+
header = f.read()
|
9 |
+
|
10 |
if "sidebar_state" not in st.session_state:
|
11 |
st.session_state.sidebar_state = "collapsed"
|
12 |
if "disabled" not in st.session_state:
|
|
|
17 |
st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
|
18 |
|
19 |
st.session_state.sidebar_state = "collapsed"
|
20 |
+
st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
+
st.markdown(header)
|
|
|
|
|
23 |
|
24 |
if __name__ == "__main__":
|
25 |
main()
|
app_lib/main.py
CHANGED
@@ -10,6 +10,7 @@ from app_lib.user_input import (
|
|
10 |
get_model_name,
|
11 |
)
|
12 |
from app_lib.test import test
|
|
|
13 |
|
14 |
|
15 |
def _disable():
|
@@ -20,50 +21,58 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
20 |
columns = st.columns([0.40, 0.60])
|
21 |
|
22 |
with columns[0]:
|
23 |
-
|
24 |
|
25 |
-
|
26 |
-
row2 = st.columns(2)
|
27 |
|
28 |
-
with
|
29 |
image = get_image()
|
30 |
st.image(image, use_column_width=True)
|
31 |
-
with row1[1]:
|
32 |
-
class_name, class_ready, class_error = get_class_name()
|
33 |
-
concepts, concepts_ready, concepts_error = get_concepts()
|
34 |
-
cardinality = get_cardinality(concepts, concepts_ready)
|
35 |
|
36 |
-
with row2[0]:
|
37 |
change_image_button = st.button(
|
38 |
"Change Image",
|
39 |
-
use_container_width=
|
40 |
disabled=st.session_state.disabled,
|
41 |
)
|
42 |
if change_image_button:
|
43 |
st.session_state.sidebar_state = "expanded"
|
44 |
st.experimental_rerun()
|
45 |
-
with
|
46 |
-
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
if class_error is not None:
|
50 |
-
error_message += f"- {class_error}\n"
|
51 |
-
if concepts_error is not None:
|
52 |
-
error_message += f"- {concepts_error}\n"
|
53 |
-
if error_message:
|
54 |
-
st.error(error_message)
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
test_button = st.button(
|
57 |
-
"Test",
|
58 |
use_container_width=True,
|
59 |
on_click=_disable,
|
60 |
disabled=st.session_state.disabled or not ready,
|
61 |
)
|
62 |
|
|
|
|
|
|
|
63 |
with columns[1]:
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
67 |
test(
|
68 |
image,
|
69 |
class_name,
|
@@ -73,3 +82,5 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
73 |
model_name,
|
74 |
device,
|
75 |
)
|
|
|
|
|
|
10 |
get_model_name,
|
11 |
)
|
12 |
from app_lib.test import test
|
13 |
+
from app_lib.viz import viz_results
|
14 |
|
15 |
|
16 |
def _disable():
|
|
|
21 |
columns = st.columns([0.40, 0.60])
|
22 |
|
23 |
with columns[0]:
|
24 |
+
st.header("Choose Image and Concepts")
|
25 |
|
26 |
+
image_col, concepts_col = st.columns(2)
|
|
|
27 |
|
28 |
+
with image_col:
|
29 |
image = get_image()
|
30 |
st.image(image, use_column_width=True)
|
|
|
|
|
|
|
|
|
31 |
|
|
|
32 |
change_image_button = st.button(
|
33 |
"Change Image",
|
34 |
+
use_container_width=False,
|
35 |
disabled=st.session_state.disabled,
|
36 |
)
|
37 |
if change_image_button:
|
38 |
st.session_state.sidebar_state = "expanded"
|
39 |
st.experimental_rerun()
|
40 |
+
with concepts_col:
|
41 |
+
model_name = get_model_name()
|
42 |
+
class_name, class_ready, class_error = get_class_name()
|
43 |
+
concepts, concepts_ready, concepts_error = get_concepts()
|
44 |
+
cardinality = int(len(concepts) / 2)
|
45 |
+
# get_cardinality(concepts, concepts_ready)
|
46 |
|
47 |
+
ready = class_ready and concepts_ready
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
+
error_message = ""
|
50 |
+
if class_error is not None:
|
51 |
+
error_message += f"- {class_error}\n"
|
52 |
+
if concepts_error is not None:
|
53 |
+
error_message += f"- {concepts_error}\n"
|
54 |
+
if error_message:
|
55 |
+
st.error(error_message)
|
56 |
+
|
57 |
+
with st.container():
|
58 |
test_button = st.button(
|
59 |
+
"Test Concepts",
|
60 |
use_container_width=True,
|
61 |
on_click=_disable,
|
62 |
disabled=st.session_state.disabled or not ready,
|
63 |
)
|
64 |
|
65 |
+
with st.popover("Advanced settings", disabled=st.session_state.disabled):
|
66 |
+
st.markdown("Hello World 👋")
|
67 |
+
|
68 |
with columns[1]:
|
69 |
+
st.header("Results")
|
70 |
+
|
71 |
+
if test_button:
|
72 |
+
st.session_state.results = None
|
73 |
+
|
74 |
+
_, centercol, _ = st.columns(3)
|
75 |
+
with centercol:
|
76 |
test(
|
77 |
image,
|
78 |
class_name,
|
|
|
82 |
model_name,
|
83 |
device,
|
84 |
)
|
85 |
+
|
86 |
+
viz_results()
|
app_lib/test.py
CHANGED
@@ -4,7 +4,7 @@ import open_clip
|
|
4 |
import h5py
|
5 |
import streamlit as st
|
6 |
import numpy as np
|
7 |
-
import
|
8 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
9 |
|
10 |
import ml_collections
|
@@ -170,12 +170,26 @@ def test(image, class_name, concepts, cardinality, dataset_name, model_name, dev
|
|
170 |
results.append(future.result())
|
171 |
progress_bar.progress((idx + 1) / len(concepts))
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
st.session_state.disabled = False
|
181 |
st.experimental_rerun()
|
|
|
4 |
import h5py
|
5 |
import streamlit as st
|
6 |
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
9 |
|
10 |
import ml_collections
|
|
|
170 |
results.append(future.result())
|
171 |
progress_bar.progress((idx + 1) / len(concepts))
|
172 |
|
173 |
+
rejected = np.empty((testing_config.r, len(concepts)))
|
174 |
+
tau = np.empty((testing_config.r, len(concepts)))
|
175 |
+
wealth = np.empty((testing_config.r, testing_config.tau_max, len(concepts)))
|
176 |
+
|
177 |
+
for _results in results:
|
178 |
+
concept_idx = concepts.index(_results["concept"])
|
179 |
+
|
180 |
+
rejected[:, concept_idx] = np.array(_results["rejected"])
|
181 |
+
tau[:, concept_idx] = np.array(_results["tau"])
|
182 |
+
wealth[:, :, concept_idx] = np.array(_results["wealth"])
|
183 |
+
|
184 |
+
tau /= testing_config.tau_max
|
185 |
+
|
186 |
+
st.session_state.results = {
|
187 |
+
"significance_level": testing_config.significance_level,
|
188 |
+
"concepts": concepts,
|
189 |
+
"rejected": rejected,
|
190 |
+
"tau": tau,
|
191 |
+
"wealth": wealth,
|
192 |
+
}
|
193 |
|
194 |
st.session_state.disabled = False
|
195 |
st.experimental_rerun()
|
app_lib/user_input.py
CHANGED
@@ -20,10 +20,9 @@ def _validate_concepts(concepts):
|
|
20 |
return (False, "Maximum 10 concepts allowed")
|
21 |
return (True, None)
|
22 |
|
23 |
-
|
24 |
def get_model_name():
|
25 |
return st.selectbox(
|
26 |
-
"
|
27 |
options=list(SUPPORTED_MODELS.keys()),
|
28 |
help="Name of the vision-language model to test the predictions of.",
|
29 |
disabled=st.session_state.disabled,
|
|
|
20 |
return (False, "Maximum 10 concepts allowed")
|
21 |
return (True, None)
|
22 |
|
|
|
23 |
def get_model_name():
|
24 |
return st.selectbox(
|
25 |
+
"Model to test",
|
26 |
options=list(SUPPORTED_MODELS.keys()),
|
27 |
help="Name of the vision-language model to test the predictions of.",
|
28 |
disabled=st.session_state.disabled,
|
app_lib/viz.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import plotly.express as px
|
4 |
+
|
5 |
+
|
6 |
+
def _viz_wealth(results):
|
7 |
+
wealth = results["wealth"]
|
8 |
+
concepts = results["concepts"]
|
9 |
+
significance_level = results["significance_level"]
|
10 |
+
|
11 |
+
wealth_mu = wealth.mean(axis=0)
|
12 |
+
|
13 |
+
wealth_df = []
|
14 |
+
for concept_idx, concept in enumerate(concepts):
|
15 |
+
for t in range(wealth.shape[1]):
|
16 |
+
wealth_df.append(
|
17 |
+
{"time": t, "concept": concept, "wealth": wealth_mu[t, concept_idx]}
|
18 |
+
)
|
19 |
+
wealth_df = pd.DataFrame(wealth_df)
|
20 |
+
|
21 |
+
fig = px.line(wealth_df, x="time", y="wealth", color="concept")
|
22 |
+
fig.update_yaxes(range=[0, 3 * 1 / significance_level])
|
23 |
+
st.plotly_chart(fig, use_container_width=True)
|
24 |
+
|
25 |
+
|
26 |
+
def viz_results():
|
27 |
+
results = st.session_state.results
|
28 |
+
|
29 |
+
if results is None:
|
30 |
+
st.info("Run tests to show results", icon="ℹ️")
|
31 |
+
else:
|
32 |
+
rank_tab, wealth_tab = st.tabs(["Rank of importance", "Wealth process"])
|
33 |
+
|
34 |
+
with rank_tab:
|
35 |
+
st.subheader("Rank of Semantic Importance")
|
36 |
+
with wealth_tab:
|
37 |
+
st.subheader("Wealth Process of Testing Procedures")
|
38 |
+
|
39 |
+
if results is not None:
|
40 |
+
_viz_wealth(results)
|
header.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🤔 I Bet You Did Not Mean That
|
2 |
+
|
3 |
+
Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantic Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
|
4 |
+
|
5 |
+
---
|
style.css
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
textarea {
|
2 |
+
font-family: monospace !important;
|
3 |
+
}
|
4 |
+
|
5 |
+
input {
|
6 |
+
font-family: monospace !important;
|
7 |
+
}
|
8 |
+
|
9 |
+
h1 {
|
10 |
+
padding-top: 0 !important;
|
11 |
+
}
|
12 |
+
|
13 |
+
[data-testid="stHorizontalBlock"] [data-testid="stHorizontalBlock"] {
|
14 |
+
align-items: center;
|
15 |
+
}
|
16 |
+
|
17 |
+
[data-testid="stButton"] {
|
18 |
+
display: flex;
|
19 |
+
justify-content: center;
|
20 |
+
}
|
21 |
+
|
22 |
+
[data-testid="stVerticalBlock"]:has(> [data-testid="stPopover"]) {
|
23 |
+
display: block;
|
24 |
+
}
|
25 |
+
|
26 |
+
[data-testid="stPopover"] {
|
27 |
+
button {
|
28 |
+
padding: 0;
|
29 |
+
border: 0;
|
30 |
+
|
31 |
+
div:first-of-type>p {
|
32 |
+
font-size: small;
|
33 |
+
}
|
34 |
+
}
|
35 |
+
|
36 |
+
button:hover>div:first-of-type>p {
|
37 |
+
text-decoration: underline;
|
38 |
+
}
|
39 |
+
}
|