import csv
import logging
import os.path
from datetime import datetime
from pathlib import Path

import pandas as pd

from utils.dedup import Dedup


class DatasetBase:
    """
    This class stores and manages all the dataset records (including the annotations and predictions).
    """

    def __init__(self, config):
        if config.records_path is None:
            self.records = pd.DataFrame(columns=['id', 'text', 'prediction', 'annotation',
                                                 'metadata', 'score', 'batch_id'])
        else:
            self.records = pd.read_csv(config.records_path)
        dt_string = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
        self.name = config.name + '__' + dt_string
        self.label_schema = config.label_schema
        self.dedup = Dedup(config)
        self.sample_size = config.get("sample_size", 3)
        self.semantic_sampling = config.get("semantic_sampling", False)
        if not config.get('dedup_new_samples', False):
            self.remove_duplicates = self._null_remove

    def __len__(self):
        """
        Return the number of samples in the dataset.
        """
        return len(self.records)

    def __getitem__(self, batch_idx):
        """
        Return the records of the given batch.
        """
        extract_records = self.records[self.records['batch_id'] == batch_idx]
        extract_records = extract_records.reset_index(drop=True)
        return extract_records

    def get_leq(self, batch_idx):
        """
        Return all the records up to batch_idx (inclusive).
        """
        extract_records = self.records[self.records['batch_id'] <= batch_idx]
        extract_records = extract_records.reset_index(drop=True)
        return extract_records

    def add(self, sample_list: list = None, batch_id: int = None, records: pd.DataFrame = None):
        """
        Add records to the dataset.
        :param sample_list: The samples to add, as a list of texts (only used in case records is None)
        :param batch_id: The batch_id of the added records (only used in case records is None)
        :param records: A DataFrame of records, appended as-is
        """
        if records is None:
            records = pd.DataFrame([{'id': len(self.records) + i, 'text': sample, 'batch_id': batch_id}
                                    for i, sample in enumerate(sample_list)])
        self.records = pd.concat([self.records, records], ignore_index=True)

    def update(self, records: pd.DataFrame):
        """
        Update existing records in the dataset, matching them by 'id'.
        """
        # Ignore if records is empty
        if len(records) == 0:
            return

        # Set 'id' as the index for both DataFrames (without mutating the caller's frame)
        records = records.set_index('id')
        self.records.set_index('id', inplace=True)

        # Update using 'id' as the key
        self.records.update(records)

        # Remove discarded annotations
        if len(self.records.loc[self.records["annotation"] == "Discarded"]) > 0:
            discarded_annotation_records = self.records.loc[self.records["annotation"] == "Discarded"]
            # TODO: direct `discarded_annotation_records` to another dataset to be used later for corner-cases
            self.records = self.records.loc[self.records["annotation"] != "Discarded"]

        # Reset index
        self.records.reset_index(inplace=True)

    def modify(self, index: int, record: dict):
        """
        Modify a record in the dataset.
        """
        # Assign to the row at `index`; plain `self.records[index] = record`
        # would create a column named `index` instead of updating a row
        self.records.loc[index] = record

    def apply(self, function, column_name: str):
        """
        Apply a function to each record, storing the result in the given column.
        """
        self.records[column_name] = self.records.apply(function, axis=1)

    def save_dataset(self, path: Path):
        """
        Save the dataset records to a csv file.
        :param path: The path of the csv file
        """
        self.records.to_csv(path, index=False, quoting=csv.QUOTE_NONNUMERIC)

    def load_dataset(self, path: Path):
        """
        Load the dataset from a csv file.
        :param path: The path of the csv file
        """
        if os.path.isfile(path):
            self.records = pd.read_csv(path, dtype={'annotation': str,
                                                    'prediction': str,
                                                    'batch_id': int})
        else:
            logging.warning('Dataset dump not found, initializing from zero')

    def remove_duplicates(self, samples: list) -> list:
        """
        Remove (soft) duplicates from the given samples
        :param samples: The samples
        :return: The samples without duplicates
        """
        dd = self.dedup.copy()
        df = pd.DataFrame(samples, columns=['text'])
        df_dedup = dd.sample(df, operation_function=min)
        return df_dedup['text'].tolist()

    def _null_remove(self, samples: list) -> list:
        # Identity function that returns the input unmodified
        return samples

    def sample_records(self, n: int = None) -> pd.DataFrame:
        """
        Return a sample of the records, optionally after semantic clustering
        :param n: The number of samples to return
        :return: A sample of the records
        """
        n = n or self.sample_size
        if self.semantic_sampling:
            dd = self.dedup.copy()
            df_samples = dd.sample(self.records).head(n)
            # Fall back to the first n records if clustering yielded too few samples
            if len(df_samples) < n:
                df_samples = self.records.head(n)
        else:
            # Cap n so pandas does not raise when the dataset is smaller than n
            df_samples = self.records.sample(min(n, len(self.records)))
        return df_samples

    @staticmethod
    def samples_to_text(records: pd.DataFrame) -> str:
        """
        Return a string that organizes the samples for a meta-prompt
        :param records: The samples for the step
        :return: A string that contains the organized samples
        """
        txt_res = '##\n'
        for i, row in records.iterrows():
            txt_res += f"Sample:\n {row.text}\n#\n"
        return txt_res