from argparse import ArgumentParser
import hashlib
import os

from sftp.data_reader import BetterDatasetReader, ConcreteDatasetReader
from tools.ontology_mapping.force_map import ontology_map, read_framenet


def _file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at ``path``.

    The file is read in ``chunk_size``-byte pieces so that large model
    archives are never loaded into memory in full, and the handle is closed
    deterministically via ``with``.
    """
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        # iter(callable, sentinel) keeps calling f.read until it returns b''.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def read_ace_better(reader, data_path):
    """Read ``data_path`` with ``reader`` and collect (sentence, spans) pairs.

    Args:
        reader: A dataset reader whose ``read`` yields instances carrying a
            ``'raw_inputs'`` field with ``metadata`` keys ``'sentence'`` and
            ``'spans'``.
        data_path: Path passed straight to ``reader.read``.

    Returns:
        A list of ``(sentence, spans)`` tuples, one per instance, in the order
        the reader yields them.
    """
    return [
        tuple(ins.fields['raw_inputs'].metadata[key] for key in ('sentence', 'spans'))
        for ins in reader.read(data_path)
    ]


def run(model_path, src_data_path, tgt_data_path, device, dst_path):
    """Map the source (FrameNet) ontology onto a target dataset's ontology.

    Args:
        model_path: Either a ``.tar.gz`` model archive, or a directory that
            contains ``model.tar.gz``. Its MD5 is recorded in the metadata.
        src_data_path: FrameNet data path, read with ``read_framenet``.
        tgt_data_path: Target dataset path. The reader is chosen by a
            case-insensitive substring match on this path: ``'better'`` selects
            ``BetterDatasetReader``, otherwise ``'ace'`` selects
            ``ConcreteDatasetReader``.
        device: Device id forwarded to ``ontology_map`` (the CLI defaults it
            to -1).
        dst_path: Output path forwarded to ``ontology_map``.

    Raises:
        NotImplementedError: If neither ``'better'`` nor ``'ace'`` occurs in
            ``tgt_data_path``.
    """
    if model_path.endswith('.tar.gz'):
        archive_path = model_path
    else:
        archive_path = os.path.join(model_path, 'model.tar.gz')
    model_md5 = _file_md5(archive_path)
    print('model md5: ', model_md5)

    # NOTE(review): this is a substring test on the *whole* path, so e.g. a
    # file under "/workspace/..." contains "ace" and selects the ACE reader.
    # Behavior is kept as-is; consider matching on the basename instead.
    tgt_lower = tgt_data_path.lower()
    if 'better' in tgt_lower:
        reader = BetterDatasetReader(eval_type='basic', pretrained_model='roberta-large', ignore_label=False)
    elif 'ace' in tgt_lower:
        reader = ConcreteDatasetReader(ignore_unlabeled_sentence=True, pretrained_model='roberta-large')
    else:
        raise NotImplementedError(
            f'Cannot choose a dataset reader for target data path {tgt_data_path!r}: '
            "expected 'better' or 'ace' in the path."
        )

    meta = {
        'model': {'path': model_path, 'md5': model_md5},
        'src_data_path': src_data_path,
        'tgt_data_path': tgt_data_path,
    }

    # Source ontology: FrameNet, loaded by read_framenet.
    # Target ontology: the ACE/BETTER data, as (sentence, spans) tuples.
    src_data = read_framenet(src_data_path)
    tgt_data = read_ace_better(reader, tgt_data_path)
    ontology_map(model_path, src_data, tgt_data, device, dst_path, meta)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('model', metavar='MODEL_PATH')
    parser.add_argument('src', metavar='SRC_DATA_PATH')
    parser.add_argument('tgt', metavar='TGT_DATA_PATH')
    parser.add_argument('dst', metavar='DESTINATION_PATH')
    parser.add_argument('-d', type=int, help='device', default=-1)
    cmd_args = parser.parse_args()
    run(cmd_args.model, cmd_args.src, cmd_args.tgt, cmd_args.d, cmd_args.dst)