import json
import os
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
from allennlp.data import Vocabulary
from tqdm import tqdm

from sftp import SpanPredictor, Span
from sftp.utils import VIRTUAL_ROOT


def read_framenet(path: str):
    # Each line of the file is a JSON object holding a token list and a
    # serialized span tree.
    ret = list()
    with open(path) as fp:
        for line in map(json.loads, fp):
            ret.append((line['tokens'], Span.from_json(line['annotations'])))
    return ret
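# Expected JSONL input, inferred from read_framenet above: one JSON object per
# line carrying a token list plus the span tree consumed by Span.from_json
# (the internal structure of 'annotations' follows sftp's Span serialization).
# A sketch of one line, with the annotation payload elided:
# {"tokens": ["...", "..."], "annotations": ...}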
def co_occur(
        predictor: SpanPredictor,
        sentences: List[Tuple[List[str], Span]],
        event_list: List[str],
        arg_list: List[str],
):
    idx2label = predictor.vocab.get_index_to_token_vocabulary('span_label')
    event_count = np.zeros([len(event_list), len(idx2label)], np.float64)
    arg_count = np.zeros([len(arg_list), len(idx2label)], np.float64)
    for sent, vr in tqdm(sentences):
        # For events: force-decode the gold event boundaries and accumulate
        # the predicted label distribution under each gold event label.
        _, _, event_dist = predictor.force_decode(sent, child_spans=[event.boundary for event in vr])
        for event, dist in zip(vr, event_dist):
            event_count[event_list.index(event.label)] += dist
        # For args: condition on the most probable event label, then
        # force-decode the gold argument boundaries under that event.
        for event, one_event_dist in zip(vr, event_dist):
            parent_label = idx2label[int(one_event_dist.argmax())]
            arg_spans = [child.boundary for child in event]
            _, _, arg_dist = predictor.force_decode(
                sent, event.boundary, parent_label, arg_spans
            )
            for arg, dist in zip(event, arg_dist):
                arg_count[arg_list.index(arg.label)] += dist
    return event_count, arg_count
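# Shape note, derived from co_occur above: event_count[t, s] accumulates, over
# every target event span whose gold label is event_list[t], the probability
# mass the source model assigns to its own label with index s; arg_count is
# the analogous table for argument spans. Columns therefore index the source
# model's 'span_label' vocabulary, which is what gen_mapping relies on below.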
def create_vocab(events, args):
    # Build a fresh target-side vocabulary: the virtual root first, then all
    # event and argument labels, all in the 'span_label' namespace.
    vocab = Vocabulary()
    vocab.add_token_to_namespace(VIRTUAL_ROOT, 'span_label')
    for event in events:
        vocab.add_token_to_namespace(event, 'span_label')
    for arg in args:
        vocab.add_token_to_namespace(arg, 'span_label')
    return vocab
def count_data(annotations: Iterable[Span]):
    # Count how often each event label and each argument label occurs.
    event_cnt, arg_cnt = defaultdict(int), defaultdict(int)
    for sent in annotations:
        for event in sent:
            event_cnt[event.label] += 1
            for arg in event:
                arg_cnt[arg.label] += 1
    return dict(event_cnt), dict(arg_cnt)
def gen_mapping(
        src_label: List[str], src_count: Dict[str, int],
        tgt_onto: List[str], tgt_label: List[str],
        cooccur_count: np.ndarray
):
    """
    :param src_label: Source label list (events and args), in source vocabulary order.
    :param src_count: Source label counts, for events or args.
    :param tgt_onto: Target ontology label list, events or args only.
    :param tgt_label: Full target label list, in target vocabulary order.
    :param cooccur_count: Co-occurrence count table from co_occur.
    :return: Mapping dict from each source label to a distribution over target labels.
    """
    # One-hot matrix that scatters ontology labels into target-vocabulary order.
    onto2label = np.zeros([len(tgt_onto), len(tgt_label)], dtype=np.float64)
    for onto_idx, onto_tag in enumerate(tgt_onto):
        onto2label[onto_idx, tgt_label.index(onto_tag)] = 1.0
    ret = dict()
    for src_tag, src_freq in src_count.items():
        if src_tag in src_label:
            src_idx = src_label.index(src_tag)
            ret[src_tag] = list((cooccur_count[:, src_idx] / src_freq) @ onto2label)
    return ret
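# A minimal worked example (labels hypothetical): with
#   src_label = [VIRTUAL_ROOT, 'Commerce_buy'], src_count = {'Commerce_buy': 2},
#   tgt_onto = ['Transaction'], tgt_label = [VIRTUAL_ROOT, 'Transaction'],
#   and cooccur_count = np.array([[0.0, 1.6]]),
# gen_mapping returns {'Commerce_buy': [0.0, 0.8]}: the 'Commerce_buy' column
# is divided by its source frequency (2) and scattered into target-vocabulary
# order via the one-hot onto2label matrix.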
def ontology_map(
        model_path,
        src_data: List[Tuple[List[str], Span]],
        tgt_data: List[Tuple[List[str], Span]],
        device: int,
        dst_path: str,
        meta: Optional[dict] = None,
) -> None:
    ret = {'meta': meta or {}}
    data = {'src': {}, 'tgt': {}}
    # Collect label inventories and frequencies for both ontologies.
    for name, datasets in [['src', src_data], ['tgt', tgt_data]]:
        d = data[name]
        d['sentences'], d['annotations'] = zip(*datasets)
        d['event_cnt'], d['arg_cnt'] = count_data(d['annotations'])
        d['event'], d['arg'] = list(d['event_cnt']), list(d['arg_cnt'])
    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    tgt_vocab = create_vocab(data['tgt']['event'], data['tgt']['arg'])
    for name, vocab in [['src', predictor.vocab], ['tgt', tgt_vocab]]:
        data[name]['label'] = [
            vocab.get_index_to_token_vocabulary('span_label')[i]
            for i in range(vocab.get_vocab_size('span_label'))
        ]
    # Run the source model over the target data to collect co-occurrence counts.
    data['event'], data['arg'] = co_occur(
        predictor, tgt_data, data['tgt']['event'], data['tgt']['arg']
    )
    mapping = {}
    for layer in ['event', 'arg']:
        mapping[layer] = gen_mapping(
            data['src']['label'], data['src'][layer + '_cnt'],
            data['tgt'][layer], data['tgt']['label'], data[layer]
        )
    for key, name in [['source', 'src'], ['target', 'tgt']]:
        ret[key] = {
            'label': data[name]['label'],
            'event': data[name]['event'],
            'argument': data[name]['arg']
        }
    ret['mapping'] = {
        'event': mapping['event'],
        'argument': mapping['arg']
    }
    os.makedirs(dst_path, exist_ok=True)
    with open(os.path.join(dst_path, 'ontology_mapping.json'), 'w') as fp:
        json.dump(ret, fp)
    # Dump a simple TSV ontology: one header row of events under the virtual
    # root, then one row of arguments per event.
    with open(os.path.join(dst_path, 'ontology.tsv'), 'w') as fp:
        to_dump = list()
        to_dump.append('\t'.join([VIRTUAL_ROOT] + ret['target']['event']))
        for event in ret['target']['event']:
            to_dump.append('\t'.join([event] + ret['target']['argument']))
        fp.write('\n'.join(to_dump))
    tgt_vocab.save_to_files(os.path.join(dst_path, 'vocabulary'))
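# A hypothetical end-to-end invocation; the paths, device id, and metadata
# below are assumptions for illustration, not part of the original script.
if __name__ == '__main__':
    src = read_framenet('data/framenet/train.jsonl')  # hypothetical path
    tgt = read_framenet('data/target/train.jsonl')    # hypothetical path
    ontology_map(
        model_path='checkpoints/sftp-framenet',       # hypothetical checkpoint
        src_data=src,
        tgt_data=tgt,
        device=0,                                     # -1 for CPU
        dst_path='output/mapping',
        meta={'note': 'example run'},
    )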