Spaces:
Running
Running
Update my_model/tabs/dataset_analysis.py
Browse files- my_model/tabs/dataset_analysis.py +160 -14
my_model/tabs/dataset_analysis.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Tuple, List, Optional
|
|
8 |
from my_model.dataset.dataset_processor import process_okvqa_dataset
|
9 |
from my_model.config import dataset_config as config
|
10 |
|
|
|
11 |
class OKVQADatasetAnalyzer:
|
12 |
"""
|
13 |
Provides tools for analyzing and visualizing distributions of question types within given question datasets.
|
@@ -29,22 +30,22 @@ class OKVQADatasetAnalyzer:
|
|
29 |
|
30 |
Parameters:
|
31 |
train_file_path (str): Path to the training dataset JSON file. This file should contain a list of questions.
|
32 |
-
test_file_path (str): Path to the testing dataset JSON file. This file should also contain a list of
|
33 |
questions.
|
34 |
-
data_choice (str): Specifies which dataset(s) to load and analyze. Valid options are 'train', 'test', or
|
35 |
'train_test', indicating whether to load training data, testing data, or both.
|
36 |
|
37 |
-
The constructor initializes the paths, selects the dataset based on the choice, and loads the initial data by
|
38 |
calling the `load_data` method.
|
39 |
It also prepares structures for categorizing questions and storing the results.
|
40 |
"""
|
41 |
-
|
42 |
self.train_file_path = train_file_path
|
43 |
self.test_file_path = test_file_path
|
44 |
self.data_choice = data_choice
|
45 |
self.questions = []
|
46 |
self.question_types = Counter()
|
47 |
-
self.Qs = {keyword: [] for keyword in config.QUESTION_KEYWORDS}
|
48 |
self.load_data()
|
49 |
|
50 |
def load_data(self) -> None:
|
@@ -71,7 +72,7 @@ class OKVQADatasetAnalyzer:
|
|
71 |
questions.
|
72 |
"""
|
73 |
|
74 |
-
question_keywords =
|
75 |
|
76 |
for question in self.questions:
|
77 |
question = contractions.fix(question)
|
@@ -98,7 +99,7 @@ class OKVQADatasetAnalyzer:
|
|
98 |
The chart sorts question types by count in descending order and includes detailed tooltips for interaction.
|
99 |
This method is intended for visualization in a Streamlit application.
|
100 |
"""
|
101 |
-
|
102 |
# Prepare data
|
103 |
total_questions = sum(self.question_types.values())
|
104 |
items = [(key, value, (value / total_questions) * 100) for key, value in self.question_types.items()]
|
@@ -118,7 +119,7 @@ class OKVQADatasetAnalyzer:
|
|
118 |
# Create the bar chart
|
119 |
bars = alt.Chart(df).mark_bar().encode(
|
120 |
x=alt.X('Question Keyword:N', sort=order, title='Question Keyword', axis=alt.Axis(labelAngle=-45)),
|
121 |
-
y=alt.Y('Count:Q', title='
|
122 |
color=alt.Color('Question Keyword:N', scale=alt.Scale(scheme='category20'), legend=None),
|
123 |
tooltip=[alt.Tooltip('Question Keyword:N', title='Type'),
|
124 |
alt.Tooltip('Count:Q', title='Count'),
|
@@ -138,17 +139,83 @@ class OKVQADatasetAnalyzer:
|
|
138 |
|
139 |
# Combine the bar and text layers
|
140 |
chart = (bars + text).properties(
|
141 |
-
width=
|
142 |
-
height=
|
143 |
-
|
144 |
-
).configure_title(fontSize=20).configure_axis(
|
145 |
labelFontSize=12,
|
146 |
-
titleFontSize=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
)
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
# Display the chart in Streamlit
|
150 |
st.altair_chart(chart, use_container_width=True)
|
151 |
|
|
|
|
|
152 |
def export_to_csv(self, qs_filename: str, question_types_filename: str) -> None:
|
153 |
"""
|
154 |
Exports the categorized questions and their counts to two separate CSV files.
|
@@ -174,4 +241,83 @@ class OKVQADatasetAnalyzer:
|
|
174 |
writer = csv.writer(file)
|
175 |
writer.writerow(['Question Type', 'Count'])
|
176 |
for q_type, count in self.question_types.items():
|
177 |
-
writer.writerow([q_type, count])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from my_model.dataset.dataset_processor import process_okvqa_dataset
|
9 |
from my_model.config import dataset_config as config
|
10 |
|
11 |
+
|
12 |
class OKVQADatasetAnalyzer:
|
13 |
"""
|
14 |
Provides tools for analyzing and visualizing distributions of question types within given question datasets.
|
|
|
30 |
|
31 |
Parameters:
|
32 |
train_file_path (str): Path to the training dataset JSON file. This file should contain a list of questions.
|
33 |
+
test_file_path (str): Path to the testing dataset JSON file. This file should also contain a list of
|
34 |
questions.
|
35 |
+
data_choice (str): Specifies which dataset(s) to load and analyze. Valid options are 'train', 'test', or
|
36 |
'train_test', indicating whether to load training data, testing data, or both.
|
37 |
|
38 |
+
The constructor initializes the paths, selects the dataset based on the choice, and loads the initial data by
|
39 |
calling the `load_data` method.
|
40 |
It also prepares structures for categorizing questions and storing the results.
|
41 |
"""
|
42 |
+
|
43 |
self.train_file_path = train_file_path
|
44 |
self.test_file_path = test_file_path
|
45 |
self.data_choice = data_choice
|
46 |
self.questions = []
|
47 |
self.question_types = Counter()
|
48 |
+
self.Qs = {keyword: [] for keyword in config.QUESTION_KEYWORDS + ['others']}
|
49 |
self.load_data()
|
50 |
|
51 |
def load_data(self) -> None:
|
|
|
72 |
questions.
|
73 |
"""
|
74 |
|
75 |
+
question_keywords = self.QUESTION_KEYWORDS
|
76 |
|
77 |
for question in self.questions:
|
78 |
question = contractions.fix(question)
|
|
|
99 |
The chart sorts question types by count in descending order and includes detailed tooltips for interaction.
|
100 |
This method is intended for visualization in a Streamlit application.
|
101 |
"""
|
102 |
+
|
103 |
# Prepare data
|
104 |
total_questions = sum(self.question_types.values())
|
105 |
items = [(key, value, (value / total_questions) * 100) for key, value in self.question_types.items()]
|
|
|
119 |
# Create the bar chart
|
120 |
bars = alt.Chart(df).mark_bar().encode(
|
121 |
x=alt.X('Question Keyword:N', sort=order, title='Question Keyword', axis=alt.Axis(labelAngle=-45)),
|
122 |
+
y=alt.Y('Count:Q', title='Question Count'),
|
123 |
color=alt.Color('Question Keyword:N', scale=alt.Scale(scheme='category20'), legend=None),
|
124 |
tooltip=[alt.Tooltip('Question Keyword:N', title='Type'),
|
125 |
alt.Tooltip('Count:Q', title='Count'),
|
|
|
139 |
|
140 |
# Combine the bar and text layers
|
141 |
chart = (bars + text).properties(
|
142 |
+
width=800,
|
143 |
+
height=600,
|
144 |
+
).configure_axis(
|
|
|
145 |
labelFontSize=12,
|
146 |
+
titleFontSize=16,
|
147 |
+
labelFontWeight='bold',
|
148 |
+
titleFontWeight='bold',
|
149 |
+
grid=False
|
150 |
+
).configure_text(
|
151 |
+
fontWeight='bold'
|
152 |
+
).configure_title(
|
153 |
+
fontSize=20,
|
154 |
+
font='bold',
|
155 |
+
anchor='middle'
|
156 |
+
)
|
157 |
+
|
158 |
+
# Display the chart in Streamlit
|
159 |
+
st.altair_chart(chart, use_container_width=True)
|
160 |
+
|
161 |
+
def plot_bar_chart(self, df: pd.DataFrame, category_col: str, value_col: str, chart_title: str) -> None:
    """
    Plots an interactive bar chart using Altair and Streamlit.

    Each bar's share of the `value_col` total is computed and drawn as a text label above the bar. The
    helper columns are added to a copy of `df`, so the DataFrame passed in by the caller is not modified.

    Args:
        df (pd.DataFrame): DataFrame containing the data for the bar chart.
        category_col (str): Name of the column containing the categories (x axis, sorted by descending bar height).
        value_col (str): Name of the column containing the values (y axis).
        chart_title (str): Title of the chart.
            NOTE(review): this argument is accepted but never applied to the chart; confirm whether the
            title should be rendered here or left to the caller.

    Returns:
        None
    """
    # Work on a copy: the helper columns below must not leak into the caller's DataFrame, and when
    # `value_col` is itself called 'Percentage' the caller's values must not be overwritten.
    data = df.copy()

    # Calculate percentage for each category
    data['Percentage'] = (data[value_col] / data[value_col].sum()) * 100
    data['PercentageText'] = data['Percentage'].round(1).astype(str) + '%'

    # Create the bar chart
    bars = alt.Chart(data).mark_bar().encode(
        x=alt.X(field=category_col, title='Category', sort='-y', axis=alt.Axis(labelAngle=-45)),
        y=alt.Y(field=value_col, type='quantitative', title='Percentage'),
        color=alt.Color(field=category_col, type='nominal', legend=None),
        tooltip=[
            alt.Tooltip(field=category_col, type='nominal', title='Category'),
            alt.Tooltip(field=value_col, type='quantitative', title='Percentage'),
            alt.Tooltip(field='Percentage', type='quantitative', title='Percentage', format='.1f')
        ]
    ).properties(
        width=800,
        height=600
    )

    # Add text labels to the bars
    text = bars.mark_text(
        align='center',
        baseline='bottom',
        dy=-10  # Nudges text up so it appears above the bar
    ).encode(
        text=alt.Text('PercentageText:N')
    )

    # Combine the bar chart and text labels
    chart = (bars + text).configure_title(
        fontSize=20
    ).configure_axis(
        labelFontSize=12,
        titleFontSize=16,
        labelFontWeight='bold',
        titleFontWeight='bold',
        grid=False
    ).configure_text(
        fontWeight='bold'
    )

    # Display the chart in Streamlit
    st.altair_chart(chart, use_container_width=True)
|
216 |
|
217 |
+
|
218 |
+
|
219 |
def export_to_csv(self, qs_filename: str, question_types_filename: str) -> None:
|
220 |
"""
|
221 |
Exports the categorized questions and their counts to two separate CSV files.
|
|
|
241 |
writer = csv.writer(file)
|
242 |
writer.writerow(['Question Type', 'Count'])
|
243 |
for q_type, count in self.question_types.items():
|
244 |
+
writer.writerow([q_type, count])
|
245 |
+
|
246 |
+
|
247 |
+
|
248 |
+
def run_dataset_analyzer():
    """
    Renders the dataset analysis page in Streamlit.

    The page is built from three containers:
        1. An overview of KB-VQA datasets (one expander per dataset) next to a comparison table.
        2. The OK-VQA section: a characteristics table, the question distribution over knowledge categories,
           and the distribution of question keywords.
        3. An expander showing the first 10 entries of the processed training data.

    The tables are read from the sheets of "dataset_analyses.xlsx"; the question files are the OK-VQA
    train/val JSON files referenced below.

    Returns:
        None
    """
    analyses_workbook = "dataset_analyses.xlsx"
    train_questions_file = 'OpenEnded_mscoco_train2014_questions.json'
    val_questions_file = 'OpenEnded_mscoco_val2014_questions.json'

    comparison_sheet = "VQA Datasets Comparison"
    characteristics_sheet = "OK-VQA Dataset Characteristics"
    category_dist_sheet = "Question Category Dist"

    # Parse the workbook once: passing a list of sheet names returns a dict of DataFrames keyed by sheet name.
    sheets = pd.read_excel(analyses_workbook,
                           sheet_name=[comparison_sheet, characteristics_sheet, category_dist_sheet])
    datasets_comparison_table = sheets[comparison_sheet]
    okvqa_dataset_characteristics = sheets[characteristics_sheet]
    category_distribution = sheets[category_dist_sheet]

    # Only the training split is displayed on this page (see the samples expander at the bottom).
    train_data = process_okvqa_dataset(train_questions_file, 'mscoco_train2014_annotations.json',
                                       save_to_csv=False)

    dataset_analyzer = OKVQADatasetAnalyzer(train_questions_file, val_questions_file, 'train_test')

    with st.container():
        st.markdown("## Overview of KB-VQA Datasets")
        col1, col2 = st.columns([2, 1])
        with col1:
            st.write(" ")
            with st.expander("1 - Knowledge-Based VQA (KB-VQA)"):
                st.markdown(""" [Knowledge-Based VQA (KB-VQA)](https://arxiv.org/abs/1511.02570): One of the earliest
                            datasets in this domain, KB-VQA comprises 700 images and 2,402 questions, with each
                            question associated with both an image and a knowledge base (KB). The KB encapsulates
                            facts about the world, including object names, properties, and relationships, aiming to
                            foster models capable of answering questions through reasoning over both the image
                            and the KB.\n""")
            with st.expander("2 - Factual VQA (FVQA)"):
                st.markdown(""" [Factual VQA (FVQA)](https://arxiv.org/abs/1606.05433): This dataset includes 2,190
                            images and 5,826 questions, accompanied by a knowledge base containing 193,449 facts.
                            The FVQA's questions are predominantly factual and less open-ended compared to those
                            in KB-VQA, offering a different challenge in knowledge-based reasoning.\n""")
            with st.expander("3 - Outside-Knowledge VQA (OK-VQA)"):
                st.markdown(""" [Outside-Knowledge VQA (OK-VQA)](https://arxiv.org/abs/1906.00067): OK-VQA poses a more
                            demanding challenge than KB-VQA, featuring an open-ended knowledge base that can be
                            updated during model training. This dataset contains 14,055 questions and 14,031 images.
                            Questions are carefully curated to ensure they require reasoning beyond the image
                            content alone.\n""")
            with st.expander("4 - Augmented OK-VQA (A-OKVQA)"):
                st.markdown(""" [Augmented OK-VQA (A-OKVQA)](https://arxiv.org/abs/2206.01718): Augmented successor of
                            OK-VQA dataset, focused on common-sense knowledge and reasoning rather than purely
                            factual knowledge, A-OKVQA offers approximately 24,903 questions across 23,692 images.
                            Questions in this dataset demand commonsense reasoning about the scenes depicted in the
                            images, moving beyond straightforward knowledge base queries. It also provides
                            rationales for answers, aiming to be a significant testbed for the development of AI
                            models that integrate visual and natural language reasoning.\n""")
        with col2:
            st.markdown("#### KB-VQA Datasets Comparison")
            # st.write() does not use extra keyword arguments; st.dataframe is the specific command that
            # accepts a width option, so the table can fill its column.
            st.dataframe(datasets_comparison_table, use_container_width=True)

    st.write("-----------------------")

    with st.container():
        st.write("\n" * 10)
        st.markdown("## OK-VQA Dataset")
        st.write("This model was fine-tuned and evaluated using OK-VQA dataset.\n")
        col1, col2, col3 = st.columns([2, 5, 5])

        with col1:
            st.markdown("#### OK-VQA Dataset Characteristics")
            st.write(okvqa_dataset_characteristics)

        with col2:
            st.markdown("#### Questions Distribution over Knowledge Category")
            dataset_analyzer.plot_bar_chart(category_distribution, "Knowledge Category", "Percentage",
                                            "Questions Distribution over Knowledge Category")

        with col3:
            # Questions must be categorized before their keyword distribution can be plotted.
            dataset_analyzer.categorize_questions()
            st.markdown("#### Distribution of Question Keywords")
            dataset_analyzer.plot_question_distribution()

    with st.container():
        with st.expander("Show Dataset Samples"):
            st.write(train_data[:10])
|
321 |
+
|
322 |
+
|
323 |
+
|