Spaces:
Running
Running
import json | |
from collections import Counter | |
import contractions | |
import csv | |
import pandas as pd | |
from typing import Tuple, List, Optional | |
from my_model.config import dataset_config as config | |
class OKVQADatasetProcessor:
    """
    Processes the OKVQA dataset by loading, processing, and merging question and annotation data.

    Attributes:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        questions (List[dict]): Entries read from the 'questions' key of the questions JSON file.
        annotations (List[dict]): Entries read from the 'annotations' key of the annotations JSON file.
        df_questions (DataFrame): DataFrame holding the questions.
        df_answers (DataFrame): DataFrame holding the annotations.
        merged_df (Optional[DataFrame]): DataFrame resulting from merging questions and answers,
            initialized as None until `merge_data` is called.
    """

    def __init__(self, questions_file_path: str, annotations_file_path: str) -> None:
        """
        Initializes the dataset processor with file paths and loads the data into DataFrames.

        Parameters:
            questions_file_path (str): The file path for the questions JSON file.
            annotations_file_path (str): The file path for the annotations JSON file.
        """
        self.questions_file_path = questions_file_path
        self.annotations_file_path = annotations_file_path
        self.questions, self.annotations = self.load_data_files()
        self.df_questions = pd.DataFrame(self.questions)
        self.df_answers = pd.DataFrame(self.annotations)
        self.merged_df = None

    def load_data_files(self) -> Tuple[List[dict], List[dict]]:
        """
        Loads the question and annotation data from JSON files.

        Returns:
            Tuple[List[dict], List[dict]]: A tuple containing lists of questions and annotations.

        Raises:
            OSError: If either file cannot be opened.
            json.JSONDecodeError: If either file is not valid JSON.
            KeyError: If the 'questions' / 'annotations' top-level key is missing.
        """
        # Explicit UTF-8 so that reading does not depend on the platform's default encoding.
        with open(self.questions_file_path, 'r', encoding='utf-8') as file:
            questions = json.load(file)['questions']
        with open(self.annotations_file_path, 'r', encoding='utf-8') as file:
            annotations = json.load(file)['annotations']
        return questions, annotations

    @staticmethod
    def find_most_frequent(my_list: List[str]) -> Optional[str]:
        """
        Determines the most frequent item in a list.

        Declared as a static method because it takes no `self`; this is what allows it to be
        used as `self.find_most_frequent` in `Series.apply` without the instance being passed
        in place of the list.

        Parameters:
            my_list (List[str]): The list from which to find the most frequent item.

        Returns:
            Optional[str]: The most frequent item or None if the list is empty.
        """
        if not my_list:
            return None
        return Counter(my_list).most_common(1)[0][0]

    def merge_data(self) -> None:
        """
        Merges the question and answer DataFrames on their common keys.

        This method sets the 'merged_df' attribute to the result of merging 'df_questions'
        and 'df_answers' on the 'question_id' and 'image_id' fields, both of which must be
        present in both DataFrames.
        """
        self.merged_df = pd.merge(self.df_questions, self.df_answers, on=['question_id', 'image_id'])

    def join_words_with_hyphen(self, sentence: Optional[str]) -> Optional[str]:
        """
        Collapses a whitespace-separated sentence into a single hyphen-joined token.

        Parameters:
            sentence (Optional[str]): The text to convert.

        Returns:
            Optional[str]: The words of `sentence` joined by '-', or None when `sentence` is
            not a string (e.g. the missing value produced for a question with no answers).
        """
        if not isinstance(sentence, str):
            return None
        return '-'.join(sentence.split())

    def process_answers(self) -> None:
        """
        Processes answers from merged DataFrame by extracting and identifying the most frequent answers.

        Adds the columns 'raw_answers', 'processed_answers', 'most_frequent_raw_answer',
        'most_frequent_processed_answer' and 'single_word_answers', and drops the original
        'answers' column. If the DataFrames have not been merged yet, a message is printed
        and nothing else happens.
        """
        if self.merged_df is None:
            print("DataFrames have not been merged yet.")
            # Nothing to process; returning here avoids subscripting None below.
            return

        self.merged_df['raw_answers'] = self.merged_df['answers'].apply(
            lambda x: [ans['raw_answer'] for ans in x])
        self.merged_df['processed_answers'] = self.merged_df['answers'].apply(
            lambda x: [ans['answer'] for ans in x])
        self.merged_df['most_frequent_raw_answer'] = self.merged_df['raw_answers'].apply(
            self.find_most_frequent)
        self.merged_df['most_frequent_processed_answer'] = self.merged_df['processed_answers'].apply(
            self.find_most_frequent)
        self.merged_df.drop(columns=['answers'], inplace=True)

        # Apply the function to the 'most_frequent_processed_answer' column
        self.merged_df['single_word_answers'] = self.merged_df['most_frequent_processed_answer'].apply(
            self.join_words_with_hyphen)

    def get_processed_data(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the processed DataFrame.

        Returns:
            Optional[pd.DataFrame]: The processed DataFrame or None if it is not available.
        """
        if self.merged_df is not None:
            return self.merged_df
        print("DataFrame is empty or not processed yet.")
        return None

    @staticmethod
    def _resolve_csv_path(saved_file_name: Optional[str]) -> str:
        """Returns the path `save_to_csv` writes to: 'data.csv' by default, else the name with '.csv' ensured."""
        if not saved_file_name:
            return "data.csv"
        # Names that already mention '.csv' (including e.g. 'x.csv.gz') are used unchanged.
        if ".csv" in saved_file_name:
            return saved_file_name
        # Append the extension; joining it as a path component would yield '<name>/.csv'.
        return saved_file_name + ".csv"

    def save_to_csv(self, df: pd.DataFrame, saved_file_name: Optional[str]) -> None:
        """
        Saves the DataFrame to a CSV file, without the index column.

        Parameters:
            df (pd.DataFrame): The DataFrame to save.
            saved_file_name (Optional[str]): The target file name or path. '.csv' is appended
                when the name does not contain it; None or an empty name falls back to 'data.csv'.
        """
        df.to_csv(self._resolve_csv_path(saved_file_name), index=False)

    def display_dataframe(self) -> None:
        """
        Displays the processed DataFrame.
        """
        if self.merged_df is not None:
            print(self.merged_df)
        else:
            print("DataFrame is empty.")
def process_okvqa_dataset(questions_file_path: str, annotations_file_path: str, save_to_csv: bool = False,
                          saved_file_name: Optional[str] = None) -> Optional[pd.DataFrame]:
    """
    Runs the full OK-VQA preparation pipeline for one pair of question/annotation JSON files.

    The steps are: build an OKVQADatasetProcessor, merge questions with annotations, process
    the answers, and fetch the resulting DataFrame. The result is written to CSV only when
    requested and when there is a result to write.

    Parameters:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        save_to_csv (bool): Whether the processed data should also be saved to CSV.
        saved_file_name (Optional[str]): Filename or path handed to the processor's save_to_csv.

    Returns:
        Optional[pd.DataFrame]: The DataFrame returned by the processor's get_processed_data,
        which may be None.
    """
    pipeline = OKVQADatasetProcessor(questions_file_path, annotations_file_path)

    # Merging must precede answer processing, which works on the merged frame.
    pipeline.merge_data()
    pipeline.process_answers()

    result = pipeline.get_processed_data()

    # Guard clause: skip the export when it was not requested or there is nothing to export.
    if not save_to_csv or result is None:
        return result

    pipeline.save_to_csv(result, saved_file_name)
    return result