import csv
import json
from collections import Counter
from typing import List, Optional, Tuple

import contractions
import pandas as pd

from my_model.config import dataset_config as config


class OKVQADatasetProcessor:
    """
    Processes the OKVQA dataset by loading, processing, and merging question and annotation data.

    Attributes:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        questions (List[dict]): Extracted list of question entries from the JSON file.
        annotations (List[dict]): Extracted list of annotation entries from the JSON file.
        df_questions (DataFrame): DataFrame holding the questions.
        df_answers (DataFrame): DataFrame holding the annotations.
        merged_df (Optional[DataFrame]): DataFrame resulting from merging questions and answers,
            initialized as None.
    """

    def __init__(self, questions_file_path: str, annotations_file_path: str) -> None:
        """
        Initializes the dataset processor with file paths and loads the data into DataFrames.

        Parameters:
            questions_file_path (str): The file path for the questions JSON file.
            annotations_file_path (str): The file path for the annotations JSON file.
        """
        self.questions_file_path = questions_file_path
        self.annotations_file_path = annotations_file_path
        self.questions, self.annotations = self.load_data_files()
        self.df_questions = pd.DataFrame(self.questions)
        self.df_answers = pd.DataFrame(self.annotations)
        # Populated by merge_data(); None until then.
        self.merged_df: Optional[pd.DataFrame] = None

    def load_data_files(self) -> Tuple[List[dict], List[dict]]:
        """
        Loads the question and annotation data from JSON files.

        Returns:
            Tuple[List[dict], List[dict]]: A tuple containing lists of questions and annotations.

        Raises:
            KeyError: If the questions file lacks a 'questions' key or the annotations
                file lacks an 'annotations' key.
        """
        # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
        with open(self.questions_file_path, 'r', encoding='utf-8') as file:
            questions = json.load(file)['questions']
        with open(self.annotations_file_path, 'r', encoding='utf-8') as file:
            annotations = json.load(file)['annotations']
        return questions, annotations

    @staticmethod
    def find_most_frequent(my_list: List[str]) -> Optional[str]:
        """
        Determines the most frequent item in a list.

        Ties are resolved in favour of the item encountered first, as per
        ``Counter.most_common``.

        Parameters:
            my_list (List[str]): The list from which to find the most frequent item.

        Returns:
            Optional[str]: The most frequent item or None if the list is empty.
        """
        if not my_list:
            return None
        return Counter(my_list).most_common(1)[0][0]

    def merge_data(self) -> None:
        """
        Merges the question and answer DataFrames on their shared keys.

        Sets 'merged_df' to the inner join (pandas' default) of 'df_questions' and
        'df_answers' on both the 'question_id' and 'image_id' columns, which are
        assumed to be present in both DataFrames.
        """
        self.merged_df = pd.merge(self.df_questions, self.df_answers, on=['question_id', 'image_id'])

    def join_words_with_hyphen(self, sentence: Optional[str]) -> Optional[str]:
        """
        Collapses a multi-word answer into a single hyphen-joined token.

        Runs of whitespace are replaced by a single '-' and leading/trailing
        whitespace is dropped (e.g. "fire  hydrant" -> "fire-hydrant").

        Parameters:
            sentence (Optional[str]): The answer text. Non-string values (such as the
                None produced for rows without answers) are returned unchanged.

        Returns:
            Optional[str]: The hyphen-joined answer, or the input if it is not a string.
        """
        if not isinstance(sentence, str):
            return sentence
        return '-'.join(sentence.split())

    def process_answers(self) -> None:
        """
        Processes answers from the merged DataFrame by extracting and identifying the most
        frequent answers.

        Adds the columns 'raw_answers', 'processed_answers', 'most_frequent_raw_answer',
        'most_frequent_processed_answer' and 'single_word_answers', and drops the original
        'answers' column. Each entry of 'answers' is assumed to be a list of dicts with
        'raw_answer' and 'answer' keys. Prints a message and does nothing if merge_data()
        has not been called.
        """
        if self.merged_df is None:
            # Bail out here: every step below needs the merged frame.
            print("DataFrames have not been merged yet.")
            return

        df = self.merged_df
        df['raw_answers'] = df['answers'].apply(lambda x: [ans['raw_answer'] for ans in x])
        df['processed_answers'] = df['answers'].apply(lambda x: [ans['answer'] for ans in x])
        df['most_frequent_raw_answer'] = df['raw_answers'].apply(self.find_most_frequent)
        df['most_frequent_processed_answer'] = df['processed_answers'].apply(self.find_most_frequent)
        df.drop(columns=['answers'], inplace=True)

        # Single-token form of the most frequent processed answer.
        df['single_word_answers'] = df['most_frequent_processed_answer'].apply(
            self.join_words_with_hyphen)

    def get_processed_data(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the processed DataFrame.

        Returns:
            Optional[pd.DataFrame]: The processed DataFrame or None if it is not available.
        """
        if self.merged_df is not None:
            return self.merged_df
        print("DataFrame is empty or not processed yet.")
        return None

    def save_to_csv(self, df: pd.DataFrame, saved_file_name: Optional[str]) -> None:
        """
        Saves the DataFrame to a CSV file without the index column.

        Parameters:
            df (pd.DataFrame): The DataFrame to save.
            saved_file_name (Optional[str]): The target file name or path. A '.csv'
                extension is appended when missing. If None or empty, 'data.csv' is used.
        """
        if not saved_file_name:
            saved_file_name = "data.csv"
        elif not saved_file_name.lower().endswith(".csv"):
            # Append the extension; joining it as a path component would write to
            # "<name>/.csv" instead of "<name>.csv".
            saved_file_name = saved_file_name + ".csv"
        df.to_csv(saved_file_name, index=False)

    def display_dataframe(self) -> None:
        """
        Displays the processed DataFrame.
        """
        if self.merged_df is not None:
            print(self.merged_df)
        else:
            print("DataFrame is empty.")


def process_okvqa_dataset(questions_file_path: str, annotations_file_path: str,
                          save_to_csv: bool = False,
                          saved_file_name: Optional[str] = None) -> Optional[pd.DataFrame]:
    """
    Orchestrates the processing of the OK-VQA dataset using specified JSON file paths for
    questions and annotations.

    Parameters:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        save_to_csv (bool): Flag to determine if the processed data should be saved to CSV.
        saved_file_name (Optional[str]): Filename or path to save the CSV file.
            If None, defaults to 'data.csv'.

    Returns:
        Optional[pd.DataFrame]: The processed DataFrame containing merged and processed VQA
        data or None if empty.
    """
    # Initialize the dataset processor
    processor = OKVQADatasetProcessor(questions_file_path, annotations_file_path)

    # Merge question and answer data and process answers
    processor.merge_data()
    processor.process_answers()

    # Retrieve the processed DataFrame
    processed_data = processor.get_processed_data()

    # Optionally save the processed DataFrame to a CSV file
    if save_to_csv and processed_data is not None:
        processor.save_to_csv(processed_data, saved_file_name)

    return processed_data