import evaluate
import numpy as np
import pickle

metric = evaluate.load("seqeval")

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a ner_feature.pickle that this project produced itself.
with open('./data/ner_feature.pickle', 'rb') as f:
    ner_feature = pickle.load(f)

# Maps an integer label id (list index) to its tag string.
label_names = ner_feature.feature.names

# Entity types whose individual scores are reported, paired with the key
# prefix used for them in the dict returned by compute_metrics().
_ENTITY_METRIC_PREFIXES = (("ORG", "org"), ("PER", "per"), ("LOC", "loc"))

# Label id marking positions (e.g. special tokens) that must not be scored.
_IGNORE_INDEX = -100


def _decode_labels(predictions, labels):
    """Convert id sequences to tag-string sequences, skipping ignored positions.

    A position is dropped from BOTH the prediction and the reference sequence
    whenever its reference label id equals ``_IGNORE_INDEX``, so the two
    outputs stay aligned token for token.

    Returns:
        ``(true_predictions, true_labels)`` as lists of lists of tag strings.
    """
    true_labels = [
        [label_names[label_id] for label_id in label if label_id != _IGNORE_INDEX]
        for label in labels
    ]
    true_predictions = [
        [
            label_names[pred_id]
            for pred_id, label_id in zip(prediction, label)
            if label_id != _IGNORE_INDEX
        ]
        for prediction, label in zip(predictions, labels)
    ]
    return true_predictions, true_labels


def compute_metrics(eval_preds):
    """Compute seqeval entity-level metrics from model logits.

    The predicted class is the argmax of the logits over the last axis.
    Softmax is unnecessary because it does not change which index is largest.
    Integer ids are decoded to tag strings, positions labelled -100 are
    removed, and the result is scored with the seqeval metric.

    Args:
        eval_preds: A ``(logits, labels)`` pair; ``labels`` uses -100 for
            positions to ignore.

    Returns:
        A dict with per-entity precision/recall/F1 for ORG, PER and LOC
        (keys ``org_*``, ``per_*``, ``loc_*``) followed by the overall
        ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    true_predictions, true_labels = _decode_labels(predictions, labels)
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)

    results = {}
    for entity, prefix in _ENTITY_METRIC_PREFIXES:
        # The seqeval result may omit an entity type that does not occur in
        # this evaluation set; report 0.0 instead of raising KeyError.
        entity_scores = all_metrics.get(entity, {})
        for score_name in ("precision", "recall", "f1"):
            results[f"{prefix}_{score_name}"] = entity_scores.get(score_name, 0.0)

    # Overall (micro-averaged) scores plus token-level accuracy.
    results["precision"] = all_metrics["overall_precision"]
    results["recall"] = all_metrics["overall_recall"]
    results["f1"] = all_metrics["overall_f1"]
    results["accuracy"] = all_metrics["overall_accuracy"]
    return results