Spaces:
Sleeping
Sleeping
import csv
import json
from collections import Counter
from typing import Tuple, List, Optional

import altair as alt
import contractions
import pandas as pd
import streamlit as st

from my_model.config import dataset_config as config
from my_model.dataset.dataset_processor import process_okvqa_dataset
class OKVQADatasetAnalyzer:
    """
    Provides tools for analyzing and visualizing distributions of question types within given question datasets.

    It supports operations such as data loading, categorization of questions based on keywords, visualization of
    question distribution, and exporting data to CSV files.

    Attributes:
        train_file_path (str): Path to the training dataset file.
        test_file_path (str): Path to the testing dataset file.
        data_choice (str): Choice of dataset(s) to analyze; options include 'train', 'test', or 'train_test'.
        questions (List[str]): List of questions aggregated based on the dataset choice.
        question_types (Counter): Counter object tracking the frequency of each question type.
        Qs (Dict[str, List[str]]): Dictionary mapping question types to lists of corresponding questions.
    """

    # Dataset selections understood by `load_data`.
    _VALID_DATA_CHOICES = ('train', 'test', 'train_test')

    def __init__(self, train_file_path: str, test_file_path: str, data_choice: str):
        """
        Initializes the analyzer with paths to dataset files and a choice of which datasets to analyze, then loads
        the selected data by calling `load_data`.

        Parameters:
            train_file_path (str): Path to the training dataset JSON file. The file must hold a top-level
                'questions' list whose items each have a 'question' string.
            test_file_path (str): Path to the testing dataset JSON file, laid out like the training file.
            data_choice (str): Specifies which dataset(s) to load and analyze. Valid options are 'train', 'test',
                or 'train_test', indicating whether to load training data, testing data, or both.

        Raises:
            ValueError: If `data_choice` is not one of the valid options (raised by `load_data`).
        """
        self.train_file_path = train_file_path
        self.test_file_path = test_file_path
        self.data_choice = data_choice
        self.questions = []
        self.question_types = Counter()
        # One bucket per configured keyword plus a catch-all 'others' bucket.
        self.Qs = {keyword: [] for keyword in config.QUESTION_KEYWORDS + ['others']}
        self.load_data()

    def load_data(self) -> None:
        """
        Loads the dataset(s) selected by `data_choice` ('train', 'test', or 'train_test') and appends their
        questions to `self.questions` (training questions first, then testing questions).

        Raises:
            ValueError: If `data_choice` is not one of 'train', 'test' or 'train_test'.
        """
        # Fail loudly instead of silently analysing an empty question list.
        if self.data_choice not in self._VALID_DATA_CHOICES:
            raise ValueError(
                f"Invalid data_choice {self.data_choice!r}; expected one of {self._VALID_DATA_CHOICES}."
            )
        if self.data_choice in ('train', 'train_test'):
            self.questions += self._read_questions(self.train_file_path)
        if self.data_choice in ('test', 'train_test'):
            self.questions += self._read_questions(self.test_file_path)

    @staticmethod
    def _read_questions(file_path: str) -> List[str]:
        """Returns the 'question' strings from the top-level 'questions' list of a JSON dataset file."""
        # JSON text is UTF-8 by specification; don't depend on the platform's default encoding.
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        return [q['question'] for q in data['questions']]

    @staticmethod
    def _match_keyword(words: List[str], question_keywords: set) -> Optional[str]:
        """
        Returns the question keyword for a tokenized, lower-cased question.

        Questions starting with "name the" map to 'name the'; otherwise the first word found in
        `question_keywords` is returned, or None when no word matches.
        """
        if words[:2] == ['name', 'the']:
            return 'name the'
        return next((word for word in words if word in question_keywords), None)

    def categorize_questions(self) -> None:
        """
        Categorizes each loaded question by keyword.

        Contractions are expanded first (the expanded text is what gets stored). Results are written to
        `self.question_types` (counts) and `self.Qs` (questions per keyword); unmatched questions go to 'others'.
        Previous results are cleared, so calling this method again does not double-count.
        """
        # Reset earlier results so repeated calls recompute instead of accumulating.
        self.question_types.clear()
        for bucket in self.Qs.values():
            bucket.clear()

        # Set membership is O(1) per word and the keyword list is fixed for the whole pass.
        question_keywords = set(config.QUESTION_KEYWORDS)
        for question in self.questions:
            question = contractions.fix(question)
            keyword = self._match_keyword(question.lower().split(), question_keywords) or 'others'
            self.question_types[keyword] += 1
            # setdefault covers keywords (such as 'name the') that may not have a pre-created bucket.
            self.Qs.setdefault(keyword, []).append(question)

    def plot_question_distribution(self) -> None:
        """
        Plots an interactive bar chart of question types using Altair and Streamlit, displaying the count and
        percentage of each type.

        Question types are sorted by count in descending order, with the 'others' bucket (if present) placed last.
        Tooltips show the type, count and percentage. Intended to be called from a Streamlit app after
        `categorize_questions`.
        """
        columns = ['Question Keyword', 'Count', 'Percentage']
        total_questions = sum(self.question_types.values())
        items = [(key, value, (value / total_questions) * 100) for key, value in self.question_types.items()]
        df = pd.DataFrame(items, columns=columns)

        # Sort the named keywords by count, then append the catch-all 'others' bucket so it always appears last.
        df = df[df['Question Keyword'] != 'others'].sort_values('Count', ascending=False)
        if 'others' in self.question_types:
            others_count = self.question_types['others']
            others_df = pd.DataFrame([('others', others_count, (others_count / total_questions) * 100)],
                                     columns=columns)
            df = pd.concat([df, others_df], ignore_index=True)

        # Pin the x-axis order to the DataFrame order computed above.
        order = df['Question Keyword'].tolist()

        bars = alt.Chart(df).mark_bar().encode(
            x=alt.X('Question Keyword:N', sort=order, title='Question Keyword', axis=alt.Axis(labelAngle=-45)),
            y=alt.Y('Count:Q', title='Question Count'),
            color=alt.Color('Question Keyword:N', scale=alt.Scale(scheme='category20'), legend=None),
            tooltip=[alt.Tooltip('Question Keyword:N', title='Type'),
                     alt.Tooltip('Count:Q', title='Count'),
                     alt.Tooltip('Percentage:Q', title='Percentage', format='.1f')]
        )

        # Text labels above each bar, e.g. "123 (4.5%)".
        text = bars.mark_text(
            align='center',
            baseline='bottom',
            dy=-5  # Nudges text up so it appears above the bar
        ).encode(
            text=alt.Text('PercentageText:N')
        ).transform_calculate(
            PercentageText="datum.Count + ' (' + format(datum.Percentage, '.1f') + '%)'"
        )

        chart = (bars + text).properties(
            width=800,
            height=600,
        ).configure_axis(
            labelFontSize=12,
            titleFontSize=16,
            labelFontWeight='bold',
            titleFontWeight='bold',
            grid=False
        ).configure_text(
            fontWeight='bold'
        ).configure_title(
            fontSize=20,
            fontWeight='bold',  # weight, not a font family name
            anchor='middle'
        )

        st.altair_chart(chart, use_container_width=True)

    def plot_bar_chart(self, df: pd.DataFrame, category_col: str, value_col: str, chart_title: str) -> None:
        """
        Plots an interactive bar chart using Altair and Streamlit.

        Each bar is labelled with its share of the `value_col` total, as a percentage.

        Args:
            df (pd.DataFrame): DataFrame containing the data for the bar chart. It is not modified; the derived
                'Percentage' and 'PercentageText' columns are added to a copy.
            category_col (str): Name of the column containing the categories.
            value_col (str): Name of the column containing the values.
            chart_title (str): Title of the chart (currently not drawn on the chart itself).

        Returns:
            None
        """
        # Work on a copy so the caller's DataFrame is not mutated with helper columns.
        df = df.copy()
        df['Percentage'] = (df[value_col] / df[value_col].sum()) * 100
        df['PercentageText'] = df['Percentage'].round(1).astype(str) + '%'

        bars = alt.Chart(df).mark_bar().encode(
            x=alt.X(field=category_col, title='Category', sort='-y', axis=alt.Axis(labelAngle=-45)),
            y=alt.Y(field=value_col, type='quantitative', title=value_col),
            color=alt.Color(field=category_col, type='nominal', legend=None),
            tooltip=[
                alt.Tooltip(field=category_col, type='nominal', title='Category'),
                alt.Tooltip(field=value_col, type='quantitative', title=value_col),
                alt.Tooltip(field='Percentage', type='quantitative', title='Percentage', format='.1f')
            ]
        ).properties(
            width=800,
            height=600
        )

        text = bars.mark_text(
            align='center',
            baseline='bottom',
            dy=-10  # Nudges text up so it appears above the bar
        ).encode(
            text=alt.Text('PercentageText:N')
        )

        chart = (bars + text).configure_title(
            fontSize=20
        ).configure_axis(
            labelFontSize=12,
            titleFontSize=16,
            labelFontWeight='bold',
            titleFontWeight='bold',
            grid=False
        ).configure_text(
            fontWeight='bold')

        st.altair_chart(chart, use_container_width=True)

    def export_to_csv(self, qs_filename: str, question_types_filename: str) -> None:
        """
        Exports the categorized questions and their counts to two separate CSV files.

        Parameters:
            qs_filename (str): The filename or path for exporting the `self.Qs` dictionary data, written as
                one ('Question Type', 'Questions') row per question.
            question_types_filename (str): The filename or path for exporting the `self.question_types` Counter
                data, written as one ('Question Type', 'Count') row per type.

        Both files are written as UTF-8 with a header row; existing files are overwritten.
        """
        with open(qs_filename, mode='w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(['Question Type', 'Questions'])
            writer.writerows((q_type, question) for q_type, questions in self.Qs.items() for question in questions)

        with open(question_types_filename, mode='w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(['Question Type', 'Count'])
            writer.writerows(self.question_types.items())
def run_dataset_analyzer():
    """
    Renders the dataset-analysis page of the Streamlit app.

    The page shows:
      1. an overview of KB-VQA datasets next to a dataset comparison table;
      2. the OK-VQA dataset characteristics next to a knowledge-category distribution chart and a
         question-keyword distribution chart;
      3. an expander with the first 10 processed OK-VQA training samples.

    Returns:
        None. All output is rendered through Streamlit.
    """
    # Prefer the configured workbook path; fall back to 'dataset_analyses.xlsx' in the working directory when the
    # dataset config does not define one.
    analyses_path = getattr(config, 'DATASET_ANALYSES_PATH', 'dataset_analyses.xlsx')
    # A single read of the workbook: a list of sheet names returns a dict of DataFrames keyed by sheet name.
    sheets = pd.read_excel(analyses_path, sheet_name=["VQA Datasets Comparison",
                                                      "OK-VQA Dataset Characteristics",
                                                      "Question Category Dist"])
    datasets_comparison_table = sheets["VQA Datasets Comparison"]
    okvqa_dataset_characteristics = sheets["OK-VQA Dataset Characteristics"]
    question_category_dist = sheets["Question Category Dist"]

    train_data = process_okvqa_dataset(config.DATASET_TRAIN_QUESTIONS_PATH, config.DATASET_TRAIN_ANNOTATIONS_PATH,
                                       save_to_csv=False)
    # Keyword statistics are computed over the training and validation questions together.
    dataset_analyzer = OKVQADatasetAnalyzer(config.DATASET_TRAIN_QUESTIONS_PATH,
                                            config.DATASET_VAL_QUESTIONS_PATH, 'train_test')

    with st.container():
        st.markdown("## Overview of KB-VQA Datasets")
        col1, col2 = st.columns([2, 1])
        with col1:
            st.write(" ")
            with st.expander("1 - Knowledge-Based VQA (KB-VQA)"):
                st.markdown(
                    " [Knowledge-Based VQA (KB-VQA)](https://arxiv.org/abs/1511.02570): One of the earliest\n"
                    "datasets in this domain, KB-VQA comprises 700 images and 2,402 questions, with each\n"
                    "question associated with both an image and a knowledge base (KB). The KB encapsulates\n"
                    "facts about the world, including object names, properties, and relationships, aiming to\n"
                    "foster models capable of answering questions through reasoning over both the image\n"
                    "and the KB.\n"
                )
            with st.expander("2 - Factual VQA (FVQA)"):
                st.markdown(
                    " [Factual VQA (FVQA)](https://arxiv.org/abs/1606.05433): This dataset includes 2,190\n"
                    "images and 5,826 questions, accompanied by a knowledge base containing 193,449 facts.\n"
                    "The FVQA's questions are predominantly factual and less open-ended compared to those\n"
                    "in KB-VQA, offering a different challenge in knowledge-based reasoning.\n"
                )
            with st.expander("3 - Outside-Knowledge VQA (OK-VQA)"):
                st.markdown(
                    " [Outside-Knowledge VQA (OK-VQA)](https://arxiv.org/abs/1906.00067): OK-VQA poses a more\n"
                    "demanding challenge than KB-VQA, featuring an open-ended knowledge base that can be\n"
                    "updated during model training. This dataset contains 14,055 questions and 14,031 images.\n"
                    "Questions are carefully curated to ensure they require reasoning beyond the image\n"
                    "content alone.\n"
                )
            with st.expander("4 - Augmented OK-VQA (A-OKVQA)"):
                st.markdown(
                    " [Augmented OK-VQA (A-OKVQA)](https://arxiv.org/abs/2206.01718): Augmented successor of\n"
                    "OK-VQA dataset, focused on common-sense knowledge and reasoning rather than purely\n"
                    "factual knowledge, A-OKVQA offers approximately 24,903 questions across 23,692 images.\n"
                    "Questions in this dataset demand commonsense reasoning about the scenes depicted in the\n"
                    "images, moving beyond straightforward knowledge base queries. It also provides\n"
                    "rationales for answers, aiming to be a significant testbed for the development of AI\n"
                    "models that integrate visual and natural language reasoning.\n"
                )
        with col2:
            st.markdown("#### KB-VQA Datasets Comparison")
            # st.write has no width option; st.dataframe is needed to stretch the table to the column.
            st.dataframe(datasets_comparison_table, use_container_width=True)
        st.write("-----------------------")

    with st.container():
        st.write("\n" * 10)
        st.markdown("## OK-VQA Dataset")
        st.write("This model was fine-tuned and evaluated using OK-VQA dataset.\n")
        col1, col2, col3 = st.columns([2, 5, 5])
        with col1:
            st.markdown("#### OK-VQA Dataset Characteristics")
            st.write(okvqa_dataset_characteristics)
        with col2:
            st.markdown("#### Questions Distribution over Knowledge Category")
            dataset_analyzer.plot_bar_chart(question_category_dist, "Knowledge Category", "Percentage",
                                            "Questions Distribution over Knowledge Category")
        with col3:
            dataset_analyzer.categorize_questions()
            st.markdown("#### Distribution of Question Keywords")
            dataset_analyzer.plot_question_distribution()

    with st.container():
        with st.expander("Show Dataset Samples"):
            st.write(train_data[:10])