File size: 2,586 Bytes
2a41a22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import os
import prettytable as pt
from evaluation.metrics import evaluator
from config import Config
# Project configuration instance shared by this module; evaluate() reads
# data_root_dir, task and verbose_eval from it.
config = Config()
def _unique_eval_filename(method):
    """Return a not-yet-existing results path for *method* under ``evaluation/``.

    Tries ``evaluation/eval-<method>.txt`` first, then ``..._1.txt``,
    ``..._2.txt`` and so on, returning the first name that does not exist.
    The base name is built directly instead of via ``str.rstrip('.txt')``,
    which strips a *set of characters* and mangled methods ending in 't'/'x'.
    """
    base = os.path.join('evaluation', 'eval-{}'.format(method))
    filename = base + '.txt'
    id_suffix = 0
    while os.path.exists(filename):
        id_suffix += 1
        filename = '{}_{}.txt'.format(base, id_suffix)
    return filename


def _sorted_paths(directory):
    """Return the full paths of all entries in *directory*, sorted by name."""
    return sorted(os.path.join(directory, name) for name in os.listdir(directory))


def evaluate(pred_dir, method, testset, only_S_MAE=False, epoch=0):
    """Evaluate saved predictions of *method* on *testset* and log a results table.

    Ground-truth maps are listed from
    ``<config.data_root_dir>/<config.task>/<testset>/gt`` and predictions from
    ``<pred_dir>/<method>/<testset>``. Both listings are sorted and passed to
    ``evaluator``; presumably it pairs them by position, so the two folders are
    assumed to hold matching file names (NOTE(review): confirm). The resulting
    table is printed and written to a fresh ``evaluation/eval-<method>[_N].txt``.

    Args:
        pred_dir: Root directory holding per-method prediction folders.
        method: Method sub-folder name; also used in the results file name.
        testset: Test-set folder name.
        only_S_MAE: If True, request only the S-measure and MAE metrics and
            report only those columns.
        epoch: Appended to *method* in the table's "Method" column.

    Returns:
        dict with keys 'e_max', 'e_mean', 'e_adp', 'sm', 'mae', 'f_max',
        'f_mean', 'f_wfm', 'f_adp', 'hce', derived from ``evaluator``'s output.
    """
    filename = _unique_eval_filename(method)
    # The output folder may not exist when run from a different working dir.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    gt_paths = _sorted_paths(os.path.join(config.data_root_dir, config.task, testset, 'gt'))
    pred_paths = _sorted_paths(os.path.join(pred_dir, method, testset))

    # 'WF' (weighted F-measure) is intentionally not requested.
    metrics = ['S', 'MAE'] if only_S_MAE else ['S', 'MAE', 'E', 'F', 'HCE']
    em, sm, fm, mae, wfm, hce = evaluator(
        gt_paths=gt_paths,
        pred_paths=pred_paths,
        metrics=metrics,
        verbose=config.verbose_eval,
    )
    e_max, e_mean, e_adp = em['curve'].max(), em['curve'].mean(), em['adp'].mean()
    f_max, f_mean, f_wfm, f_adp = fm['curve'].max(), fm['curve'].mean(), wfm, fm['adp']

    method_tag = method + str(epoch)
    tb = pt.PrettyTable()
    # Header and row are built together so their lengths always agree
    # (the previous name filter kept 'HCE' and broke the S/MAE-only row).
    if only_S_MAE:
        tb.field_names = ["Method", "Dataset", 'MAE', "Smeasure"]
        tb.add_row([method_tag, testset, mae.round(3), sm.round(3)])
    else:
        tb.field_names = [
            "Method", "Dataset", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "maxEm", "meanFm",
            "adpEm", "adpFm", 'HCE',
        ]
        tb.add_row([
            method_tag, testset, f_max.round(3), f_wfm.round(3), mae.round(3), sm.round(3),
            e_mean.round(3), e_max.round(3), f_mean.round(3), e_adp.round(3), f_adp.round(3), hce.round(3),
        ])
    print(tb)

    # Opened only after the metrics succeed, so a failure leaves no empty file.
    with open(filename, 'a+') as file_to_write:
        # Table corner characters '+' are saved as '|'.
        file_to_write.write(str(tb).replace('+', '|') + '\n')

    return {
        'e_max': e_max, 'e_mean': e_mean, 'e_adp': e_adp, 'sm': sm, 'mae': mae,
        'f_max': f_max, 'f_mean': f_mean, 'f_wfm': f_wfm, 'f_adp': f_adp, 'hce': hce,
    }
def main(pred_dir='.', method='tmp_val', testsets='DIS-VD+DIS-TE1', only_S_MAE=False):
    """Evaluate *method*'s predictions on every test set listed in *testsets*.

    Args:
        pred_dir: Root directory holding per-method prediction folders.
        method: Method sub-folder name to evaluate.
        testsets: '+'-separated test-set folder names.
        only_S_MAE: If True, report only the S-measure and MAE metrics.

    Returns:
        dict mapping each test-set name to the metrics dict returned by
        ``evaluate`` (previously these results were computed and discarded).
    """
    return {
        testset: evaluate(pred_dir, method, testset, only_S_MAE=only_S_MAE)
        for testset in testsets.split('+')
    }


if __name__ == '__main__':
    main()
|