import evaluate
import numpy as np
import pickle

metric = evaluate.load("seqeval")
# ner_feature is a pickled `datasets` Sequence-of-ClassLabel feature
# (e.g. a CoNLL-style `ner_tags` column); `.feature.names` maps label
# ids to their tag strings.
with open('./data/ner_feature.pickle', 'rb') as f:
    ner_feature = pickle.load(f)

label_names = ner_feature.feature.names
# label2id = {label: ner_feature.feature.str2int(label) for label in label_names}
# id2label = {v: k for k, v in label2id.items()}
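# With a CoNLL-2003-style dataset, these names are IOB2 tags along the lines
# of ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', ...]
# (illustrative only; the actual list comes from the pickled feature above).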

def compute_metrics(eval_preds):
    """
    This compute_metrics() function first takes the argmax of the logits to convert them to predictions
    (as usual, the logits and the probabilities are in the same order,
    so we don’t need to apply the softmax).
    Then we have to convert both labels and predictions from integers to strings.
    We remove all the values where the label is -100, then pass the results to the metric.compute() method:
    """

    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
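    # seqeval returns a nested dict: one sub-dict per entity type present in
    # the references, e.g. all_metrics['PER'] == {'precision': ..., 'recall':
    # ..., 'f1': ..., 'number': ...}, plus flat overall_precision /
    # overall_recall / overall_f1 / overall_accuracy keys. The per-type
    # lookups below therefore assume ORG, PER, and LOC each appear at least
    # once in the evaluation set.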
    
    return {
        # organization metrics
        'org_precision': all_metrics['ORG']['precision'],
        'org_recall': all_metrics['ORG']['recall'],
        'org_f1': all_metrics['ORG']['f1'],
        
        # person metrics
        'per_precision': all_metrics['PER']['precision'],
        'per_recall': all_metrics['PER']['recall'],
        'per_f1': all_metrics['PER']['f1'],
        
        # location metrics
        'loc_precision': all_metrics['LOC']['precision'],
        'loc_recall': all_metrics['LOC']['recall'],
        'loc_f1': all_metrics['LOC']['f1'],
        
        # overall metrics
        'precision': all_metrics['overall_precision'],
        'recall': all_metrics['overall_recall'],
        'f1': all_metrics['overall_f1'],
        'accuracy': all_metrics['overall_accuracy']
    }
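

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): passing compute_metrics to a Hugging Face
# Trainer for token classification. `model`, `tokenized_datasets`, and
# `data_collator` are hypothetical placeholders, not defined in this module.
#
# from transformers import Trainer, TrainingArguments
#
# trainer = Trainer(
#     model=model,
#     args=TrainingArguments(output_dir="ner-out", evaluation_strategy="epoch"),
#     train_dataset=tokenized_datasets["train"],
#     eval_dataset=tokenized_datasets["validation"],
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
# )
# trainer.train()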