nielklug commited on
Commit
8778cfe
1 Parent(s): 7884ed6

add parsing

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. __pycache__/parse.cpython-38.pyc +0 -0
  2. app.py +11 -9
  3. benepar/__init__.py +20 -0
  4. benepar/__pycache__/__init__.cpython-310.pyc +0 -0
  5. benepar/__pycache__/__init__.cpython-37.pyc +0 -0
  6. benepar/__pycache__/__init__.cpython-38.pyc +0 -0
  7. benepar/__pycache__/char_lstm.cpython-310.pyc +0 -0
  8. benepar/__pycache__/char_lstm.cpython-37.pyc +0 -0
  9. benepar/__pycache__/char_lstm.cpython-38.pyc +0 -0
  10. benepar/__pycache__/decode_chart.cpython-310.pyc +0 -0
  11. benepar/__pycache__/decode_chart.cpython-37.pyc +0 -0
  12. benepar/__pycache__/decode_chart.cpython-38.pyc +0 -0
  13. benepar/__pycache__/nkutil.cpython-310.pyc +0 -0
  14. benepar/__pycache__/nkutil.cpython-37.pyc +0 -0
  15. benepar/__pycache__/nkutil.cpython-38.pyc +0 -0
  16. benepar/__pycache__/parse_base.cpython-310.pyc +0 -0
  17. benepar/__pycache__/parse_base.cpython-37.pyc +0 -0
  18. benepar/__pycache__/parse_base.cpython-38.pyc +0 -0
  19. benepar/__pycache__/parse_chart.cpython-310.pyc +0 -0
  20. benepar/__pycache__/parse_chart.cpython-37.pyc +0 -0
  21. benepar/__pycache__/parse_chart.cpython-38.pyc +0 -0
  22. benepar/__pycache__/partitioned_transformer.cpython-310.pyc +0 -0
  23. benepar/__pycache__/partitioned_transformer.cpython-37.pyc +0 -0
  24. benepar/__pycache__/partitioned_transformer.cpython-38.pyc +0 -0
  25. benepar/__pycache__/ptb_unescape.cpython-310.pyc +0 -0
  26. benepar/__pycache__/ptb_unescape.cpython-37.pyc +0 -0
  27. benepar/__pycache__/ptb_unescape.cpython-38.pyc +0 -0
  28. benepar/__pycache__/retokenization.cpython-310.pyc +0 -0
  29. benepar/__pycache__/retokenization.cpython-37.pyc +0 -0
  30. benepar/__pycache__/retokenization.cpython-38.pyc +0 -0
  31. benepar/__pycache__/subbatching.cpython-310.pyc +0 -0
  32. benepar/__pycache__/subbatching.cpython-37.pyc +0 -0
  33. benepar/__pycache__/subbatching.cpython-38.pyc +0 -0
  34. benepar/char_lstm.py +160 -0
  35. benepar/decode_chart.py +291 -0
  36. benepar/decode_chart.py~ +291 -0
  37. benepar/integrations/__init__.py +0 -0
  38. benepar/integrations/__pycache__/__init__.cpython-310.pyc +0 -0
  39. benepar/integrations/__pycache__/__init__.cpython-37.pyc +0 -0
  40. benepar/integrations/__pycache__/__init__.cpython-38.pyc +0 -0
  41. benepar/integrations/__pycache__/downloader.cpython-310.pyc +0 -0
  42. benepar/integrations/__pycache__/downloader.cpython-37.pyc +0 -0
  43. benepar/integrations/__pycache__/downloader.cpython-38.pyc +0 -0
  44. benepar/integrations/__pycache__/nltk_plugin.cpython-310.pyc +0 -0
  45. benepar/integrations/__pycache__/nltk_plugin.cpython-37.pyc +0 -0
  46. benepar/integrations/__pycache__/nltk_plugin.cpython-38.pyc +0 -0
  47. benepar/integrations/__pycache__/spacy_extensions.cpython-310.pyc +0 -0
  48. benepar/integrations/__pycache__/spacy_extensions.cpython-37.pyc +0 -0
  49. benepar/integrations/__pycache__/spacy_extensions.cpython-38.pyc +0 -0
  50. benepar/integrations/__pycache__/spacy_plugin.cpython-310.pyc +0 -0
__pycache__/parse.cpython-38.pyc CHANGED
Binary files a/__pycache__/parse.cpython-38.pyc and b/__pycache__/parse.cpython-38.pyc differ
 
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- # from parse import parse_text
3
  from nltk import Tree
4
  import pandas as pd
5
  import re
@@ -31,19 +31,21 @@ if text:
31
 
32
  df = pd.DataFrame(zipped, columns=['Token', 'Tag', 'Prob.'])
33
 
34
- # # Convert the bracket parse tree into an NLTK Tree
35
- # t = Tree.fromstring(re.sub(r'(\.[^ )]+)+', '', parse_tree))
36
 
37
- # tree_svg = TreePrettyPrinter(t).svg(nodecolor='black', leafcolor='black', funccolor='black')
 
 
 
38
 
39
  col1 = st.columns(1)[0]
40
  col1.header("POS tagging result:")
41
  col1.table(df)
42
 
43
- # col2 = st.columns(1)[0]
44
- # col2.header("Parsing result:")
45
- # col2.write(parse_tree.replace('_', '\_').replace('$', '\$').replace('*', '\*'))
46
 
47
- # # Display the graph in the Streamlit app
48
- # col2.image(tree_svg, use_column_width=True)
49
 
 
1
  import streamlit as st
2
+ from parse import parse
3
  from nltk import Tree
4
  import pandas as pd
5
  import re
 
31
 
32
  df = pd.DataFrame(zipped, columns=['Token', 'Tag', 'Prob.'])
33
 
34
+ parse_tree = parse(tokens)
 
35
 
36
+ # Convert the bracket parse tree into an NLTK Tree
37
+ t = Tree.fromstring(re.sub(r'-[^ )]*', '', parse_tree))
38
+
39
+ tree_svg = TreePrettyPrinter(t).svg(nodecolor='black', leafcolor='black', funccolor='black')
40
 
41
  col1 = st.columns(1)[0]
42
  col1.header("POS tagging result:")
43
  col1.table(df)
44
 
45
+ col2 = st.columns(1)[0]
46
+ col2.header("Parsing result:")
47
+ col2.write(parse_tree.replace('_', '\_').replace('$', '\$').replace('*', '\*'))
48
 
49
+ # Display the graph in the Streamlit app
50
+ col2.image(tree_svg, use_column_width=True)
51
 
benepar/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benepar: Berkeley Neural Parser
3
+ """
4
+
5
+ # This file and all code in integrations/ relate to the version of the parser
6
+ # released via PyPI. If you only need to run research experiments, it is safe
7
+ # to delete the integrations/ folder and replace this __init__.py with an
8
+ # empty file.
9
+
10
+ __all__ = [
11
+ "Parser",
12
+ "InputSentence",
13
+ "download",
14
+ "BeneparComponent",
15
+ "NonConstituentException",
16
+ ]
17
+
18
+ from .integrations.downloader import download
19
+ from .integrations.nltk_plugin import Parser, InputSentence
20
+ from .integrations.spacy_plugin import BeneparComponent, NonConstituentException
benepar/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (526 Bytes). View file
 
benepar/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (521 Bytes). View file
 
benepar/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (505 Bytes). View file
 
benepar/__pycache__/char_lstm.cpython-310.pyc ADDED
Binary file (4.94 kB). View file
 
benepar/__pycache__/char_lstm.cpython-37.pyc ADDED
Binary file (4.92 kB). View file
 
benepar/__pycache__/char_lstm.cpython-38.pyc ADDED
Binary file (4.96 kB). View file
 
benepar/__pycache__/decode_chart.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
benepar/__pycache__/decode_chart.cpython-37.pyc ADDED
Binary file (10.2 kB). View file
 
benepar/__pycache__/decode_chart.cpython-38.pyc ADDED
Binary file (10.2 kB). View file
 
benepar/__pycache__/nkutil.cpython-310.pyc ADDED
Binary file (2.14 kB). View file
 
benepar/__pycache__/nkutil.cpython-37.pyc ADDED
Binary file (2.1 kB). View file
 
benepar/__pycache__/nkutil.cpython-38.pyc ADDED
Binary file (2.09 kB). View file
 
benepar/__pycache__/parse_base.cpython-310.pyc ADDED
Binary file (7.38 kB). View file
 
benepar/__pycache__/parse_base.cpython-37.pyc ADDED
Binary file (7.16 kB). View file
 
benepar/__pycache__/parse_base.cpython-38.pyc ADDED
Binary file (7.26 kB). View file
 
benepar/__pycache__/parse_chart.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
benepar/__pycache__/parse_chart.cpython-37.pyc ADDED
Binary file (11 kB). View file
 
benepar/__pycache__/parse_chart.cpython-38.pyc ADDED
Binary file (11.1 kB). View file
 
benepar/__pycache__/partitioned_transformer.cpython-310.pyc ADDED
Binary file (7.83 kB). View file
 
benepar/__pycache__/partitioned_transformer.cpython-37.pyc ADDED
Binary file (7.9 kB). View file
 
benepar/__pycache__/partitioned_transformer.cpython-38.pyc ADDED
Binary file (7.82 kB). View file
 
benepar/__pycache__/ptb_unescape.cpython-310.pyc ADDED
Binary file (3.05 kB). View file
 
benepar/__pycache__/ptb_unescape.cpython-37.pyc ADDED
Binary file (3.2 kB). View file
 
benepar/__pycache__/ptb_unescape.cpython-38.pyc ADDED
Binary file (3.19 kB). View file
 
benepar/__pycache__/retokenization.cpython-310.pyc ADDED
Binary file (6.83 kB). View file
 
benepar/__pycache__/retokenization.cpython-37.pyc ADDED
Binary file (6.73 kB). View file
 
benepar/__pycache__/retokenization.cpython-38.pyc ADDED
Binary file (6.83 kB). View file
 
benepar/__pycache__/subbatching.cpython-310.pyc ADDED
Binary file (2.53 kB). View file
 
benepar/__pycache__/subbatching.cpython-37.pyc ADDED
Binary file (2.49 kB). View file
 
benepar/__pycache__/subbatching.cpython-38.pyc ADDED
Binary file (2.48 kB). View file
 
benepar/char_lstm.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Character LSTM implementation (matches https://arxiv.org/pdf/1805.01052.pdf)
3
+ """
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class CharacterLSTM(nn.Module):
12
+ def __init__(self, num_embeddings, d_embedding, d_out, char_dropout=0.0, **kwargs):
13
+ super().__init__()
14
+
15
+ self.d_embedding = d_embedding
16
+ self.d_out = d_out
17
+
18
+ self.lstm = nn.LSTM(
19
+ self.d_embedding, self.d_out // 2, num_layers=1, bidirectional=True
20
+ )
21
+
22
+ self.emb = nn.Embedding(num_embeddings, self.d_embedding, **kwargs)
23
+ self.char_dropout = nn.Dropout(char_dropout)
24
+
25
+ def forward(self, chars_packed, valid_token_mask):
26
+ inp_embs = nn.utils.rnn.PackedSequence(
27
+ self.char_dropout(self.emb(chars_packed.data)),
28
+ batch_sizes=chars_packed.batch_sizes,
29
+ sorted_indices=chars_packed.sorted_indices,
30
+ unsorted_indices=chars_packed.unsorted_indices,
31
+ )
32
+
33
+ _, (lstm_out, _) = self.lstm(inp_embs)
34
+ lstm_out = torch.cat([lstm_out[0], lstm_out[1]], -1)
35
+
36
+ # Switch to a representation where there are dummy vectors for invalid
37
+ # tokens generated by padding.
38
+ res = lstm_out.new_zeros(
39
+ (valid_token_mask.shape[0], valid_token_mask.shape[1], lstm_out.shape[-1])
40
+ )
41
+ res[valid_token_mask] = lstm_out
42
+ return res
43
+
44
+
45
+ class RetokenizerForCharLSTM:
46
+ # Assumes that these control characters are not present in treebank text
47
+ CHAR_UNK = "\0"
48
+ CHAR_ID_UNK = 0
49
+ CHAR_START_SENTENCE = "\1"
50
+ CHAR_START_WORD = "\2"
51
+ CHAR_STOP_WORD = "\3"
52
+ CHAR_STOP_SENTENCE = "\4"
53
+
54
+ def __init__(self, char_vocab):
55
+ self.char_vocab = char_vocab
56
+
57
+ @classmethod
58
+ def build_vocab(cls, sentences):
59
+ char_set = set()
60
+ for sentence in sentences:
61
+ if isinstance(sentence, tuple):
62
+ sentence = sentence[0]
63
+ for word in sentence:
64
+ char_set |= set(word)
65
+
66
+ # If codepoints are small (e.g. Latin alphabet), index by codepoint
67
+ # directly
68
+ highest_codepoint = max(ord(char) for char in char_set)
69
+ if highest_codepoint < 512:
70
+ if highest_codepoint < 256:
71
+ highest_codepoint = 256
72
+ else:
73
+ highest_codepoint = 512
74
+
75
+ char_vocab = {}
76
+ # This also takes care of constants like CHAR_UNK, etc.
77
+ for codepoint in range(highest_codepoint):
78
+ char_vocab[chr(codepoint)] = codepoint
79
+ return char_vocab
80
+ else:
81
+ char_vocab = {}
82
+ char_vocab[cls.CHAR_UNK] = 0
83
+ char_vocab[cls.CHAR_START_SENTENCE] = 1
84
+ char_vocab[cls.CHAR_START_WORD] = 2
85
+ char_vocab[cls.CHAR_STOP_WORD] = 3
86
+ char_vocab[cls.CHAR_STOP_SENTENCE] = 4
87
+ for id_, char in enumerate(sorted(char_set), start=5):
88
+ char_vocab[char] = id_
89
+ return char_vocab
90
+
91
+ def __call__(self, words, space_after="ignored", return_tensors=None):
92
+ if return_tensors != "np":
93
+ raise NotImplementedError("Only return_tensors='np' is supported.")
94
+
95
+ res = {}
96
+
97
+ # Sentence-level start/stop tokens are encoded as 3 pseudo-chars
98
+ # Within each word, account for 2 start/stop characters
99
+ max_word_len = max(3, max(len(word) for word in words)) + 2
100
+ char_ids = np.zeros((len(words) + 2, max_word_len), dtype=int)
101
+ word_lens = np.zeros(len(words) + 2, dtype=int)
102
+
103
+ char_ids[0, :5] = [
104
+ self.char_vocab[self.CHAR_START_WORD],
105
+ self.char_vocab[self.CHAR_START_SENTENCE],
106
+ self.char_vocab[self.CHAR_START_SENTENCE],
107
+ self.char_vocab[self.CHAR_START_SENTENCE],
108
+ self.char_vocab[self.CHAR_STOP_WORD],
109
+ ]
110
+ word_lens[0] = 5
111
+ for i, word in enumerate(words, start=1):
112
+ char_ids[i, 0] = self.char_vocab[self.CHAR_START_WORD]
113
+ for j, char in enumerate(word, start=1):
114
+ char_ids[i, j] = self.char_vocab.get(char, self.CHAR_ID_UNK)
115
+ char_ids[i, j + 1] = self.char_vocab[self.CHAR_STOP_WORD]
116
+ word_lens[i] = j + 2
117
+ char_ids[i + 1, :5] = [
118
+ self.char_vocab[self.CHAR_START_WORD],
119
+ self.char_vocab[self.CHAR_STOP_SENTENCE],
120
+ self.char_vocab[self.CHAR_STOP_SENTENCE],
121
+ self.char_vocab[self.CHAR_STOP_SENTENCE],
122
+ self.char_vocab[self.CHAR_STOP_WORD],
123
+ ]
124
+ word_lens[i + 1] = 5
125
+
126
+ res["char_ids"] = char_ids
127
+ res["word_lens"] = word_lens
128
+ res["valid_token_mask"] = np.ones_like(word_lens, dtype=bool)
129
+
130
+ return res
131
+
132
+ def pad(self, examples, return_tensors=None):
133
+ if return_tensors != "pt":
134
+ raise NotImplementedError("Only return_tensors='pt' is supported.")
135
+ max_word_len = max(example["char_ids"].shape[-1] for example in examples)
136
+ char_ids = torch.cat(
137
+ [
138
+ F.pad(
139
+ torch.tensor(example["char_ids"]),
140
+ (0, max_word_len - example["char_ids"].shape[-1]),
141
+ )
142
+ for example in examples
143
+ ]
144
+ )
145
+ word_lens = torch.cat(
146
+ [torch.tensor(example["word_lens"]) for example in examples]
147
+ )
148
+ valid_token_mask = nn.utils.rnn.pad_sequence(
149
+ [torch.tensor(example["valid_token_mask"]) for example in examples],
150
+ batch_first=True,
151
+ padding_value=False,
152
+ )
153
+
154
+ char_ids = nn.utils.rnn.pack_padded_sequence(
155
+ char_ids, word_lens, batch_first=True, enforce_sorted=False
156
+ )
157
+ return {
158
+ "char_ids": char_ids,
159
+ "valid_token_mask": valid_token_mask,
160
+ }
benepar/decode_chart.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parsing formulated as span classification (https://arxiv.org/abs/1705.03919)
3
+ """
4
+
5
+ import nltk
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch_struct
11
+
12
+ from .parse_base import CompressedParserOutput
13
+
14
+
15
+ def pad_charts(charts, padding_value=-100):
16
+ """Pad a list of variable-length charts with `padding_value`."""
17
+ batch_size = len(charts)
18
+ max_len = max(chart.shape[0] for chart in charts)
19
+ padded_charts = torch.full(
20
+ (batch_size, max_len, max_len),
21
+ padding_value,
22
+ dtype=charts[0].dtype,
23
+ device=charts[0].device,
24
+ )
25
+ for i, chart in enumerate(charts):
26
+ chart_size = chart.shape[0]
27
+ padded_charts[i, :chart_size, :chart_size] = chart
28
+ return padded_charts
29
+
30
+
31
+ def collapse_unary_strip_pos(tree, strip_top=True):
32
+ """Collapse unary chains and strip part of speech tags."""
33
+
34
+ def strip_pos(tree):
35
+ if len(tree) == 1 and isinstance(tree[0], str):
36
+ return tree[0]
37
+ else:
38
+ return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree])
39
+
40
+ collapsed_tree = strip_pos(tree)
41
+ collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::")
42
+ if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"):
43
+ if strip_top:
44
+ if len(collapsed_tree) == 1:
45
+ collapsed_tree = collapsed_tree[0]
46
+ else:
47
+ collapsed_tree.set_label("")
48
+ elif len(collapsed_tree) == 1:
49
+ collapsed_tree[0].set_label(
50
+ collapsed_tree.label() + "::" + collapsed_tree[0].label())
51
+ collapsed_tree = collapsed_tree[0]
52
+ return collapsed_tree
53
+
54
+
55
+ def _get_labeled_spans(tree, spans_out, start):
56
+ if isinstance(tree, str):
57
+ return start + 1
58
+
59
+ assert len(tree) > 1 or isinstance(
60
+ tree[0], str
61
+ ), "Must call collapse_unary_strip_pos first"
62
+ end = start
63
+ for child in tree:
64
+ end = _get_labeled_spans(child, spans_out, end)
65
+ # Spans are returned as closed intervals on both ends
66
+ spans_out.append((start, end - 1, tree.label()))
67
+ return end
68
+
69
+
70
+ def get_labeled_spans(tree):
71
+ """Converts a tree into a list of labeled spans.
72
+
73
+ Args:
74
+ tree: an nltk.tree.Tree object
75
+
76
+ Returns:
77
+ A list of (span_start, span_end, span_label) tuples. The start and end
78
+ indices indicate the first and last words of the span (a closed
79
+ interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will
80
+ result in a single span labeled "S+VP".
81
+ """
82
+ tree = collapse_unary_strip_pos(tree)
83
+ spans_out = []
84
+ _get_labeled_spans(tree, spans_out, start=0)
85
+ return spans_out
86
+
87
+
88
+ def uncollapse_unary(tree, ensure_top=False):
89
+ """Un-collapse unary chains."""
90
+ if isinstance(tree, str):
91
+ return tree
92
+ else:
93
+ labels = tree.label().split("::")
94
+ if ensure_top and labels[0] != "TOP":
95
+ labels = ["TOP"] + labels
96
+ children = []
97
+ for child in tree:
98
+ child = uncollapse_unary(child)
99
+ children.append(child)
100
+ for label in labels[::-1]:
101
+ children = [nltk.tree.Tree(label, children)]
102
+ return children[0]
103
+
104
+
105
+ class ChartDecoder:
106
+ """A chart decoder for parsing formulated as span classification."""
107
+
108
+ def __init__(self, label_vocab, force_root_constituent=True):
109
+ """Constructs a new ChartDecoder object.
110
+ Args:
111
+ label_vocab: A mapping from span labels to integer indices.
112
+ """
113
+ self.label_vocab = label_vocab
114
+ self.label_from_index = {i: label for label, i in label_vocab.items()}
115
+ self.force_root_constituent = force_root_constituent
116
+
117
+ @staticmethod
118
+ def build_vocab(trees):
119
+ label_set = set()
120
+ for tree in trees:
121
+ for _, _, label in get_labeled_spans(tree):
122
+ if label:
123
+ label_set.add(label)
124
+ label_set = [""] + sorted(label_set)
125
+ return {label: i for i, label in enumerate(label_set)}
126
+
127
+ @staticmethod
128
+ def infer_force_root_constituent(trees):
129
+ for tree in trees:
130
+ for _, _, label in get_labeled_spans(tree):
131
+ if not label:
132
+ return False
133
+ return True
134
+
135
+ def chart_from_tree(self, tree):
136
+ spans = get_labeled_spans(tree)
137
+ num_words = len(tree.leaves())
138
+ chart = np.full((num_words, num_words), -100, dtype=int)
139
+ chart = np.tril(chart, -1)
140
+ # Now all invalid entries are filled with -100, and valid entries with 0
141
+ for start, end, label in spans:
142
+ # Previously unseen unary chains can occur in the dev/test sets.
143
+ # For now, we ignore them and don't mark the corresponding chart
144
+ # entry as a constituent.
145
+ if label in self.label_vocab:
146
+ chart[start, end] = self.label_vocab[label]
147
+ return chart
148
+
149
+ def charts_from_pytorch_scores_batched(self, scores, lengths):
150
+ """Runs CKY to recover span labels from scores (e.g. logits).
151
+
152
+ This method uses pytorch-struct to speed up decoding compared to the
153
+ pure-Python implementation of CKY used by tree_from_scores().
154
+
155
+ Args:
156
+ scores: a pytorch tensor of shape (batch size, max length,
157
+ max length, label vocab size).
158
+ lengths: a pytorch tensor of shape (batch size,)
159
+
160
+ Returns:
161
+ A list of numpy arrays, each of shape (sentence length, sentence
162
+ length).
163
+ """
164
+ scores = scores.detach()
165
+ scores = scores - scores[..., :1]
166
+ if self.force_root_constituent:
167
+ scores[torch.arange(scores.shape[0]), 0, lengths - 1, 0] -= 1e9
168
+ dist = torch_struct.TreeCRF(scores, lengths=lengths)
169
+ amax = dist.argmax
170
+ amax[..., 0] += 1e-9
171
+ padded_charts = amax.argmax(-1)
172
+ padded_charts = padded_charts.detach().cpu().numpy()
173
+ return [
174
+ chart[:length, :length] for chart, length in zip(padded_charts, lengths)
175
+ ]
176
+
177
+ def compressed_output_from_chart(self, chart):
178
+ chart_with_filled_diagonal = chart.copy()
179
+ np.fill_diagonal(chart_with_filled_diagonal, 1)
180
+ chart_with_filled_diagonal[0, -1] = 1
181
+ starts, inclusive_ends = np.where(chart_with_filled_diagonal)
182
+ preorder_sort = np.lexsort((-inclusive_ends, starts))
183
+ starts = starts[preorder_sort]
184
+ inclusive_ends = inclusive_ends[preorder_sort]
185
+ labels = chart[starts, inclusive_ends]
186
+ ends = inclusive_ends + 1
187
+ return CompressedParserOutput(starts=starts, ends=ends, labels=labels)
188
+
189
+ def tree_from_chart(self, chart, leaves):
190
+ compressed_output = self.compressed_output_from_chart(chart)
191
+ return compressed_output.to_tree(leaves, self.label_from_index)
192
+
193
+ def tree_from_scores(self, scores, leaves):
194
+ """Runs CKY to decode a tree from scores (e.g. logits).
195
+
196
+ If speed is important, consider using charts_from_pytorch_scores_batched
197
+ followed by compressed_output_from_chart or tree_from_chart instead.
198
+
199
+ Args:
200
+ scores: a chart of scores (or logits) of shape
201
+ (sentence length, sentence length, label vocab size). The first
202
+ two dimensions may be padded to a longer length, but all padded
203
+ values will be ignored.
204
+ leaves: the leaf nodes to use in the constructed tree. These
205
+ may be of type str or nltk.Tree, or (word, tag) tuples that
206
+ will be used to construct the leaf node objects.
207
+
208
+ Returns:
209
+ An nltk.Tree object.
210
+ """
211
+ leaves = [
212
+ nltk.Tree(node[1], [node[0]]) if isinstance(node, tuple) else node
213
+ for node in leaves
214
+ ]
215
+
216
+ chart = {}
217
+ scores = scores - scores[:, :, 0, None]
218
+ for length in range(1, len(leaves) + 1):
219
+ for left in range(0, len(leaves) + 1 - length):
220
+ right = left + length
221
+
222
+ label_scores = scores[left, right - 1]
223
+ label_scores = label_scores - label_scores[0]
224
+
225
+ argmax_label_index = int(
226
+ label_scores.argmax()
227
+ if length < len(leaves) or not self.force_root_constituent
228
+ else label_scores[1:].argmax() + 1
229
+ )
230
+ argmax_label = self.label_from_index[argmax_label_index]
231
+ label = argmax_label
232
+ label_score = label_scores[argmax_label_index]
233
+
234
+ if length == 1:
235
+ tree = leaves[left]
236
+ if label:
237
+ tree = nltk.tree.Tree(label, [tree])
238
+ chart[left, right] = [tree], label_score
239
+ continue
240
+
241
+ best_split = max(
242
+ range(left + 1, right),
243
+ key=lambda split: (chart[left, split][1] + chart[split, right][1]),
244
+ )
245
+
246
+ left_trees, left_score = chart[left, best_split]
247
+ right_trees, right_score = chart[best_split, right]
248
+
249
+ children = left_trees + right_trees
250
+ if label:
251
+ children = [nltk.tree.Tree(label, children)]
252
+
253
+ chart[left, right] = (children, label_score + left_score + right_score)
254
+
255
+ children, score = chart[0, len(leaves)]
256
+ tree = nltk.tree.Tree("TOP", children)
257
+ tree = uncollapse_unary(tree)
258
+ return tree
259
+
260
+
261
+ class SpanClassificationMarginLoss(nn.Module):
262
+ def __init__(self, force_root_constituent=True, reduction="mean"):
263
+ super().__init__()
264
+ self.force_root_constituent = force_root_constituent
265
+ if reduction not in ("none", "mean", "sum"):
266
+ raise ValueError(f"Invalid value for reduction: {reduction}")
267
+ self.reduction = reduction
268
+
269
+ def forward(self, logits, labels):
270
+ gold_event = F.one_hot(F.relu(labels), num_classes=logits.shape[-1])
271
+
272
+ logits = logits - logits[..., :1]
273
+ lengths = (labels[:, 0, :] != -100).sum(-1)
274
+ augment = (1 - gold_event).to(torch.float)
275
+ if self.force_root_constituent:
276
+ augment[torch.arange(augment.shape[0]), 0, lengths - 1, 0] -= 1e9
277
+ dist = torch_struct.TreeCRF(logits + augment, lengths=lengths)
278
+
279
+ pred_score = dist.max
280
+ gold_score = (logits * gold_event).sum((1, 2, 3))
281
+
282
+ margin_losses = F.relu(pred_score - gold_score)
283
+
284
+ if self.reduction == "none":
285
+ return margin_losses
286
+ elif self.reduction == "mean":
287
+ return margin_losses.mean()
288
+ elif self.reduction == "sum":
289
+ return margin_losses.sum()
290
+ else:
291
+ assert False, f"Unexpected reduction: {self.reduction}"
benepar/decode_chart.py~ ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parsing formulated as span classification (https://arxiv.org/abs/1705.03919)
3
+ """
4
+
5
+ import nltk
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch_struct
11
+
12
+ from .parse_base import CompressedParserOutput
13
+
14
+
15
+ def pad_charts(charts, padding_value=-100):
16
+ """Pad a list of variable-length charts with `padding_value`."""
17
+ batch_size = len(charts)
18
+ max_len = max(chart.shape[0] for chart in charts)
19
+ padded_charts = torch.full(
20
+ (batch_size, max_len, max_len),
21
+ padding_value,
22
+ dtype=charts[0].dtype,
23
+ device=charts[0].device,
24
+ )
25
+ for i, chart in enumerate(charts):
26
+ chart_size = chart.shape[0]
27
+ padded_charts[i, :chart_size, :chart_size] = chart
28
+ return padded_charts
29
+
30
+
31
+ def collapse_unary_strip_pos(tree, strip_top=True):
32
+ """Collapse unary chains and strip part of speech tags."""
33
+
34
+ def strip_pos(tree):
35
+ if len(tree) == 1 and isinstance(tree[0], str):
36
+ return tree[0]
37
+ else:
38
+ return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree])
39
+
40
+ collapsed_tree = strip_pos(tree)
41
+ collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::")
42
+ if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"):
43
+ if strip_top:
44
+ if len(collapsed_tree) == 1:
45
+ collapsed_tree = collapsed_tree[0]
46
+ else:
47
+ collapsed_tree.set_label("")
48
+ elif len(collapsed_tree) == 1:
49
+ collapsed_tree[0].set_label(
50
+ collapsed_tree.label() + "::" + collapsed_tree[0].label())
51
+ collapsed_tree = collapsed_tree[0]
52
+ return collapsed_tree
53
+
54
+
55
+ def _get_labeled_spans(tree, spans_out, start):
56
+ if isinstance(tree, str):
57
+ return start + 1
58
+
59
+ assert len(tree) > 1 or isinstance(
60
+ tree[0], str
61
+ ), "Must call collapse_unary_strip_pos first"
62
+ end = start
63
+ for child in tree:
64
+ end = _get_labeled_spans(child, spans_out, end)
65
+ # Spans are returned as closed intervals on both ends
66
+ spans_out.append((start, end - 1, tree.label()))
67
+ return end
68
+
69
+
70
+ def get_labeled_spans(tree):
71
+ """Converts a tree into a list of labeled spans.
72
+
73
+ Args:
74
+ tree: an nltk.tree.Tree object
75
+
76
+ Returns:
77
+ A list of (span_start, span_end, span_label) tuples. The start and end
78
+ indices indicate the first and last words of the span (a closed
79
+ interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will
80
+ result in a single span labeled "S+VP".
81
+ """
82
+ tree = collapse_unary_strip_pos(tree)
83
+ spans_out = []
84
+ _get_labeled_spans(tree, spans_out, start=0)
85
+ return spans_out
86
+
87
+
88
+ def uncollapse_unary(tree, ensure_top=False):
89
+ """Un-collapse unary chains."""
90
+ if isinstance(tree, str):
91
+ return tree
92
+ else:
93
+ labels = tree.label().split("::")
94
+ if ensure_top and labels[0] != "TOP":
95
+ labels = ["TOP"] + labels
96
+ children = []
97
+ for child in tree:
98
+ child = uncollapse_unary(child)
99
+ children.append(child)
100
+ for label in labels[::-1]:
101
+ children = [nltk.tree.Tree(label, children)]
102
+ return children[0]
103
+
104
+
105
+ class ChartDecoder:
106
+ """A chart decoder for parsing formulated as span classification."""
107
+
108
+ def __init__(self, label_vocab, force_root_constituent=True):
109
+ """Constructs a new ChartDecoder object.
110
+ Args:
111
+ label_vocab: A mapping from span labels to integer indices.
112
+ """
113
+ self.label_vocab = label_vocab
114
+ self.label_from_index = {i: label for label, i in label_vocab.items()}
115
+ self.force_root_constituent = force_root_constituent
116
+
117
+ @staticmethod
118
+ def build_vocab(trees):
119
+ label_set = set()
120
+ for tree in trees:
121
+ for _, _, label in get_labeled_spans(tree):
122
+ if label:
123
+ label_set.add(label)
124
+ label_set = [""] + sorted(label_set)
125
+ return {label: i for i, label in enumerate(label_set)}
126
+
127
+ @staticmethod
128
+ def infer_force_root_constituent(trees):
129
+ for tree in trees:
130
+ for _, _, label in get_labeled_spans(tree):
131
+ if not label:
132
+ return False
133
+ return True
134
+
135
+ def chart_from_tree(self, tree):
136
+ spans = get_labeled_spans(tree)
137
+ num_words = len(tree.leaves())
138
+ chart = np.full((num_words, num_words), -100, dtype=int)
139
+ chart = np.tril(chart, -1)
140
+ # Now all invalid entries are filled with -100, and valid entries with 0
141
+ for start, end, label in spans:
142
+ # Previously unseen unary chains can occur in the dev/test sets.
143
+ # For now, we ignore them and don't mark the corresponding chart
144
+ # entry as a constituent.
145
+ if label in self.label_vocab:
146
+ chart[start, end] = self.label_vocab[label]
147
+ return chart
148
+
149
+ def charts_from_pytorch_scores_batched(self, scores, lengths):
150
+ """Runs CKY to recover span labels from scores (e.g. logits).
151
+
152
+ This method uses pytorch-struct to speed up decoding compared to the
153
+ pure-Python implementation of CKY used by tree_from_scores().
154
+
155
+ Args:
156
+ scores: a pytorch tensor of shape (batch size, max length,
157
+ max length, label vocab size).
158
+ lengths: a pytorch tensor of shape (batch size,)
159
+
160
+ Returns:
161
+ A list of numpy arrays, each of shape (sentence length, sentence
162
+ length).
163
+ """
164
+ scores = scores.detach()
165
+ scores = scores - scores[..., :1]
166
+ if self.force_root_constituent:
167
+ scores[torch.arange(scores.shape[0]), 0, lengths - 1, 0] -= 1e9
168
+ dist = torch_struct.TreeCRF(scores, lengths=lengths)
169
+ amax = dist.argmax
170
+ amax[..., 0] += 1e-9
171
+ padded_charts = amax.argmax(-1)
172
+ padded_charts = padded_charts.detach().cpu().numpy()
173
+ return [
174
+ chart[:length, :length] for chart, length in zip(padded_charts, lengths)
175
+ ]
176
+
177
+ def compressed_output_from_chart(self, chart):
178
+ chart_with_filled_diagonal = chart.copy()
179
+ np.fill_diagonal(chart_with_filled_diagonal, 1)
180
+ chart_with_filled_diagonal[0, -1] = 1
181
+ starts, inclusive_ends = np.where(chart_with_filled_diagonal)
182
+ preorder_sort = np.lexsort((-inclusive_ends, starts))
183
+ starts = starts[preorder_sort]
184
+ inclusive_ends = inclusive_ends[preorder_sort]
185
+ labels = chart[starts, inclusive_ends]
186
+ ends = inclusive_ends + 1
187
+ return CompressedParserOutput(starts=starts, ends=ends, labels=labels)
188
+
189
+ def tree_from_chart(self, chart, leaves):
190
+ compressed_output = self.compressed_output_from_chart(chart)
191
+ return compressed_output.to_tree(leaves, self.label_from_index)
192
+
193
+ def tree_from_scores(self, scores, leaves):
194
+ """Runs CKY to decode a tree from scores (e.g. logits).
195
+
196
+ If speed is important, consider using charts_from_pytorch_scores_batched
197
+ followed by compressed_output_from_chart or tree_from_chart instead.
198
+
199
+ Args:
200
+ scores: a chart of scores (or logits) of shape
201
+ (sentence length, sentence length, label vocab size). The first
202
+ two dimensions may be padded to a longer length, but all padded
203
+ values will be ignored.
204
+ leaves: the leaf nodes to use in the constructed tree. These
205
+ may be of type str or nltk.Tree, or (word, tag) tuples that
206
+ will be used to construct the leaf node objects.
207
+
208
+ Returns:
209
+ An nltk.Tree object.
210
+ """
211
+ leaves = [
212
+ nltk.Tree(node[1], [node[0]]) if isinstance(node, tuple) else node
213
+ for node in leaves
214
+ ]
215
+
216
+ chart = {}
217
+ scores = scores - scores[:, :, 0, None]
218
+ for length in range(1, len(leaves) + 1):
219
+ for left in range(0, len(leaves) + 1 - length):
220
+ right = left + length
221
+
222
+ label_scores = scores[left, right - 1]
223
+ label_scores = label_scores - label_scores[0]
224
+
225
+ argmax_label_index = int(
226
+ label_scores.argmax()
227
+ if length < len(leaves) or not self.force_root_constituent
228
+ else label_scores[1:].argmax() + 1
229
+ )
230
+ argmax_label = self.label_from_index[argmax_label_index]
231
+ label = argmax_label
232
+ label_score = label_scores[argmax_label_index]
233
+
234
+ if length == 1:
235
+ tree = leaves[left]
236
+ if label:
237
+ tree = nltk.tree.Tree(label, [tree])
238
+ chart[left, right] = [tree], label_score
239
+ continue
240
+
241
+ best_split = max(
242
+ range(left + 1, right),
243
+ key=lambda split: (chart[left, split][1] + chart[split, right][1]),
244
+ )
245
+
246
+ left_trees, left_score = chart[left, best_split]
247
+ right_trees, right_score = chart[best_split, right]
248
+
249
+ children = left_trees + right_trees
250
+ if label:
251
+ children = [nltk.tree.Tree(label, children)]
252
+
253
+ chart[left, right] = (children, label_score + left_score + right_score)
254
+
255
+ children, score = chart[0, len(leaves)]
256
+ tree = nltk.tree.Tree("TOP", children)
257
+ tree = uncollapse_unary(tree)
258
+ return tree
259
+
260
+
261
+ class SpanClassificationMarginLoss(nn.Module):
262
+ def __init__(self, force_root_constituent=True, reduction="mean"):
263
+ super().__init__()
264
+ self.force_root_constituent = force_root_constituent
265
+ if reduction not in ("none", "mean", "sum"):
266
+ raise ValueError(f"Invalid value for reduction: {reduction}")
267
+ self.reduction = reduction
268
+
269
+ def forward(self, logits, labels):
270
+ gold_event = F.one_hot(F.relu(labels), num_classes=logits.shape[-1])
271
+
272
+ logits = logits - logits[..., :1]
273
+ lengths = (labels[:, 0, :] != -100).sum(-1)
274
+ augment = (1 - gold_event).to(torch.float)
275
+ if self.force_root_constituent:
276
+ augment[torch.arange(augment.shape[0]), 0, lengths - 1, 0] -= 1e9
277
+ dist = torch_struct.TreeCRF(logits + augment, lengths=lengths)
278
+
279
+ pred_score = dist.max
280
+ gold_score = (logits * gold_event).sum((1, 2, 3))
281
+
282
+ margin_losses = F.relu(pred_score - gold_score)
283
+
284
+ if self.reduction == "none":
285
+ return margin_losses
286
+ elif self.reduction == "mean":
287
+ return margin_losses.mean()
288
+ elif self.reduction == "sum":
289
+ return margin_losses.sum()
290
+ else:
291
+ assert False, f"Unexpected reduction: {self.reduction}"
benepar/integrations/__init__.py ADDED
File without changes
benepar/integrations/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (208 Bytes). View file
 
benepar/integrations/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (191 Bytes). View file
 
benepar/integrations/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes). View file
 
benepar/integrations/__pycache__/downloader.cpython-310.pyc ADDED
Binary file (1.35 kB). View file
 
benepar/integrations/__pycache__/downloader.cpython-37.pyc ADDED
Binary file (1.31 kB). View file
 
benepar/integrations/__pycache__/downloader.cpython-38.pyc ADDED
Binary file (1.31 kB). View file
 
benepar/integrations/__pycache__/nltk_plugin.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
benepar/integrations/__pycache__/nltk_plugin.cpython-37.pyc ADDED
Binary file (11.1 kB). View file
 
benepar/integrations/__pycache__/nltk_plugin.cpython-38.pyc ADDED
Binary file (11.2 kB). View file
 
benepar/integrations/__pycache__/spacy_extensions.cpython-310.pyc ADDED
Binary file (4.44 kB). View file
 
benepar/integrations/__pycache__/spacy_extensions.cpython-37.pyc ADDED
Binary file (4.32 kB). View file
 
benepar/integrations/__pycache__/spacy_extensions.cpython-38.pyc ADDED
Binary file (4.36 kB). View file
 
benepar/integrations/__pycache__/spacy_plugin.cpython-310.pyc ADDED
Binary file (6.63 kB). View file