import json
import os
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
from allennlp.data import Vocabulary
from tqdm import tqdm

from sftp import SpanPredictor, Span
from sftp.utils import VIRTUAL_ROOT
def read_framenet(path: str):
    """Load a JSON-lines file into (tokens, annotation tree) pairs."""
    ret = list()
    with open(path) as fp:
        for line in map(json.loads, fp):
            ret.append((line['tokens'], Span.from_json(line['annotations'])))
    return ret
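
# The input is expected to be JSON-lines: one object per line with a 'tokens'
# list and an 'annotations' tree consumable by sftp's Span.from_json.  An
# illustrative (not verbatim) line, with the annotation payload elided:
#   {"tokens": ["He", "bought", "a", "car"], "annotations": {...}}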
def co_occur(
        predictor: SpanPredictor,
        sentences: List[Tuple[List[str], Span]],
        event_list: List[str],
        arg_list: List[str],
):
    """Accumulate, for each target event/arg label, the predictor's distribution over source labels."""
    idx2label = predictor.vocab.get_index_to_token_vocabulary('span_label')
    event_count = np.zeros([len(event_list), len(idx2label)], np.float64)
    arg_count = np.zeros([len(arg_list), len(idx2label)], np.float64)
    for sent, vr in tqdm(sentences):
        # Force-decode the gold event boundaries to get, for each event span,
        # a distribution over the source ontology's labels.
        _, _, event_dist = predictor.force_decode(sent, child_spans=[event.boundary for event in vr])
        for event, dist in zip(vr, event_dist):
            event_count[event_list.index(event.label)] += dist
        # Do the same for each event's argument spans, conditioning on the
        # most probable source label of the parent event.
        for event, one_event_dist in zip(vr, event_dist):
            parent_label = idx2label[int(one_event_dist.argmax())]
            arg_spans = [child.boundary for child in event]
            _, _, arg_dist = predictor.force_decode(
                sent, event.boundary, parent_label, arg_spans
            )
            for arg, dist in zip(event, arg_dist):
                arg_count[arg_list.index(arg.label)] += dist
    return event_count, arg_count
def create_vocab(events, args):
    """Build an AllenNLP vocabulary over the target ontology's labels."""
    vocab = Vocabulary()
    vocab.add_token_to_namespace(VIRTUAL_ROOT, 'span_label')
    for event in events:
        vocab.add_token_to_namespace(event, 'span_label')
    for arg in args:
        vocab.add_token_to_namespace(arg, 'span_label')
    return vocab
def count_data(annotations: Iterable[Span]):
    """Count how often each event and argument label occurs in the data."""
    event_cnt, arg_cnt = defaultdict(int), defaultdict(int)
    for sent in annotations:
        for event in sent:
            event_cnt[event.label] += 1
            for arg in event:
                arg_cnt[arg.label] += 1
    return dict(event_cnt), dict(arg_cnt)
def gen_mapping(
        src_label: List[str], src_count: Dict[str, int],
        tgt_onto: List[str], tgt_label: List[str],
        cooccur_count: np.ndarray
):
    """
    :param src_label: Source label list, including both events and args.
    :param src_count: Source label counts, for events or args.
    :param tgt_onto: Target ontology tag list, events or args only.
    :param tgt_label: Full target label list (the target vocabulary).
    :param cooccur_count: Co-occurrence count table of shape [len(tgt_onto), len(src_label)].
    :return: Mapping dict from each source tag to a score vector over tgt_label.
    """
    # One-hot matrix projecting target ontology tags onto the full target vocabulary.
    onto2label = np.zeros([len(tgt_onto), len(tgt_label)], dtype=np.float64)
    for onto_idx, onto_tag in enumerate(tgt_onto):
        onto2label[onto_idx, tgt_label.index(onto_tag)] = 1.0
    ret = dict()
    for src_tag, src_freq in src_count.items():
        if src_tag in src_label:
            src_idx = src_label.index(src_tag)
            # Normalize by the source tag's frequency, then project into the
            # target vocabulary space; .tolist() yields JSON-serializable floats.
            ret[src_tag] = ((cooccur_count[:, src_idx] / src_freq) @ onto2label).tolist()
    return ret
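
# A minimal numeric sketch of gen_mapping's projection (illustrative values
# only).  Suppose tgt_onto = ['Attack', 'Transport'] and the full vocabulary is
# tgt_label = [VIRTUAL_ROOT, 'Attack', 'Transport'].  Then onto2label is
#   [[0., 1., 0.],
#    [0., 0., 1.]]
# and a source tag seen src_freq = 4 times with co-occurrence column [3.2, 0.4]
# maps to ([3.2, 0.4] / 4) @ onto2label = [0.0, 0.8, 0.1] over tgt_label.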
def ontology_map(
        model_path,
        src_data: List[Tuple[List[str], Span]],
        tgt_data: List[Tuple[List[str], Span]],
        device: int,
        dst_path: str,
        meta: Optional[dict] = None,
) -> None:
    ret = {'meta': meta or {}}
    data = {'src': {}, 'tgt': {}}
    # Collect label inventories and frequencies for both ontologies.
    for name, datasets in [['src', src_data], ['tgt', tgt_data]]:
        d = data[name]
        d['sentences'], d['annotations'] = zip(*datasets)
        d['event_cnt'], d['arg_cnt'] = count_data(d['annotations'])
        d['event'], d['arg'] = list(d['event_cnt']), list(d['arg_cnt'])
    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    tgt_vocab = create_vocab(data['tgt']['event'], data['tgt']['arg'])
    # Flatten both vocabularies into index-ordered label lists.
    for name, vocab in [['src', predictor.vocab], ['tgt', tgt_vocab]]:
        data[name]['label'] = [
            vocab.get_index_to_token_vocabulary('span_label')[i]
            for i in range(vocab.get_vocab_size('span_label'))
        ]
    # Run the source-ontology predictor over the target data to collect
    # co-occurrence statistics for events and arguments.
    data['event'], data['arg'] = co_occur(
        predictor, tgt_data, data['tgt']['event'], data['tgt']['arg']
    )
    mapping = {}
    for layer in ['event', 'arg']:
        mapping[layer] = gen_mapping(
            data['src']['label'], data['src'][layer + '_cnt'],
            data['tgt'][layer], data['tgt']['label'], data[layer]
        )
    for key, name in [['source', 'src'], ['target', 'tgt']]:
        ret[key] = {
            'label': data[name]['label'],
            'event': data[name]['event'],
            'argument': data[name]['arg'],
        }
    ret['mapping'] = {
        'event': mapping['event'],
        'argument': mapping['arg'],
    }
    os.makedirs(dst_path, exist_ok=True)
    with open(os.path.join(dst_path, 'ontology_mapping.json'), 'w') as fp:
        json.dump(ret, fp)
    # Write the target ontology as a TSV: first line lists all events under the
    # virtual root, then one line per event listing its candidate arguments.
    with open(os.path.join(dst_path, 'ontology.tsv'), 'w') as fp:
        to_dump = list()
        to_dump.append('\t'.join([VIRTUAL_ROOT] + ret['target']['event']))
        for event in ret['target']['event']:
            to_dump.append('\t'.join([event] + ret['target']['argument']))
        fp.write('\n'.join(to_dump))
    tgt_vocab.save_to_files(os.path.join(dst_path, 'vocabulary'))
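
# A minimal usage sketch.  All paths and the device index below are
# placeholders, not part of this module.
if __name__ == '__main__':
    src = read_framenet('data/framenet/train.jsonl')  # source-ontology data
    tgt = read_framenet('data/target/train.jsonl')    # target-ontology data
    ontology_map(
        model_path='checkpoints/sftp-framenet',  # archived SpanPredictor model
        src_data=src,
        tgt_data=tgt,
        device=0,  # CUDA device index; use -1 for CPU
        dst_path='output/mapping',
        meta={'note': 'example run'},
    )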