""" Character LSTM implementation (matches https://arxiv.org/pdf/1805.01052.pdf) """ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class CharacterLSTM(nn.Module): def __init__(self, num_embeddings, d_embedding, d_out, char_dropout=0.0, **kwargs): super().__init__() self.d_embedding = d_embedding self.d_out = d_out self.lstm = nn.LSTM( self.d_embedding, self.d_out // 2, num_layers=1, bidirectional=True ) self.emb = nn.Embedding(num_embeddings, self.d_embedding, **kwargs) self.char_dropout = nn.Dropout(char_dropout) def forward(self, chars_packed, valid_token_mask): inp_embs = nn.utils.rnn.PackedSequence( self.char_dropout(self.emb(chars_packed.data)), batch_sizes=chars_packed.batch_sizes, sorted_indices=chars_packed.sorted_indices, unsorted_indices=chars_packed.unsorted_indices, ) _, (lstm_out, _) = self.lstm(inp_embs) lstm_out = torch.cat([lstm_out[0], lstm_out[1]], -1) # Switch to a representation where there are dummy vectors for invalid # tokens generated by padding. res = lstm_out.new_zeros( (valid_token_mask.shape[0], valid_token_mask.shape[1], lstm_out.shape[-1]) ) res[valid_token_mask] = lstm_out return res class RetokenizerForCharLSTM: # Assumes that these control characters are not present in treebank text CHAR_UNK = "\0" CHAR_ID_UNK = 0 CHAR_START_SENTENCE = "\1" CHAR_START_WORD = "\2" CHAR_STOP_WORD = "\3" CHAR_STOP_SENTENCE = "\4" def __init__(self, char_vocab): self.char_vocab = char_vocab @classmethod def build_vocab(cls, sentences): char_set = set() for sentence in sentences: if isinstance(sentence, tuple): sentence = sentence[0] for word in sentence: char_set |= set(word) # If codepoints are small (e.g. Latin alphabet), index by codepoint # directly highest_codepoint = max(ord(char) for char in char_set) if highest_codepoint < 512: if highest_codepoint < 256: highest_codepoint = 256 else: highest_codepoint = 512 char_vocab = {} # This also takes care of constants like CHAR_UNK, etc. 
            for codepoint in range(highest_codepoint):
                char_vocab[chr(codepoint)] = codepoint
            return char_vocab
        else:
            char_vocab = {}
            char_vocab[cls.CHAR_UNK] = 0
            char_vocab[cls.CHAR_START_SENTENCE] = 1
            char_vocab[cls.CHAR_START_WORD] = 2
            char_vocab[cls.CHAR_STOP_WORD] = 3
            char_vocab[cls.CHAR_STOP_SENTENCE] = 4
            for id_, char in enumerate(sorted(char_set), start=5):
                char_vocab[char] = id_
            return char_vocab

    def __call__(self, words, space_after="ignored", return_tensors=None):
        if return_tensors != "np":
            raise NotImplementedError("Only return_tensors='np' is supported.")

        res = {}

        # Sentence-level start/stop tokens are encoded as 3 pseudo-chars
        # Within each word, account for 2 start/stop characters
        max_word_len = max(3, max(len(word) for word in words)) + 2
        char_ids = np.zeros((len(words) + 2, max_word_len), dtype=int)
        word_lens = np.zeros(len(words) + 2, dtype=int)

        # Pseudo-word that marks the start of the sentence
        char_ids[0, :5] = [
            self.char_vocab[self.CHAR_START_WORD],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_STOP_WORD],
        ]
        word_lens[0] = 5
        for i, word in enumerate(words, start=1):
            char_ids[i, 0] = self.char_vocab[self.CHAR_START_WORD]
            for j, char in enumerate(word, start=1):
                char_ids[i, j] = self.char_vocab.get(char, self.CHAR_ID_UNK)
            char_ids[i, j + 1] = self.char_vocab[self.CHAR_STOP_WORD]
            word_lens[i] = j + 2
        # Pseudo-word that marks the end of the sentence
        char_ids[i + 1, :5] = [
            self.char_vocab[self.CHAR_START_WORD],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_WORD],
        ]
        word_lens[i + 1] = 5

        res["char_ids"] = char_ids
        res["word_lens"] = word_lens
        res["valid_token_mask"] = np.ones_like(word_lens, dtype=bool)

        return res

    def pad(self, examples, return_tensors=None):
        if return_tensors != "pt":
            raise NotImplementedError("Only return_tensors='pt' is supported.")
        # Right-pad each example's char_ids to the longest word in the batch,
        # then stack all words from all sentences along the first dimension.
        max_word_len = max(example["char_ids"].shape[-1] for example in examples)
        char_ids = torch.cat(
            [
                F.pad(
                    torch.tensor(example["char_ids"]),
                    (0, max_word_len - example["char_ids"].shape[-1]),
                )
                for example in examples
            ]
        )
        word_lens = torch.cat(
            [torch.tensor(example["word_lens"]) for example in examples]
        )
        valid_token_mask = nn.utils.rnn.pad_sequence(
            [torch.tensor(example["valid_token_mask"]) for example in examples],
            batch_first=True,
            padding_value=False,
        )
        # Pack so the LSTM skips the per-word character padding.
        char_ids = nn.utils.rnn.pack_padded_sequence(
            char_ids, word_lens, batch_first=True, enforce_sorted=False
        )
        return {
            "char_ids": char_ids,
            "valid_token_mask": valid_token_mask,
        }
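

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# end-to-end pass showing how the retokenizer and CharacterLSTM compose.
# The sentences and the d_embedding/d_out sizes below are assumed values
# chosen for the demo, not settings taken from any trained parser.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sentences = [["The", "cat", "sat"], ["Hello", "world"]]

    char_vocab = RetokenizerForCharLSTM.build_vocab(sentences)
    retokenizer = RetokenizerForCharLSTM(char_vocab)

    # Retokenize each sentence to numpy arrays, then pad/pack into one batch.
    examples = [retokenizer(words, return_tensors="np") for words in sentences]
    batch = retokenizer.pad(examples, return_tensors="pt")

    char_lstm = CharacterLSTM(
        num_embeddings=max(char_vocab.values()) + 1, d_embedding=32, d_out=64
    )
    with torch.no_grad():
        word_reprs = char_lstm(batch["char_ids"], batch["valid_token_mask"])

    # One vector per token, including the sentence start/stop pseudo-words;
    # positions masked out by valid_token_mask remain zero.
    print(word_reprs.shape)  # torch.Size([2, 5, 64]): (batch, 3 words + 2, d_out)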