File size: 2,586 Bytes
2a41a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import prettytable as pt

from evaluation.metrics import evaluator
from config import Config


# Module-wide configuration instance; `evaluate` reads `data_root_dir`, `task`
# and `verbose_eval` from it.
config = Config()

def _unused_eval_log_path(method):
    """Return 'evaluation/eval-<method>.txt', or the first free 'eval-<method>_<k>.txt' (k >= 1)."""
    # Build candidates from the stem instead of using str.rstrip('.txt'):
    # rstrip removes any trailing '.', 't' or 'x' characters, which would
    # corrupt method names that happen to end in those characters.
    stem = os.path.join('evaluation', 'eval-{}'.format(method))
    candidate = stem + '.txt'
    id_suffix = 0
    while os.path.exists(candidate):
        id_suffix += 1
        candidate = '{}_{}.txt'.format(stem, id_suffix)
    return candidate


def evaluate(pred_dir, method, testset, only_S_MAE=False, epoch=0):
    """Evaluate the saved predictions of `method` on `testset` and log the result table.

    Ground-truth files are listed from
    ``<config.data_root_dir>/<config.task>/<testset>/gt`` and predictions from
    ``<pred_dir>/<method>/<testset>``. Both listings are sorted and handed to
    `evaluator`. The one-row table is printed and written to a fresh log file
    under ``evaluation/`` (``eval-<method>.txt``, or ``eval-<method>_<k>.txt``
    when earlier logs already exist).

    Args:
        pred_dir: Root directory that contains the per-method prediction folders.
        method: Method name; used as the prediction sub-directory and in the
            log file name.
        testset: Dataset name; used as sub-directory for both ground truth and
            predictions.
        only_S_MAE: When true, only the 'S' and 'MAE' metrics are requested and
            the table is limited to the Dataset, Method, MAE and Smeasure columns.
        epoch: Appended to the method name in the "Method" column of the full table.

    Returns:
        A dict with the keys 'e_max', 'e_mean', 'e_adp', 'sm', 'mae', 'f_max',
        'f_mean', 'f_wfm', 'f_adp' and 'hce'.
    """
    gt_dir = os.path.join(config.data_root_dir, config.task, testset, 'gt')
    pred_subdir = os.path.join(pred_dir, method, testset)
    gt_paths = sorted(os.path.join(gt_dir, p) for p in os.listdir(gt_dir))
    pred_paths = sorted(os.path.join(pred_subdir, p) for p in os.listdir(pred_subdir))

    # 'WF' is intentionally not requested, so `wfm` holds whatever `evaluator`
    # returns for a metric that was not computed.
    # NOTE(review): the code below assumes `evaluator` returns usable
    # placeholders (with 'curve'/'adp' entries and a `.round` method) for
    # metrics that were not requested -- confirm against its implementation.
    metrics = ['S', 'MAE'] if only_S_MAE else ['S', 'MAE', 'E', 'F', 'HCE']
    em, sm, fm, mae, wfm, hce = evaluator(
        gt_paths=gt_paths,
        pred_paths=pred_paths,
        metrics=metrics,
        verbose=config.verbose_eval,
    )
    e_max, e_mean, e_adp = em['curve'].max(), em['curve'].mean(), em['adp'].mean()
    f_max, f_mean, f_wfm, f_adp = fm['curve'].max(), fm['curve'].mean(), wfm, fm['adp']

    tb = pt.PrettyTable()
    if only_S_MAE:
        # Header and row are spelled out together so that their lengths always
        # agree; filtering the full header by 'Em'/'Fm' would keep 'HCE' and
        # make add_row() fail on the 4-value row.
        tb.field_names = ["Dataset", "Method", 'MAE', "Smeasure"]
        row = [testset, method, mae.round(3), sm.round(3)]
    else:
        tb.field_names = [
            "Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "maxEm", "meanFm",
            "adpEm", "adpFm", 'HCE'
        ]
        # Values are ordered to match the header: dataset first, then method.
        row = [
            testset, method + str(epoch), f_max.round(3), f_wfm.round(3), mae.round(3), sm.round(3),
            e_mean.round(3), e_max.round(3), f_mean.round(3), em['adp'].round(3), f_adp.round(3), hce.round(3)
        ]
    tb.add_row(row)
    print(tb)

    # The log file is picked and opened only after evaluation succeeded, so a
    # failing run does not leave an empty log behind. The `with` block closes it.
    filename = _unused_eval_log_path(method)
    with open(filename, 'a+') as file_to_write:
        file_to_write.write(str(tb).replace('+', '|') + '\n')

    return {
        'e_max': e_max, 'e_mean': e_mean, 'e_adp': e_adp,
        'sm': sm, 'mae': mae,
        'f_max': f_max, 'f_mean': f_mean, 'f_wfm': f_wfm, 'f_adp': f_adp,
        'hce': hce,
    }


def main():
    """Evaluate the 'tmp_val' predictions found under the current directory.

    Runs `evaluate` once per dataset in the hard-coded list (DIS-VD and
    DIS-TE1) with the full metric set; the returned metric dicts are not used.
    """
    prediction_root, run_name = '.', 'tmp_val'
    restrict_to_s_mae = False
    # Dataset names are kept in a single '+'-joined string and split here.
    for dataset_name in 'DIS-VD+DIS-TE1'.split('+'):
        evaluate(prediction_root, run_name, dataset_name, only_S_MAE=restrict_to_s_mae)


# Run the default evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()