Commit: add parsing

Files changed (this view is limited to 50 files because the commit contains too many changes):
- __pycache__/parse.cpython-38.pyc +0 -0
- app.py +11 -9
- benepar/__init__.py +20 -0
- benepar/__pycache__/__init__.cpython-310.pyc +0 -0
- benepar/__pycache__/__init__.cpython-37.pyc +0 -0
- benepar/__pycache__/__init__.cpython-38.pyc +0 -0
- benepar/__pycache__/char_lstm.cpython-310.pyc +0 -0
- benepar/__pycache__/char_lstm.cpython-37.pyc +0 -0
- benepar/__pycache__/char_lstm.cpython-38.pyc +0 -0
- benepar/__pycache__/decode_chart.cpython-310.pyc +0 -0
- benepar/__pycache__/decode_chart.cpython-37.pyc +0 -0
- benepar/__pycache__/decode_chart.cpython-38.pyc +0 -0
- benepar/__pycache__/nkutil.cpython-310.pyc +0 -0
- benepar/__pycache__/nkutil.cpython-37.pyc +0 -0
- benepar/__pycache__/nkutil.cpython-38.pyc +0 -0
- benepar/__pycache__/parse_base.cpython-310.pyc +0 -0
- benepar/__pycache__/parse_base.cpython-37.pyc +0 -0
- benepar/__pycache__/parse_base.cpython-38.pyc +0 -0
- benepar/__pycache__/parse_chart.cpython-310.pyc +0 -0
- benepar/__pycache__/parse_chart.cpython-37.pyc +0 -0
- benepar/__pycache__/parse_chart.cpython-38.pyc +0 -0
- benepar/__pycache__/partitioned_transformer.cpython-310.pyc +0 -0
- benepar/__pycache__/partitioned_transformer.cpython-37.pyc +0 -0
- benepar/__pycache__/partitioned_transformer.cpython-38.pyc +0 -0
- benepar/__pycache__/ptb_unescape.cpython-310.pyc +0 -0
- benepar/__pycache__/ptb_unescape.cpython-37.pyc +0 -0
- benepar/__pycache__/ptb_unescape.cpython-38.pyc +0 -0
- benepar/__pycache__/retokenization.cpython-310.pyc +0 -0
- benepar/__pycache__/retokenization.cpython-37.pyc +0 -0
- benepar/__pycache__/retokenization.cpython-38.pyc +0 -0
- benepar/__pycache__/subbatching.cpython-310.pyc +0 -0
- benepar/__pycache__/subbatching.cpython-37.pyc +0 -0
- benepar/__pycache__/subbatching.cpython-38.pyc +0 -0
- benepar/char_lstm.py +160 -0
- benepar/decode_chart.py +291 -0
- benepar/decode_chart.py~ +291 -0
- benepar/integrations/__init__.py +0 -0
- benepar/integrations/__pycache__/__init__.cpython-310.pyc +0 -0
- benepar/integrations/__pycache__/__init__.cpython-37.pyc +0 -0
- benepar/integrations/__pycache__/__init__.cpython-38.pyc +0 -0
- benepar/integrations/__pycache__/downloader.cpython-310.pyc +0 -0
- benepar/integrations/__pycache__/downloader.cpython-37.pyc +0 -0
- benepar/integrations/__pycache__/downloader.cpython-38.pyc +0 -0
- benepar/integrations/__pycache__/nltk_plugin.cpython-310.pyc +0 -0
- benepar/integrations/__pycache__/nltk_plugin.cpython-37.pyc +0 -0
- benepar/integrations/__pycache__/nltk_plugin.cpython-38.pyc +0 -0
- benepar/integrations/__pycache__/spacy_extensions.cpython-310.pyc +0 -0
- benepar/integrations/__pycache__/spacy_extensions.cpython-37.pyc +0 -0
- benepar/integrations/__pycache__/spacy_extensions.cpython-38.pyc +0 -0
- benepar/integrations/__pycache__/spacy_plugin.cpython-310.pyc +0 -0
__pycache__/parse.cpython-38.pyc
CHANGED
Binary files a/__pycache__/parse.cpython-38.pyc and b/__pycache__/parse.cpython-38.pyc differ
app.py
CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-
+from parse import parse
 from nltk import Tree
 import pandas as pd
 import re
@@ -31,19 +31,21 @@ if text:
 
     df = pd.DataFrame(zipped, columns=['Token', 'Tag', 'Prob.'])
 
-
-    # t = Tree.fromstring(re.sub(r'(\.[^ )]+)+', '', parse_tree))
-
-    #
+    parse_tree = parse(tokens)
+
+    # Convert the bracket parse tree into an NLTK Tree
+    t = Tree.fromstring(re.sub(r'-[^ )]*', '', parse_tree))
+
+    tree_svg = TreePrettyPrinter(t).svg(nodecolor='black', leafcolor='black', funccolor='black')
 
     col1 = st.columns(1)[0]
     col1.header("POS tagging result:")
     col1.table(df)
 
-
-
-
-    #
-
+    col2 = st.columns(1)[0]
+    col2.header("Parsing result:")
+    col2.write(parse_tree.replace('_', '\_').replace('$', '\$').replace('*', '\*'))
+
+    # Display the graph in the Streamlit app
+    col2.image(tree_svg, use_column_width=True)
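The new code path calls parse(tokens) to get a bracketed tree string, strips hyphenated function tags from the labels, escapes markdown-special characters so Streamlit shows the bracket string verbatim, and renders the tree as an SVG. A self-contained sketch of that pipeline follows (not part of the commit; the hunks do not show where TreePrettyPrinter is imported, so the import line is our assumption, and the sample bracket string is invented for illustration):

import re
from nltk import Tree
from nltk.tree import TreePrettyPrinter  # nltk >= 3.7; older releases use nltk.treeprettyprinter

# Hypothetical parser output: a bracketed tree with a "-SBJ" function tag
parse_tree = "(S (NP-SBJ (DT The) (NN cat)) (VP (VBD sat)))"

# Same regex as app.py: drop hyphenated annotations so labels render cleanly
t = Tree.fromstring(re.sub(r'-[^ )]*', '', parse_tree))

# Same call as app.py: render the tree to an SVG string for st.image
tree_svg = TreePrettyPrinter(t).svg(nodecolor='black', leafcolor='black', funccolor='black')
print(tree_svg[:40])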
benepar/__init__.py
ADDED
@@ -0,0 +1,20 @@
"""
benepar: Berkeley Neural Parser
"""

# This file and all code in integrations/ relate to the version of the parser
# released via PyPI. If you only need to run research experiments, it is safe
# to delete the integrations/ folder and replace this __init__.py with an
# empty file.

__all__ = [
    "Parser",
    "InputSentence",
    "download",
    "BeneparComponent",
    "NonConstituentException",
]

from .integrations.downloader import download
from .integrations.nltk_plugin import Parser, InputSentence
from .integrations.spacy_plugin import BeneparComponent, NonConstituentException
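For reference, a minimal sketch of how this public API is used, following the benepar README (the model name "benepar_en3" and the download step are the standard ones, not shown in this commit):

import benepar

benepar.download("benepar_en3")  # one-time model download
parser = benepar.Parser("benepar_en3")

# Plain strings are word-tokenized automatically; pre-tokenized input can be
# passed via benepar.InputSentence(words=[...]).
tree = parser.parse("The quick brown fox jumps over the lazy dog.")
print(tree.pformat())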
benepar/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (526 Bytes).

benepar/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (521 Bytes).

benepar/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (505 Bytes).

benepar/__pycache__/char_lstm.cpython-310.pyc
ADDED
Binary file (4.94 kB).

benepar/__pycache__/char_lstm.cpython-37.pyc
ADDED
Binary file (4.92 kB).

benepar/__pycache__/char_lstm.cpython-38.pyc
ADDED
Binary file (4.96 kB).

benepar/__pycache__/decode_chart.cpython-310.pyc
ADDED
Binary file (10.2 kB).

benepar/__pycache__/decode_chart.cpython-37.pyc
ADDED
Binary file (10.2 kB).

benepar/__pycache__/decode_chart.cpython-38.pyc
ADDED
Binary file (10.2 kB).

benepar/__pycache__/nkutil.cpython-310.pyc
ADDED
Binary file (2.14 kB).

benepar/__pycache__/nkutil.cpython-37.pyc
ADDED
Binary file (2.1 kB).

benepar/__pycache__/nkutil.cpython-38.pyc
ADDED
Binary file (2.09 kB).

benepar/__pycache__/parse_base.cpython-310.pyc
ADDED
Binary file (7.38 kB).

benepar/__pycache__/parse_base.cpython-37.pyc
ADDED
Binary file (7.16 kB).

benepar/__pycache__/parse_base.cpython-38.pyc
ADDED
Binary file (7.26 kB).

benepar/__pycache__/parse_chart.cpython-310.pyc
ADDED
Binary file (11 kB).

benepar/__pycache__/parse_chart.cpython-37.pyc
ADDED
Binary file (11 kB).

benepar/__pycache__/parse_chart.cpython-38.pyc
ADDED
Binary file (11.1 kB).

benepar/__pycache__/partitioned_transformer.cpython-310.pyc
ADDED
Binary file (7.83 kB).

benepar/__pycache__/partitioned_transformer.cpython-37.pyc
ADDED
Binary file (7.9 kB).

benepar/__pycache__/partitioned_transformer.cpython-38.pyc
ADDED
Binary file (7.82 kB).

benepar/__pycache__/ptb_unescape.cpython-310.pyc
ADDED
Binary file (3.05 kB).

benepar/__pycache__/ptb_unescape.cpython-37.pyc
ADDED
Binary file (3.2 kB).

benepar/__pycache__/ptb_unescape.cpython-38.pyc
ADDED
Binary file (3.19 kB).

benepar/__pycache__/retokenization.cpython-310.pyc
ADDED
Binary file (6.83 kB).

benepar/__pycache__/retokenization.cpython-37.pyc
ADDED
Binary file (6.73 kB).

benepar/__pycache__/retokenization.cpython-38.pyc
ADDED
Binary file (6.83 kB).

benepar/__pycache__/subbatching.cpython-310.pyc
ADDED
Binary file (2.53 kB).

benepar/__pycache__/subbatching.cpython-37.pyc
ADDED
Binary file (2.49 kB).

benepar/__pycache__/subbatching.cpython-38.pyc
ADDED
Binary file (2.48 kB).
benepar/char_lstm.py
ADDED
@@ -0,0 +1,160 @@
"""
Character LSTM implementation (matches https://arxiv.org/pdf/1805.01052.pdf)
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CharacterLSTM(nn.Module):
    def __init__(self, num_embeddings, d_embedding, d_out, char_dropout=0.0, **kwargs):
        super().__init__()

        self.d_embedding = d_embedding
        self.d_out = d_out

        self.lstm = nn.LSTM(
            self.d_embedding, self.d_out // 2, num_layers=1, bidirectional=True
        )

        self.emb = nn.Embedding(num_embeddings, self.d_embedding, **kwargs)
        self.char_dropout = nn.Dropout(char_dropout)

    def forward(self, chars_packed, valid_token_mask):
        inp_embs = nn.utils.rnn.PackedSequence(
            self.char_dropout(self.emb(chars_packed.data)),
            batch_sizes=chars_packed.batch_sizes,
            sorted_indices=chars_packed.sorted_indices,
            unsorted_indices=chars_packed.unsorted_indices,
        )

        _, (lstm_out, _) = self.lstm(inp_embs)
        lstm_out = torch.cat([lstm_out[0], lstm_out[1]], -1)

        # Switch to a representation where there are dummy vectors for invalid
        # tokens generated by padding.
        res = lstm_out.new_zeros(
            (valid_token_mask.shape[0], valid_token_mask.shape[1], lstm_out.shape[-1])
        )
        res[valid_token_mask] = lstm_out
        return res


class RetokenizerForCharLSTM:
    # Assumes that these control characters are not present in treebank text
    CHAR_UNK = "\0"
    CHAR_ID_UNK = 0
    CHAR_START_SENTENCE = "\1"
    CHAR_START_WORD = "\2"
    CHAR_STOP_WORD = "\3"
    CHAR_STOP_SENTENCE = "\4"

    def __init__(self, char_vocab):
        self.char_vocab = char_vocab

    @classmethod
    def build_vocab(cls, sentences):
        char_set = set()
        for sentence in sentences:
            if isinstance(sentence, tuple):
                sentence = sentence[0]
            for word in sentence:
                char_set |= set(word)

        # If codepoints are small (e.g. Latin alphabet), index by codepoint
        # directly
        highest_codepoint = max(ord(char) for char in char_set)
        if highest_codepoint < 512:
            if highest_codepoint < 256:
                highest_codepoint = 256
            else:
                highest_codepoint = 512

            char_vocab = {}
            # This also takes care of constants like CHAR_UNK, etc.
            for codepoint in range(highest_codepoint):
                char_vocab[chr(codepoint)] = codepoint
            return char_vocab
        else:
            char_vocab = {}
            char_vocab[cls.CHAR_UNK] = 0
            char_vocab[cls.CHAR_START_SENTENCE] = 1
            char_vocab[cls.CHAR_START_WORD] = 2
            char_vocab[cls.CHAR_STOP_WORD] = 3
            char_vocab[cls.CHAR_STOP_SENTENCE] = 4
            for id_, char in enumerate(sorted(char_set), start=5):
                char_vocab[char] = id_
            return char_vocab

    def __call__(self, words, space_after="ignored", return_tensors=None):
        if return_tensors != "np":
            raise NotImplementedError("Only return_tensors='np' is supported.")

        res = {}

        # Sentence-level start/stop tokens are encoded as 3 pseudo-chars
        # Within each word, account for 2 start/stop characters
        max_word_len = max(3, max(len(word) for word in words)) + 2
        char_ids = np.zeros((len(words) + 2, max_word_len), dtype=int)
        word_lens = np.zeros(len(words) + 2, dtype=int)

        char_ids[0, :5] = [
            self.char_vocab[self.CHAR_START_WORD],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_STOP_WORD],
        ]
        word_lens[0] = 5
        for i, word in enumerate(words, start=1):
            char_ids[i, 0] = self.char_vocab[self.CHAR_START_WORD]
            for j, char in enumerate(word, start=1):
                char_ids[i, j] = self.char_vocab.get(char, self.CHAR_ID_UNK)
            char_ids[i, j + 1] = self.char_vocab[self.CHAR_STOP_WORD]
            word_lens[i] = j + 2
        char_ids[i + 1, :5] = [
            self.char_vocab[self.CHAR_START_WORD],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_WORD],
        ]
        word_lens[i + 1] = 5

        res["char_ids"] = char_ids
        res["word_lens"] = word_lens
        res["valid_token_mask"] = np.ones_like(word_lens, dtype=bool)

        return res

    def pad(self, examples, return_tensors=None):
        if return_tensors != "pt":
            raise NotImplementedError("Only return_tensors='pt' is supported.")
        max_word_len = max(example["char_ids"].shape[-1] for example in examples)
        char_ids = torch.cat(
            [
                F.pad(
                    torch.tensor(example["char_ids"]),
                    (0, max_word_len - example["char_ids"].shape[-1]),
                )
                for example in examples
            ]
        )
        word_lens = torch.cat(
            [torch.tensor(example["word_lens"]) for example in examples]
        )
        valid_token_mask = nn.utils.rnn.pad_sequence(
            [torch.tensor(example["valid_token_mask"]) for example in examples],
            batch_first=True,
            padding_value=False,
        )

        char_ids = nn.utils.rnn.pack_padded_sequence(
            char_ids, word_lens, batch_first=True, enforce_sorted=False
        )
        return {
            "char_ids": char_ids,
            "valid_token_mask": valid_token_mask,
        }
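A minimal usage sketch (not in the commit) showing how the retokenizer and the LSTM compose; shapes follow the conventions above, where two extra positions hold the sentence start/stop pseudo-words:

words = ["The", "cat", "sat"]

# All characters here are Latin, so build_vocab indexes by codepoint (256 ids)
retok = RetokenizerForCharLSTM(RetokenizerForCharLSTM.build_vocab([words]))
example = retok(words, return_tensors="np")
batch = retok.pad([example], return_tensors="pt")  # packs char ids per word

char_lstm = CharacterLSTM(num_embeddings=256, d_embedding=64, d_out=128)
word_reprs = char_lstm(batch["char_ids"], batch["valid_token_mask"])
print(word_reprs.shape)  # torch.Size([1, 5, 128]): 3 words + 2 boundary tokens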
benepar/decode_chart.py
ADDED
@@ -0,0 +1,291 @@
"""
Parsing formulated as span classification (https://arxiv.org/abs/1705.03919)
"""

import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_struct

from .parse_base import CompressedParserOutput


def pad_charts(charts, padding_value=-100):
    """Pad a list of variable-length charts with `padding_value`."""
    batch_size = len(charts)
    max_len = max(chart.shape[0] for chart in charts)
    padded_charts = torch.full(
        (batch_size, max_len, max_len),
        padding_value,
        dtype=charts[0].dtype,
        device=charts[0].device,
    )
    for i, chart in enumerate(charts):
        chart_size = chart.shape[0]
        padded_charts[i, :chart_size, :chart_size] = chart
    return padded_charts


def collapse_unary_strip_pos(tree, strip_top=True):
    """Collapse unary chains and strip part of speech tags."""

    def strip_pos(tree):
        if len(tree) == 1 and isinstance(tree[0], str):
            return tree[0]
        else:
            return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree])

    collapsed_tree = strip_pos(tree)
    collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::")
    if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"):
        if strip_top:
            if len(collapsed_tree) == 1:
                collapsed_tree = collapsed_tree[0]
            else:
                collapsed_tree.set_label("")
        elif len(collapsed_tree) == 1:
            collapsed_tree[0].set_label(
                collapsed_tree.label() + "::" + collapsed_tree[0].label())
            collapsed_tree = collapsed_tree[0]
    return collapsed_tree


def _get_labeled_spans(tree, spans_out, start):
    if isinstance(tree, str):
        return start + 1

    assert len(tree) > 1 or isinstance(
        tree[0], str
    ), "Must call collapse_unary_strip_pos first"
    end = start
    for child in tree:
        end = _get_labeled_spans(child, spans_out, end)
    # Spans are returned as closed intervals on both ends
    spans_out.append((start, end - 1, tree.label()))
    return end


def get_labeled_spans(tree):
    """Converts a tree into a list of labeled spans.

    Args:
        tree: an nltk.tree.Tree object

    Returns:
        A list of (span_start, span_end, span_label) tuples. The start and end
        indices indicate the first and last words of the span (a closed
        interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will
        result in a single span labeled "S+VP".
    """
    tree = collapse_unary_strip_pos(tree)
    spans_out = []
    _get_labeled_spans(tree, spans_out, start=0)
    return spans_out


def uncollapse_unary(tree, ensure_top=False):
    """Un-collapse unary chains."""
    if isinstance(tree, str):
        return tree
    else:
        labels = tree.label().split("::")
        if ensure_top and labels[0] != "TOP":
            labels = ["TOP"] + labels
        children = []
        for child in tree:
            child = uncollapse_unary(child)
            children.append(child)
        for label in labels[::-1]:
            children = [nltk.tree.Tree(label, children)]
        return children[0]


class ChartDecoder:
    """A chart decoder for parsing formulated as span classification."""

    def __init__(self, label_vocab, force_root_constituent=True):
        """Constructs a new ChartDecoder object.
        Args:
            label_vocab: A mapping from span labels to integer indices.
        """
        self.label_vocab = label_vocab
        self.label_from_index = {i: label for label, i in label_vocab.items()}
        self.force_root_constituent = force_root_constituent

    @staticmethod
    def build_vocab(trees):
        label_set = set()
        for tree in trees:
            for _, _, label in get_labeled_spans(tree):
                if label:
                    label_set.add(label)
        label_set = [""] + sorted(label_set)
        return {label: i for i, label in enumerate(label_set)}

    @staticmethod
    def infer_force_root_constituent(trees):
        for tree in trees:
            for _, _, label in get_labeled_spans(tree):
                if not label:
                    return False
        return True

    def chart_from_tree(self, tree):
        spans = get_labeled_spans(tree)
        num_words = len(tree.leaves())
        chart = np.full((num_words, num_words), -100, dtype=int)
        chart = np.tril(chart, -1)
        # Now all invalid entries are filled with -100, and valid entries with 0
        for start, end, label in spans:
            # Previously unseen unary chains can occur in the dev/test sets.
            # For now, we ignore them and don't mark the corresponding chart
            # entry as a constituent.
            if label in self.label_vocab:
                chart[start, end] = self.label_vocab[label]
        return chart

    def charts_from_pytorch_scores_batched(self, scores, lengths):
        """Runs CKY to recover span labels from scores (e.g. logits).

        This method uses pytorch-struct to speed up decoding compared to the
        pure-Python implementation of CKY used by tree_from_scores().

        Args:
            scores: a pytorch tensor of shape (batch size, max length,
                max length, label vocab size).
            lengths: a pytorch tensor of shape (batch size,)

        Returns:
            A list of numpy arrays, each of shape (sentence length, sentence
            length).
        """
        scores = scores.detach()
        scores = scores - scores[..., :1]
        if self.force_root_constituent:
            scores[torch.arange(scores.shape[0]), 0, lengths - 1, 0] -= 1e9
        dist = torch_struct.TreeCRF(scores, lengths=lengths)
        amax = dist.argmax
        amax[..., 0] += 1e-9
        padded_charts = amax.argmax(-1)
        padded_charts = padded_charts.detach().cpu().numpy()
        return [
            chart[:length, :length] for chart, length in zip(padded_charts, lengths)
        ]

    def compressed_output_from_chart(self, chart):
        chart_with_filled_diagonal = chart.copy()
        np.fill_diagonal(chart_with_filled_diagonal, 1)
        chart_with_filled_diagonal[0, -1] = 1
        starts, inclusive_ends = np.where(chart_with_filled_diagonal)
        preorder_sort = np.lexsort((-inclusive_ends, starts))
        starts = starts[preorder_sort]
        inclusive_ends = inclusive_ends[preorder_sort]
        labels = chart[starts, inclusive_ends]
        ends = inclusive_ends + 1
        return CompressedParserOutput(starts=starts, ends=ends, labels=labels)

    def tree_from_chart(self, chart, leaves):
        compressed_output = self.compressed_output_from_chart(chart)
        return compressed_output.to_tree(leaves, self.label_from_index)

    def tree_from_scores(self, scores, leaves):
        """Runs CKY to decode a tree from scores (e.g. logits).

        If speed is important, consider using charts_from_pytorch_scores_batched
        followed by compressed_output_from_chart or tree_from_chart instead.

        Args:
            scores: a chart of scores (or logits) of shape
                (sentence length, sentence length, label vocab size). The first
                two dimensions may be padded to a longer length, but all padded
                values will be ignored.
            leaves: the leaf nodes to use in the constructed tree. These
                may be of type str or nltk.Tree, or (word, tag) tuples that
                will be used to construct the leaf node objects.

        Returns:
            An nltk.Tree object.
        """
        leaves = [
            nltk.Tree(node[1], [node[0]]) if isinstance(node, tuple) else node
            for node in leaves
        ]

        chart = {}
        scores = scores - scores[:, :, 0, None]
        for length in range(1, len(leaves) + 1):
            for left in range(0, len(leaves) + 1 - length):
                right = left + length

                label_scores = scores[left, right - 1]
                label_scores = label_scores - label_scores[0]

                argmax_label_index = int(
                    label_scores.argmax()
                    if length < len(leaves) or not self.force_root_constituent
                    else label_scores[1:].argmax() + 1
                )
                argmax_label = self.label_from_index[argmax_label_index]
                label = argmax_label
                label_score = label_scores[argmax_label_index]

                if length == 1:
                    tree = leaves[left]
                    if label:
                        tree = nltk.tree.Tree(label, [tree])
                    chart[left, right] = [tree], label_score
                    continue

                best_split = max(
                    range(left + 1, right),
                    key=lambda split: (chart[left, split][1] + chart[split, right][1]),
                )

                left_trees, left_score = chart[left, best_split]
                right_trees, right_score = chart[best_split, right]

                children = left_trees + right_trees
                if label:
                    children = [nltk.tree.Tree(label, children)]

                chart[left, right] = (children, label_score + left_score + right_score)

        children, score = chart[0, len(leaves)]
        tree = nltk.tree.Tree("TOP", children)
        tree = uncollapse_unary(tree)
        return tree


class SpanClassificationMarginLoss(nn.Module):
    def __init__(self, force_root_constituent=True, reduction="mean"):
        super().__init__()
        self.force_root_constituent = force_root_constituent
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(f"Invalid value for reduction: {reduction}")
        self.reduction = reduction

    def forward(self, logits, labels):
        gold_event = F.one_hot(F.relu(labels), num_classes=logits.shape[-1])

        logits = logits - logits[..., :1]
        lengths = (labels[:, 0, :] != -100).sum(-1)
        augment = (1 - gold_event).to(torch.float)
        if self.force_root_constituent:
            augment[torch.arange(augment.shape[0]), 0, lengths - 1, 0] -= 1e9
        dist = torch_struct.TreeCRF(logits + augment, lengths=lengths)

        pred_score = dist.max
        gold_score = (logits * gold_event).sum((1, 2, 3))

        margin_losses = F.relu(pred_score - gold_score)

        if self.reduction == "none":
            return margin_losses
        elif self.reduction == "mean":
            return margin_losses.mean()
        elif self.reduction == "sum":
            return margin_losses.sum()
        else:
            assert False, f"Unexpected reduction: {self.reduction}"
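To make the span conventions concrete, a small sketch (not in the commit) of the tree-to-chart direction on a toy tree:

tree = nltk.Tree.fromstring("(TOP (S (NP (DT the) (NN cat)) (VP (VBD sat))))")

# Child spans are emitted before their parents; intervals are closed
print(get_labeled_spans(tree))
# [(0, 1, 'NP'), (2, 2, 'VP'), (0, 2, 'S')]

decoder = ChartDecoder(ChartDecoder.build_vocab([tree]))
print(decoder.chart_from_tree(tree))
# [[   0    1    2]
#  [-100    0    0]
#  [-100 -100    3]]   (label ids: ""=0, NP=1, S=2, VP=3)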
benepar/decode_chart.py~
ADDED
@@ -0,0 +1,291 @@
(Apparently an editor backup file; its 291 lines are identical to benepar/decode_chart.py above, so they are not repeated here.)
benepar/integrations/__init__.py
ADDED
File without changes

benepar/integrations/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (208 Bytes).

benepar/integrations/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (191 Bytes).

benepar/integrations/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (173 Bytes).

benepar/integrations/__pycache__/downloader.cpython-310.pyc
ADDED
Binary file (1.35 kB).

benepar/integrations/__pycache__/downloader.cpython-37.pyc
ADDED
Binary file (1.31 kB).

benepar/integrations/__pycache__/downloader.cpython-38.pyc
ADDED
Binary file (1.31 kB).

benepar/integrations/__pycache__/nltk_plugin.cpython-310.pyc
ADDED
Binary file (11.2 kB).

benepar/integrations/__pycache__/nltk_plugin.cpython-37.pyc
ADDED
Binary file (11.1 kB).

benepar/integrations/__pycache__/nltk_plugin.cpython-38.pyc
ADDED
Binary file (11.2 kB).

benepar/integrations/__pycache__/spacy_extensions.cpython-310.pyc
ADDED
Binary file (4.44 kB).

benepar/integrations/__pycache__/spacy_extensions.cpython-37.pyc
ADDED
Binary file (4.32 kB).

benepar/integrations/__pycache__/spacy_extensions.cpython-38.pyc
ADDED
Binary file (4.36 kB).

benepar/integrations/__pycache__/spacy_plugin.cpython-310.pyc
ADDED
Binary file (6.63 kB).