import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel

from . import char_lstm
from . import decode_chart
from . import nkutil
from .partitioned_transformer import (
    ConcatPositionalEncoding,
    FeatureDropout,
    PartitionedTransformerEncoder,
    PartitionedTransformerEncoderLayer,
)
from . import parse_base
from . import retokenization
from . import subbatching


class ChartParser(nn.Module, parse_base.BaseParser):
    def __init__(
        self,
        tag_vocab,
        label_vocab,
        char_vocab,
        hparams,
        pretrained_model_path=None,
    ):
        super().__init__()
        self.config = locals()
        self.config.pop("self")
        self.config.pop("__class__")
        self.config.pop("pretrained_model_path")
        self.config["hparams"] = hparams.to_dict()

        self.tag_vocab = tag_vocab
        self.label_vocab = label_vocab
        self.char_vocab = char_vocab

        self.d_model = hparams.d_model

        self.char_encoder = None
        self.pretrained_model = None
        if hparams.use_chars_lstm:
            assert (
                not hparams.use_pretrained
            ), "use_chars_lstm and use_pretrained are mutually exclusive"
            self.retokenizer = char_lstm.RetokenizerForCharLSTM(self.char_vocab)
            self.char_encoder = char_lstm.CharacterLSTM(
                max(self.char_vocab.values()) + 1,
                hparams.d_char_emb,
                hparams.d_model // 2,  # Half-size to leave room for
                # partitioned positional encoding
                char_dropout=hparams.char_lstm_input_dropout,
            )
        elif hparams.use_pretrained:
            if pretrained_model_path is None:
                self.retokenizer = retokenization.Retokenizer(
                    hparams.pretrained_model, retain_start_stop=True
                )
                self.pretrained_model = AutoModel.from_pretrained(
                    hparams.pretrained_model
                )
            else:
                self.retokenizer = retokenization.Retokenizer(
                    pretrained_model_path, retain_start_stop=True
                )
                self.pretrained_model = AutoModel.from_config(
                    AutoConfig.from_pretrained(pretrained_model_path)
                )
            d_pretrained = self.pretrained_model.config.hidden_size
            if hparams.use_encoder:
                self.project_pretrained = nn.Linear(
                    d_pretrained, hparams.d_model // 2, bias=False
                )
            else:
                self.project_pretrained = nn.Linear(
                    d_pretrained, hparams.d_model, bias=False
                )

        if hparams.use_encoder:
            self.morpho_emb_dropout = FeatureDropout(hparams.morpho_emb_dropout)
            self.add_timing = ConcatPositionalEncoding(
                d_model=hparams.d_model,
                max_len=hparams.encoder_max_len,
            )
            encoder_layer = PartitionedTransformerEncoderLayer(
                hparams.d_model,
                n_head=hparams.num_heads,
                d_qkv=hparams.d_kv,
                d_ff=hparams.d_ff,
                ff_dropout=hparams.relu_dropout,
                residual_dropout=hparams.residual_dropout,
                attention_dropout=hparams.attention_dropout,
            )
            self.encoder = PartitionedTransformerEncoder(
                encoder_layer, hparams.num_layers
            )
        else:
            self.morpho_emb_dropout = None
            self.add_timing = None
            self.encoder = None

        self.f_label = nn.Sequential(
            nn.Linear(hparams.d_model, hparams.d_label_hidden),
            nn.LayerNorm(hparams.d_label_hidden),
            nn.ReLU(),
            nn.Linear(hparams.d_label_hidden, max(label_vocab.values())),
        )

        if hparams.predict_tags:
            self.f_tag = nn.Sequential(
                nn.Linear(hparams.d_model, hparams.d_tag_hidden),
                nn.LayerNorm(hparams.d_tag_hidden),
                nn.ReLU(),
                nn.Linear(hparams.d_tag_hidden, max(tag_vocab.values()) + 1),
            )
            self.tag_loss_scale = hparams.tag_loss_scale
            self.tag_from_index = {i: label for label, i in tag_vocab.items()}
        else:
            self.f_tag = None
            self.tag_from_index = None

        self.decoder = decode_chart.ChartDecoder(
            label_vocab=self.label_vocab,
            force_root_constituent=hparams.force_root_constituent,
        )
        self.criterion = decode_chart.SpanClassificationMarginLoss(
            reduction="sum", force_root_constituent=hparams.force_root_constituent
        )

        self.parallelized_devices = None
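
    # Note: ``self.config`` (captured from locals() above) records the
    # constructor arguments, with hparams stored as a plain dict. Training
    # checkpoints save this dict alongside the state_dict, and from_trained()
    # below rebuilds the parser by calling ``cls(**config)``.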

    @property
    def device(self):
        if self.parallelized_devices is not None:
            return self.parallelized_devices[0]
        else:
            return next(self.f_label.parameters()).device

    @property
    def output_device(self):
        if self.parallelized_devices is not None:
            return self.parallelized_devices[1]
        else:
            return next(self.f_label.parameters()).device

    def parallelize(self, *args, **kwargs):
        self.parallelized_devices = (torch.device("cuda", 0), torch.device("cuda", 1))
        for child in self.children():
            if child != self.pretrained_model:
                child.to(self.output_device)
        self.pretrained_model.parallelize(*args, **kwargs)

    @classmethod
    def from_trained(cls, model_path):
        if os.path.isdir(model_path):
            # Multi-file format used when exporting models for release.
            # Unlike the checkpoints saved during training, these files include
            # all tokenizer parameters and a copy of the pre-trained model
            # config (rather than downloading these on-demand).
            config = AutoConfig.from_pretrained(model_path).benepar
            state_dict = torch.load(
                os.path.join(model_path, "benepar_model.bin"), map_location="cpu"
            )
            config["pretrained_model_path"] = model_path
        else:
            # Single-file format used for saving checkpoints during training.
            data = torch.load(model_path, map_location="cpu")
            config = data["config"]
            state_dict = data["state_dict"]

        hparams = config["hparams"]
        if "force_root_constituent" not in hparams:
            hparams["force_root_constituent"] = True
        config["hparams"] = nkutil.HParams(**hparams)

        parser = cls(**config)
        parser.load_state_dict(state_dict)
        return parser

    def encode(self, example):
        if self.char_encoder is not None:
            encoded = self.retokenizer(example.words, return_tensors="np")
        else:
            encoded = self.retokenizer(example.words, example.space_after)

        if example.tree is not None:
            encoded["span_labels"] = torch.tensor(
                self.decoder.chart_from_tree(example.tree)
            )
            if self.f_tag is not None:
                encoded["tag_labels"] = torch.tensor(
                    [-100] + [self.tag_vocab[tag] for _, tag in example.pos()] + [-100]
                )
        return encoded

    def pad_encoded(self, encoded_batch):
        batch = self.retokenizer.pad(
            [
                {
                    k: v
                    for k, v in example.items()
                    if k not in ("span_labels", "tag_labels")
                }
                for example in encoded_batch
            ],
            return_tensors="pt",
        )
        if encoded_batch and "span_labels" in encoded_batch[0]:
            batch["span_labels"] = decode_chart.pad_charts(
                [example["span_labels"] for example in encoded_batch]
            )
        if encoded_batch and "tag_labels" in encoded_batch[0]:
            batch["tag_labels"] = nn.utils.rnn.pad_sequence(
                [example["tag_labels"] for example in encoded_batch],
                batch_first=True,
                padding_value=-100,
            )
        return batch

    def _get_lens(self, encoded_batch):
        if self.pretrained_model is not None:
            return [len(encoded["input_ids"]) for encoded in encoded_batch]
        return [len(encoded["valid_token_mask"]) for encoded in encoded_batch]

    def encode_and_collate_subbatches(self, examples, subbatch_max_tokens):
        batch_size = len(examples)
        batch_num_tokens = sum(len(x.words) for x in examples)
        encoded = [self.encode(example) for example in examples]

        res = []
        for ids, subbatch_encoded in subbatching.split(
            encoded, costs=self._get_lens(encoded), max_cost=subbatch_max_tokens
        ):
            subbatch = self.pad_encoded(subbatch_encoded)
            subbatch["batch_size"] = batch_size
            subbatch["batch_num_tokens"] = batch_num_tokens
            res.append((len(ids), subbatch))
        return res

    def forward(self, batch):
        valid_token_mask = batch["valid_token_mask"].to(self.output_device)

        if (
            self.encoder is not None
            and valid_token_mask.shape[1] > self.add_timing.timing_table.shape[0]
        ):
            raise ValueError(
                "Sentence of length {} exceeds the maximum supported length of "
                "{}".format(
                    valid_token_mask.shape[1] - 2,
                    self.add_timing.timing_table.shape[0] - 2,
                )
            )

        if self.char_encoder is not None:
            assert isinstance(self.char_encoder, char_lstm.CharacterLSTM)
            char_ids = batch["char_ids"].to(self.device)
            extra_content_annotations = self.char_encoder(char_ids, valid_token_mask)
        elif self.pretrained_model is not None:
            input_ids = batch["input_ids"].to(self.device)
            words_from_tokens = batch["words_from_tokens"].to(self.output_device)
            pretrained_attention_mask = batch["attention_mask"].to(self.device)

            extra_kwargs = {}
            if "token_type_ids" in batch:
                extra_kwargs["token_type_ids"] = batch["token_type_ids"].to(self.device)
            if "decoder_input_ids" in batch:
                extra_kwargs["decoder_input_ids"] = batch["decoder_input_ids"].to(
                    self.device
                )
                extra_kwargs["decoder_attention_mask"] = batch[
                    "decoder_attention_mask"
                ].to(self.device)

            pretrained_out = self.pretrained_model(
                input_ids, attention_mask=pretrained_attention_mask, **extra_kwargs
            )
            features = pretrained_out.last_hidden_state.to(self.output_device)
            features = features[
                torch.arange(features.shape[0])[:, None],
                # Note that words_from_tokens uses index -100 for invalid positions
                F.relu(words_from_tokens),
            ]
            features.masked_fill_(~valid_token_mask[:, :, None], 0)
            if self.encoder is not None:
                extra_content_annotations = self.project_pretrained(features)

        if self.encoder is not None:
            encoder_in = self.add_timing(
                self.morpho_emb_dropout(extra_content_annotations)
            )
            annotations = self.encoder(encoder_in, valid_token_mask)
            # Rearrange the annotations to ensure that the transition to
            # fenceposts captures an even split between position and content.
            annotations = torch.cat(
                [
                    annotations[..., 0::2],
                    annotations[..., 1::2],
                ],
                -1,
            )
        else:
            assert self.pretrained_model is not None
            annotations = self.project_pretrained(features)

        if self.f_tag is not None:
            tag_scores = self.f_tag(annotations)
        else:
            tag_scores = None

        fencepost_annotations = torch.cat(
            [
                annotations[:, :-1, : self.d_model // 2],
                annotations[:, 1:, self.d_model // 2 :],
            ],
            -1,
        )

        # Note that the bias added to the final layer norm is useless because
        # this subtraction gets rid of it
        span_features = (
            torch.unsqueeze(fencepost_annotations, 1)
            - torch.unsqueeze(fencepost_annotations, 2)
        )[:, :-1, 1:]
        span_scores = self.f_label(span_features)
        span_scores = torch.cat(
            [span_scores.new_zeros(span_scores.shape[:-1] + (1,)), span_scores], -1
        )
        return span_scores, tag_scores
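
    # Note on forward()'s outputs: the final torch.cat prepends a constant
    # zero channel to span_scores, so label index 0 acts as the null/empty
    # label with a fixed score of zero, and f_label only scores the real
    # labels (indices 1..max(label_vocab.values())) relative to that baseline.
    # Both the margin criterion and the chart decoder consume scores in this
    # layout.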

    def compute_loss(self, batch):
        span_scores, tag_scores = self.forward(batch)
        span_labels = batch["span_labels"].to(span_scores.device)
        span_loss = self.criterion(span_scores, span_labels)
        # Divide by the total batch size, not by the subbatch size
        span_loss = span_loss / batch["batch_size"]
        if tag_scores is None:
            return span_loss
        else:
            tag_labels = batch["tag_labels"].to(tag_scores.device)
            tag_loss = self.tag_loss_scale * F.cross_entropy(
                tag_scores.reshape((-1, tag_scores.shape[-1])),
                tag_labels.reshape((-1,)),
                reduction="sum",
                ignore_index=-100,
            )
            tag_loss = tag_loss / batch["batch_num_tokens"]
            return span_loss + tag_loss
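
    # Illustrative training-loop sketch; ``optimizer``, ``examples``, and the
    # token budget below are placeholders, not definitions from this module.
    # Because compute_loss() normalizes by the full-batch statistics stored in
    # every subbatch, accumulating gradients over subbatches and stepping once
    # per batch matches processing the whole batch at once:
    #
    #     for _num_examples, subbatch in parser.encode_and_collate_subbatches(
    #         examples, subbatch_max_tokens=2000
    #     ):
    #         loss = parser.compute_loss(subbatch)
    #         loss.backward()
    #     optimizer.step()
    #     optimizer.zero_grad()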

    def _parse_encoded(
        self, examples, encoded, return_compressed=False, return_scores=False
    ):
        with torch.no_grad():
            batch = self.pad_encoded(encoded)
            span_scores, tag_scores = self.forward(batch)
            if return_scores:
                span_scores_np = span_scores.cpu().numpy()
            else:
                # Start/stop tokens don't count, so subtract 2
                lengths = batch["valid_token_mask"].sum(-1) - 2
                charts_np = self.decoder.charts_from_pytorch_scores_batched(
                    span_scores, lengths.to(span_scores.device)
                )
            if tag_scores is not None:
                tag_ids_np = tag_scores.argmax(-1).cpu().numpy()
            else:
                tag_ids_np = None

        for i in range(len(encoded)):
            example_len = len(examples[i].words)
            if return_scores:
                yield span_scores_np[i, :example_len, :example_len]
            elif return_compressed:
                output = self.decoder.compressed_output_from_chart(charts_np[i])
                if tag_ids_np is not None:
                    output = output.with_tags(tag_ids_np[i, 1 : example_len + 1])
                yield output
            else:
                if tag_scores is None:
                    leaves = examples[i].pos()
                else:
                    predicted_tags = [
                        self.tag_from_index[tag_id]
                        for tag_id in tag_ids_np[i, 1 : example_len + 1]
                    ]
                    leaves = [
                        (word, predicted_tag)
                        for predicted_tag, (word, gold_tag) in zip(
                            predicted_tags, examples[i].pos()
                        )
                    ]
                yield self.decoder.tree_from_chart(charts_np[i], leaves=leaves)

    def parse(
        self,
        examples,
        return_compressed=False,
        return_scores=False,
        subbatch_max_tokens=None,
    ):
        training = self.training
        self.eval()
        encoded = [self.encode(example) for example in examples]
        if subbatch_max_tokens is not None:
            res = subbatching.map(
                self._parse_encoded,
                examples,
                encoded,
                costs=self._get_lens(encoded),
                max_cost=subbatch_max_tokens,
                return_compressed=return_compressed,
                return_scores=return_scores,
            )
        else:
            res = self._parse_encoded(
                examples,
                encoded,
                return_compressed=return_compressed,
                return_scores=return_scores,
            )
        res = list(res)
        self.train(training)
        return res
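

# Illustrative inference sketch; the checkpoint path is a placeholder, and
# ``examples`` stands for objects with the interface encode() relies on above
# (``words``, ``space_after``, ``tree``, and ``pos()``).
#
#     parser = ChartParser.from_trained("path/to/checkpoint.pt")
#     if torch.cuda.is_available():
#         parser.cuda()
#     trees = parser.parse(examples, subbatch_max_tokens=2000)
#     compressed = parser.parse(examples, return_compressed=True)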