# KB-VQA / my_model/dataset/dataset_processor.py
# Last commit: 55cd839 (verified) by m7mdal7aj - "Update my_model/dataset/dataset_processor.py"
# File size: 7.04 kB
import json
from collections import Counter
import contractions
import csv
import pandas as pd
from typing import Tuple, List, Optional
from my_model.config import dataset_config as config
class OKVQADatasetProcessor:
    """
    Processes the OKVQA dataset by loading, processing, and merging question and annotation data.

    Attributes:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        questions (List[dict]): Entries read from the ``'questions'`` key of the questions file.
        annotations (List[dict]): Entries read from the ``'annotations'`` key of the annotations file.
        df_questions (DataFrame): DataFrame holding the questions.
        df_answers (DataFrame): DataFrame holding the annotations.
        merged_df (Optional[DataFrame]): Result of merging questions and answers; None until
            ``merge_data`` has been called.
    """

    def __init__(self, questions_file_path: str, annotations_file_path: str) -> None:
        """
        Initializes the dataset processor with file paths and loads the data into DataFrames.

        Parameters:
            questions_file_path (str): The file path for the questions JSON file.
            annotations_file_path (str): The file path for the annotations JSON file.
        """
        self.questions_file_path = questions_file_path
        self.annotations_file_path = annotations_file_path
        self.questions, self.annotations = self.load_data_files()
        self.df_questions = pd.DataFrame(self.questions)
        self.df_answers = pd.DataFrame(self.annotations)
        self.merged_df: Optional[pd.DataFrame] = None

    def load_data_files(self) -> Tuple[List[dict], List[dict]]:
        """
        Loads the question and annotation data from JSON files.

        Returns:
            Tuple[List[dict], List[dict]]: The ``'questions'`` list from the questions file and the
            ``'annotations'`` list from the annotations file.

        Raises:
            OSError: If either file cannot be opened.
            KeyError: If a file lacks its expected top-level key.
        """
        # JSON text is UTF-8 by specification; do not depend on the platform's default encoding.
        with open(self.questions_file_path, 'r', encoding='utf-8') as file:
            questions = json.load(file)['questions']
        with open(self.annotations_file_path, 'r', encoding='utf-8') as file:
            annotations = json.load(file)['annotations']
        return questions, annotations

    @staticmethod
    def find_most_frequent(my_list: List[str]) -> Optional[str]:
        """
        Determines the most frequent item in a list.

        Ties are resolved in favour of the item encountered first (``Counter.most_common`` ordering).

        Parameters:
            my_list (List[str]): The list from which to find the most frequent item.

        Returns:
            Optional[str]: The most frequent item, or None if the list is empty.
        """
        if not my_list:
            return None
        return Counter(my_list).most_common(1)[0][0]

    def merge_data(self) -> None:
        """
        Merges the question and answer DataFrames.

        Sets ``merged_df`` to an inner join of ``df_questions`` and ``df_answers`` on both the
        ``'question_id'`` and ``'image_id'`` columns, which must be present in both DataFrames.
        Rows whose key pair is missing from either side are dropped.
        """
        self.merged_df = pd.merge(self.df_questions, self.df_answers, on=['question_id', 'image_id'])

    def join_words_with_hyphen(self, sentence):
        """
        Collapses whitespace-separated words into a single hyphen-joined token.

        Parameters:
            sentence: The text to join (e.g. ``"ice cream"`` -> ``"ice-cream"``).

        Returns:
            The hyphen-joined string, or ``sentence`` unchanged when it is not a string
            (e.g. None for questions that had no answers).
        """
        if not isinstance(sentence, str):
            return sentence
        return '-'.join(sentence.split())

    def process_answers(self) -> None:
        """
        Processes answers from the merged DataFrame by extracting and identifying the most frequent answers.

        Adds the columns ``raw_answers``, ``processed_answers``, ``most_frequent_raw_answer``,
        ``most_frequent_processed_answer`` and ``single_word_answers``, and drops ``answers``.
        Each entry of ``answers`` is expected to be a list of dicts with ``'raw_answer'`` and
        ``'answer'`` keys. If ``merge_data`` has not been called yet, prints a message and
        does nothing.
        """
        if self.merged_df is None:
            print("DataFrames have not been merged yet.")
            return

        answers = self.merged_df['answers']
        self.merged_df['raw_answers'] = answers.apply(lambda x: [ans['raw_answer'] for ans in x])
        self.merged_df['processed_answers'] = answers.apply(lambda x: [ans['answer'] for ans in x])
        self.merged_df['most_frequent_raw_answer'] = self.merged_df['raw_answers'].apply(
            self.find_most_frequent)
        self.merged_df['most_frequent_processed_answer'] = self.merged_df['processed_answers'].apply(
            self.find_most_frequent)
        self.merged_df.drop(columns=['answers'], inplace=True)
        # Single-token form of the consensus answer (multi-word answers become hyphenated).
        self.merged_df['single_word_answers'] = self.merged_df['most_frequent_processed_answer'].apply(
            self.join_words_with_hyphen)

    def get_processed_data(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the processed DataFrame.

        Returns:
            Optional[pd.DataFrame]: The processed DataFrame, or None if it is not available.
        """
        if self.merged_df is not None:
            return self.merged_df
        print("DataFrame is empty or not processed yet.")
        return None

    def save_to_csv(self, df: pd.DataFrame, saved_file_name: Optional[str]) -> None:
        """
        Saves the DataFrame to a CSV file without the index column.

        Parameters:
            df (pd.DataFrame): The DataFrame to save.
            saved_file_name (Optional[str]): The target file name or path. A ``.csv`` suffix is
                appended when missing; None writes to ``data.csv`` in the working directory.
        """
        if saved_file_name is None:
            saved_file_name = "data.csv"
        elif not saved_file_name.endswith(".csv"):
            # Append the extension to the name itself (not as a path component).
            saved_file_name += ".csv"
        df.to_csv(saved_file_name, index=False)

    def display_dataframe(self) -> None:
        """
        Prints the processed DataFrame, or a notice if it has not been built yet.
        """
        if self.merged_df is not None:
            print(self.merged_df)
        else:
            print("DataFrame is empty.")
def process_okvqa_dataset(questions_file_path: str, annotations_file_path: str, save_to_csv: bool = False,
                          saved_file_name: Optional[str] = None) -> Optional[pd.DataFrame]:
    """
    Run the complete OK-VQA pipeline: load, merge, process answers and optionally export.

    Parameters:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        save_to_csv (bool): When truthy, write the processed data to a CSV file.
        saved_file_name (Optional[str]): Filename or path for the CSV file; None lets the
            processor fall back to 'data.csv'.

    Returns:
        Optional[pd.DataFrame]: The merged and processed VQA data, or None if it is unavailable.
    """
    pipeline = OKVQADatasetProcessor(questions_file_path, annotations_file_path)

    # Order matters: answers are processed on the merged frame.
    for step in (pipeline.merge_data, pipeline.process_answers):
        step()

    result = pipeline.get_processed_data()
    if result is None or not save_to_csv:
        return result

    pipeline.save_to_csv(result, saved_file_name)
    return result