import os.path as osp
import argparse
import json

from data import Tasks, DATASET_TASK_DICT
from utils import preprocess_path


def process_result(entry, name, task):
    """Convert one raw per-dataset result entry into a flat metrics dict.

    Args:
        entry: Raw metrics mapping for one dataset, keyed like
            ``'<metric>,none'`` (e.g. ``'acc,none'``).
        name: Dataset name; stored in the output and used to special-case
            ``mkqa_tr``.
        task: A ``Tasks`` member selecting which metrics to extract.

    Returns:
        A dict with ``'name'``, ``'task'`` (``str(task)``) and the metrics
        relevant to ``task``. Unrecognised tasks get only name and task.

    Raises:
        KeyError: If a metric required for ``task`` is missing from ``entry``.
    """
    processed = {
        'name': name,
        'task': str(task),
    }
    if task == Tasks.EXTRACTIVE_QUESTION_ANSWERING:
        is_mkqa = name == 'mkqa_tr'
        # mkqa_tr reports its exact-match metric under 'em'; other extractive
        # QA datasets use 'exact'. Non-mkqa_tr scores are multiplied by 0.01
        # (presumably a 0-100 percentage -> 0-1 fraction; confirm upstream).
        key = 'em,none' if is_mkqa else 'exact,none'
        scale = 1 if is_mkqa else 0.01
        processed['exact_match'] = scale * entry[key]
        processed['f1'] = scale * entry['f1,none']
    elif task == Tasks.SUMMARIZATION:
        processed['rouge1'] = entry['rouge1,none']
        processed['rouge2'] = entry['rouge2,none']
        processed['rougeL'] = entry['rougeL,none']
    elif task in (
        Tasks.MULTIPLE_CHOICE,
        Tasks.NATURAL_LANGUAGE_INFERENCE,
        Tasks.TEXT_CLASSIFICATION,
    ):
        processed['acc'] = entry['acc,none']
        # Fall back to plain accuracy when no normalised accuracy is reported.
        processed['acc_norm'] = entry.get('acc_norm,none', processed['acc'])
    elif task == Tasks.MACHINE_TRANSLATION:
        processed['wer'] = entry['wer,none']
        processed['bleu'] = entry['bleu,none']
    elif task == Tasks.GRAMMATICAL_ERROR_CORRECTION:
        processed['exact_match'] = entry['exact_match,none']
    return processed


def _parse_model_args(model_args):
    """Parse a comma-separated ``key=value`` string into a new dict.

    Each pair is split on its first ``=`` only, so values that themselves
    contain ``=`` stay intact. Empty segments (an empty string or a trailing
    comma) are skipped. If ``model_args`` is already a mapping, a shallow
    copy of it is returned.

    Raises:
        ValueError: If a non-empty segment contains no ``=``.
    """
    if isinstance(model_args, dict):
        return dict(model_args)
    parsed = {}
    for pair in model_args.split(','):
        if not pair:
            continue
        key, sep, value = pair.partition('=')
        if not sep:
            raise ValueError(f'Malformed model_args entry (expected key=value): {pair!r}')
        parsed[key] = value
    return parsed


def main():
    """Read a raw results JSON file, reformat it, and write the result.

    The output JSON has two keys:
        ``'model'``: the parsed model args, with ``'pretrained'`` renamed to
            ``'model'`` and ``config['model']`` added as ``'api'``.
        ``'results'``: one ``process_result`` dict per dataset that appears
            in ``DATASET_TASK_DICT``. Other datasets are ignored.
    """
    parser = argparse.ArgumentParser(description='Results file formatter.')
    parser.add_argument('-i', '--input-file', type=str, required=True,
                        help='Input JSON file for the results.')
    parser.add_argument('-o', '--output-file', type=str, required=True,
                        help='Output JSON file for the formatted results.')
    args = parser.parse_args()

    with open(preprocess_path(args.input_file), encoding='utf-8') as f:
        raw_data = json.load(f)

    # First, get model args.
    config = raw_data['config']
    model_info = _parse_model_args(config['model_args'])
    model_info['model'] = model_info.pop('pretrained')
    model_info['api'] = config['model']
    processed = {'model': model_info}

    # Then, process results for the datasets we know how to format.
    processed['results'] = [
        process_result(entry, dataset, DATASET_TASK_DICT[dataset])
        for dataset, entry in raw_data['results'].items()
        if dataset in DATASET_TASK_DICT
    ]

    with open(preprocess_path(args.output_file), 'w', encoding='utf-8') as f:
        json.dump(processed, f, indent=4)
    print('done')


if __name__ == '__main__':
    main()