|
import os.path as osp |
|
import argparse |
|
import json |
|
from data import Tasks, DATASET_TASK_DICT |
|
from utils import preprocess_path |
|
|
|
|
|
def process_result(entry, name, task):
    """Flatten one raw results entry into a plain metrics dict.

    Args:
        entry: Raw metrics mapping for a single dataset, keyed with the
            harness's ``'<metric>,none'`` convention.
        name: Dataset name (e.g. ``'mkqa_tr'``).
        task: A ``Tasks`` enum member describing the dataset's task type.

    Returns:
        Dict with ``'name'``, ``'task'`` and the task-appropriate metrics.
        Tasks outside the handled set yield only ``'name'`` and ``'task'``.
    """
    result = {
        'name': name,
        'task': str(task),
    }

    if task == Tasks.EXTRACTIVE_QUESTION_ANSWERING:
        # mkqa_tr reports 'em' already on the target scale; other QA sets
        # report 'exact' as a percentage and are rescaled by 0.01.
        if name == 'mkqa_tr':
            key, scale = 'em,none', 1
        else:
            key, scale = 'exact,none', 0.01
        result['exact_match'] = scale * entry[key]
        result['f1'] = scale * entry['f1,none']
    elif task == Tasks.SUMMARIZATION:
        for metric in ('rouge1', 'rouge2', 'rougeL'):
            result[metric] = entry[f'{metric},none']
    elif task in (
        Tasks.MULTIPLE_CHOICE,
        Tasks.NATURAL_LANGUAGE_INFERENCE,
        Tasks.TEXT_CLASSIFICATION,
    ):
        accuracy = entry['acc,none']
        result['acc'] = accuracy
        # Fall back to plain accuracy when no normalized score is reported.
        result['acc_norm'] = entry.get('acc_norm,none', accuracy)
    elif task == Tasks.MACHINE_TRANSLATION:
        result['wer'] = entry['wer,none']
        result['bleu'] = entry['bleu,none']
    elif task == Tasks.GRAMMATICAL_ERROR_CORRECTION:
        result['exact_match'] = entry['exact_match,none']

    return result
|
|
|
|
|
def main():
    """Read a raw results JSON, reshape it, and write the formatted JSON.

    Command-line interface:
        -i/--input-file   path to the raw results JSON produced by the harness
        -o/--output-file  path for the simplified, formatted JSON

    The output contains a ``'model'`` section (parsed from the run config)
    and a ``'results'`` list with one processed entry per known dataset.
    """
    parser = argparse.ArgumentParser(description='Results file formatter.')
    parser.add_argument('-i', '--input-file', type=str, help='Input JSON file for the results.')
    parser.add_argument('-o', '--output-file', type=str, help='Output JSON file for the formatted results.')
    args = parser.parse_args()

    with open(preprocess_path(args.input_file)) as f:
        raw_data = json.load(f)

    # 'model_args' is a comma-separated 'key=value' string. Split each pair on
    # the FIRST '=' only, so values that themselves contain '=' (e.g. URLs or
    # template strings) are not truncated.
    model_args = dict(
        pair.split('=', 1) for pair in raw_data['config']['model_args'].split(',')
    )
    # Rename 'pretrained' to 'model' for the output schema.
    model_args['model'] = model_args.pop('pretrained')

    processed = dict()
    processed['model'] = model_args
    processed['model']['api'] = raw_data['config']['model']

    processed['results'] = list()
    for dataset, entry in raw_data['results'].items():
        # Skip entries (e.g. aggregates) that have no registered task type.
        if dataset not in DATASET_TASK_DICT:
            continue
        task = DATASET_TASK_DICT[dataset]
        processed['results'].append(process_result(entry, dataset, task))

    with open(preprocess_path(args.output_file), 'w') as f:
        json.dump(processed, f, indent=4)

    print('done')
|
|
|
|
|
# Script entry point: only run the formatter when executed directly.
if __name__ == '__main__':

    main()