# Page header captured with this file (author: m7mdal7aj)
# Commit 7e216b9 (verified): "Update my_model/results/demo.py" (12.2 kB)
import os
import altair as alt
from my_model.config import evaluation_config as config
import streamlit as st
from PIL import Image
import pandas as pd
import random
class ResultDemonstrator:
    """
    A class to demonstrate the results of the Knowledge-Based Visual Question Answering (KB-VQA) model.

    Attributes:
        main_data (pd.DataFrame): Per-question evaluation results loaded from the "Main Data" sheet of the
            evaluation Excel file.
        sample_img_pool (list[str]): File names found in the demo images directory.
        model_names (list[str]): List of model names as defined in the configuration.
        model_configs (list[str]): List of model configurations as defined in the configuration.
    """

    # Column-name templates for each score type; "{config}" is replaced by "<model_name>_<configuration>".
    # Shared by the ablation tables and the per-sample summary so both read the same columns.
    SCORE_COLUMN_TEMPLATES = {
        'vqa': 'vqa_score_{config}',
        'vqa_gpt4': 'gpt4_vqa_score_{config}',
        'em': 'exact_match_score_{config}',
        'em_gpt4': 'gpt4_em_score_{config}',
    }

    # Header used in the per-sample summary table for each score type (same keys as above).
    SCORE_DISPLAY_NAMES = {
        'vqa': 'VQA Score',
        'vqa_gpt4': 'VQA Score (GPT-4)',
        'em': 'EM Score',
        'em_gpt4': 'EM Score (GPT-4)',
    }

    def __init__(self) -> None:
        """
        Initializes the ResultDemonstrator class by loading the data from an Excel file.
        """
        # Load data
        self.main_data = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Main Data")
        self.sample_img_pool = list(os.listdir(config.DEMO_IMAGES_PATH))
        self.model_names = config.MODEL_NAMES
        self.model_configs = config.MODEL_CONFIGURATIONS

    @staticmethod
    def display_table(data: pd.DataFrame) -> None:
        """
        Displays a DataFrame using Streamlit's dataframe display function.

        Args:
            data (pd.DataFrame): The data to display.
        """
        st.dataframe(data)

    def calculate_and_append_data(self, data_list: list, score_column: str, model_config: str) -> None:
        """
        Calculates mean scores by question category and appends them to the data list.

        Nothing is appended when ``score_column`` is not a column of ``main_data``.

        Args:
            data_list (list): List to append new data rows to; each row is a dict with keys
                "Category", "Configuration" and "Mean Value" (mean * 100, rounded to 2 decimals).
            score_column (str): Name of the column to calculate mean scores for.
            model_config (str): Configuration of the model, stored verbatim in each row.
        """
        if score_column in self.main_data.columns:
            category_means = self.main_data.groupby('question_category')[score_column].mean()
            for category, mean_value in category_means.items():
                data_list.append({
                    "Category": category,
                    "Configuration": model_config,
                    "Mean Value": round(mean_value * 100, 2)
                })

    def display_ablation_results_per_question_category(self) -> None:
        """
        Displays ablation results per question category for each model configuration.

        One expander is rendered per score type, holding a table of mean scores (formatted as percentages)
        with question categories as rows and "<model_name>_<configuration>" combinations as columns.
        Score types with no matching columns in ``main_data`` are skipped.
        """
        data_lists = {score_type: [] for score_type in self.SCORE_COLUMN_TEMPLATES}
        for model_name in self.model_names:
            for conf in self.model_configs:
                model_config = f"{model_name}_{conf}"
                for score_type, col_template in self.SCORE_COLUMN_TEMPLATES.items():
                    self.calculate_and_append_data(data_lists[score_type],
                                                   col_template.format(config=model_config),
                                                   model_config)
        # Process and display results for each score type
        for score_type, data_list in data_lists.items():
            if not data_list:
                # No matching score columns: pivoting an empty frame would raise KeyError on 'Category'.
                continue
            df = pd.DataFrame(data_list)
            results_df = df.pivot(index='Category', columns='Configuration', values='Mean Value').applymap(
                lambda x: f"{x:.2f}%")
            with st.expander(f"{score_type.upper()} Scores per Question Category and Model Configuration"):
                self.display_table(results_df)

    def display_main_results(self) -> None:
        """Displays the main model results from the Scores sheet, these are displayed from the file directly."""
        # The sheet's first column becomes the index and is shown as the table's row labels.
        main_scores = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Scores", index_col=0)
        st.markdown("### Main Model Results (Inclusive of Ablation Experiments)")
        self.display_table(main_scores)

    def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
        """
        Plots an interactive scatter plot comparing token counts to VQA or EM scores using Altair.

        Args:
            conf (str): The configuration name; also selects the "tokens_count_<conf>" column.
            model_name (str): The name of the model.
            score_name (str): 'VQA Score' plots VQA scores; any other value plots exact-match scores.
        """
        # Construct the full model configuration name
        model_configuration = f"{model_name}_{conf}"
        # Determine the score column name and legend mapping based on the score type
        if score_name == 'VQA Score':
            score_column_name = f"vqa_score_{model_configuration}"
            scores = self.main_data[score_column_name]
            # Map scores to categories for the legend (0.67 ~ 2/3 is treated as partially correct)
            legend_map = ['Correct' if score == 1 else 'Partially Correct' if round(score, 2) == 0.67 else 'Incorrect'
                          for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Partially Correct', 'Incorrect'],
                                    range=['green', 'orange', 'red'])
        else:
            score_column_name = f"exact_match_score_{model_configuration}"
            scores = self.main_data[score_column_name]
            # Map scores to categories for the legend
            legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])
        # Retrieve token counts from the data
        token_counts = self.main_data[f'tokens_count_{conf}']
        # Create a DataFrame for the scatter plot
        scatter_data = pd.DataFrame({
            'Index': range(len(token_counts)),
            'Token Counts': token_counts,
            score_name: legend_map
        })
        # Create an interactive scatter plot using Altair
        chart = alt.Chart(scatter_data).mark_circle(
            size=60,
            fillOpacity=1,  # Sets the fill opacity to maximum
            strokeWidth=1,  # Adjusts the border width making the circles bolder
            stroke='black'  # Sets the border color to black
        ).encode(
            # NOTE(review): x-axis domain is hard-coded to 0-1020, presumably sized for the evaluation set.
            x=alt.X('Index', scale=alt.Scale(domain=[0, 1020])),
            y=alt.Y('Token Counts', scale=alt.Scale(domain=[token_counts.min() - 200, token_counts.max() + 200])),
            color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
            tooltip=['Index', 'Token Counts', score_name]
        ).interactive()  # Enables zoom & pan
        chart = chart.properties(
            title={
                "text": f"Token Counts vs {score_name} ({model_configuration})",
                "color": "black",  # Optional color
                "fontSize": 20,  # Optional font size
                "anchor": "middle",  # Optional anchor position
                "offset": 0  # Optional offset
            },
            width=700,
            height=500
        )
        # Display the interactive plot in Streamlit
        st.altair_chart(chart, use_container_width=True)

    @staticmethod
    def color_scores(value: float) -> str:
        """
        Applies color coding based on the score value.

        Args:
            value (float): The score value; numeric strings such as "0.67" are accepted.

        Returns:
            str: 'color: green;' for 1, 'color: red;' for 0, 'color: orange;' for a score that rounds
            to 0.67, and 'color: black;' for anything else or for values that are not numbers.
        """
        try:
            value = float(value)  # Convert to float to handle numerical comparisons
        except (TypeError, ValueError):
            return 'color: black;'  # Return black if value is not a number (e.g. '-' or None)
        if value == 1.0:
            return 'color: green;'
        if value == 0.0:
            return 'color: red;'
        # Round so an unformatted 2/3 matches too (same rule as the plot's 'Partially Correct' legend).
        if round(value, 2) == 0.67:
            return 'color: orange;'
        return 'color: black;'

    def _build_summary_table(self, sample_row: pd.Series, model_configs: list) -> pd.DataFrame:
        """
        Builds the table of answers and formatted scores of every model configuration for one sample.

        Args:
            sample_row (pd.Series): The ``main_data`` row of one sample image.
            model_configs (list[str]): "<model_name>_<configuration>" identifiers; each is used as the
                answer column name and substituted into SCORE_COLUMN_TEMPLATES.

        Returns:
            pd.DataFrame: One row per configuration with columns 'Model Configuration', 'Answer' and the
            four score headers. Numeric scores are formatted to two decimals, NaN is kept as is, and
            missing columns show '-'.
        """
        rows = []
        for model_config in model_configs:
            row = {
                'Model Configuration': model_config,
                'Answer': sample_row.get(model_config, '-'),
            }
            for score_type, template in self.SCORE_COLUMN_TEMPLATES.items():
                score_value = sample_row.get(template.format(config=model_config), '-')
                if pd.notna(score_value) and not isinstance(score_value, str):
                    # Format score to two decimals if it's a valid number
                    score_value = f"{float(score_value):.2f}"
                row[self.SCORE_DISPLAY_NAMES[score_type]] = score_value
            rows.append(row)
        columns = ['Model Configuration', 'Answer', *self.SCORE_DISPLAY_NAMES.values()]
        return pd.DataFrame(rows, columns=columns)

    def show_samples(self, num_samples: int = 3) -> None:
        """
        Displays random sample images and their associated models answers and evaluations.

        Each sample is a row: the image with caption / detected-object expanders on the left, and the
        question, ground-truth answers and a colour-coded table of every model configuration's answer
        and scores on the right.

        Args:
            num_samples (int): Number of sample images to display; capped at the size of the image pool.
        """
        # random.sample raises ValueError when asked for more items than the pool holds.
        sample_size = min(num_samples, len(self.sample_img_pool))
        target_imgs = random.sample(self.sample_img_pool, sample_size)
        # Generate model configurations
        model_configs = [f"{model_name}_{conf}" for model_name in self.model_names for conf in self.model_configs]
        score_headers = list(self.SCORE_DISPLAY_NAMES.values())

        for img_filename in target_imgs:
            image_data = self.main_data[self.main_data['image_filename'] == img_filename]
            sample_row = None if image_data.empty else image_data.iloc[0]
            # One container per sample; columns are created inside it so the separator precedes each row.
            with st.container():
                st.write("-------------------------------")
                col1, col2 = st.columns([1, 2])  # to display images side by side with their data.
                with col1:
                    # The context manager releases the file handle once Streamlit has consumed the image.
                    with Image.open(os.path.join(config.DEMO_IMAGES_PATH, img_filename)) as im:
                        st.image(im, use_column_width=True)
                    if sample_row is not None:
                        with st.expander('Show Caption'):
                            st.text(sample_row['caption'])
                        with st.expander('Show DETIC Objects'):
                            st.text(sample_row['objects_detic_trimmed'])
                        with st.expander('Show YOLOv5 Objects'):
                            st.text(sample_row['objects_yolov5'])
                with col2:
                    if sample_row is None:
                        st.write("No data available for this image.")
                        continue
                    st.write(f"**Question:** {sample_row['question']}")
                    st.write(f"**Ground Truth Answers:** {sample_row['raw_answers']}")
                    summary_data = self._build_summary_table(sample_row, model_configs)
                    # Apply styling to DataFrame for score coloring
                    styled_summary = summary_data.style.applymap(self.color_scores, subset=score_headers)
                    # Styler.to_html has no `index` parameter, so the index must be hidden on the Styler
                    # (`hide` exists from pandas 1.4; older versions only provide `hide_index`).
                    if hasattr(styled_summary, 'hide'):
                        styled_summary = styled_summary.hide(axis='index')
                    else:
                        styled_summary = styled_summary.hide_index()
                    st.markdown(styled_summary.to_html(), unsafe_allow_html=True)