import json import os from collections import defaultdict from typing import * import numpy as np from allennlp.data import Vocabulary from tqdm import tqdm from sftp import SpanPredictor, Span from sftp.utils import VIRTUAL_ROOT def read_framenet(path: str): ret = list() for line in map(json.loads, open(path).readlines()): ret.append((line['tokens'], Span.from_json(line['annotations']))) return ret def co_occur( predictor: SpanPredictor, sentences: List[Tuple[List[str], Span]], event_list: List[str], arg_list: List[str], ): idx2label = predictor.vocab.get_index_to_token_vocabulary('span_label') event_count = np.zeros([len(event_list), len(idx2label)], np.float64) arg_count = np.zeros([len(arg_list), len(idx2label)], np.float64) for sent, vr in tqdm(sentences): # For events _, _, event_dist = predictor.force_decode(sent, child_spans=[event.boundary for event in vr]) for event, dist in zip(vr, event_dist): event_count[event_list.index(event.label)] += dist # For args for event, one_event_dist in zip(vr, event_dist): parent_label = idx2label[int(one_event_dist.argmax())] arg_spans = [child.boundary for child in event] _, _, arg_dist = predictor.force_decode( sent, event.boundary, parent_label, arg_spans ) for arg, dist in zip(event, arg_dist): arg_count[arg_list.index(arg.label)] += dist return event_count, arg_count def create_vocab(events, args): vocab = Vocabulary() vocab.add_token_to_namespace(VIRTUAL_ROOT, 'span_label') for event in events: vocab.add_token_to_namespace(event, 'span_label') for arg in args: vocab.add_token_to_namespace(arg, 'span_label') return vocab def count_data(annotations: Iterable[Span]): event_cnt, arg_cnt = defaultdict(int), defaultdict(int) for sent in annotations: for event in sent: event_cnt[event.label] += 1 for arg in event: arg_cnt[arg.label] += 1 return dict(event_cnt), dict(arg_cnt) def gen_mapping( src_label: List[str], src_count: Dict[str, int], tgt_onto: List[str], tgt_label: List[str], cooccur_count: np.ndarray ): """ :param src_label: Src label list, including events and args. :param src_count: Src label count, event or arg. :param tgt_onto: Target label list, only event or arg. :param tgt_label: Target label count, event or arg. :param cooccur_count: Co-occurrence counting table. :return: Mapping dict. """ onto2label = np.zeros([len(tgt_onto), len(tgt_label)], dtype=np.float) for onto_idx, onto_tag in enumerate(tgt_onto): onto2label[onto_idx, tgt_label.index(onto_tag)] = 1.0 ret = dict() for src_tag, src_freq in src_count.items(): if src_tag in src_label: src_idx = src_label.index(src_tag) ret[src_tag] = list((cooccur_count[:, src_idx] / src_freq) @ onto2label) return ret def ontology_map( model_path, src_data: List[Tuple[List[str], Span]], tgt_data: List[Tuple[List[str], Span]], device: int, dst_path: str, meta: Optional[dict] = None, ) -> None: ret = {'meta': meta or {}} data = {'src': {}, 'tgt': {}} for name, datasets in [['src', src_data], ['tgt', tgt_data]]: d = data[name] d['sentences'], d['annotations'] = zip(*datasets) d['event_cnt'], d['arg_cnt'] = count_data(d['annotations']) d['event'], d['arg'] = list(d['event_cnt']), list(d['arg_cnt']) predictor = SpanPredictor.from_path(model_path, cuda_device=device) tgt_vocab = create_vocab(data['tgt']['event'], data['tgt']['arg']) for name, vocab in [['src', predictor.vocab], ['tgt', tgt_vocab]]: data[name]['label'] = [ vocab.get_index_to_token_vocabulary('span_label')[i] for i in range(vocab.get_vocab_size('span_label')) ] data['event'], data['arg'] = co_occur( predictor, tgt_data, data['tgt']['event'], data['tgt']['arg'] ) mapping = {} for layer in ['event', 'arg']: mapping[layer] = gen_mapping( data['src']['label'], data['src'][layer+'_cnt'], data['tgt'][layer], data['tgt']['label'], data[layer] ) for key, name in [['source', 'src'], ['target', 'tgt']]: ret[key] = { 'label': data[name]['label'], 'event': data[name]['event'], 'argument': data[name]['arg'] } ret['mapping'] = { 'event': mapping['event'], 'argument': mapping['arg'] } os.makedirs(dst_path, exist_ok=True) with open(os.path.join(dst_path, 'ontology_mapping.json'), 'w') as fp: json.dump(ret, fp) with open(os.path.join(dst_path, 'ontology.tsv'), 'w') as fp: to_dump = list() to_dump.append('\t'.join([VIRTUAL_ROOT] + ret['target']['event'])) for event in ret['target']['event']: to_dump.append('\t'.join([event] + ret['target']['argument'])) fp.write('\n'.join(to_dump)) tgt_vocab.save_to_files(os.path.join(dst_path, 'vocabulary'))