Spaces:
Sleeping
Sleeping
from argparse import ArgumentParser | |
import hashlib | |
import os | |
from sftp.data_reader import BetterDatasetReader, ConcreteDatasetReader | |
from tools.ontology_mapping.force_map import ontology_map, read_framenet | |
def read_ace_better(reader, data_path):
    """Gather the sentence and span metadata of every instance in a dataset.

    Args:
        reader: Object exposing ``read(data_path)`` that yields instances
            whose ``fields['raw_inputs'].metadata`` mapping holds at least
            the keys ``'sentence'`` and ``'spans'``.
        data_path: Location handed unchanged to ``reader.read``.

    Returns:
        A list with one ``(sentence, spans)`` tuple per instance, in the
        order the reader produced them.
    """
    wanted_keys = ('sentence', 'spans')
    return [
        tuple(instance.fields['raw_inputs'].metadata[name] for name in wanted_keys)
        for instance in reader.read(data_path)
    ]
def _file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at ``path``, read in chunks."""
    digest = hashlib.md5()
    # The context manager guarantees the handle is closed, and chunked reads
    # keep memory use flat even when the archive is very large.
    with open(path, 'rb') as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def run(model_path, src_data_path, tgt_data_path, device, dst_path):
    """Fingerprint the model, load both datasets and run the ontology mapping.

    Args:
        model_path: Either a ``.tar.gz`` archive, or a directory that
            contains ``model.tar.gz``.
        src_data_path: Path passed to ``read_framenet`` to load the source
            (FrameNet) data.
        tgt_data_path: Path of the target data. Its lower-cased form must
            contain ``'better'`` or ``'ace'``; that substring selects the
            dataset reader.
        device: Forwarded unchanged to ``ontology_map``.
        dst_path: Forwarded unchanged to ``ontology_map``
            (presumably the output location -- confirm against its definition).

    Raises:
        NotImplementedError: If ``tgt_data_path`` names neither a BETTER nor
            an ACE dataset.
    """
    if model_path.endswith('.tar.gz'):
        archive_path = model_path
    else:
        archive_path = os.path.join(model_path, 'model.tar.gz')
    model_md5 = _file_md5(archive_path)
    print('model md5: ', model_md5)

    # The reader is chosen from the target path's name, 'better' taking
    # precedence over 'ace' when both substrings are present.
    tgt_name = tgt_data_path.lower()
    if 'better' in tgt_name:
        reader = BetterDatasetReader(eval_type='basic', pretrained_model='roberta-large', ignore_label=False)
    elif 'ace' in tgt_name:
        reader = ConcreteDatasetReader(ignore_unlabeled_sentence=True, pretrained_model='roberta-large')
    else:
        raise NotImplementedError(
            f"cannot pick a dataset reader for {tgt_data_path!r}: "
            "expected 'better' or 'ace' in the path"
        )

    # Provenance passed along to ontology_map together with the data.
    meta = {
        'model': {'path': model_path, 'md5': model_md5},
        'src_data_path': src_data_path,
        'tgt_data_path': tgt_data_path,
    }

    # Source side is FrameNet; target side is the ACE/BETTER data read above.
    src_data = read_framenet(src_data_path)
    tgt_data = read_ace_better(reader, tgt_data_path)
    ontology_map(model_path, src_data, tgt_data, device, dst_path, meta)
if __name__ == '__main__':
    # Command-line entry point: four positional paths plus an optional
    # integer device flag, all handed straight to run().
    arg_parser = ArgumentParser()
    positional_specs = (
        ('model', 'MODEL_PATH'),
        ('src', 'SRC_DATA_PATH'),
        ('tgt', 'TGT_DATA_PATH'),
        ('dst', 'DESTINATION_PATH'),
    )
    for arg_name, placeholder in positional_specs:
        arg_parser.add_argument(arg_name, metavar=placeholder)
    arg_parser.add_argument('-d', type=int, help='device', default=-1)

    options = arg_parser.parse_args()
    run(
        model_path=options.model,
        src_data_path=options.src,
        tgt_data_path=options.tgt,
        device=options.d,
        dst_path=options.dst,
    )