""" | |
Character LSTM implementation (matches https://arxiv.org/pdf/1805.01052.pdf) | |
""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |


class CharacterLSTM(nn.Module):
    def __init__(self, num_embeddings, d_embedding, d_out, char_dropout=0.0, **kwargs):
        super().__init__()

        self.d_embedding = d_embedding
        self.d_out = d_out

        # Bidirectional character LSTM; each direction contributes d_out // 2
        # dimensions, so the concatenated word representation has size d_out.
        self.lstm = nn.LSTM(
            self.d_embedding, self.d_out // 2, num_layers=1, bidirectional=True
        )

        self.emb = nn.Embedding(num_embeddings, self.d_embedding, **kwargs)
        self.char_dropout = nn.Dropout(char_dropout)

    def forward(self, chars_packed, valid_token_mask):
        # Embed and apply dropout to the packed character ids while keeping the
        # PackedSequence metadata intact.
        inp_embs = nn.utils.rnn.PackedSequence(
            self.char_dropout(self.emb(chars_packed.data)),
            batch_sizes=chars_packed.batch_sizes,
            sorted_indices=chars_packed.sorted_indices,
            unsorted_indices=chars_packed.unsorted_indices,
        )

        # Use the final hidden state of each direction as the word representation.
        _, (lstm_out, _) = self.lstm(inp_embs)
        lstm_out = torch.cat([lstm_out[0], lstm_out[1]], -1)

        # Switch to a representation where there are dummy vectors for invalid
        # tokens generated by padding.
        res = lstm_out.new_zeros(
            (valid_token_mask.shape[0], valid_token_mask.shape[1], lstm_out.shape[-1])
        )
        res[valid_token_mask] = lstm_out
        return res
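
# A minimal shape sketch (illustrative, not from the original source): the sizes
# below are assumptions, chosen only to show the contract of CharacterLSTM.
#
#   char_lstm = CharacterLSTM(num_embeddings=256, d_embedding=32, d_out=64)
#   word_reprs = char_lstm(chars_packed, valid_token_mask)
#   # chars_packed: PackedSequence of per-word character ids (see pad() below)
#   # valid_token_mask: bool tensor of shape (batch, max_tokens)
#   # word_reprs: (batch, max_tokens, 64), zero vectors where the mask is False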


class RetokenizerForCharLSTM:
    # Assumes that these control characters are not present in treebank text
    CHAR_UNK = "\0"
    CHAR_ID_UNK = 0
    CHAR_START_SENTENCE = "\1"
    CHAR_START_WORD = "\2"
    CHAR_STOP_WORD = "\3"
    CHAR_STOP_SENTENCE = "\4"

    def __init__(self, char_vocab):
        self.char_vocab = char_vocab

    @classmethod
    def build_vocab(cls, sentences):
        char_set = set()
        for sentence in sentences:
            if isinstance(sentence, tuple):
                sentence = sentence[0]
            for word in sentence:
                char_set |= set(word)

        # If codepoints are small (e.g. Latin alphabet), index by codepoint
        # directly
        highest_codepoint = max(ord(char) for char in char_set)
        if highest_codepoint < 512:
            if highest_codepoint < 256:
                highest_codepoint = 256
            else:
                highest_codepoint = 512

            char_vocab = {}
            # This also takes care of constants like CHAR_UNK, etc.
            for codepoint in range(highest_codepoint):
                char_vocab[chr(codepoint)] = codepoint
            return char_vocab
        else:
            char_vocab = {}
            char_vocab[cls.CHAR_UNK] = 0
            char_vocab[cls.CHAR_START_SENTENCE] = 1
            char_vocab[cls.CHAR_START_WORD] = 2
            char_vocab[cls.CHAR_STOP_WORD] = 3
            char_vocab[cls.CHAR_STOP_SENTENCE] = 4
            for id_, char in enumerate(sorted(char_set), start=5):
                char_vocab[char] = id_
            return char_vocab
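
    # Illustrative example (an assumption, not from the original source): for
    # purely Latin-alphabet input, build_vocab indexes characters by codepoint,
    # so e.g.
    #   vocab = RetokenizerForCharLSTM.build_vocab([["The", "cat"]])
    #   vocab["T"] == ord("T") == 84 and len(vocab) == 256
    # whereas input containing any character at codepoint 512 or higher gets a
    # compact vocabulary with ids 0-4 reserved for the control pseudo-characters.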

    def __call__(self, words, space_after="ignored", return_tensors=None):
        if return_tensors != "np":
            raise NotImplementedError("Only return_tensors='np' is supported.")

        res = {}

        # Sentence-level start/stop tokens are encoded as 3 pseudo-chars
        # Within each word, account for 2 start/stop characters
        max_word_len = max(3, max(len(word) for word in words)) + 2
        char_ids = np.zeros((len(words) + 2, max_word_len), dtype=int)
        word_lens = np.zeros(len(words) + 2, dtype=int)

        # Row 0 is the sentence-start pseudo-word.
        char_ids[0, :5] = [
            self.char_vocab[self.CHAR_START_WORD],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_START_SENTENCE],
            self.char_vocab[self.CHAR_STOP_WORD],
        ]
        word_lens[0] = 5
        for i, word in enumerate(words, start=1):
            char_ids[i, 0] = self.char_vocab[self.CHAR_START_WORD]
            for j, char in enumerate(word, start=1):
                char_ids[i, j] = self.char_vocab.get(char, self.CHAR_ID_UNK)
            char_ids[i, j + 1] = self.char_vocab[self.CHAR_STOP_WORD]
            word_lens[i] = j + 2
        # The final row is the sentence-stop pseudo-word.
        char_ids[i + 1, :5] = [
            self.char_vocab[self.CHAR_START_WORD],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_SENTENCE],
            self.char_vocab[self.CHAR_STOP_WORD],
        ]
        word_lens[i + 1] = 5

        res["char_ids"] = char_ids
        res["word_lens"] = word_lens
        res["valid_token_mask"] = np.ones_like(word_lens, dtype=bool)

        return res

    def pad(self, examples, return_tensors=None):
        if return_tensors != "pt":
            raise NotImplementedError("Only return_tensors='pt' is supported.")
        # Pad every example's char_ids to a common max word length, then pack
        # all words across the batch into a single PackedSequence.
        max_word_len = max(example["char_ids"].shape[-1] for example in examples)
        char_ids = torch.cat(
            [
                F.pad(
                    torch.tensor(example["char_ids"]),
                    (0, max_word_len - example["char_ids"].shape[-1]),
                )
                for example in examples
            ]
        )
        word_lens = torch.cat(
            [torch.tensor(example["word_lens"]) for example in examples]
        )
        valid_token_mask = nn.utils.rnn.pad_sequence(
            [torch.tensor(example["valid_token_mask"]) for example in examples],
            batch_first=True,
            padding_value=False,
        )

        char_ids = nn.utils.rnn.pack_padded_sequence(
            char_ids, word_lens, batch_first=True, enforce_sorted=False
        )
        return {
            "char_ids": char_ids,
            "valid_token_mask": valid_token_mask,
        }
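

# A minimal end-to-end sketch (not part of the original module). The sentences
# and dimension sizes below are illustrative assumptions, included only to show
# how RetokenizerForCharLSTM and CharacterLSTM are meant to fit together.
if __name__ == "__main__":
    sentences = [["The", "cat", "sat"], ["Hello", "world"]]

    # Build the character vocabulary and retokenize each sentence separately.
    char_vocab = RetokenizerForCharLSTM.build_vocab(sentences)
    retokenizer = RetokenizerForCharLSTM(char_vocab)
    examples = [retokenizer(sentence, return_tensors="np") for sentence in sentences]

    # Collate into a packed batch and run the character LSTM.
    batch = retokenizer.pad(examples, return_tensors="pt")
    char_lstm = CharacterLSTM(num_embeddings=len(char_vocab), d_embedding=32, d_out=64)
    word_reprs = char_lstm(batch["char_ids"], batch["valid_token_mask"])

    # One vector per token slot, including the sentence start/stop pseudo-words:
    # (batch, max_tokens, d_out) == (2, 5, 64) for this toy batch.
    print(word_reprs.shape)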