Gosse Minnema
Initial commit
05922fb
raw
history blame
2 kB
from argparse import ArgumentParser
import hashlib
import os
from sftp.data_reader import BetterDatasetReader, ConcreteDatasetReader
from tools.ontology_mapping.force_map import ontology_map, read_framenet
def read_ace_better(reader, data_path):
    """Collect ``(sentence, spans)`` pairs from a target-corpus reader.

    Args:
        reader: dataset reader whose ``read(data_path)`` yields instances with a
            ``'raw_inputs'`` field; that field's ``metadata`` must contain the
            keys ``'sentence'`` and ``'spans'``.
        data_path: path handed unchanged to ``reader.read``.

    Returns:
        A list holding one ``(sentence, spans)`` tuple per instance, in the
        order the reader yields them.
    """
    metadata_stream = (ins.fields['raw_inputs'].metadata for ins in reader.read(data_path))
    return [(meta['sentence'], meta['spans']) for meta in metadata_stream]
def _file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at *path*.

    The file is streamed in ``chunk_size``-byte pieces, so a large model
    archive is never held in memory all at once. The handle is closed on exit.
    """
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        # iter() with a b'' sentinel stops at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def run(model_path, src_data_path, tgt_data_path, device, dst_path):
    """Run ``ontology_map`` between FrameNet data and an ACE/BETTER corpus.

    Args:
        model_path: path to a ``.tar.gz`` model archive, or a directory that
            contains ``model.tar.gz``. The archive's MD5 is printed and stored
            in the metadata passed to ``ontology_map``.
        src_data_path: source-ontology (FrameNet) data, read with
            ``read_framenet``.
        tgt_data_path: target corpus. The dataset reader is chosen by
            substring of the lower-cased path: ``'better'`` (checked first)
            selects ``BetterDatasetReader``, ``'ace'`` selects
            ``ConcreteDatasetReader``.
        device: forwarded unchanged to ``ontology_map``.
        dst_path: destination path, forwarded unchanged to ``ontology_map``.

    Raises:
        NotImplementedError: if ``tgt_data_path`` contains neither
            ``'better'`` nor ``'ace'``.
    """
    if model_path.endswith('.tar.gz'):
        archive_path = model_path
    else:
        archive_path = os.path.join(model_path, 'model.tar.gz')
    model_md5 = _file_md5(archive_path)
    print('model md5: ', model_md5)

    # NOTE(review): this is a plain substring test on the whole path, so e.g. a
    # directory named "workspace" or "trace" also selects the ACE reader.
    tgt_lower = tgt_data_path.lower()
    if 'better' in tgt_lower:
        reader = BetterDatasetReader(eval_type='basic', pretrained_model='roberta-large', ignore_label=False)
    elif 'ace' in tgt_lower:
        reader = ConcreteDatasetReader(ignore_unlabeled_sentence=True, pretrained_model='roberta-large')
    else:
        raise NotImplementedError(
            'No dataset reader for target data path {!r}; '
            'expected "better" or "ace" in the path.'.format(tgt_data_path)
        )

    meta = {
        'model': {'path': model_path, 'md5': model_md5},
        'src_data_path': src_data_path,
        'tgt_data_path': tgt_data_path
    }

    # Source ontology: FrameNet, via read_framenet.
    # Target ontology: (sentence, spans) pairs read from the ACE/BETTER corpus.
    src_data = read_framenet(src_data_path)
    tgt_data = read_ace_better(reader, tgt_data_path)
    ontology_map(model_path, src_data, tgt_data, device, dst_path, meta)
if __name__ == '__main__':
    arg_parser = ArgumentParser()
    # Positional command-line arguments, registered in command-line order.
    for name, placeholder in (
        ('model', 'MODEL_PATH'),
        ('src', 'SRC_DATA_PATH'),
        ('tgt', 'TGT_DATA_PATH'),
        ('dst', 'DESTINATION_PATH'),
    ):
        arg_parser.add_argument(name, metavar=placeholder)
    arg_parser.add_argument('-d', type=int, help='device', default=-1)
    opts = arg_parser.parse_args()
    # run() takes the device before the destination path.
    run(opts.model, opts.src, opts.tgt, opts.d, opts.dst)