# Scraped-page residue preserved from the original paste (Hugging Face
# "Spaces: Runtime error" banner); commented out so the module parses.
# Spaces:
# Runtime error
# Runtime error
import evaluate
import numpy as np
import pickle

# seqeval scores NER in IOB-style tagging: per-entity-type and overall
# precision/recall/F1 plus token accuracy.
metric = evaluate.load("seqeval")

# NOTE(review): pickle.load executes arbitrary code from the file — only load
# pickles produced by this project, never untrusted input.
with open('./data/ner_feature.pickle', 'rb') as f:
    ner_feature = pickle.load(f)

# Integer label id -> label string (presumably IOB tags such as "B-ORG";
# inferred from the per-entity keys used in compute_metrics below).
label_names = ner_feature.feature.names
def compute_metrics(eval_preds):
    """Compute seqeval NER metrics for one evaluation step.

    Takes the argmax of the logits to convert them to predictions (the
    logits and the probabilities are in the same order, so we don't need
    to apply the softmax), converts both labels and predictions from
    integer ids to label strings, removes every position whose label is
    -100 (special / sub-word tokens), and passes the result to
    ``metric.compute()``.

    Args:
        eval_preds: ``(logits, labels)`` pair as supplied by a
            ``transformers.Trainer`` — presumably logits of shape
            (batch, seq_len, num_labels) and labels of shape
            (batch, seq_len), with -100 marking ignored positions.

    Returns:
        dict mapping metric names to floats: per-entity (ORG / PER / LOC)
        precision, recall and f1, plus overall precision, recall, f1 and
        accuracy.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert ids to label strings.
    true_labels = [
        [label_names[l] for l in label if l != -100]
        for label in labels
    ]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)

    # NOTE(review): seqeval only emits a per-entity key when that entity type
    # occurs in the references — the direct indexing below assumes ORG, PER
    # and LOC all appear in every evaluation set; confirm for your data.
    return {
        # organization metrics
        'org_precision': all_metrics['ORG']['precision'],
        'org_recall': all_metrics['ORG']['recall'],
        'org_f1': all_metrics['ORG']['f1'],
        # person metrics
        'per_precision': all_metrics['PER']['precision'],
        'per_recall': all_metrics['PER']['recall'],
        'per_f1': all_metrics['PER']['f1'],
        # location metrics
        'loc_precision': all_metrics['LOC']['precision'],
        'loc_recall': all_metrics['LOC']['recall'],
        'loc_f1': all_metrics['LOC']['f1'],
        # overall metrics
        'precision': all_metrics['overall_precision'],
        'recall': all_metrics['overall_recall'],
        'f1': all_metrics['overall_f1'],
        'accuracy': all_metrics['overall_accuracy']
    }