Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,873 Bytes
e1aa577 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
import eval.eval_utils as utils
class Eval:
    """
    The Eval class is responsible for calculating the score and extracting the large errors.
    """

    def __init__(self, config, analyzer=None, label_schema=None):
        """
        Initialize a new instance of the Eval class.
        :param config: The eval configuration (EasyDict); reads function_name, num_large_errors,
            error_threshold and, for the 'ranking' function, function_params
        :param analyzer: (optional) A chain that analyzes the errors; add_history calls its invoke()
        :param label_schema: (optional) The label schema
        """
        self.score_function_name = config.function_name
        self.score_func = self.get_eval_function(config)
        self.num_errors = config.num_large_errors
        self.error_threshold = config.error_threshold
        self.dataset = None
        self.mean_score = None
        self.label_schema = label_schema
        self.errors = None
        self.history = []
        self.analyzer = analyzer

    @staticmethod
    def get_eval_function(config: dict):
        """
        Return the eval function selected by config.function_name.
        :param config: The eval configuration (accessed by attribute, e.g. an EasyDict)
        :return: A callable that takes the dataset dataframe and returns it with a 'score' column
        :raises NotImplementedError: If config.function_name is neither 'accuracy' nor 'ranking'
        """
        if config.function_name == 'accuracy':
            return utils.set_function_from_iterrow(lambda record: record['annotation'] == record['prediction'])
        elif config.function_name == 'ranking':
            return utils.set_ranking_function(config.function_params)
        else:
            raise NotImplementedError("Eval function not implemented")

    def eval_score(self) -> float:
        """
        Score every row of self.dataset and return the mean score.
        Rows whose prediction or annotation is 'Discarded' are dropped first; self.dataset is
        replaced by the filtered, scored dataframe and self.mean_score is updated.
        :return: The mean score
        """
        # filter out the discarded samples
        keep = (self.dataset['prediction'] != 'Discarded') & (self.dataset['annotation'] != 'Discarded')
        self.dataset = self.score_func(self.dataset[keep])
        self.mean_score = self.dataset['score'].mean()
        return self.mean_score

    def get_max_score(self, warmup=0):
        """
        Return the epoch index and value of the maximum 'mean score' in the history, considering
        the epochs from index `warmup` up to, but NOT including, the last epoch.
        :param warmup: The first history index to consider
        :return: The epoch index of the maximum score, and the maximum score
        :raises ValueError: If self.history[warmup:-1] is empty (np.argmax of an empty sequence)
        """
        candidate_scores = [epoch['score'] for epoch in self.history[warmup:-1]]
        # argmax is relative to the slice, so shift it back to a history index
        max_idx = int(np.argmax(candidate_scores)) + warmup
        return max_idx, self.history[max_idx]['score']

    def large_error_to_str(self, error_df: pd.DataFrame, num_large_errors_per_label: int) -> str:
        """
        Return a string that contains the large errors, for use in the meta-prompt.
        For each annotation value, up to num_large_errors_per_label rows are picked after a
        deterministic shuffle (random_state=42); the picked rows are then shuffled together.
        :param error_df: A dataframe containing the mislabeled samples; must have the columns
            'annotation', 'text', 'score' and 'prediction'
        :param num_large_errors_per_label: The (maximum) number of large errors per label
        :return: The formatted errors, or '' when there are none
        """
        required_columns = ['annotation', 'text', 'score', 'prediction']
        # In ranking mode the annotation is a rank rather than a ground-truth label.
        gt_name = 'Rank' if self.score_function_name == 'ranking' else 'GT'
        per_label_errors = []
        for label in error_df['annotation'].unique():
            label_df = error_df[error_df['annotation'] == label]
            label_df = label_df.sample(frac=1.0, random_state=42).iloc[:num_large_errors_per_label]
            per_label_errors.append(label_df[required_columns])
        if not per_label_errors:
            return ''
        error_res_df = pd.concat(per_label_errors, ignore_index=True)
        error_res_df = error_res_df.sample(frac=1.0, random_state=42)
        # itertuples keeps each column's own dtype (iterrows would coerce every row to one Series)
        return ''.join(
            f"Sample: {row.text}\nPrediction: {row.prediction}, {gt_name}: {row.annotation}\n#\n"
            for row in error_res_df.itertuples(index=False)
        )

    def sample_to_text(self, sample: dict, num_errors_per_label: int = 0, is_score: bool = True) -> str:
        """
        Format the information of one step run for the meta-prompt.
        :param sample: A history entry; reads 'prompt', plus 'score' when is_score is True,
            or 'errors' when is_score is False
        :param num_errors_per_label: The max number of large errors per label that will appear
            in the meta-prompt (only used when is_score is False)
        :param is_score: If True, add the score information; otherwise add the large errors
        :return: A string that contains the information of the step run
        """
        if is_score:
            return f"####\n##Prompt Score: {sample['score']:.2f}\n##Prompt:\n{sample['prompt']}\n#################\n"
        return f"####\n##Prompt:\n{sample['prompt']}\n{self.large_error_to_str(sample['errors'], num_errors_per_label)}####\n "

    def add_history(self, prompt: str, task_description: str):
        """
        Analyze the current step with self.analyzer and append it to the history.
        Expects self.errors, self.mean_score, self.dataset (accuracy only) and self.analyzer
        to be set beforehand.
        :param prompt: The current prompt
        :param task_description: The task description
        """
        conf_matrix = None
        # NOTE: the 'accuracy' key carries the mean score for every score function; it is kept
        # under this name because the analyzer's prompt template presumably references it.
        prompt_input = {'task_description': task_description, 'accuracy': self.mean_score, 'prompt': prompt,
                        'failure_cases': self.large_error_to_str(self.errors, self.num_errors)}
        if self.score_function_name == 'accuracy':
            conf_matrix = confusion_matrix(self.dataset['annotation'],
                                           self.dataset['prediction'], labels=self.label_schema)
            # one row of the matrix per label, in label_schema order
            conf_rows = ''.join(f"\n{label}: {row}" for label, row in zip(self.label_schema, conf_matrix))
            prompt_input['confusion_matrix'] = f"Confusion matrix columns:{self.label_schema} the matrix data:" + conf_rows
        elif self.score_function_name == 'ranking':
            prompt_input['labels'] = self.label_schema
        analysis = self.analyzer.invoke(prompt_input)
        self.history.append({'prompt': prompt, 'score': self.mean_score,
                             'errors': self.errors, 'confusion_matrix': conf_matrix, 'analysis': analysis['text']})

    def extract_errors(self) -> pd.DataFrame:
        """
        Extract the samples whose score is below error_threshold, sorted by ascending score
        (lowest score first), and store them in self.errors.
        :return: records that contain the errors
        """
        df = self.dataset
        # sort_values returns a new frame, so its result must be kept; a stable sort keeps
        # the dataset order among samples with equal scores.
        self.errors = df[df['score'] < self.error_threshold].sort_values(by=['score'], kind='stable')
        return self.errors

    def extract_correct(self) -> pd.DataFrame:
        """
        Extract the correct samples (score strictly above error_threshold) from the dataset.
        :return: records that contain the correct samples
        """
        df = self.dataset
        # NOTE(review): samples with score == error_threshold are neither errors (<) nor
        # correct (>) — confirm this gap is intended.
        return df[df['score'] > self.error_threshold]

    def extract_boundary_predictions(self) -> pd.DataFrame:
        """
        Extract boundary samples on which the model is uncertain.
        Not implemented yet: currently returns None.
        :return: records that contain boundary samples
        """
        pass