|
import os |
|
import altair as alt |
|
from my_model.config import evaluation_config as config |
|
import streamlit as st |
|
from PIL import Image |
|
import pandas as pd |
|
import random |
|
|
|
|
|
class ResultDemonstrator:
    """
    A class to demonstrate the results of the Knowledge-Based Visual Question Answering (KB-VQA) model.

    Attributes:
        main_data (pd.DataFrame): Rows of the "Main Data" sheet of the evaluation Excel file.
        sample_img_pool (list[str]): Non-hidden file names found in the demo images directory.
        model_names (list[str]): List of model names as defined in the configuration.
        model_configs (list[str]): List of model configurations as defined in the configuration.
    """

    # Score-column templates in ``main_data``; ``{config}`` is filled with "<model_name>_<configuration>".
    SCORE_COLUMN_TEMPLATES = {
        'vqa': 'vqa_score_{config}',
        'vqa_gpt4': 'gpt4_vqa_score_{config}',
        'em': 'exact_match_score_{config}',
        'em_gpt4': 'gpt4_em_score_{config}',
    }

    # Header shown for each score type in the per-sample summary table.
    SCORE_DISPLAY_NAMES = {
        'vqa': 'VQA Score',
        'vqa_gpt4': 'VQA Score (GPT-4)',
        'em': 'EM Score',
        'em_gpt4': 'EM Score (GPT-4)',
    }

    def __init__(self) -> None:
        """
        Initializes the ResultDemonstrator by loading the evaluation data and the demo image file names.
        """
        self.main_data = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Main Data")
        # Skip hidden entries (e.g. ".DS_Store") so they can never be sampled and opened as images.
        self.sample_img_pool = [name for name in os.listdir(config.DEMO_IMAGES_PATH) if not name.startswith('.')]
        self.model_names = config.MODEL_NAMES
        self.model_configs = config.MODEL_CONFIGURATIONS

    @staticmethod
    def display_table(data: pd.DataFrame) -> None:
        """
        Displays a DataFrame using Streamlit's dataframe display function.

        Args:
            data (pd.DataFrame): The data to display.
        """
        st.dataframe(data)

    def _model_config_names(self) -> list:
        """Returns every "<model_name>_<configuration>" label, iterating model names in the outer loop."""
        return [f"{model_name}_{conf}" for model_name in self.model_names for conf in self.model_configs]

    def calculate_and_append_data(self, data_list: list, score_column: str, model_config: str) -> None:
        """
        Calculates mean scores per question category and appends them to ``data_list``.

        Does nothing when ``score_column`` is not a column of ``main_data``.

        Args:
            data_list (list): List to which one dict per category is appended, with keys "Category",
                "Configuration" and "Mean Value" (the category mean * 100, rounded to 2 decimals).
            score_column (str): Name of the column to calculate mean scores for.
            model_config (str): Configuration label stored under "Configuration".
        """
        if score_column not in self.main_data.columns:
            return
        category_means = self.main_data.groupby('question_category')[score_column].mean()
        data_list.extend(
            {"Category": category, "Configuration": model_config, "Mean Value": round(mean_value * 100, 2)}
            for category, mean_value in category_means.items()
        )

    def display_ablation_results_per_question_category(self) -> None:
        """
        Displays, for each score type, mean scores per question category and model configuration.

        Each score type gets its own Streamlit expander containing a Category x Configuration table
        whose values are formatted as percentages.
        """
        data_lists = {score_type: [] for score_type in self.SCORE_COLUMN_TEMPLATES}
        for model_config in self._model_config_names():
            for score_type, col_template in self.SCORE_COLUMN_TEMPLATES.items():
                self.calculate_and_append_data(data_lists[score_type],
                                               col_template.format(config=model_config),
                                               model_config)

        for score_type, data_list in data_lists.items():
            with st.expander(f"{score_type.upper()} Scores per Question Category and Model Configuration"):
                if not data_list:
                    # No matching score columns: pivoting an empty frame would raise a KeyError.
                    st.write("No scores available.")
                    continue
                results_df = pd.DataFrame(data_list).pivot(index='Category', columns='Configuration',
                                                           values='Mean Value')
                # Column-wise Series.map instead of DataFrame.applymap (deprecated in pandas 2.1, removed in 3.0).
                results_df = results_df.apply(lambda column: column.map(lambda x: f"{x:.2f}%"))
                self.display_table(results_df)

    def display_main_results(self) -> None:
        """Displays the main model results from the "Scores" sheet, exactly as stored in the file."""
        main_scores = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Scores", index_col=0)
        st.markdown("### Main Model Results (Inclusive of Ablation Experiments)")
        self.display_table(main_scores)

    def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
        """
        Plots an interactive scatter plot comparing token counts to VQA or EM scores using Altair.

        Args:
            conf (str): The configuration name (also selects the ``tokens_count_<conf>`` column).
            model_name (str): The name of the model.
            score_name (str): 'VQA Score' plots the VQA score column; any other value plots the
                exact-match score column. The value is also used as the legend title.
        """
        model_configuration = f"{model_name}_{conf}"

        if score_name == 'VQA Score':
            scores = self.main_data[f"vqa_score_{model_configuration}"]
            # 1 is correct, ~0.67 (compared at 2 decimals) is partially correct, anything else incorrect.
            legend_map = ['Correct' if score == 1 else 'Partially Correct' if round(score, 2) == 0.67 else 'Incorrect'
                          for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Partially Correct', 'Incorrect'],
                                    range=['green', 'orange', 'red'])
        else:
            scores = self.main_data[f"exact_match_score_{model_configuration}"]
            legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])

        token_counts = self.main_data[f'tokens_count_{conf}']

        scatter_data = pd.DataFrame({
            'Index': range(len(token_counts)),
            'Token Counts': token_counts,
            score_name: legend_map,
        })

        chart = alt.Chart(scatter_data).mark_circle(
            size=60,
            fillOpacity=1,
            strokeWidth=1,
            stroke='black',
        ).encode(
            # Derive the x range from the data (small right padding) instead of a hard-coded 1020.
            x=alt.X('Index', scale=alt.Scale(domain=[0, len(token_counts) + 20])),
            y=alt.Y('Token Counts', scale=alt.Scale(domain=[token_counts.min() - 200, token_counts.max() + 200])),
            color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
            tooltip=['Index', 'Token Counts', score_name],
        ).interactive()

        chart = chart.properties(
            title={
                "text": f"Token Counts vs {score_name} ({model_configuration})",
                "color": "black",
                "fontSize": 20,
                "anchor": "middle",
                "offset": 0,
            },
            width=700,
            height=500,
        )

        st.altair_chart(chart, use_container_width=True)

    @staticmethod
    def color_scores(value: float) -> str:
        """
        Applies color coding based on the score value.

        Args:
            value (float): The score value; numeric strings such as "0.67" are accepted.

        Returns:
            str: 'color: green;' for 1, 'color: red;' for 0, 'color: orange;' for ~0.67 (2/3),
            and 'color: black;' for anything else, including non-numeric placeholders like '-' or None.
        """
        try:
            value = float(value)
        except (TypeError, ValueError):
            # TypeError covers None, which float() rejects with TypeError rather than ValueError.
            return 'color: black;'

        if value == 1.0:
            return 'color: green;'
        if value == 0.0:
            return 'color: red;'
        # Compare at 2 decimals so an unrounded 2/3 is treated the same as "0.67".
        if round(value, 2) == 0.67:
            return 'color: orange;'
        return 'color: black;'

    def _display_sample_summary(self, row: pd.Series, model_configs: list) -> None:
        """
        Renders the color-coded answer/score table for one sample as HTML.

        Args:
            row (pd.Series): The ``main_data`` row of the sample.
            model_configs (list): "<model_name>_<configuration>" labels, one table row each.
        """
        records = []
        for model_config in model_configs:
            record = {'Model Configuration': model_config, 'Answer': row.get(model_config, '-')}
            for score_type, template in self.SCORE_COLUMN_TEMPLATES.items():
                score_value = row.get(template.format(config=model_config), '-')
                if pd.notna(score_value) and not isinstance(score_value, str):
                    score_value = f"{float(score_value):.2f}"
                record[self.SCORE_DISPLAY_NAMES[score_type]] = score_value
            records.append(record)

        score_columns = list(self.SCORE_DISPLAY_NAMES.values())
        # Build the frame once from the collected rows instead of concatenating inside the loop.
        summary_data = pd.DataFrame(records, columns=['Model Configuration', 'Answer', *score_columns])

        styled_summary = summary_data.style.apply(
            lambda column: [self.color_scores(v) for v in column], subset=score_columns)
        # Styler.to_html has no ``index`` argument (extra kwargs only reach the template), so hide it explicitly.
        styled_summary = styled_summary.hide(axis="index")
        st.markdown(styled_summary.to_html(), unsafe_allow_html=True)

    def show_samples(self, num_samples: int = 3) -> None:
        """
        Displays random sample images and their associated models answers and evaluations.

        Args:
            num_samples (int): Number of sample images to display; capped at the size of the image pool.
        """
        # random.sample raises ValueError when asked for more items than the pool holds.
        target_imgs = random.sample(self.sample_img_pool, min(num_samples, len(self.sample_img_pool)))
        model_configs = self._model_config_names()

        for img_filename in target_imgs:
            image_data = self.main_data[self.main_data['image_filename'] == img_filename]
            row = None if image_data.empty else image_data.iloc[0]

            # Copy the pixels and close the file right away instead of leaking the handle.
            with Image.open(os.path.join(config.DEMO_IMAGES_PATH, img_filename)) as img_file:
                im = img_file.copy()

            col1, col2 = st.columns([1, 2])

            with st.container():
                st.write("-------------------------------")
                with col1:
                    st.image(im, use_column_width=True)
                    # Only show per-row details when the image actually has a data row.
                    if row is not None:
                        with st.expander('Show Caption'):
                            st.text(row['caption'])
                        with st.expander('Show DETIC Objects'):
                            st.text(row['objects_detic_trimmed'])
                        with st.expander('Show YOLOv5 Objects'):
                            st.text(row['objects_yolov5'])
                with col2:
                    if row is None:
                        st.write("No data available for this image.")
                    else:
                        st.write(f"**Question:** {row['question']}")
                        st.write(f"**Ground Truth Answers:** {row['raw_answers']}")
                        self._display_sample_summary(row, model_configs)
|
|
|
|