BiRefNet / BiRefNet_github /eval_existingOnes.py
ZhengPeng7's picture
For users to load in one key.
2a41a22
raw
history blame
No virus
7.3 kB
import os
import argparse
from glob import glob
import prettytable as pt
from evaluation.evaluate import evaluator
from config import Config
config = Config()
def do_eval(args):
# evaluation for whole dataset
# dataset first in evaluation
for _data_name in args.data_lst.split('+'):
pred_data_dir = sorted(glob(os.path.join(args.pred_root, args.model_lst[0], _data_name)))
if not pred_data_dir:
print('Skip dataset {}.'.format(_data_name))
continue
gt_src = os.path.join(args.gt_root, _data_name)
gt_paths = sorted(glob(os.path.join(gt_src, 'gt', '*')))
print('#' * 20, _data_name, '#' * 20)
filename = os.path.join(args.save_dir, '{}_eval.txt'.format(_data_name))
tb = pt.PrettyTable()
tb.vertical_char = '&'
if config.task == 'DIS5K':
tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm"]
elif config.task == 'COD':
tb.field_names = ["Dataset", "Method", "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", 'MAE', "maxFm", "adpEm", "adpFm", "HCE"]
elif config.task == 'HRSOD':
tb.field_names = ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE"]
elif config.task == 'DIS5K+HRSOD+HRS10K':
tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm"]
elif config.task == 'P3M-10k':
tb.field_names = ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE"]
else:
tb.field_names = ["Dataset", "Method", "Smeasure", 'MAE', "maxEm", "meanEm", "maxFm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE"]
for _model_name in args.model_lst[:]:
print('\t', 'Evaluating model: {}...'.format(_model_name))
pred_paths = [p.replace(args.gt_root, os.path.join(args.pred_root, _model_name)).replace('/gt/', '/') for p in gt_paths]
# print(pred_paths[:1], gt_paths[:1])
em, sm, fm, mae, wfm, hce = evaluator(
gt_paths=gt_paths,
pred_paths=pred_paths,
metrics=args.metrics.split('+'),
verbose=config.verbose_eval
)
if config.task == 'DIS5K':
scores = [
fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()),
em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3),
]
elif config.task == 'COD':
scores = [
sm.round(3), wfm.round(3), fm['curve'].mean().round(3), em['curve'].mean().round(3), em['curve'].max().round(3), mae.round(3),
fm['curve'].max().round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()),
]
elif config.task == 'HRSOD':
scores = [
sm.round(3), fm['curve'].max().round(3), em['curve'].mean().round(3), mae.round(3),
em['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()),
]
elif config.task == 'DIS5K+HRSOD+HRS10K':
scores = [
fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()),
em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3),
]
elif config.task == 'P3M-10k':
scores = [
sm.round(3), fm['curve'].max().round(3), em['curve'].mean().round(3), mae.round(3),
em['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()),
]
else:
scores = [
sm.round(3), mae.round(3), em['curve'].max().round(3), em['curve'].mean().round(3),
fm['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3),
em['adp'].round(3), fm['adp'].round(3), int(hce.round()),
]
for idx_score, score in enumerate(scores):
scores[idx_score] = '.' + format(score, '.3f').split('.')[-1] if score <= 1 else format(score, '<4')
records = [_data_name, _model_name] + scores
tb.add_row(records)
# Write results after every check.
with open(filename, 'w+') as file_to_write:
file_to_write.write(str(tb)+'\n')
print(tb)
if __name__ == '__main__':
# set parameters
parser = argparse.ArgumentParser()
parser.add_argument(
'--gt_root', type=str, help='ground-truth root',
default=os.path.join(config.data_root_dir, config.task))
parser.add_argument(
'--pred_root', type=str, help='prediction root',
default='./e_preds')
parser.add_argument(
'--data_lst', type=str, help='test dataset',
default={
'DIS5K': '+'.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4'][:]),
'COD': '+'.join(['TE-COD10K', 'NC4K', 'TE-CAMO', 'CHAMELEON'][:]),
'HRSOD': '+'.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'TE-DUTS', 'DUT-OMRON'][:]),
'DIS5K+HRSOD+HRS10K': '+'.join(['DIS-VD'][:]),
'P3M-10k': '+'.join(['TE-P3M-500-P', 'TE-P3M-500-NP'][:]),
}[config.task])
parser.add_argument(
'--save_dir', type=str, help='candidate competitors',
default='e_results')
parser.add_argument(
'--check_integrity', type=bool, help='whether to check the file integrity',
default=False)
parser.add_argument(
'--metrics', type=str, help='candidate competitors',
default='+'.join(['S', 'MAE', 'E', 'F', 'WF', 'HCE'][:100 if 'DIS5K' in config.task else -1]))
args = parser.parse_args()
os.makedirs(args.save_dir, exist_ok=True)
try:
args.model_lst = [m for m in sorted(os.listdir(args.pred_root), key=lambda x: int(x.split('epoch_')[-1]), reverse=True) if int(m.split('epoch_')[-1]) % 1 == 0]
except:
args.model_lst = [m for m in sorted(os.listdir(args.pred_root))]
# check the integrity of each candidates
if args.check_integrity:
for _data_name in args.data_lst.split('+'):
for _model_name in args.model_lst:
gt_pth = os.path.join(args.gt_root, _data_name)
pred_pth = os.path.join(args.pred_root, _model_name, _data_name)
if not sorted(os.listdir(gt_pth)) == sorted(os.listdir(pred_pth)):
print(len(sorted(os.listdir(gt_pth))), len(sorted(os.listdir(pred_pth))))
print('The {} Dataset of {} Model is not matching to the ground-truth'.format(_data_name, _model_name))
else:
print('>>> skip check the integrity of each candidates')
# start engine
do_eval(args)