diff --git a/app.py b/app.py index f1c54fb8b3fee713ce3296bd3208a8146ab3cbee..7e32cb2603cb8ae19556299fe28beb9de6b9d24e 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,5 @@ import streamlit as st -# from parse import parse_text +from parse import parse from nltk import Tree import pandas as pd import re @@ -31,19 +31,21 @@ if text: df = pd.DataFrame(zipped, columns=['Token', 'Tag', 'Prob.']) - # # Convert the bracket parse tree into an NLTK Tree - # t = Tree.fromstring(re.sub(r'(\.[^ )]+)+', '', parse_tree)) + parse_tree = parse(tokens) - # tree_svg = TreePrettyPrinter(t).svg(nodecolor='black', leafcolor='black', funccolor='black') + # Convert the bracket parse tree into an NLTK Tree + t = Tree.fromstring(re.sub(r'-[^ )]*', '', parse_tree)) + + tree_svg = TreePrettyPrinter(t).svg(nodecolor='black', leafcolor='black', funccolor='black') col1 = st.columns(1)[0] col1.header("POS tagging result:") col1.table(df) -# col2 = st.columns(1)[0] -# col2.header("Parsing result:") -# col2.write(parse_tree.replace('_', '\_').replace('$', '\$').replace('*', '\*')) + col2 = st.columns(1)[0] + col2.header("Parsing result:") + col2.write(parse_tree.replace('_', '\\_').replace('$', '\\$').replace('*', '\\*')) -# # Display the graph in the Streamlit app -# col2.image(tree_svg, use_column_width=True) + # Display the graph in the Streamlit app + col2.image(tree_svg, use_column_width=True)
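The updated app.py assumes that parse(tokens) returns a single bracketed parse string whose labels may carry functional suffixes (e.g. NP-SBJ), which the re.sub call strips before building an nltk.Tree. Below is a minimal sketch of that stripping step, using a hypothetical tree string rather than real parser output:

import re
from nltk import Tree

# Hypothetical bracketed output with functional suffixes on some labels.
parse_tree = "(S (NP-SBJ (PRP We)) (VP (VBP fly) (ADVP-MNR (RB safely))) (. .))"

# Same regex as app.py: drop everything from a hyphen up to the next
# space or closing bracket, so "NP-SBJ" becomes "NP".
t = Tree.fromstring(re.sub(r'-[^ )]*', '', parse_tree))
t.pretty_print()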
diff --git a/benepar/__init__.py b/benepar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6ad660648d0c407a0cda608e0e5fc027ca33c5 --- /dev/null +++ b/benepar/__init__.py @@ -0,0 +1,20 @@ +""" +benepar: Berkeley Neural Parser +""" + +# This file and all code in integrations/ relate to the version of the parser +# released via PyPI. If you only need to run research experiments, it is safe +# to delete the integrations/ folder and replace this __init__.py with an +# empty file. + +__all__ = [ + "Parser", + "InputSentence", + "download", + "BeneparComponent", + "NonConstituentException", +] + +from .integrations.downloader import download +from .integrations.nltk_plugin import Parser, InputSentence +from .integrations.spacy_plugin import BeneparComponent, NonConstituentException
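The exports above define the package's public surface. A short usage sketch based on the sample code in the nltk_plugin docstring later in this diff (benepar_en3 is the trained English model distributed by the project; downloading it requires network access):

import benepar

benepar.download("benepar_en3")  # one-time model download via the NLTK downloader
parser = benepar.Parser("benepar_en3")
tree = parser.parse(benepar.InputSentence(words=["Fly", "safely", "."]))
print(tree)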
diff --git a/benepar/char_lstm.py b/benepar/char_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..0aefc5c18959865e9a75cbb476b21e0d2afd5678 --- /dev/null +++ b/benepar/char_lstm.py @@ -0,0 +1,160 @@ +""" +Character LSTM implementation (matches https://arxiv.org/pdf/1805.01052.pdf) +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CharacterLSTM(nn.Module): + def __init__(self, num_embeddings, d_embedding, d_out, char_dropout=0.0, **kwargs): + super().__init__() + + self.d_embedding = d_embedding + self.d_out = d_out + + self.lstm = nn.LSTM( + self.d_embedding, self.d_out // 2, num_layers=1, bidirectional=True + ) + + self.emb = nn.Embedding(num_embeddings, self.d_embedding, **kwargs) + self.char_dropout = nn.Dropout(char_dropout) + + def forward(self, chars_packed, valid_token_mask): + inp_embs = nn.utils.rnn.PackedSequence( + self.char_dropout(self.emb(chars_packed.data)), + batch_sizes=chars_packed.batch_sizes, + sorted_indices=chars_packed.sorted_indices, + unsorted_indices=chars_packed.unsorted_indices, + ) + + _, (lstm_out, _) = self.lstm(inp_embs) +
lstm_out = torch.cat([lstm_out[0], lstm_out[1]], -1) + + # Switch to a representation where there are dummy vectors for invalid + # tokens generated by padding. + res = lstm_out.new_zeros( + (valid_token_mask.shape[0], valid_token_mask.shape[1], lstm_out.shape[-1]) + ) + res[valid_token_mask] = lstm_out + return res + + +class RetokenizerForCharLSTM: + # Assumes that these control characters are not present in treebank text + CHAR_UNK = "\0" + CHAR_ID_UNK = 0 + CHAR_START_SENTENCE = "\1" + CHAR_START_WORD = "\2" + CHAR_STOP_WORD = "\3" + CHAR_STOP_SENTENCE = "\4" + + def __init__(self, char_vocab): + self.char_vocab = char_vocab + + @classmethod + def build_vocab(cls, sentences): + char_set = set() + for sentence in sentences: + if isinstance(sentence, tuple): + sentence = sentence[0] + for word in sentence: + char_set |= set(word) + + # If codepoints are small (e.g. Latin alphabet), index by codepoint + # directly + highest_codepoint = max(ord(char) for char in char_set) + if highest_codepoint < 512: + if highest_codepoint < 256: + highest_codepoint = 256 + else: + highest_codepoint = 512 + + char_vocab = {} + # This also takes care of constants like CHAR_UNK, etc. + for codepoint in range(highest_codepoint): + char_vocab[chr(codepoint)] = codepoint + return char_vocab + else: + char_vocab = {} + char_vocab[cls.CHAR_UNK] = 0 + char_vocab[cls.CHAR_START_SENTENCE] = 1 + char_vocab[cls.CHAR_START_WORD] = 2 + char_vocab[cls.CHAR_STOP_WORD] = 3 + char_vocab[cls.CHAR_STOP_SENTENCE] = 4 + for id_, char in enumerate(sorted(char_set), start=5): + char_vocab[char] = id_ + return char_vocab + + def __call__(self, words, space_after="ignored", return_tensors=None): + if return_tensors != "np": + raise NotImplementedError("Only return_tensors='np' is supported.") + + res = {} + + # Sentence-level start/stop tokens are encoded as 3 pseudo-chars + # Within each word, account for 2 start/stop characters + max_word_len = max(3, max(len(word) for word in words)) + 2 + char_ids = np.zeros((len(words) + 2, max_word_len), dtype=int) + word_lens = np.zeros(len(words) + 2, dtype=int) + + char_ids[0, :5] = [ + self.char_vocab[self.CHAR_START_WORD], + self.char_vocab[self.CHAR_START_SENTENCE], + self.char_vocab[self.CHAR_START_SENTENCE], + self.char_vocab[self.CHAR_START_SENTENCE], + self.char_vocab[self.CHAR_STOP_WORD], + ] + word_lens[0] = 5 + for i, word in enumerate(words, start=1): + char_ids[i, 0] = self.char_vocab[self.CHAR_START_WORD] + for j, char in enumerate(word, start=1): + char_ids[i, j] = self.char_vocab.get(char, self.CHAR_ID_UNK) + char_ids[i, j + 1] = self.char_vocab[self.CHAR_STOP_WORD] + word_lens[i] = j + 2 + char_ids[i + 1, :5] = [ + self.char_vocab[self.CHAR_START_WORD], + self.char_vocab[self.CHAR_STOP_SENTENCE], + self.char_vocab[self.CHAR_STOP_SENTENCE], + self.char_vocab[self.CHAR_STOP_SENTENCE], + self.char_vocab[self.CHAR_STOP_WORD], + ] + word_lens[i + 1] = 5 + + res["char_ids"] = char_ids + res["word_lens"] = word_lens + res["valid_token_mask"] = np.ones_like(word_lens, dtype=bool) + + return res + + def pad(self, examples, return_tensors=None): + if return_tensors != "pt": + raise NotImplementedError("Only return_tensors='pt' is supported.") + max_word_len = max(example["char_ids"].shape[-1] for example in examples) + char_ids = torch.cat( + [ + F.pad( + torch.tensor(example["char_ids"]), + (0, max_word_len - example["char_ids"].shape[-1]), + ) + for example in examples + ] + ) + word_lens = torch.cat( + [torch.tensor(example["word_lens"]) for example in examples] + ) + 
valid_token_mask = nn.utils.rnn.pad_sequence( + [torch.tensor(example["valid_token_mask"]) for example in examples], + batch_first=True, + padding_value=False, + ) + + char_ids = nn.utils.rnn.pack_padded_sequence( + char_ids, word_lens, batch_first=True, enforce_sorted=False + ) + return { + "char_ids": char_ids, + "valid_token_mask": valid_token_mask, + } diff --git a/benepar/decode_chart.py b/benepar/decode_chart.py new file mode 100644 index 0000000000000000000000000000000000000000..8d32ed1bdbe3bef17f509ceffdd1138267a36b0e --- /dev/null +++ b/benepar/decode_chart.py @@ -0,0 +1,291 @@ +""" +Parsing formulated as span classification (https://arxiv.org/abs/1705.03919) +""" + +import nltk +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_struct + +from .parse_base import CompressedParserOutput + + +def pad_charts(charts, padding_value=-100): + """Pad a list of variable-length charts with `padding_value`.""" + batch_size = len(charts) + max_len = max(chart.shape[0] for chart in charts) + padded_charts = torch.full( + (batch_size, max_len, max_len), + padding_value, + dtype=charts[0].dtype, + device=charts[0].device, + ) + for i, chart in enumerate(charts): + chart_size = chart.shape[0] + padded_charts[i, :chart_size, :chart_size] = chart + return padded_charts + + +def collapse_unary_strip_pos(tree, strip_top=True): + """Collapse unary chains and strip part of speech tags.""" + + def strip_pos(tree): + if len(tree) == 1 and isinstance(tree[0], str): + return tree[0] + else: + return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree]) + + collapsed_tree = strip_pos(tree) + collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::") + if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"): + if strip_top: + if len(collapsed_tree) == 1: + collapsed_tree = collapsed_tree[0] + else: + collapsed_tree.set_label("") + elif len(collapsed_tree) == 1: + collapsed_tree[0].set_label( + collapsed_tree.label() + "::" + collapsed_tree[0].label()) + collapsed_tree = collapsed_tree[0] + return collapsed_tree + + +def _get_labeled_spans(tree, spans_out, start): + if isinstance(tree, str): + return start + 1 + + assert len(tree) > 1 or isinstance( + tree[0], str + ), "Must call collapse_unary_strip_pos first" + end = start + for child in tree: + end = _get_labeled_spans(child, spans_out, end) + # Spans are returned as closed intervals on both ends + spans_out.append((start, end - 1, tree.label())) + return end + + +def get_labeled_spans(tree): + """Converts a tree into a list of labeled spans. + + Args: + tree: an nltk.tree.Tree object + + Returns: + A list of (span_start, span_end, span_label) tuples. The start and end + indices indicate the first and last words of the span (a closed + interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will + result in a single span labeled "S+VP". 
+ """ + tree = collapse_unary_strip_pos(tree) + spans_out = [] + _get_labeled_spans(tree, spans_out, start=0) + return spans_out + + +def uncollapse_unary(tree, ensure_top=False): + """Un-collapse unary chains.""" + if isinstance(tree, str): + return tree + else: + labels = tree.label().split("::") + if ensure_top and labels[0] != "TOP": + labels = ["TOP"] + labels + children = [] + for child in tree: + child = uncollapse_unary(child) + children.append(child) + for label in labels[::-1]: + children = [nltk.tree.Tree(label, children)] + return children[0] + + +class ChartDecoder: + """A chart decoder for parsing formulated as span classification.""" + + def __init__(self, label_vocab, force_root_constituent=True): + """Constructs a new ChartDecoder object. + Args: + label_vocab: A mapping from span labels to integer indices. + """ + self.label_vocab = label_vocab + self.label_from_index = {i: label for label, i in label_vocab.items()} + self.force_root_constituent = force_root_constituent + + @staticmethod + def build_vocab(trees): + label_set = set() + for tree in trees: + for _, _, label in get_labeled_spans(tree): + if label: + label_set.add(label) + label_set = [""] + sorted(label_set) + return {label: i for i, label in enumerate(label_set)} + + @staticmethod + def infer_force_root_constituent(trees): + for tree in trees: + for _, _, label in get_labeled_spans(tree): + if not label: + return False + return True + + def chart_from_tree(self, tree): + spans = get_labeled_spans(tree) + num_words = len(tree.leaves()) + chart = np.full((num_words, num_words), -100, dtype=int) + chart = np.tril(chart, -1) + # Now all invalid entries are filled with -100, and valid entries with 0 + for start, end, label in spans: + # Previously unseen unary chains can occur in the dev/test sets. + # For now, we ignore them and don't mark the corresponding chart + # entry as a constituent. + if label in self.label_vocab: + chart[start, end] = self.label_vocab[label] + return chart + + def charts_from_pytorch_scores_batched(self, scores, lengths): + """Runs CKY to recover span labels from scores (e.g. logits). + + This method uses pytorch-struct to speed up decoding compared to the + pure-Python implementation of CKY used by tree_from_scores(). + + Args: + scores: a pytorch tensor of shape (batch size, max length, + max length, label vocab size). + lengths: a pytorch tensor of shape (batch size,) + + Returns: + A list of numpy arrays, each of shape (sentence length, sentence + length). 
+ """ + scores = scores.detach() + scores = scores - scores[..., :1] + if self.force_root_constituent: + scores[torch.arange(scores.shape[0]), 0, lengths - 1, 0] -= 1e9 + dist = torch_struct.TreeCRF(scores, lengths=lengths) + amax = dist.argmax + amax[..., 0] += 1e-9 + padded_charts = amax.argmax(-1) + padded_charts = padded_charts.detach().cpu().numpy() + return [ + chart[:length, :length] for chart, length in zip(padded_charts, lengths) + ] + + def compressed_output_from_chart(self, chart): + chart_with_filled_diagonal = chart.copy() + np.fill_diagonal(chart_with_filled_diagonal, 1) + chart_with_filled_diagonal[0, -1] = 1 + starts, inclusive_ends = np.where(chart_with_filled_diagonal) + preorder_sort = np.lexsort((-inclusive_ends, starts)) + starts = starts[preorder_sort] + inclusive_ends = inclusive_ends[preorder_sort] + labels = chart[starts, inclusive_ends] + ends = inclusive_ends + 1 + return CompressedParserOutput(starts=starts, ends=ends, labels=labels) + + def tree_from_chart(self, chart, leaves): + compressed_output = self.compressed_output_from_chart(chart) + return compressed_output.to_tree(leaves, self.label_from_index) + + def tree_from_scores(self, scores, leaves): + """Runs CKY to decode a tree from scores (e.g. logits). + + If speed is important, consider using charts_from_pytorch_scores_batched + followed by compressed_output_from_chart or tree_from_chart instead. + + Args: + scores: a chart of scores (or logits) of shape + (sentence length, sentence length, label vocab size). The first + two dimensions may be padded to a longer length, but all padded + values will be ignored. + leaves: the leaf nodes to use in the constructed tree. These + may be of type str or nltk.Tree, or (word, tag) tuples that + will be used to construct the leaf node objects. + + Returns: + An nltk.Tree object. 
+ """ + leaves = [ + nltk.Tree(node[1], [node[0]]) if isinstance(node, tuple) else node + for node in leaves + ] + + chart = {} + scores = scores - scores[:, :, 0, None] + for length in range(1, len(leaves) + 1): + for left in range(0, len(leaves) + 1 - length): + right = left + length + + label_scores = scores[left, right - 1] + label_scores = label_scores - label_scores[0] + + argmax_label_index = int( + label_scores.argmax() + if length < len(leaves) or not self.force_root_constituent + else label_scores[1:].argmax() + 1 + ) + argmax_label = self.label_from_index[argmax_label_index] + label = argmax_label + label_score = label_scores[argmax_label_index] + + if length == 1: + tree = leaves[left] + if label: + tree = nltk.tree.Tree(label, [tree]) + chart[left, right] = [tree], label_score + continue + + best_split = max( + range(left + 1, right), + key=lambda split: (chart[left, split][1] + chart[split, right][1]), + ) + + left_trees, left_score = chart[left, best_split] + right_trees, right_score = chart[best_split, right] + + children = left_trees + right_trees + if label: + children = [nltk.tree.Tree(label, children)] + + chart[left, right] = (children, label_score + left_score + right_score) + + children, score = chart[0, len(leaves)] + tree = nltk.tree.Tree("TOP", children) + tree = uncollapse_unary(tree) + return tree + + +class SpanClassificationMarginLoss(nn.Module): + def __init__(self, force_root_constituent=True, reduction="mean"): + super().__init__() + self.force_root_constituent = force_root_constituent + if reduction not in ("none", "mean", "sum"): + raise ValueError(f"Invalid value for reduction: {reduction}") + self.reduction = reduction + + def forward(self, logits, labels): + gold_event = F.one_hot(F.relu(labels), num_classes=logits.shape[-1]) + + logits = logits - logits[..., :1] + lengths = (labels[:, 0, :] != -100).sum(-1) + augment = (1 - gold_event).to(torch.float) + if self.force_root_constituent: + augment[torch.arange(augment.shape[0]), 0, lengths - 1, 0] -= 1e9 + dist = torch_struct.TreeCRF(logits + augment, lengths=lengths) + + pred_score = dist.max + gold_score = (logits * gold_event).sum((1, 2, 3)) + + margin_losses = F.relu(pred_score - gold_score) + + if self.reduction == "none": + return margin_losses + elif self.reduction == "mean": + return margin_losses.mean() + elif self.reduction == "sum": + return margin_losses.sum() + else: + assert False, f"Unexpected reduction: {self.reduction}" diff --git a/benepar/decode_chart.py~ b/benepar/decode_chart.py~ new file mode 100644 index 0000000000000000000000000000000000000000..8d32ed1bdbe3bef17f509ceffdd1138267a36b0e --- /dev/null +++ b/benepar/decode_chart.py~ @@ -0,0 +1,291 @@ +""" +Parsing formulated as span classification (https://arxiv.org/abs/1705.03919) +""" + +import nltk +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_struct + +from .parse_base import CompressedParserOutput + + +def pad_charts(charts, padding_value=-100): + """Pad a list of variable-length charts with `padding_value`.""" + batch_size = len(charts) + max_len = max(chart.shape[0] for chart in charts) + padded_charts = torch.full( + (batch_size, max_len, max_len), + padding_value, + dtype=charts[0].dtype, + device=charts[0].device, + ) + for i, chart in enumerate(charts): + chart_size = chart.shape[0] + padded_charts[i, :chart_size, :chart_size] = chart + return padded_charts + + +def collapse_unary_strip_pos(tree, strip_top=True): + """Collapse unary chains and strip part of speech 
tags.""" + + def strip_pos(tree): + if len(tree) == 1 and isinstance(tree[0], str): + return tree[0] + else: + return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree]) + + collapsed_tree = strip_pos(tree) + collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::") + if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"): + if strip_top: + if len(collapsed_tree) == 1: + collapsed_tree = collapsed_tree[0] + else: + collapsed_tree.set_label("") + elif len(collapsed_tree) == 1: + collapsed_tree[0].set_label( + collapsed_tree.label() + "::" + collapsed_tree[0].label()) + collapsed_tree = collapsed_tree[0] + return collapsed_tree + + +def _get_labeled_spans(tree, spans_out, start): + if isinstance(tree, str): + return start + 1 + + assert len(tree) > 1 or isinstance( + tree[0], str + ), "Must call collapse_unary_strip_pos first" + end = start + for child in tree: + end = _get_labeled_spans(child, spans_out, end) + # Spans are returned as closed intervals on both ends + spans_out.append((start, end - 1, tree.label())) + return end + + +def get_labeled_spans(tree): + """Converts a tree into a list of labeled spans. + + Args: + tree: an nltk.tree.Tree object + + Returns: + A list of (span_start, span_end, span_label) tuples. The start and end + indices indicate the first and last words of the span (a closed + interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will + result in a single span labeled "S+VP". + """ + tree = collapse_unary_strip_pos(tree) + spans_out = [] + _get_labeled_spans(tree, spans_out, start=0) + return spans_out + + +def uncollapse_unary(tree, ensure_top=False): + """Un-collapse unary chains.""" + if isinstance(tree, str): + return tree + else: + labels = tree.label().split("::") + if ensure_top and labels[0] != "TOP": + labels = ["TOP"] + labels + children = [] + for child in tree: + child = uncollapse_unary(child) + children.append(child) + for label in labels[::-1]: + children = [nltk.tree.Tree(label, children)] + return children[0] + + +class ChartDecoder: + """A chart decoder for parsing formulated as span classification.""" + + def __init__(self, label_vocab, force_root_constituent=True): + """Constructs a new ChartDecoder object. + Args: + label_vocab: A mapping from span labels to integer indices. + """ + self.label_vocab = label_vocab + self.label_from_index = {i: label for label, i in label_vocab.items()} + self.force_root_constituent = force_root_constituent + + @staticmethod + def build_vocab(trees): + label_set = set() + for tree in trees: + for _, _, label in get_labeled_spans(tree): + if label: + label_set.add(label) + label_set = [""] + sorted(label_set) + return {label: i for i, label in enumerate(label_set)} + + @staticmethod + def infer_force_root_constituent(trees): + for tree in trees: + for _, _, label in get_labeled_spans(tree): + if not label: + return False + return True + + def chart_from_tree(self, tree): + spans = get_labeled_spans(tree) + num_words = len(tree.leaves()) + chart = np.full((num_words, num_words), -100, dtype=int) + chart = np.tril(chart, -1) + # Now all invalid entries are filled with -100, and valid entries with 0 + for start, end, label in spans: + # Previously unseen unary chains can occur in the dev/test sets. + # For now, we ignore them and don't mark the corresponding chart + # entry as a constituent. 
+ if label in self.label_vocab: + chart[start, end] = self.label_vocab[label] + return chart + + def charts_from_pytorch_scores_batched(self, scores, lengths): + """Runs CKY to recover span labels from scores (e.g. logits). + + This method uses pytorch-struct to speed up decoding compared to the + pure-Python implementation of CKY used by tree_from_scores(). + + Args: + scores: a pytorch tensor of shape (batch size, max length, + max length, label vocab size). + lengths: a pytorch tensor of shape (batch size,) + + Returns: + A list of numpy arrays, each of shape (sentence length, sentence + length). + """ + scores = scores.detach() + scores = scores - scores[..., :1] + if self.force_root_constituent: + scores[torch.arange(scores.shape[0]), 0, lengths - 1, 0] -= 1e9 + dist = torch_struct.TreeCRF(scores, lengths=lengths) + amax = dist.argmax + amax[..., 0] += 1e-9 + padded_charts = amax.argmax(-1) + padded_charts = padded_charts.detach().cpu().numpy() + return [ + chart[:length, :length] for chart, length in zip(padded_charts, lengths) + ] + + def compressed_output_from_chart(self, chart): + chart_with_filled_diagonal = chart.copy() + np.fill_diagonal(chart_with_filled_diagonal, 1) + chart_with_filled_diagonal[0, -1] = 1 + starts, inclusive_ends = np.where(chart_with_filled_diagonal) + preorder_sort = np.lexsort((-inclusive_ends, starts)) + starts = starts[preorder_sort] + inclusive_ends = inclusive_ends[preorder_sort] + labels = chart[starts, inclusive_ends] + ends = inclusive_ends + 1 + return CompressedParserOutput(starts=starts, ends=ends, labels=labels) + + def tree_from_chart(self, chart, leaves): + compressed_output = self.compressed_output_from_chart(chart) + return compressed_output.to_tree(leaves, self.label_from_index) + + def tree_from_scores(self, scores, leaves): + """Runs CKY to decode a tree from scores (e.g. logits). + + If speed is important, consider using charts_from_pytorch_scores_batched + followed by compressed_output_from_chart or tree_from_chart instead. + + Args: + scores: a chart of scores (or logits) of shape + (sentence length, sentence length, label vocab size). The first + two dimensions may be padded to a longer length, but all padded + values will be ignored. + leaves: the leaf nodes to use in the constructed tree. These + may be of type str or nltk.Tree, or (word, tag) tuples that + will be used to construct the leaf node objects. + + Returns: + An nltk.Tree object. 
+ """ + leaves = [ + nltk.Tree(node[1], [node[0]]) if isinstance(node, tuple) else node + for node in leaves + ] + + chart = {} + scores = scores - scores[:, :, 0, None] + for length in range(1, len(leaves) + 1): + for left in range(0, len(leaves) + 1 - length): + right = left + length + + label_scores = scores[left, right - 1] + label_scores = label_scores - label_scores[0] + + argmax_label_index = int( + label_scores.argmax() + if length < len(leaves) or not self.force_root_constituent + else label_scores[1:].argmax() + 1 + ) + argmax_label = self.label_from_index[argmax_label_index] + label = argmax_label + label_score = label_scores[argmax_label_index] + + if length == 1: + tree = leaves[left] + if label: + tree = nltk.tree.Tree(label, [tree]) + chart[left, right] = [tree], label_score + continue + + best_split = max( + range(left + 1, right), + key=lambda split: (chart[left, split][1] + chart[split, right][1]), + ) + + left_trees, left_score = chart[left, best_split] + right_trees, right_score = chart[best_split, right] + + children = left_trees + right_trees + if label: + children = [nltk.tree.Tree(label, children)] + + chart[left, right] = (children, label_score + left_score + right_score) + + children, score = chart[0, len(leaves)] + tree = nltk.tree.Tree("TOP", children) + tree = uncollapse_unary(tree) + return tree + + +class SpanClassificationMarginLoss(nn.Module): + def __init__(self, force_root_constituent=True, reduction="mean"): + super().__init__() + self.force_root_constituent = force_root_constituent + if reduction not in ("none", "mean", "sum"): + raise ValueError(f"Invalid value for reduction: {reduction}") + self.reduction = reduction + + def forward(self, logits, labels): + gold_event = F.one_hot(F.relu(labels), num_classes=logits.shape[-1]) + + logits = logits - logits[..., :1] + lengths = (labels[:, 0, :] != -100).sum(-1) + augment = (1 - gold_event).to(torch.float) + if self.force_root_constituent: + augment[torch.arange(augment.shape[0]), 0, lengths - 1, 0] -= 1e9 + dist = torch_struct.TreeCRF(logits + augment, lengths=lengths) + + pred_score = dist.max + gold_score = (logits * gold_event).sum((1, 2, 3)) + + margin_losses = F.relu(pred_score - gold_score) + + if self.reduction == "none": + return margin_losses + elif self.reduction == "mean": + return margin_losses.mean() + elif self.reduction == "sum": + return margin_losses.sum() + else: + assert False, f"Unexpected reduction: {self.reduction}" diff --git a/benepar/integrations/__init__.py b/benepar/integrations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/benepar/integrations/__pycache__/__init__.cpython-310.pyc b/benepar/integrations/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..472b8f490320e3043a0516e7b31f60831e7d940d Binary files /dev/null and b/benepar/integrations/__pycache__/__init__.cpython-310.pyc differ diff --git a/benepar/integrations/__pycache__/__init__.cpython-37.pyc b/benepar/integrations/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3982b477e04fd204a3b37ed2cfc84c310c1e48d2 Binary files /dev/null and b/benepar/integrations/__pycache__/__init__.cpython-37.pyc differ diff --git a/benepar/integrations/__pycache__/__init__.cpython-38.pyc b/benepar/integrations/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
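ChartDecoder turns that same span encoding into the dense charts used for training. A sketch of building a label vocabulary and a gold chart from a single tree (same toy tree and dependency caveat as above):

import nltk
from benepar.decode_chart import ChartDecoder

tree = nltk.Tree.fromstring("(TOP (S (NP (PRP We)) (VP (VBP fly))))")
decoder = ChartDecoder(label_vocab=ChartDecoder.build_vocab([tree]))
print(decoder.chart_from_tree(tree))
# 2x2 array: entries on and above the diagonal hold label indices
# (0 means "no label"); entries below the diagonal are masked with -100.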
diff --git a/benepar/integrations/downloader.py b/benepar/integrations/downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..019aa4e286fcce338659bab0894068512d5f71ce --- /dev/null +++ b/benepar/integrations/downloader.py @@ -0,0 +1,35 @@ +import os + +BENEPAR_SERVER_INDEX = "https://kitaev.com/benepar/index.xml" + +_downloader = None +def get_downloader(): + global _downloader + if _downloader is None: + import nltk.downloader + _downloader = nltk.downloader.Downloader(server_index_url=BENEPAR_SERVER_INDEX) + return _downloader + +def download(*args, **kwargs): + return get_downloader().download(*args, **kwargs) + +def locate_model(name): + if os.path.exists(name): + return name + elif "/" not in name and "." not in name: + import nltk.data + try: + nltk_loc = nltk.data.find(f"models/{name}") + return nltk_loc.path + except LookupError as e: + arg = e.args[0].replace("nltk.download", "benepar.download") + + raise LookupError(arg) + + raise LookupError("Can't find {}".format(name)) + +def load_trained_model(model_name_or_path): + model_path = locate_model(model_name_or_path) + from ..parse_chart import ChartParser + parser = ChartParser.from_trained(model_path) + return parser diff --git a/benepar/integrations/nltk_plugin.py b/benepar/integrations/nltk_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a454d79eed45700c0882779cc62d5a0190c5b6 --- /dev/null +++ b/benepar/integrations/nltk_plugin.py @@ -0,0 +1,279 @@ +import dataclasses +import itertools +from typing import List, Optional, Tuple + +import nltk +import torch + +from .downloader import load_trained_model +from ..parse_base import BaseParser, BaseInputExample +from ..ptb_unescape import ptb_unescape, guess_space_after + + +TOKENIZER_LOOKUP = { + "en": "english", + "de": "german", + "fr": "french", + "pl": "polish", + "sv": "swedish", +} + +LANGUAGE_GUESS = { + "ar": ("X", "XP", "WHADVP", "WHNP", "WHPP"), + "zh": ("VSB", "VRD", "VPT", "VNV"), + "en": ("WHNP", "WHADJP", "SINV", "SQ"), + "de": ("AA", "AP", "CCP", "CH", "CNP", "VZ"), + "fr": ("P+", "P+D+", "PRO+", "PROREL+"), + "he": ("PREDP", "SYN_REL", "SYN_yyDOT"), + "pl": ("formaczas", "znakkonca"), + "sv": ("PSEUDO", "AVP", "XP"), +} + + +def guess_language(label_vocab): + """Guess parser language based on its syntactic label inventory. + + The parser training scripts are designed to accept arbitrary input tree + files with minimal language-specific behavior, but at inference time we may + need to know the language identity in order to invoke other pipeline + elements, such as tokenizers.
+ """ + for language, required_labels in LANGUAGE_GUESS.items(): + if all(label in label_vocab for label in required_labels): + return language + return None + + +@dataclasses.dataclass +class InputSentence(BaseInputExample): + """Parser input for a single sentence. + + At least one of `words` and `escaped_words` is required for each input + sentence. The remaining fields are optional: the parser will attempt to + derive the value for any missing fields using the fields that are provided. + + `words` and `space_after` together form a reversible tokenization of the + input text: they represent, respectively, the Unicode text for each word and + an indicator for whether the word is followed by whitespace. These are used + as inputs by the parser. + + `tags` is a list of part-of-speech tags, if available prior to running the + parser. The parser does not actually use these tags as input, but it will + pass them through to its output. If `tags` is None, the parser will perform + its own part of speech tagging (if the parser was not trained to also do + tagging, "UNK" part-of-speech tags will be used in the output instead). + + `escaped_words` are the representations of each leaf to use in the output + tree. If `words` is provided, `escaped_words` will not be used by the neural + network portion of the parser, and will only be incorporated when + constructing the output tree. Therefore, `escaped_words` may be used to + accommodate any dataset-specific text encoding, such as transliteration. + + Here is an example of the differences between these fields for English PTB: + (raw text): "Fly safely." + words: " Fly safely . " + space_after: False True False False False + tags: `` VB RB . '' + escaped_words: `` Fly safely . '' + """ + + words: Optional[List[str]] = None + space_after: Optional[List[bool]] = None + tags: Optional[List[str]] = None + escaped_words: Optional[List[str]] = None + + @property + def tree(self): + return None + + def leaves(self): + return self.escaped_words + + def pos(self): + if self.tags is not None: + return list(zip(self.escaped_words, self.tags)) + else: + return [(word, "UNK") for word in self.escaped_words] + + +class Parser: + """Berkeley Neural Parser (benepar), integrated with NLTK. + + Use this class to apply the Berkeley Neural Parser to pre-tokenized datasets + and treebanks, or when integrating the parser into an NLP pipeline that + already performs tokenization, sentence splitting, and (optionally) + part-of-speech tagging. For parsing starting with raw text, it is strongly + encouraged that you use spaCy and benepar.BeneparComponent instead. + + Sample usage: + >>> parser = benepar.Parser("benepar_en3") + >>> input_sentence = benepar.InputSentence( + words=['"', 'Fly', 'safely', '.', '"'], + space_after=[False, True, False, False, False], + tags=['``', 'VB', 'RB', '.', "''"], + escaped_words=['``', 'Fly', 'safely', '.', "''"], + ) + >>> parser.parse(input_sentence) + + Not all fields of benepar.InputSentence are required, but at least one of + `words` and `escaped_words` must not be None. The parser will attempt to + guess the value for missing fields. For example, + >>> input_sentence = benepar.InputSentence( + words=['"', 'Fly', 'safely', '.', '"'], + ) + >>> parser.parse(input_sentence) + + Although this class is primarily designed for use with data that has already + been tokenized, to help with interactive use and debugging it also accepts + simple text string inputs. 
However, using this class to parse from raw text + is STRONGLY DISCOURAGED for any application where parsing accuracy matters. + When parsing from raw text, use spaCy and benepar.BeneparComponent instead. + The reason is that parser models do not ship with a tokenizer or sentence + splitter, and some models may not include a part-of-speech tagger either. A + toolkit must be used to fill in these pipeline components, and spaCy + outperforms NLTK in all of these areas (sometimes by a large margin). + >>> parser.parse('"Fly safely."') # For debugging/interactive use only. + """ + + def __init__(self, name, batch_size=64, language_code=None): + """Load a trained parser model. + + Args: + name (str): Model name, or path to pytorch saved model + batch_size (int): Maximum number of sentences to process per batch + language_code (str, optional): language code for the parser (e.g. + 'en', 'he', 'zh', etc). Our official trained models will set + this automatically, so this argument is only needed if training + on new languages or treebanks. + """ + self._parser = load_trained_model(name) + if torch.cuda.is_available(): + self._parser.cuda() + if language_code is not None: + self._language_code = language_code + else: + self._language_code = guess_language(self._parser.config["label_vocab"]) + self._tokenizer_lang = TOKENIZER_LOOKUP.get(self._language_code, None) + + self.batch_size = batch_size + + def parse(self, sentence): + """Parse a single sentence + + Args: + sentence (InputSentence or List[str] or str): Sentence to parse. + If the input is of List[str], it is assumed to be a sequence of + words and will behave the same as only setting the `words` field + of InputSentence. If the input is of type str, the sentence will + be tokenized using the default NLTK tokenizer (not recommended: + if parsing from raw text, use spaCy and benepar.BeneparComponent + instead). + + Returns: + nltk.Tree + """ + return list(self.parse_sents([sentence]))[0] + + def parse_sents(self, sents): + """Parse multiple sentences in batches. + + Args: + sents (Iterable[InputSentence]): An iterable of sentences to be + parsed. `sents` may also be a string, in which case it will be + segmented into sentences using the default NLTK sentence + splitter (not recommended: if parsing from raw text, use spaCy + and benepar.BeneparComponent instead). Otherwise, each element + of `sents` will be treated as a sentence. The elements of + `sents` may also be List[str] or str: see Parser.parse() for + documentation regarding these cases. + + Yields: + nltk.Tree objects, one per input sentence. + """ + if isinstance(sents, str): + if self._tokenizer_lang is None: + raise ValueError( + "No tokenizer available for this language. " + "Please split into individual sentences and tokens " + "before calling the parser." + ) + sents = nltk.sent_tokenize(sents, self._tokenizer_lang) + + end_sentinel = object() + for batch_sents in itertools.zip_longest( + *([iter(sents)] * self.batch_size), fillvalue=end_sentinel + ): + batch_inputs = [] + for sent in batch_sents: + if sent is end_sentinel: + break + elif isinstance(sent, str): + if self._tokenizer_lang is None: + raise ValueError( + "No word tokenizer available for this language. " + "Please tokenize before calling the parser." 
+ ) + escaped_words = nltk.word_tokenize(sent, self._tokenizer_lang) + sent = InputSentence(escaped_words=escaped_words) + elif isinstance(sent, (list, tuple)): + sent = InputSentence(words=sent) + elif not isinstance(sent, InputSentence): + raise ValueError( + "Sentences must be one of: InputSentence, list, tuple, or str" + ) + batch_inputs.append(self._with_missing_fields_filled(sent)) + + for inp, output in zip( + batch_inputs, self._parser.parse(batch_inputs, return_compressed=True) + ): + # If pos tags are provided as input, ignore any tags predicted + # by the parser. + if inp.tags is not None: + output = output.without_predicted_tags() + yield output.to_tree( + inp.pos(), + self._parser.decoder.label_from_index, + self._parser.tag_from_index, + ) + + def _with_missing_fields_filled(self, sent): + if not isinstance(sent, InputSentence): + raise ValueError("Input is not an instance of InputSentence") + if sent.words is None and sent.escaped_words is None: + raise ValueError("At least one of words or escaped_words is required") + elif sent.words is None: + sent = dataclasses.replace(sent, words=ptb_unescape(sent.escaped_words)) + elif sent.escaped_words is None: + escaped_words = [ + word.replace("(", "-LRB-") + .replace(")", "-RRB-") + .replace("{", "-LCB-") + .replace("}", "-RCB-") + .replace("[", "-LSB-") + .replace("]", "-RSB-") + for word in sent.words + ] + sent = dataclasses.replace(sent, escaped_words=escaped_words) + else: + if len(sent.words) != len(sent.escaped_words): + raise ValueError( + f"Length of words ({len(sent.words)}) does not match " + f"escaped_words ({len(sent.escaped_words)})" + ) + + if sent.space_after is None: + if self._language_code == "zh": + space_after = [False for _ in sent.words] + elif self._language_code in ("ar", "he"): + space_after = [True for _ in sent.words] + else: + space_after = guess_space_after(sent.words) + sent = dataclasses.replace(sent, space_after=space_after) + elif len(sent.words) != len(sent.space_after): + raise ValueError( + f"Length of words ({len(sent.words)}) does not match " + f"space_after ({len(sent.space_after)})" + ) + + assert len(sent.words) == len(sent.escaped_words) == len(sent.space_after) + return sent diff --git a/benepar/integrations/spacy_extensions.py b/benepar/integrations/spacy_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..572dc45fa8371d97f758a39d213834ce33bed998 --- /dev/null +++ b/benepar/integrations/spacy_extensions.py @@ -0,0 +1,179 @@ +NOT_PARSED_SENTINEL = object() + + +class NonConstituentException(Exception): + pass + + +class ConstituentData: + def __init__(self, starts, ends, labels, loc_to_constituent, label_vocab): + self.starts = starts + self.ends = ends + self.labels = labels + self.loc_to_constituent = loc_to_constituent + self.label_vocab = label_vocab + + +def get_constituent(span): + constituent_data = span.doc._._constituent_data + if constituent_data is NOT_PARSED_SENTINEL: + raise Exception( + "No constituency parse is available for this document." + " Consider adding a BeneparComponent to the pipeline." 
+ ) + + search_start = constituent_data.loc_to_constituent[span.start] + if span.start + 1 < len(constituent_data.loc_to_constituent): + search_end = constituent_data.loc_to_constituent[span.start + 1] + else: + search_end = len(constituent_data.ends) + found_position = None + for position in range(search_start, search_end): + if constituent_data.ends[position] <= span.end: + if constituent_data.ends[position] == span.end: + found_position = position + break + + if found_position is None: + raise NonConstituentException("Span is not a constituent: {}".format(span)) + return constituent_data, found_position + + +def get_labels(span): + constituent_data, position = get_constituent(span) + label_num = constituent_data.labels[position] + return constituent_data.label_vocab[label_num] + + +def parse_string(span): + constituent_data, position = get_constituent(span) + label_vocab = constituent_data.label_vocab + doc = span.doc + + idx = position - 1 + + def make_str(): + nonlocal idx + idx += 1 + i, j, label_idx = ( + constituent_data.starts[idx], + constituent_data.ends[idx], + constituent_data.labels[idx], + ) + label = label_vocab[label_idx] + if (i + 1) >= j: + token = doc[i] + s = ( + "(" + + u"{} {}".format(token.tag_, token.text) + .replace("(", "-LRB-") + .replace(")", "-RRB-") + .replace("{", "-LCB-") + .replace("}", "-RCB-") + .replace("[", "-LSB-") + .replace("]", "-RSB-") + + ")" + ) + else: + children = [] + while ( + (idx + 1) < len(constituent_data.starts) + and i <= constituent_data.starts[idx + 1] + and constituent_data.ends[idx + 1] <= j + ): + children.append(make_str()) + + s = u" ".join(children) + + for sublabel in reversed(label): + s = u"({} {})".format(sublabel, s) + return s + + return make_str() + + +def get_subconstituents(span): + constituent_data, position = get_constituent(span) + label_vocab = constituent_data.label_vocab + doc = span.doc + + while position < len(constituent_data.starts): + start = constituent_data.starts[position] + end = constituent_data.ends[position] + + if span.end <= start or span.end < end: + break + + yield doc[start:end] + position += 1 + + +def get_child_spans(span): + constituent_data, position = get_constituent(span) + label_vocab = constituent_data.label_vocab + doc = span.doc + + child_start_expected = span.start + position += 1 + while position < len(constituent_data.starts): + start = constituent_data.starts[position] + end = constituent_data.ends[position] + + if span.end <= start or span.end < end: + break + + if start == child_start_expected: + yield doc[start:end] + child_start_expected = end + + position += 1 + + +def get_parent_span(span): + constituent_data, position = get_constituent(span) + label_vocab = constituent_data.label_vocab + doc = span.doc + sent = span.sent + + position -= 1 + while position >= 0: + start = constituent_data.starts[position] + end = constituent_data.ends[position] + + if start <= span.start and span.end <= end: + return doc[start:end] + if end < span.sent.start: + break + position -= 1 + + return None + + +def install_spacy_extensions(): + from spacy.tokens import Doc, Span, Token + + # None is not allowed as a default extension value! 
+ Doc.set_extension("_constituent_data", default=NOT_PARSED_SENTINEL) + + Span.set_extension("labels", getter=get_labels) + Span.set_extension("parse_string", getter=parse_string) + Span.set_extension("constituents", getter=get_subconstituents) + Span.set_extension("parent", getter=get_parent_span) + Span.set_extension("children", getter=get_child_spans) + + Token.set_extension( + "labels", getter=lambda token: get_labels(token.doc[token.i : token.i + 1]) + ) + Token.set_extension( + "parse_string", + getter=lambda token: parse_string(token.doc[token.i : token.i + 1]), + ) + Token.set_extension( + "parent", getter=lambda token: get_parent_span(token.doc[token.i : token.i + 1]) + ) + + +try: + install_spacy_extensions() +except ImportError: + pass diff --git a/benepar/integrations/spacy_plugin.py b/benepar/integrations/spacy_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..41ca8b6e41a6a3368a7c1d207a99704b68a82491 --- /dev/null +++ b/benepar/integrations/spacy_plugin.py @@ -0,0 +1,206 @@ +import numpy as np + +from .downloader import load_trained_model +from ..parse_base import BaseParser, BaseInputExample +from .spacy_extensions import ConstituentData, NonConstituentException + +import torch + + +class PartialConstituentData: + def __init__(self): + self.starts = [np.array([], dtype=int)] + self.ends = [np.array([], dtype=int)] + self.labels = [np.array([], dtype=int)] + + def finalize(self, doc, label_vocab): + self.starts = np.hstack(self.starts) + self.ends = np.hstack(self.ends) + self.labels = np.hstack(self.labels) + + # TODO(nikita): Python for loops aren't very fast + loc_to_constituent = np.full(len(doc), -1, dtype=int) + prev = None + for position in range(self.starts.shape[0]): + if self.starts[position] != prev: + prev = self.starts[position] + loc_to_constituent[self.starts[position]] = position + + return ConstituentData( + self.starts, self.ends, self.labels, loc_to_constituent, label_vocab + ) + + +class SentenceWrapper(BaseInputExample): + TEXT_NORMALIZATION_MAPPING = { + "`": "'", + "«": '"', + "»": '"', + "‘": "'", + "’": "'", + "“": '"', + "”": '"', + "„": '"', + "‹": "'", + "›": "'", + "—": "--", # em dash + } + + def __init__(self, spacy_sent): + self.sent = spacy_sent + + @property + def words(self): + return [ + self.TEXT_NORMALIZATION_MAPPING.get(token.text, token.text) + for token in self.sent + ] + + @property + def space_after(self): + return [bool(token.whitespace_) for token in self.sent] + + @property + def tree(self): + return None + + def leaves(self): + return self.words + + def pos(self): + return [(word, "UNK") for word in self.words] + + +class BeneparComponent: + """ + Berkeley Neural Parser (benepar) component for spaCy. + + Sample usage: + >>> nlp = spacy.load('en_core_web_md') + >>> if spacy.__version__.startswith('2'): + nlp.add_pipe(BeneparComponent("benepar_en3")) + else: + nlp.add_pipe("benepar", config={"model": "benepar_en3"}) + >>> doc = nlp("The quick brown fox jumps over the lazy dog.") + >>> sent = list(doc.sents)[0] + >>> print(sent._.parse_string) + + This component is only responsible for constituency parsing and (for some + trained models) part-of-speech tagging. It should be preceded in the + pipeline by other components that can, at minimum, perform tokenization and + sentence segmentation. + """ + + name = "benepar" + + def __init__( + self, + name, + subbatch_max_tokens=500, + disable_tagger=False, + batch_size="ignored", + ): + """Load a trained parser model. 
+ + Args: + name (str): Model name, or path to pytorch saved model + subbatch_max_tokens (int): Maximum number of tokens to process in + each batch + disable_tagger (bool, default False): Unless disabled, the parser + will set predicted part-of-speech tags for the document, + overwriting any existing tags provided by spaCy models or + previous pipeline steps. This option has no effect for parser + models that do not have a part-of-speech tagger built in. + batch_size: deprecated and ignored; use subbatch_max_tokens instead + """ + self._parser = load_trained_model(name) + if torch.cuda.is_available(): + self._parser.cuda() + + self.subbatch_max_tokens = subbatch_max_tokens + self.disable_tagger = disable_tagger + + self._label_vocab = self._parser.config["label_vocab"] + label_vocab_size = max(self._label_vocab.values()) + 1 + self._label_from_index = [()] * label_vocab_size + for label, i in self._label_vocab.items(): + if label: + self._label_from_index[i] = tuple(label.split("::")) + else: + self._label_from_index[i] = () + self._label_from_index = tuple(self._label_from_index) + + if not self.disable_tagger: + tag_vocab = self._parser.config["tag_vocab"] + tag_vocab_size = max(tag_vocab.values()) + 1 + self._tag_from_index = [()] * tag_vocab_size + for tag, i in tag_vocab.items(): + self._tag_from_index[i] = tag + self._tag_from_index = tuple(self._tag_from_index) + else: + self._tag_from_index = None + + def __call__(self, doc): + """Update the input document with predicted constituency parses.""" + # TODO(https://github.com/nikitakit/self-attentive-parser/issues/16): handle + # tokens that consist entirely of whitespace. + constituent_data = PartialConstituentData() + wrapped_sents = [SentenceWrapper(sent) for sent in doc.sents] + for sent, parse in zip( + doc.sents, + self._parser.parse( + wrapped_sents, + return_compressed=True, + subbatch_max_tokens=self.subbatch_max_tokens, + ), + ): + constituent_data.starts.append(parse.starts + sent.start) + constituent_data.ends.append(parse.ends + sent.start) + constituent_data.labels.append(parse.labels) + + if parse.tags is not None and not self.disable_tagger: + for i, tag_id in enumerate(parse.tags): + sent[i].tag_ = self._tag_from_index[tag_id] + + doc._._constituent_data = constituent_data.finalize(doc, self._label_from_index) + return doc + + +def create_benepar_component( + nlp, + name, + model: str, + subbatch_max_tokens: int, + disable_tagger: bool, +): + return BeneparComponent( + model, + subbatch_max_tokens=subbatch_max_tokens, + disable_tagger=disable_tagger, + ) + + +def register_benepar_component_factory(): + # Starting with spaCy 3.0, nlp.add_pipe no longer directly accepts + # BeneparComponent instances. We must instead register a component factory. 
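+    # Illustrative difference between the two APIs, using the model name from
+    # the BeneparComponent docstring above:
+    #
+    #     spaCy 2.x: nlp.add_pipe(BeneparComponent("benepar_en3"))
+    #     spaCy 3.x: nlp.add_pipe("benepar", config={"model": "benepar_en3"})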
+ import spacy + + if spacy.__version__.startswith("2"): + return + + from spacy.language import Language + + Language.factory( + "benepar", + default_config={ + "subbatch_max_tokens": 500, + "disable_tagger": False, + }, + func=create_benepar_component, + ) + + +try: + register_benepar_component_factory() +except ImportError: + pass diff --git a/benepar/nkutil.py b/benepar/nkutil.py new file mode 100644 index 0000000000000000000000000000000000000000..290ad20474d1406f9091aebbbbc960562c9075c1 --- /dev/null +++ b/benepar/nkutil.py @@ -0,0 +1,51 @@ +class HParams: + _skip_keys = ["populate_arguments", "set_from_args", "print", "to_dict"] + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, item, value): + if not hasattr(self, item): + raise KeyError(f"Hyperparameter {item} has not been declared yet") + setattr(self, item, value) + + def to_dict(self): + res = {} + for k in dir(self): + if k.startswith("_") or k in self._skip_keys: + continue + res[k] = self[k] + return res + + def populate_arguments(self, parser): + for k in dir(self): + if k.startswith("_") or k in self._skip_keys: + continue + v = self[k] + k = k.replace("_", "-") + if type(v) in (int, float, str): + parser.add_argument(f"--{k}", type=type(v), default=v) + elif isinstance(v, bool): + if not v: + parser.add_argument(f"--{k}", action="store_true") + else: + parser.add_argument(f"--no-{k}", action="store_false") + + def set_from_args(self, args): + for k in dir(self): + if k.startswith("_") or k in self._skip_keys: + continue + if hasattr(args, k): + self[k] = getattr(args, k) + elif hasattr(args, f"no_{k}"): + self[k] = getattr(args, f"no_{k}") + + def print(self): + for k in dir(self): + if k.startswith("_") or k in self._skip_keys: + continue + print(k, repr(self[k])) diff --git a/benepar/parse_base.py b/benepar/parse_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9be49169f6ed97148ba6d109c7512a3c0e5feb05 --- /dev/null +++ b/benepar/parse_base.py @@ -0,0 +1,216 @@ +from abc import ABC, abstractmethod +import dataclasses +from typing import Any, Iterable, List, Optional, Tuple, Union + +import nltk +import numpy as np + + +class BaseInputExample(ABC): + """Parser input for a single sentence (abstract interface).""" + + # Subclasses must define the following attributes or properties. + # `words` is a list of unicode representations for each word in the sentence + # and `space_after` is a list of booleans that indicate whether there is + # whitespace after a word. Together, these should form a reversible + # tokenization of raw text input. `tree` is an optional gold parse tree. + words: List[str] + space_after: List[bool] + tree: Optional[nltk.Tree] + + @abstractmethod + def leaves(self) -> Optional[List[str]]: + """Returns leaves to use in the parse tree. + + While `words` must be raw unicode text, these should be whatever is + standard for the treebank. For example, '(' in words might correspond to + '-LRB-' in leaves, and leaves might include other transformations such + as transliteration. + """ + pass + + @abstractmethod + def pos(self) -> Optional[List[Tuple[str, str]]]: + """Returns a list of (leaf, part-of-speech tag) tuples.""" + pass + + +@dataclasses.dataclass +class CompressedParserOutput: + """Parser output, encoded as a collection of numpy arrays. + + By default, a parser will return nltk.Tree objects. 
These have much nicer + APIs than the CompressedParserOutput class, and the code involved is simpler + and more readable. As a trade-off, code dealing with nltk.Tree objects is + slower: the nltk.Tree type itself has some overhead, and algorithms dealing + with it are implemented in pure Python as opposed to C or even CUDA. The + CompressedParserOutput type is an alternative that has some optimizations + for the sole purpose of speeding up inference. + + If trying a new parser type for research purposes, it's safe to ignore this + class and the return_compressed argument to parse(). If the parser works + well and is being released, the return_compressed argument can then be added + with a dedicated fast implementation, or simply by using the from_tree + method defined below. + """ + + # A parse tree is represented as a set of constituents. In the case of + # non-binary trees, only the labeled non-terminal nodes are included: there + # are no dummy nodes inserted for binarization purposes. However, single + # words are always included in the set of constituents, and they may have a + # null label if there is no phrasal category above the part-of-speech tag. + # All constituents are sorted according to pre-order traversal, and each has + # an associated start (the index of the first word in the constituent), end + # (1 + the index of the last word in the constituent), and label (index + # associated with an external label_vocab dictionary.) These are then stored + # in three numpy arrays: + starts: Iterable[int] # Must be a numpy array + ends: Iterable[int] # Must be a numpy array + labels: Iterable[int] # Must be a numpy array + + # Part of speech tag ids as output by the parser (may be None if the parser + # does not do POS tagging). These indices are associated with an external + # tag_vocab dictionary. 
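+    # For example, with a hypothetical tag_vocab of {"DT": 0, "NN": 1, "VBD": 2},
+    # the tagged sentence "the/DT cat/NN sat/VBD" would be stored as
+    # tags == [0, 1, 2].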
+ tags: Optional[Iterable[int]] = None # Must be None or a numpy array + + def without_predicted_tags(self): + return dataclasses.replace(self, tags=None) + + def with_tags(self, tags): + return dataclasses.replace(self, tags=tags) + + @classmethod + def from_tree( + cls, tree: nltk.Tree, label_vocab: dict, tag_vocab: Optional[dict] = None + ) -> "CompressedParserOutput": + num_words = len(tree.leaves()) + starts = np.empty(2 * num_words, dtype=int) + ends = np.empty(2 * num_words, dtype=int) + labels = np.empty(2 * num_words, dtype=int) + + def helper(tree, start, write_idx): + nonlocal starts, ends, labels + label = [] + while len(tree) == 1 and not isinstance(tree[0], str): + if tree.label() != "TOP": + label.append(tree.label()) + tree = tree[0] + + if len(tree) == 1 and isinstance(tree[0], str): + starts[write_idx] = start + ends[write_idx] = start + 1 + labels[write_idx] = label_vocab["::".join(label)] + return start + 1, write_idx + 1 + + label.append(tree.label()) + starts[write_idx] = start + labels[write_idx] = label_vocab["::".join(label)] + + end = start + new_write_idx = write_idx + 1 + for child in tree: + end, new_write_idx = helper(child, end, new_write_idx) + + ends[write_idx] = end + return end, new_write_idx + + _, num_constituents = helper(tree, 0, 0) + starts = starts[:num_constituents] + ends = ends[:num_constituents] + labels = labels[:num_constituents] + + if tag_vocab is None: + tags = None + else: + tags = np.array([tag_vocab[tag] for _, tag in tree.pos()], dtype=int) + + return cls(starts=starts, ends=ends, labels=labels, tags=tags) + + def to_tree(self, leaves, label_from_index: dict, tag_from_index: dict = None): + if self.tags is not None: + if tag_from_index is None: + raise ValueError( + "tags_from_index is required to convert predicted pos tags" + ) + predicted_tags = [tag_from_index[i] for i in self.tags] + assert len(leaves) == len(predicted_tags) + leaves = [ + nltk.Tree(tag, [leaf[0] if isinstance(leaf, tuple) else leaf]) + for tag, leaf in zip(predicted_tags, leaves) + ] + else: + leaves = [ + nltk.Tree(leaf[1], [leaf[0]]) + if isinstance(leaf, tuple) + else (nltk.Tree("UNK", [leaf]) if isinstance(leaf, str) else leaf) + for leaf in leaves + ] + + idx = -1 + + def helper(): + nonlocal idx + idx += 1 + i, j, label = ( + self.starts[idx], + self.ends[idx], + label_from_index[self.labels[idx]], + ) + if (i + 1) >= j: + children = [leaves[i]] + else: + children = [] + while ( + (idx + 1) < len(self.starts) + and i <= self.starts[idx + 1] + and self.ends[idx + 1] <= j + ): + children.extend(helper()) + + if label: + for sublabel in reversed(label.split("::")): + children = [nltk.Tree(sublabel, children)] + + return children + + children = helper() + return nltk.Tree("TOP", children) + + +class BaseParser(ABC): + """Parser (abstract interface)""" + + @classmethod + @abstractmethod + def from_trained( + cls, model_name: str, config: dict = None, state_dict: dict = None + ) -> "BaseParser": + """Load a trained parser.""" + pass + + @abstractmethod + def parallelize(self, *args, **kwargs): + """Spread out pre-trained model layers across GPUs.""" + pass + + @abstractmethod + def parse( + self, + examples: Iterable[BaseInputExample], + return_compressed: bool = False, + return_scores: bool = False, + subbatch_max_tokens: Optional[int] = None, + ) -> Union[Iterable[nltk.Tree], Iterable[Any]]: + """Parse sentences.""" + pass + + @abstractmethod + def encode_and_collate_subbatches( + self, examples: List[BaseInputExample], subbatch_max_tokens: int + ) -> List[dict]: 
+ """Split batch into sub-batches and convert to tensor features""" + pass + + @abstractmethod + def compute_loss(self, batch: dict): + pass diff --git a/benepar/parse_chart.py b/benepar/parse_chart.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf8314885a8b77f01dd71d0636c34eb85d7f5ae --- /dev/null +++ b/benepar/parse_chart.py @@ -0,0 +1,434 @@ +import os + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import AutoConfig, AutoModel + +from . import char_lstm +from . import decode_chart +from . import nkutil +from .partitioned_transformer import ( + ConcatPositionalEncoding, + FeatureDropout, + PartitionedTransformerEncoder, + PartitionedTransformerEncoderLayer, +) +from . import parse_base +from . import retokenization +from . import subbatching + + +class ChartParser(nn.Module, parse_base.BaseParser): + def __init__( + self, + tag_vocab, + label_vocab, + char_vocab, + hparams, + pretrained_model_path=None, + ): + super().__init__() + self.config = locals() + self.config.pop("self") + self.config.pop("__class__") + self.config.pop("pretrained_model_path") + self.config["hparams"] = hparams.to_dict() + + self.tag_vocab = tag_vocab + self.label_vocab = label_vocab + self.char_vocab = char_vocab + + self.d_model = hparams.d_model + + self.char_encoder = None + self.pretrained_model = None + if hparams.use_chars_lstm: + assert ( + not hparams.use_pretrained + ), "use_chars_lstm and use_pretrained are mutually exclusive" + self.retokenizer = char_lstm.RetokenizerForCharLSTM(self.char_vocab) + self.char_encoder = char_lstm.CharacterLSTM( + max(self.char_vocab.values()) + 1, + hparams.d_char_emb, + hparams.d_model // 2, # Half-size to leave room for + # partitioned positional encoding + char_dropout=hparams.char_lstm_input_dropout, + ) + elif hparams.use_pretrained: + if pretrained_model_path is None: + self.retokenizer = retokenization.Retokenizer( + hparams.pretrained_model, retain_start_stop=True + ) + self.pretrained_model = AutoModel.from_pretrained( + hparams.pretrained_model + ) + else: + self.retokenizer = retokenization.Retokenizer( + pretrained_model_path, retain_start_stop=True + ) + self.pretrained_model = AutoModel.from_config( + AutoConfig.from_pretrained(pretrained_model_path) + ) + d_pretrained = self.pretrained_model.config.hidden_size + + if hparams.use_encoder: + self.project_pretrained = nn.Linear( + d_pretrained, hparams.d_model // 2, bias=False + ) + else: + self.project_pretrained = nn.Linear( + d_pretrained, hparams.d_model, bias=False + ) + + if hparams.use_encoder: + self.morpho_emb_dropout = FeatureDropout(hparams.morpho_emb_dropout) + self.add_timing = ConcatPositionalEncoding( + d_model=hparams.d_model, + max_len=hparams.encoder_max_len, + ) + encoder_layer = PartitionedTransformerEncoderLayer( + hparams.d_model, + n_head=hparams.num_heads, + d_qkv=hparams.d_kv, + d_ff=hparams.d_ff, + ff_dropout=hparams.relu_dropout, + residual_dropout=hparams.residual_dropout, + attention_dropout=hparams.attention_dropout, + ) + self.encoder = PartitionedTransformerEncoder( + encoder_layer, hparams.num_layers + ) + else: + self.morpho_emb_dropout = None + self.add_timing = None + self.encoder = None + + self.f_label = nn.Sequential( + nn.Linear(hparams.d_model, hparams.d_label_hidden), + nn.LayerNorm(hparams.d_label_hidden), + nn.ReLU(), + nn.Linear(hparams.d_label_hidden, max(label_vocab.values())), + ) + + if hparams.predict_tags: + self.f_tag = nn.Sequential( + nn.Linear(hparams.d_model, 
hparams.d_tag_hidden), + nn.LayerNorm(hparams.d_tag_hidden), + nn.ReLU(), + nn.Linear(hparams.d_tag_hidden, max(tag_vocab.values()) + 1), + ) + self.tag_loss_scale = hparams.tag_loss_scale + self.tag_from_index = {i: label for label, i in tag_vocab.items()} + else: + self.f_tag = None + self.tag_from_index = None + + self.decoder = decode_chart.ChartDecoder( + label_vocab=self.label_vocab, + force_root_constituent=hparams.force_root_constituent, + ) + self.criterion = decode_chart.SpanClassificationMarginLoss( + reduction="sum", force_root_constituent=hparams.force_root_constituent + ) + + self.parallelized_devices = None + + @property + def device(self): + if self.parallelized_devices is not None: + return self.parallelized_devices[0] + else: + return next(self.f_label.parameters()).device + + @property + def output_device(self): + if self.parallelized_devices is not None: + return self.parallelized_devices[1] + else: + return next(self.f_label.parameters()).device + + def parallelize(self, *args, **kwargs): + self.parallelized_devices = (torch.device("cuda", 0), torch.device("cuda", 1)) + for child in self.children(): + if child != self.pretrained_model: + child.to(self.output_device) + self.pretrained_model.parallelize(*args, **kwargs) + + @classmethod + def from_trained(cls, model_path): + if os.path.isdir(model_path): + # Multi-file format used when exporting models for release. + # Unlike the checkpoints saved during training, these files include + # all tokenizer parameters and a copy of the pre-trained model + # config (rather than downloading these on-demand). + config = AutoConfig.from_pretrained(model_path).benepar + state_dict = torch.load( + os.path.join(model_path, "benepar_model.bin"), map_location="cpu" + ) + config["pretrained_model_path"] = model_path + else: + # Single-file format used for saving checkpoints during training. 
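+            # Such checkpoints are dicts of the form
+            #     {"config": ..., "state_dict": ...}
+            # which is exactly what gets unpacked below.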
+ data = torch.load(model_path, map_location="cpu") + config = data["config"] + state_dict = data["state_dict"] + + hparams = config["hparams"] + + if "force_root_constituent" not in hparams: + hparams["force_root_constituent"] = True + + config["hparams"] = nkutil.HParams(**hparams) + parser = cls(**config) + parser.load_state_dict(state_dict) + return parser + + def encode(self, example): + if self.char_encoder is not None: + encoded = self.retokenizer(example.words, return_tensors="np") + else: + encoded = self.retokenizer(example.words, example.space_after) + + if example.tree is not None: + encoded["span_labels"] = torch.tensor( + self.decoder.chart_from_tree(example.tree) + ) + if self.f_tag is not None: + encoded["tag_labels"] = torch.tensor( + [-100] + [self.tag_vocab[tag] for _, tag in example.pos()] + [-100] + ) + return encoded + + def pad_encoded(self, encoded_batch): + batch = self.retokenizer.pad( + [ + { + k: v + for k, v in example.items() + if (k != "span_labels" and k != "tag_labels") + } + for example in encoded_batch + ], + return_tensors="pt", + ) + if encoded_batch and "span_labels" in encoded_batch[0]: + batch["span_labels"] = decode_chart.pad_charts( + [example["span_labels"] for example in encoded_batch] + ) + if encoded_batch and "tag_labels" in encoded_batch[0]: + batch["tag_labels"] = nn.utils.rnn.pad_sequence( + [example["tag_labels"] for example in encoded_batch], + batch_first=True, + padding_value=-100, + ) + return batch + + def _get_lens(self, encoded_batch): + if self.pretrained_model is not None: + return [len(encoded["input_ids"]) for encoded in encoded_batch] + return [len(encoded["valid_token_mask"]) for encoded in encoded_batch] + + def encode_and_collate_subbatches(self, examples, subbatch_max_tokens): + batch_size = len(examples) + batch_num_tokens = sum(len(x.words) for x in examples) + encoded = [self.encode(example) for example in examples] + + res = [] + for ids, subbatch_encoded in subbatching.split( + encoded, costs=self._get_lens(encoded), max_cost=subbatch_max_tokens + ): + subbatch = self.pad_encoded(subbatch_encoded) + subbatch["batch_size"] = batch_size + subbatch["batch_num_tokens"] = batch_num_tokens + res.append((len(ids), subbatch)) + return res + + def forward(self, batch): + valid_token_mask = batch["valid_token_mask"].to(self.output_device) + + if ( + self.encoder is not None + and valid_token_mask.shape[1] > self.add_timing.timing_table.shape[0] + ): + raise ValueError( + "Sentence of length {} exceeds the maximum supported length of " + "{}".format( + valid_token_mask.shape[1] - 2, + self.add_timing.timing_table.shape[0] - 2, + ) + ) + + if self.char_encoder is not None: + assert isinstance(self.char_encoder, char_lstm.CharacterLSTM) + char_ids = batch["char_ids"].to(self.device) + extra_content_annotations = self.char_encoder(char_ids, valid_token_mask) + elif self.pretrained_model is not None: + input_ids = batch["input_ids"].to(self.device) + words_from_tokens = batch["words_from_tokens"].to(self.output_device) + pretrained_attention_mask = batch["attention_mask"].to(self.device) + + extra_kwargs = {} + if "token_type_ids" in batch: + extra_kwargs["token_type_ids"] = batch["token_type_ids"].to(self.device) + if "decoder_input_ids" in batch: + extra_kwargs["decoder_input_ids"] = batch["decoder_input_ids"].to( + self.device + ) + extra_kwargs["decoder_attention_mask"] = batch[ + "decoder_attention_mask" + ].to(self.device) + + pretrained_out = self.pretrained_model( + input_ids, attention_mask=pretrained_attention_mask, 
**extra_kwargs + ) + features = pretrained_out.last_hidden_state.to(self.output_device) + features = features[ + torch.arange(features.shape[0])[:, None], + # Note that words_from_tokens uses index -100 for invalid positions + F.relu(words_from_tokens), + ] + features.masked_fill_(~valid_token_mask[:, :, None], 0) + if self.encoder is not None: + extra_content_annotations = self.project_pretrained(features) + + if self.encoder is not None: + encoder_in = self.add_timing( + self.morpho_emb_dropout(extra_content_annotations) + ) + + annotations = self.encoder(encoder_in, valid_token_mask) + # Rearrange the annotations to ensure that the transition to + # fenceposts captures an even split between position and content. + + annotations = torch.cat( + [ + annotations[..., 0::2], + annotations[..., 1::2], + ], + -1, + ) + else: + assert self.pretrained_model is not None + annotations = self.project_pretrained(features) + + if self.f_tag is not None: + tag_scores = self.f_tag(annotations) + else: + tag_scores = None + + fencepost_annotations = torch.cat( + [ + annotations[:, :-1, : self.d_model // 2], + annotations[:, 1:, self.d_model // 2 :], + ], + -1, + ) + + # Note that the bias added to the final layer norm is useless because + # this subtraction gets rid of it + span_features = ( + torch.unsqueeze(fencepost_annotations, 1) + - torch.unsqueeze(fencepost_annotations, 2) + )[:, :-1, 1:] + span_scores = self.f_label(span_features) + span_scores = torch.cat( + [span_scores.new_zeros(span_scores.shape[:-1] + (1,)), span_scores], -1 + ) + return span_scores, tag_scores + + def compute_loss(self, batch): + span_scores, tag_scores = self.forward(batch) + span_labels = batch["span_labels"].to(span_scores.device) + span_loss = self.criterion(span_scores, span_labels) + # Divide by the total batch size, not by the subbatch size + span_loss = span_loss / batch["batch_size"] + if tag_scores is None: + return span_loss + else: + tag_labels = batch["tag_labels"].to(tag_scores.device) + tag_loss = self.tag_loss_scale * F.cross_entropy( + tag_scores.reshape((-1, tag_scores.shape[-1])), + tag_labels.reshape((-1,)), + reduction="sum", + ignore_index=-100, + ) + tag_loss = tag_loss / batch["batch_num_tokens"] + return span_loss + tag_loss + + def _parse_encoded( + self, examples, encoded, return_compressed=False, return_scores=False + ): + with torch.no_grad(): + batch = self.pad_encoded(encoded) + span_scores, tag_scores = self.forward(batch) + if return_scores: + span_scores_np = span_scores.cpu().numpy() + else: + # Start/stop tokens don't count, so subtract 2 + lengths = batch["valid_token_mask"].sum(-1) - 2 + charts_np = self.decoder.charts_from_pytorch_scores_batched( + span_scores, lengths.to(span_scores.device) + ) + if tag_scores is not None: + tag_ids_np = tag_scores.argmax(-1).cpu().numpy() + else: + tag_ids_np = None + + for i in range(len(encoded)): + example_len = len(examples[i].words) + if return_scores: + yield span_scores_np[i, :example_len, :example_len] + elif return_compressed: + output = self.decoder.compressed_output_from_chart(charts_np[i]) + if tag_ids_np is not None: + output = output.with_tags(tag_ids_np[i, 1 : example_len + 1]) + yield output + else: + if tag_scores is None: + leaves = examples[i].pos() + else: + predicted_tags = [ + self.tag_from_index[i] + for i in tag_ids_np[i, 1 : example_len + 1] + ] + leaves = [ + (word, predicted_tag) + for predicted_tag, (word, gold_tag) in zip( + predicted_tags, examples[i].pos() + ) + ] + yield self.decoder.tree_from_chart(charts_np[i], 
leaves=leaves)
+
+    def parse(
+        self,
+        examples,
+        return_compressed=False,
+        return_scores=False,
+        subbatch_max_tokens=None,
+    ):
+        training = self.training
+        self.eval()
+        encoded = [self.encode(example) for example in examples]
+        if subbatch_max_tokens is not None:
+            res = subbatching.map(
+                self._parse_encoded,
+                examples,
+                encoded,
+                costs=self._get_lens(encoded),
+                max_cost=subbatch_max_tokens,
+                return_compressed=return_compressed,
+                return_scores=return_scores,
+            )
+        else:
+            res = self._parse_encoded(
+                examples,
+                encoded,
+                return_compressed=return_compressed,
+                return_scores=return_scores,
+            )
+        res = list(res)
+        self.train(training)
+        return res
diff --git a/benepar/partitioned_transformer.py b/benepar/partitioned_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a078b41b5c26e1ec283794d735e0e0e9bbe29201
--- /dev/null
+++ b/benepar/partitioned_transformer.py
@@ -0,0 +1,206 @@
+"""
+Transformer with partitioned content and position features.
+
+See section 3 of https://arxiv.org/pdf/1805.01052.pdf
+"""
+
+import copy
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
+    @staticmethod
+    def forward(ctx, input, p=0.5, train=False, inplace=False):
+        if p < 0 or p > 1:
+            raise ValueError(
+                "dropout probability has to be between 0 and 1, but got {}".format(p)
+            )
+
+        ctx.p = p
+        ctx.train = train
+        ctx.inplace = inplace
+
+        if ctx.inplace:
+            ctx.mark_dirty(input)
+            output = input
+        else:
+            output = input.clone()
+
+        if ctx.p > 0 and ctx.train:
+            ctx.noise = torch.empty(
+                (input.size(0), input.size(-1)),
+                dtype=input.dtype,
+                layout=input.layout,
+                device=input.device,
+            )
+            if ctx.p == 1:
+                ctx.noise.fill_(0)
+            else:
+                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
+            ctx.noise = ctx.noise[:, None, :]
+            output.mul_(ctx.noise)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.p > 0 and ctx.train:
+            return grad_output.mul(ctx.noise), None, None, None
+        else:
+            return grad_output, None, None, None
+
+
+class FeatureDropout(nn.Dropout):
+    """
+    Feature-level dropout: takes an input of size len x num_features and drops
+    each feature with probability p. A feature is dropped across the full
+    portion of the input that corresponds to a single batch element.
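+
+    For example, with an input of shape (batch, len, features) and p = 0.5,
+    each dropped feature channel is zeroed at every position of the sequence
+    for that batch element, and the surviving channels are rescaled by
+    1 / (1 - p) during training (see FeatureDropoutFunction above).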
+ """ + + def forward(self, x): + if isinstance(x, tuple): + x_c, x_p = x + x_c = FeatureDropoutFunction.apply(x_c, self.p, self.training, self.inplace) + x_p = FeatureDropoutFunction.apply(x_p, self.p, self.training, self.inplace) + return x_c, x_p + else: + return FeatureDropoutFunction.apply(x, self.p, self.training, self.inplace) + + +class PartitionedReLU(nn.ReLU): + def forward(self, x): + if isinstance(x, tuple): + x_c, x_p = x + else: + x_c, x_p = torch.chunk(x, 2, dim=-1) + return super().forward(x_c), super().forward(x_p) + + +class PartitionedLinear(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.linear_c = nn.Linear(in_features // 2, out_features // 2, bias) + self.linear_p = nn.Linear(in_features // 2, out_features // 2, bias) + + def forward(self, x): + if isinstance(x, tuple): + x_c, x_p = x + else: + x_c, x_p = torch.chunk(x, 2, dim=-1) + + out_c = self.linear_c(x_c) + out_p = self.linear_p(x_p) + return out_c, out_p + + +class PartitionedMultiHeadAttention(nn.Module): + def __init__( + self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02 + ): + super().__init__() + + self.w_qkv_c = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2)) + self.w_qkv_p = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2)) + self.w_o_c = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2)) + self.w_o_p = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2)) + + bound = math.sqrt(3.0) * initializer_range + for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]: + nn.init.uniform_(param, -bound, bound) + self.scaling_factor = 1 / d_qkv ** 0.5 + + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, x, mask=None): + if isinstance(x, tuple): + x_c, x_p = x + else: + x_c, x_p = torch.chunk(x, 2, dim=-1) + qkv_c = torch.einsum("btf,hfca->bhtca", x_c, self.w_qkv_c) + qkv_p = torch.einsum("btf,hfca->bhtca", x_p, self.w_qkv_p) + q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)] + q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)] + q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor + k = torch.cat([k_c, k_p], dim=-1) + v = torch.cat([v_c, v_p], dim=-1) + dots = torch.einsum("bhqa,bhka->bhqk", q, k) + if mask is not None: + dots.data.masked_fill_(~mask[:, None, None, :], -float("inf")) + probs = F.softmax(dots, dim=-1) + probs = self.dropout(probs) + o = torch.einsum("bhqk,bhka->bhqa", probs, v) + o_c, o_p = torch.chunk(o, 2, dim=-1) + out_c = torch.einsum("bhta,haf->btf", o_c, self.w_o_c) + out_p = torch.einsum("bhta,haf->btf", o_p, self.w_o_p) + return out_c, out_p + + +class PartitionedTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + n_head, + d_qkv, + d_ff, + ff_dropout=0.1, + residual_dropout=0.1, + attention_dropout=0.1, + activation=PartitionedReLU(), + ): + super().__init__() + self.self_attn = PartitionedMultiHeadAttention( + d_model, n_head, d_qkv, attention_dropout=attention_dropout + ) + self.linear1 = PartitionedLinear(d_model, d_ff) + self.ff_dropout = FeatureDropout(ff_dropout) + self.linear2 = PartitionedLinear(d_ff, d_model) + + self.norm_attn = nn.LayerNorm(d_model) + self.norm_ff = nn.LayerNorm(d_model) + self.residual_dropout_attn = FeatureDropout(residual_dropout) + self.residual_dropout_ff = FeatureDropout(residual_dropout) + + self.activation = activation + + def forward(self, x, mask=None): + residual = self.self_attn(x, mask=mask) + residual = torch.cat(residual, dim=-1) 
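+        # self_attn returns separate (content, position) halves; they were
+        # concatenated above so that the residual dropout and LayerNorm below
+        # operate on a single d_model-sized vector.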
+ residual = self.residual_dropout_attn(residual) + x = self.norm_attn(x + residual) + residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x)))) + residual = torch.cat(residual, dim=-1) + residual = self.residual_dropout_ff(residual) + x = self.norm_ff(x + residual) + return x + + +class PartitionedTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, n_layers): + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(n_layers)] + ) + + def forward(self, x, mask=None): + for layer in self.layers: + x = layer(x, mask=mask) + return x + + +class ConcatPositionalEncoding(nn.Module): + def __init__(self, d_model=256, max_len=512): + super().__init__() + self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model // 2)) + nn.init.normal_(self.timing_table) + self.norm = nn.LayerNorm(d_model) + + def forward(self, x): + timing = self.timing_table[None, : x.shape[1], :] + x, timing = torch.broadcast_tensors(x, timing) + out = torch.cat([x, timing], dim=-1) + out = self.norm(out) + return out diff --git a/benepar/ptb_unescape.py b/benepar/ptb_unescape.py new file mode 100644 index 0000000000000000000000000000000000000000..b9403d492257003c9145c494314b6e670da61fcc --- /dev/null +++ b/benepar/ptb_unescape.py @@ -0,0 +1,83 @@ +PTB_UNESCAPE_MAPPING = { + "«": '"', + "»": '"', + "‘": "'", + "’": "'", + "“": '"', + "”": '"', + "„": '"', + "‹": "'", + "›": "'", + "\u2013": "--", # en dash + "\u2014": "--", # em dash +} + +NO_SPACE_BEFORE = {"-RRB-", "-RCB-", "-RSB-", "''"} | set("%.,!?:;") +NO_SPACE_AFTER = {"-LRB-", "-LCB-", "-LSB-", "``", "`"} | set("$#") +NO_SPACE_BEFORE_TOKENS_ENGLISH = {"'", "'s", "'ll", "'re", "'d", "'m", "'ve"} +PTB_DASH_ESCAPED = {"-RRB-", "-RCB-", "-RSB-", "-LRB-", "-LCB-", "-LSB-", "--"} + + +def ptb_unescape(words): + cleaned_words = [] + for word in words: + word = PTB_UNESCAPE_MAPPING.get(word, word) + # This un-escaping for / and * was not yet added for the + # parser version in https://arxiv.org/abs/1812.11760v1 + # and related model releases (e.g. 
benepar_en2) + word = word.replace("\\/", "/").replace("\\*", "*") + # Mid-token punctuation occurs in biomedical text + word = word.replace("-LSB-", "[").replace("-RSB-", "]") + word = word.replace("-LRB-", "(").replace("-RRB-", ")") + word = word.replace("-LCB-", "{").replace("-RCB-", "}") + word = word.replace("``", '"').replace("`", "'").replace("''", '"') + cleaned_words.append(word) + return cleaned_words + + +def guess_space_after_non_english(escaped_words): + sp_after = [True for _ in escaped_words] + for i, word in enumerate(escaped_words): + if i > 0 and ( + ( + word.startswith("-") + and not any(word.startswith(x) for x in PTB_DASH_ESCAPED) + ) + or any(word.startswith(x) for x in NO_SPACE_BEFORE) + or word == "'" + ): + sp_after[i - 1] = False + if ( + word.endswith("-") and not any(word.endswith(x) for x in PTB_DASH_ESCAPED) + ) or any(word.endswith(x) for x in NO_SPACE_AFTER): + sp_after[i] = False + + return sp_after + + +def guess_space_after(escaped_words, for_english=True): + if not for_english: + return guess_space_after_non_english(escaped_words) + + sp_after = [True for _ in escaped_words] + for i, word in enumerate(escaped_words): + if word.lower() == "n't" and i > 0: + sp_after[i - 1] = False + elif word.lower() == "not" and i > 0 and escaped_words[i - 1].lower() == "can": + sp_after[i - 1] = False + + if i > 0 and ( + ( + word.startswith("-") + and not any(word.startswith(x) for x in PTB_DASH_ESCAPED) + ) + or any(word.startswith(x) for x in NO_SPACE_BEFORE) + or word.lower() in NO_SPACE_BEFORE_TOKENS_ENGLISH + ): + sp_after[i - 1] = False + if ( + word.endswith("-") and not any(word.endswith(x) for x in PTB_DASH_ESCAPED) + ) or any(word.endswith(x) for x in NO_SPACE_AFTER): + sp_after[i] = False + + return sp_after diff --git a/benepar/retokenization.py b/benepar/retokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..42f77188c5faf721f0587aeeac6da302ac3d8be3 --- /dev/null +++ b/benepar/retokenization.py @@ -0,0 +1,258 @@ +""" +Converts from linguistically motivated word-based tokenization to subword +tokenization used by pre-trained models. +""" + +import numpy as np +import torch +import transformers + + +def retokenize( + tokenizer, + words, + space_after, + return_attention_mask=True, + return_offsets_mapping=False, + return_tensors=None, + **kwargs +): + """Re-tokenize into subwords. + + Args: + tokenizer: An instance of transformers.PreTrainedTokenizerFast + words: List of words + space_after: A list of the same length as `words`, indicating whether + whitespace follows each word. + **kwargs: all remaining arguments are passed on to tokenizer.__call__ + + Returns: + The output of tokenizer.__call__, with one additional dictionary field: + - **words_from_tokens** -- List of the same length as `words`, where + each entry is the index of the *last* subword that overlaps the + corresponding word. 
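+
+          For example (illustrative, assuming a tokenizer that adds no
+          special tokens): if words ["the", "cat", "sat"] are split into
+          subwords ["the", "c", "##at", "sat"], then
+          words_from_tokens == [0, 2, 3].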
+ """ + s = "".join([w + (" " if sp else "") for w, sp in zip(words, space_after)]) + word_offset_starts = np.cumsum( + [0] + [len(w) + (1 if sp else 0) for w, sp in zip(words, space_after)] + )[:-1] + word_offset_ends = word_offset_starts + np.asarray([len(w) for w in words]) + + tokenized = tokenizer( + s, + return_attention_mask=return_attention_mask, + return_offsets_mapping=True, + return_tensors=return_tensors, + **kwargs + ) + if return_offsets_mapping: + token_offset_mapping = tokenized["offset_mapping"] + else: + token_offset_mapping = tokenized.pop("offset_mapping") + if return_tensors is not None: + token_offset_mapping = np.asarray(token_offset_mapping)[0].tolist() + + offset_mapping_iter = iter( + [ + (i, (start, end)) + for (i, (start, end)) in enumerate(token_offset_mapping) + if start != end + ] + ) + token_idx, (token_start, token_end) = next(offset_mapping_iter) + words_from_tokens = [-100] * len(words) + for word_idx, (word_start, word_end) in enumerate( + zip(word_offset_starts, word_offset_ends) + ): + while token_end <= word_start: + token_idx, (token_start, token_end) = next(offset_mapping_iter) + if token_end > word_end: + words_from_tokens[word_idx] = token_idx + while token_end <= word_end: + words_from_tokens[word_idx] = token_idx + try: + token_idx, (token_start, token_end) = next(offset_mapping_iter) + except StopIteration: + assert word_idx == len(words) - 1 + break + if return_tensors == "np": + words_from_tokens = np.asarray(words_from_tokens, dtype=int) + elif return_tensors == "pt": + words_from_tokens = torch.tensor(words_from_tokens, dtype=torch.long) + elif return_tensors == "tf": + raise NotImplementedError("Returning tf tensors is not implemented") + tokenized["words_from_tokens"] = words_from_tokens + return tokenized + + +class Retokenizer: + def __init__(self, pretrained_model_name_or_path, retain_start_stop=False): + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, fast=True + ) + if not self.tokenizer.is_fast: + raise NotImplementedError( + "Converting from treebank tokenization to tokenization used by a " + "pre-trained model requires a 'fast' tokenizer, which appears to not " + "be available for this pre-trained model type." + ) + self.retain_start_stop = retain_start_stop + self.is_t5 = "T5Tokenizer" in str(type(self.tokenizer)) + self.is_gpt2 = "GPT2Tokenizer" in str(type(self.tokenizer)) + + if self.is_gpt2: + # The provided GPT-2 tokenizer does not specify a padding token by default + self.tokenizer.pad_token = self.tokenizer.eos_token + + if self.retain_start_stop: + # When retain_start_stop is set, the next layer after the pre-trained model + # expects start and stop token embeddings. For BERT these can naturally be + # the feature vectors for CLS and SEP, but pre-trained models differ in the + # special tokens that they use. This code attempts to find special token + # positions for each pre-trained model. + dummy_ids = self.tokenizer.build_inputs_with_special_tokens([-100]) + if self.is_t5: + # For T5 we use the output from the decoder, which accepts inputs that + # are shifted relative to the encoder. 
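+                # Prepending the pad token below mirrors that shift, so the
+                # index of the -100 probe matches its position in the
+                # decoder's input sequence.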
+                dummy_ids = [self.tokenizer.pad_token_id] + dummy_ids
+            if self.is_gpt2:
+                # For GPT-2, we append an eos token if special tokens are needed
+                dummy_ids = dummy_ids + [self.tokenizer.eos_token_id]
+            try:
+                input_idx = dummy_ids.index(-100)
+            except ValueError:
+                raise NotImplementedError(
+                    "Could not automatically infer how to extract start/stop tokens "
+                    "from this pre-trained model"
+                )
+            num_prefix_tokens = input_idx
+            num_suffix_tokens = len(dummy_ids) - input_idx - 1
+            self.start_token_idx = None
+            self.stop_token_idx = None
+            if num_prefix_tokens > 0:
+                self.start_token_idx = num_prefix_tokens - 1
+            if num_suffix_tokens > 0:
+                self.stop_token_idx = -num_suffix_tokens
+            if self.start_token_idx is None and num_suffix_tokens > 0:
+                self.start_token_idx = -1
+            if self.stop_token_idx is None and num_prefix_tokens > 0:
+                self.stop_token_idx = 0
+            if self.start_token_idx is None or self.stop_token_idx is None:
+                assert num_prefix_tokens == 0 and num_suffix_tokens == 0
+                raise NotImplementedError(
+                    "Could not automatically infer how to extract start/stop tokens "
+                    "from this pre-trained model because the associated tokenizer "
+                    "appears not to add any special start/stop/cls/sep/etc. tokens "
+                    "to the sequence."
+                )
+
+    def __call__(self, words, space_after, **kwargs):
+        example = retokenize(self.tokenizer, words, space_after, **kwargs)
+        if self.is_t5:
+            # decoder_input_ids (which are shifted wrt input_ids) will be created after
+            # padding, but we adjust words_from_tokens now, in anticipation.
+            if isinstance(example["words_from_tokens"], list):
+                example["words_from_tokens"] = [
+                    x + 1 for x in example["words_from_tokens"]
+                ]
+            else:
+                example["words_from_tokens"] += 1
+        if self.retain_start_stop:
+            num_tokens = len(example["input_ids"])
+            if self.is_t5:
+                num_tokens += 1
+            if self.is_gpt2:
+                num_tokens += 1
+                if kwargs.get("return_tensors") == "pt":
+                    example["input_ids"] = torch.cat(
+                        [
+                            example["input_ids"],
+                            torch.tensor([self.tokenizer.eos_token_id]),
+                        ]
+                    )
+                    example["attention_mask"] = torch.cat(
+                        [example["attention_mask"], torch.tensor([1])]
+                    )
+                else:
+                    example["input_ids"].append(self.tokenizer.eos_token_id)
+                    example["attention_mask"].append(1)
+            if num_tokens > self.tokenizer.model_max_length:
+                raise ValueError(
+                    f"Sentence of length {num_tokens} (in sub-word tokens) exceeds the "
+                    f"maximum supported length of {self.tokenizer.model_max_length}"
+                )
+            start_token_idx = (
+                self.start_token_idx
+                if self.start_token_idx >= 0
+                else num_tokens + self.start_token_idx
+            )
+            stop_token_idx = (
+                self.stop_token_idx
+                if self.stop_token_idx >= 0
+                else num_tokens + self.stop_token_idx
+            )
+            if kwargs.get("return_tensors") == "pt":
+                example["words_from_tokens"] = torch.cat(
+                    [
+                        torch.tensor([start_token_idx]),
+                        example["words_from_tokens"],
+                        torch.tensor([stop_token_idx]),
+                    ]
+                )
+            else:
+                example["words_from_tokens"] = (
+                    [start_token_idx] + example["words_from_tokens"] + [stop_token_idx]
+                )
+        return example
+
+    def pad(self, encoded_inputs, return_tensors=None, **kwargs):
+        if return_tensors != "pt":
+            raise NotImplementedError("Only return_tensors='pt' is supported.")
+        res = self.tokenizer.pad(
+            [
+                {k: v for k, v in example.items() if k != "words_from_tokens"}
+                for example in encoded_inputs
+            ],
+            return_tensors=return_tensors,
+            **kwargs
+        )
+        if self.tokenizer.padding_side == "right":
+            res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence(
+                [
+                    torch.tensor(example["words_from_tokens"])
+                    for example in encoded_inputs
+                ],
+                batch_first=True,
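+                # -100 marks padded positions; pad() derives valid_token_mask
+                # from it below, and ChartParser.forward clamps it with
+                # F.relu before using it as an index.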
padding_value=-100, + ) + else: + # XLNet adds padding tokens on the left of the sequence, so + # words_from_tokens must be adjusted to skip the added padding tokens. + assert self.tokenizer.padding_side == "left" + res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence( + [ + torch.tensor(example["words_from_tokens"]) + + (res["input_ids"].shape[-1] - len(example["input_ids"])) + for example in encoded_inputs + ], + batch_first=True, + padding_value=-100, + ) + + if self.is_t5: + res["decoder_input_ids"] = torch.cat( + [ + torch.full_like( + res["input_ids"][:, :1], self.tokenizer.pad_token_id + ), + res["input_ids"], + ], + 1, + ) + res["decoder_attention_mask"] = torch.cat( + [ + torch.ones_like(res["attention_mask"][:, :1]), + res["attention_mask"], + ], + 1, + ) + res["valid_token_mask"] = res["words_from_tokens"] != -100 + return res diff --git a/benepar/spacy_plugin.py b/benepar/spacy_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..3923cc46064c5f098a2276a4312c09bc65b8891a --- /dev/null +++ b/benepar/spacy_plugin.py @@ -0,0 +1,13 @@ +__all__ = ["BeneparComponent", "NonConstituentException"] + +import warnings + +from .integrations.spacy_plugin import BeneparComponent, NonConstituentException + +warnings.warn( + "BeneparComponent and NonConstituentException have been moved to the benepar " + "module. Use `from benepar import BeneparComponent, NonConstituentException` " + "instead of benepar.spacy_plugin. The benepar.spacy_plugin namespace is deprecated " + "and will be removed in a future version.", + FutureWarning, +) diff --git a/benepar/subbatching.py b/benepar/subbatching.py new file mode 100644 index 0000000000000000000000000000000000000000..53bed87ce8743034a670e358acc947709c57ee3d --- /dev/null +++ b/benepar/subbatching.py @@ -0,0 +1,62 @@ +""" +Utilities for splitting batches of examples into smaller sub-batches. + +This is useful during training when the batch size is too large to fit on GPU, +meaning that gradient accumulation across multiple sub-batches must be used. +It is also useful for batching examples during evaluation. Unlike a naive +approach, this code groups examples with similar lengths to reduce the amount +of wasted computation due to padding. +""" + +import numpy as np + + +def split(*data, costs, max_cost): + """Splits a batch of input items into sub-batches. + + Args: + *data: One or more lists of input items, all of the same length + costs: A list of costs for each item + max_cost: Maximum total cost for each sub-batch + + Yields: + (example_ids, *subbatch_data) tuples. + """ + costs = np.asarray(costs, dtype=int) + costs_argsort = np.argsort(costs).tolist() + + subbatch_size = 1 + while costs_argsort: + if subbatch_size == len(costs_argsort) or ( + subbatch_size * costs[costs_argsort[subbatch_size]] > max_cost + ): + subbatch_item_ids = costs_argsort[:subbatch_size] + subbatch_data = [[items[i] for i in subbatch_item_ids] for items in data] + yield (subbatch_item_ids,) + tuple(subbatch_data) + costs_argsort = costs_argsort[subbatch_size:] + subbatch_size = 1 + else: + subbatch_size += 1 + + +def map(func, *data, costs, max_cost, **common_kwargs): + """Maps a function over subbatches of input items. + + Args: + func: Function to map over the data + *data: One or more lists of input items, all of the same length. 
+        costs: A list of costs for each item
+        max_cost: Maximum total cost for each sub-batch
+        **common_kwargs: Keyword arguments to pass to all calls of func
+
+    Returns:
+        A list containing the outputs of func(*subbatch_data, **common_kwargs)
+        for each sub-batch, rearranged back into the original item order.
+    """
+    res = [None] * len(data[0])
+    for item_ids, *subbatch_items in split(*data, costs=costs, max_cost=max_cost):
+        subbatch_out = func(*subbatch_items, **common_kwargs)
+        for item_id, item_out in zip(item_ids, subbatch_out):
+            res[item_id] = item_out
+    return res
diff --git a/parse.py b/parse.py
index 9261a3bd3bb07cfd7490694be920d1da81eaebee..855f3f75234744bb8cfbfd8855b6e91dc9c2d64d 100644
--- a/parse.py
+++ b/parse.py
@@ -2,12 +2,9 @@ import re
 import sys
 import benepar
 from huggingface_hub import hf_hub_download
-
-model_path = "ParserModels/ENHG/new-convbert-german-europeana0_dev=83.03.pt"
-
-hf_hub_download(repo_id=model_path, filename='german-delex-parser_dev=83.10.pt')
-parser = benepar.Parser(model_path)
 
 def parse(words):
+    model_path = hf_hub_download(repo_id="nielklug/enhg_parser", filename='new-convbert-german-europeana0_dev=83.03.pt')
+    parser = benepar.Parser(model_path)
     words = [word.replace('(','-LRB-').replace(')','-RRB-') for word in words]
     input_sentence = benepar.InputSentence(words=words)
     tree = parser.parse(input_sentence)
@@ -17,11 +14,3 @@ def parse(words):
     tree = re.sub(r' \(', '(', tree)
 
     return tree
-
-with open(sys.argv[1]) as file:
-    for line in file:
-        line = re.sub(r'(\S)([.,;:?!)"])', r'\1 \2', line.strip())
-        line = re.sub(r'(["(])(\S)', r'\1 \2', line)
-        words = line.split()
-        tree = parse(words)
-        print(tree)
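
Note: with this change, parse.py no longer runs as a standalone script; the
checkpoint is instead fetched from the Hugging Face Hub inside parse()
(hf_hub_download caches the file locally, but a fresh benepar.Parser is still
constructed on every call). A minimal usage sketch, with an illustrative
input sentence (parse() expects a pre-tokenized list of words):

    from parse import parse

    # Words must already be tokenized; round brackets are escaped to
    # -LRB-/-RRB- inside parse() before being handed to the parser.
    words = ["Das", "ist", "ein", "Beispiel", "."]
    print(parse(words))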