import os
import random

import altair as alt
import pandas as pd
import streamlit as st
from PIL import Image

from my_model.config import evaluation_config as config
class ResultDemonstrator:
    """
    Demonstrates the results of the Knowledge-Based Visual Question Answering (KB-VQA) model.

    Attributes:
        main_data (pd.DataFrame): Evaluation results loaded from the Excel file.
        sample_img_pool (list[str]): Image file names available for demonstration.
        model_names (list[str]): Model names as defined in the configuration.
        model_configs (list[str]): Model configurations as defined in the configuration.
        demo_images_path (str): Path to the demo images directory.
    """
def __init__(self) -> None:
"""
Initializes the ResultDemonstrator class by loading the data from an Excel file.
"""
# Load data
self.main_data = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Main Data")
self.sample_img_pool = list(os.listdir(config.DEMO_IMAGES_PATH))
self.model_names = config.MODEL_NAMES
self.model_configs = config.MODEL_CONFIGURATIONS
self.demo_images_path = config.DEMO_IMAGES_PATH
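
    # Expected layout of the "Main Data" sheet (inferred from usage below, not from
    # any schema in the original file): one row per evaluated sample, with columns
    # such as 'question_category', 'image_filename', 'question', 'raw_answers',
    # 'caption', 'objects_detic_trimmed', 'objects_yolov5', per-configuration score
    # columns like 'vqa_score_<model>_<conf>', and token counts 'tokens_count_<conf>'.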
@staticmethod
def display_table(data: pd.DataFrame) -> None:
"""
Displays a DataFrame using Streamlit's dataframe display function.
Args:
data (pd.DataFrame): The data to display.
"""
st.dataframe(data)
def calculate_and_append_data(self, data_list: list, score_column: str, model_config: str) -> None:
"""
Calculates mean scores by category and appends them to the data list.
Args:
data_list (list): List to append new data rows.
score_column (str): Name of the column to calculate mean scores for.
model_config (str): Configuration of the model.
"""
if score_column in self.main_data.columns:
category_means = self.main_data.groupby('question_category')[score_column].mean()
for category, mean_value in category_means.items():
data_list.append({
"Category": category,
"Configuration": model_config,
"Mean Value": round(mean_value * 100, 2)
})
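
    # Illustrative example of a row appended by calculate_and_append_data, assuming
    # a hypothetical configuration "m1_c1" whose mean VQA score for the 'color'
    # category is 0.5432:
    #   {"Category": "color", "Configuration": "m1_c1", "Mean Value": 54.32}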
def display_ablation_results_per_question_category(self) -> None:
"""Displays ablation results per question category for each model configuration."""
score_types = ['vqa', 'vqa_gpt4', 'em', 'em_gpt4']
data_lists = {key: [] for key in score_types}
column_names = {
'vqa': 'vqa_score_{config}',
'vqa_gpt4': 'gpt4_vqa_score_{config}',
'em': 'exact_match_score_{config}',
'em_gpt4': 'gpt4_em_score_{config}'
}
        # Iterate over the model names and configurations stored in __init__
        for model_name in self.model_names:
            for conf in self.model_configs:
                model_config = f"{model_name}_{conf}"
                for score_type, col_template in column_names.items():
                    self.calculate_and_append_data(data_lists[score_type],
                                                   col_template.format(config=model_config),
                                                   model_config)
        # Process and display results for each score type
        for score_type, data_list in data_lists.items():
            if not data_list:  # skip score types whose columns are absent from the data
                continue
            df = pd.DataFrame(data_list)
            # DataFrame.applymap is deprecated in pandas >= 2.1 (use DataFrame.map);
            # kept here for compatibility with older pandas versions.
            results_df = df.pivot(index='Category', columns='Configuration',
                                  values='Mean Value').applymap(lambda x: f"{x:.2f}%")
with st.expander(f"{score_type.upper()} Scores per Question Category and Model Configuration"):
self.display_table(results_df)
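
    # For a hypothetical model configuration "m1_c1", the templates above resolve to
    # the score columns 'vqa_score_m1_c1', 'gpt4_vqa_score_m1_c1',
    # 'exact_match_score_m1_c1' and 'gpt4_em_score_m1_c1'.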
    def display_main_results(self) -> None:
        """Displays the main model results from the 'Scores' sheet; these are read directly from the file."""
        main_scores = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Scores", index_col=0)
        st.markdown("### Main Model Results (Inclusive of Ablation Experiments)")
        # reset_index returns a new DataFrame; assign the result so the index
        # actually becomes a column (the original call discarded it).
        main_scores = main_scores.reset_index()
        self.display_table(main_scores)
    def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
        """
        Plots an interactive scatter plot comparing token count to VQA or EM scores using Altair.

        Args:
            conf (str): The configuration name.
            model_name (str): The name of the model.
            score_name (str): The score type to plot; 'VQA Score' selects the VQA
                columns, while any other value falls back to the exact-match columns.
        """
# Construct the full model configuration name
model_configuration = f"{model_name}_{conf}"
# Determine the score column name and legend mapping based on the score type
if score_name == 'VQA Score':
score_column_name = f"vqa_score_{model_configuration}"
scores = self.main_data[score_column_name]
            # Map scores to categories for the legend. The VQA metric awards
            # partial credit (~0.67) when only some annotators gave the answer.
            legend_map = ['Correct' if score == 1
                          else 'Partially Correct' if round(score, 2) == 0.67
                          else 'Incorrect'
                          for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Partially Correct', 'Incorrect'],
                                    range=['green', 'orange', 'red'])
else:
score_column_name = f"exact_match_score_{model_configuration}"
scores = self.main_data[score_column_name]
# Map scores to categories for the legend
legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])
# Retrieve token count from the data
token_count = self.main_data[f'tokens_count_{conf}']
# Create a DataFrame for the scatter plot
scatter_data = pd.DataFrame({
'Index': range(len(token_count)),
'Token Count': token_count,
score_name: legend_map
})
# Create an interactive scatter plot using Altair
chart = alt.Chart(scatter_data).mark_circle(
size=60,
fillOpacity=1, # Sets the fill opacity to maximum
strokeWidth=1, # Adjusts the border width making the circles bolder
stroke='black' # Sets the border color to black
).encode(
            # Derive the axis domains from the data rather than hard-coding them
            x=alt.X('Index', scale=alt.Scale(domain=[0, len(scatter_data)])),
            y=alt.Y('Token Count',
                    scale=alt.Scale(domain=[token_count.min() - 200, token_count.max() + 200])),
color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
tooltip=['Index', 'Token Count', score_name]
).interactive() # Enables zoom & pan
chart = chart.properties(
title={
"text": f"Token Count vs {score_name} ({model_configuration.replace('_', '-')})",
"color": "black", # Optional color
"fontSize": 20, # Optional font size
"anchor": "middle", # Optional anchor position
"offset": 0 # Optional offset
},
width=700,
height=500
)
# Display the interactive plot in Streamlit
st.altair_chart(chart, use_container_width=True)
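
    # Usage sketch (illustrative; "m1"/"c1" are hypothetical names from the config):
    #   demonstrator = ResultDemonstrator()
    #   demonstrator.plot_token_count_vs_scores(conf='c1', model_name='m1',
    #                                           score_name='VQA Score')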
@staticmethod
def color_scores(value: float) -> str:
"""
Applies color coding based on the score value.
Args:
value (float): The score value.
Returns:
str: CSS color style based on score value.
"""
try:
value = float(value) # Convert to float to handle numerical comparisons
except ValueError:
return 'color: black;' # Return black if value is not a number
if value == 1.0:
return 'color: green;'
elif value == 0.0:
return 'color: red;'
elif value == 0.67:
return 'color: orange;'
return 'color: black;'
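
    # Illustrative examples, given the two-decimal strings produced in show_samples():
    #   color_scores("1.00") -> 'color: green;'
    #   color_scores("0.67") -> 'color: orange;'
    #   color_scores("0.00") -> 'color: red;'
    #   color_scores("-")    -> 'color: black;'  (non-numeric placeholder)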
    def show_samples(self, num_samples: int = 3) -> None:
        """
        Displays random sample images along with their associated model answers and evaluations.

        Args:
            num_samples (int): Number of sample images to display.
        """
# Sample images from the pool
target_imgs = random.sample(self.sample_img_pool, num_samples)
# Generate model configurations
model_configs = [f"{model_name}_{conf}" for model_name in self.model_names for conf in self.model_configs]
# Define column names for scores dynamically
column_names = {
'vqa': 'vqa_score_{config}',
'vqa_gpt4': 'gpt4_vqa_score_{config}',
'em': 'exact_match_score_{config}',
'em_gpt4': 'gpt4_em_score_{config}'
}
for img_filename in target_imgs:
image_data = self.main_data[self.main_data['image_filename'] == img_filename]
im = Image.open(f"{self.demo_images_path}/{img_filename}")
col1, col2 = st.columns([1, 2]) # to display images side by side with their data.
# Create a container for each image
with st.container():
st.write("-------------------------------")
with col1:
st.image(im, use_column_width=True)
with st.expander('Show Caption'):
st.text(image_data.iloc[0]['caption'])
with st.expander('Show DETIC Objects'):
st.text(image_data.iloc[0]['objects_detic_trimmed'])
with st.expander('Show YOLOv5 Objects'):
st.text(image_data.iloc[0]['objects_yolov5'])
with col2:
if not image_data.empty:
st.write(f"**Question:** {image_data.iloc[0]['question']}")
st.write(f"**Ground Truth Answers:** {image_data.iloc[0]['raw_answers']}")
# Initialize an empty DataFrame for summary data
summary_data = pd.DataFrame(
columns=['Model Configuration', 'Answer', 'VQA Score', 'VQA Score (GPT-4)', 'EM Score',
'EM Score (GPT-4)'])
                    for model_config in model_configs:
                        # Collect data for each model configuration. The loop variable
                        # was renamed from `config` to avoid shadowing the imported
                        # configuration module.
                        row_data = {
                            'Model Configuration': model_config,
                            'Answer': image_data.iloc[0].get(model_config, '-')
                        }
                        for score_type, score_template in column_names.items():
                            score_col = score_template.format(config=model_config)
                            score_value = image_data.iloc[0].get(score_col, '-')
                            if pd.notna(score_value) and not isinstance(score_value, str):
                                # Format the score to two decimals if it is a valid number
                                score_value = f"{float(score_value):.2f}"
                            row_data[score_type.replace('_', ' ').title()] = score_value
                        # Convert the row to a DataFrame and align its columns
                        # positionally with summary_data: the title-cased score keys
                        # (e.g. 'Vqa Gpt4') differ from the display headers, and the
                        # dict preserves insertion order, so this rename is safe.
                        rd = pd.DataFrame([row_data])
                        rd.columns = summary_data.columns
                        summary_data = pd.concat([summary_data, rd], axis=0, ignore_index=True)
                    # Apply styling to the DataFrame for score colouring.
                    # (Styler.applymap is deprecated in pandas >= 2.1 in favour of
                    # Styler.map, but is kept here for wider compatibility.)
                    styled_summary = summary_data.style.applymap(
                        self.color_scores,
                        subset=['VQA Score', 'VQA Score (GPT-4)', 'EM Score', 'EM Score (GPT-4)'])
                    st.markdown(styled_summary.to_html(escape=False, index=False), unsafe_allow_html=True)
else:
st.write("No data available for this image.")