""" | |
Converts from linguistically motivated word-based tokenization to subword | |
tokenization used by pre-trained models. | |
""" | |
import numpy as np | |
import torch | |
import transformers | |
def retokenize( | |
tokenizer, | |
words, | |
space_after, | |
return_attention_mask=True, | |
return_offsets_mapping=False, | |
return_tensors=None, | |
**kwargs | |
): | |
"""Re-tokenize into subwords. | |
Args: | |
tokenizer: An instance of transformers.PreTrainedTokenizerFast | |
words: List of words | |
space_after: A list of the same length as `words`, indicating whether | |
whitespace follows each word. | |
**kwargs: all remaining arguments are passed on to tokenizer.__call__ | |
Returns: | |
The output of tokenizer.__call__, with one additional dictionary field: | |
- **words_from_tokens** -- List of the same length as `words`, where | |
each entry is the index of the *last* subword that overlaps the | |
corresponding word. | |
""" | |
s = "".join([w + (" " if sp else "") for w, sp in zip(words, space_after)]) | |
word_offset_starts = np.cumsum( | |
[0] + [len(w) + (1 if sp else 0) for w, sp in zip(words, space_after)] | |
)[:-1] | |
word_offset_ends = word_offset_starts + np.asarray([len(w) for w in words]) | |
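
    # Offsets are always requested from the subword tokenizer (even if the
    # caller did not ask for them) because they are needed below to align
    # subwords with the original words.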
    tokenized = tokenizer(
        s,
        return_attention_mask=return_attention_mask,
        return_offsets_mapping=True,
        return_tensors=return_tensors,
        **kwargs
    )
    if return_offsets_mapping:
        token_offset_mapping = tokenized["offset_mapping"]
    else:
        token_offset_mapping = tokenized.pop("offset_mapping")
    if return_tensors is not None:
        token_offset_mapping = np.asarray(token_offset_mapping)[0].tolist()
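
    # Special tokens added by the tokenizer have empty character spans
    # (start == end); they are filtered out so that only content-bearing
    # subwords participate in the alignment.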
    offset_mapping_iter = iter(
        [
            (i, (start, end))
            for (i, (start, end)) in enumerate(token_offset_mapping)
            if start != end
        ]
    )
    token_idx, (token_start, token_end) = next(offset_mapping_iter)

    words_from_tokens = [-100] * len(words)
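
    # Walk through words and subword spans in parallel. Each word records the
    # index of the last subword whose character span overlaps it; words with no
    # overlapping subword keep the sentinel value -100.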
    for word_idx, (word_start, word_end) in enumerate(
        zip(word_offset_starts, word_offset_ends)
    ):
        while token_end <= word_start:
            token_idx, (token_start, token_end) = next(offset_mapping_iter)
        if token_end > word_end:
            words_from_tokens[word_idx] = token_idx
        while token_end <= word_end:
            words_from_tokens[word_idx] = token_idx
            try:
                token_idx, (token_start, token_end) = next(offset_mapping_iter)
            except StopIteration:
                assert word_idx == len(words) - 1
                break

    if return_tensors == "np":
        words_from_tokens = np.asarray(words_from_tokens, dtype=int)
    elif return_tensors == "pt":
        words_from_tokens = torch.tensor(words_from_tokens, dtype=torch.long)
    elif return_tensors == "tf":
        raise NotImplementedError("Returning tf tensors is not implemented")
    tokenized["words_from_tokens"] = words_from_tokens
    return tokenized


class Retokenizer:
    def __init__(self, pretrained_model_name_or_path, retain_start_stop=False):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, use_fast=True
        )
        if not self.tokenizer.is_fast:
            raise NotImplementedError(
                "Converting from treebank tokenization to tokenization used by a "
                "pre-trained model requires a 'fast' tokenizer, which appears to not "
                "be available for this pre-trained model type."
            )
        self.retain_start_stop = retain_start_stop
        self.is_t5 = "T5Tokenizer" in str(type(self.tokenizer))
        self.is_gpt2 = "GPT2Tokenizer" in str(type(self.tokenizer))

        if self.is_gpt2:
            # The provided GPT-2 tokenizer does not specify a padding token by default
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if self.retain_start_stop:
            # When retain_start_stop is set, the next layer after the pre-trained model
            # expects start and stop token embeddings. For BERT these can naturally be
            # the feature vectors for CLS and SEP, but pre-trained models differ in the
            # special tokens that they use. This code attempts to find special token
            # positions for each pre-trained model.
            dummy_ids = self.tokenizer.build_inputs_with_special_tokens([-100])
            if self.is_t5:
                # For T5 we use the output from the decoder, which accepts inputs that
                # are shifted relative to the encoder.
                dummy_ids = [self.tokenizer.pad_token_id] + dummy_ids
            if self.is_gpt2:
                # For GPT-2, we append an eos token if special tokens are needed
                dummy_ids = dummy_ids + [self.tokenizer.eos_token_id]
            try:
                input_idx = dummy_ids.index(-100)
            except ValueError:
                raise NotImplementedError(
                    "Could not automatically infer how to extract start/stop tokens "
                    "from this pre-trained model"
                )
            num_prefix_tokens = input_idx
            num_suffix_tokens = len(dummy_ids) - input_idx - 1
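
            # The start embedding is read from the last special token before the
            # word positions, and the stop embedding from the first special token
            # after them (stored as a negative, from-the-end index). If the
            # tokenizer only adds special tokens on one side, that side supplies
            # both positions.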
            self.start_token_idx = None
            self.stop_token_idx = None
            if num_prefix_tokens > 0:
                self.start_token_idx = num_prefix_tokens - 1
            if num_suffix_tokens > 0:
                self.stop_token_idx = -num_suffix_tokens
            if self.start_token_idx is None and num_suffix_tokens > 0:
                self.start_token_idx = -1
            if self.stop_token_idx is None and num_prefix_tokens > 0:
                self.stop_token_idx = 0
            if self.start_token_idx is None or self.stop_token_idx is None:
                assert num_prefix_tokens == 0 and num_suffix_tokens == 0
                raise NotImplementedError(
                    "Could not automatically infer how to extract start/stop tokens "
                    "from this pre-trained model because the associated tokenizer "
                    "appears not to add any special start/stop/cls/sep/etc. tokens "
                    "to the sequence."
                )

    def __call__(self, words, space_after, **kwargs):
        example = retokenize(self.tokenizer, words, space_after, **kwargs)
        if self.is_t5:
            # decoder_input_ids (which are shifted wrt input_ids) will be created after
            # padding, but we adjust words_from_tokens now, in anticipation.
            if isinstance(example["words_from_tokens"], list):
                example["words_from_tokens"] = [
                    x + 1 for x in example["words_from_tokens"]
                ]
            else:
                example["words_from_tokens"] += 1
        if self.retain_start_stop:
            num_tokens = len(example["input_ids"])
            if self.is_t5:
                num_tokens += 1
            if self.is_gpt2:
                num_tokens += 1
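                # GPT-2 does not add any special tokens on its own, so an eos
                # token is appended here to serve as the stop position.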
                if kwargs.get("return_tensors") == "pt":
                    example["input_ids"] = torch.cat(
                        [
                            example["input_ids"],
                            torch.tensor([self.tokenizer.eos_token_id]),
                        ]
                    )
                    example["attention_mask"] = torch.cat(
                        [example["attention_mask"], torch.tensor([1])]
                    )
                else:
                    example["input_ids"].append(self.tokenizer.eos_token_id)
                    example["attention_mask"].append(1)
            if num_tokens > self.tokenizer.model_max_length:
                raise ValueError(
                    f"Sentence of length {num_tokens} (in sub-word tokens) exceeds the "
                    f"maximum supported length of {self.tokenizer.model_max_length}"
                )
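            # Resolve the (possibly negative, from-the-end) special-token
            # positions into absolute indices for this particular sentence.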
            start_token_idx = (
                self.start_token_idx
                if self.start_token_idx >= 0
                else num_tokens + self.start_token_idx
            )
            stop_token_idx = (
                self.stop_token_idx
                if self.stop_token_idx >= 0
                else num_tokens + self.stop_token_idx
            )
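            # Surround the word alignment with pointers to the start and stop
            # special tokens so that downstream layers can read their features.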
            if kwargs.get("return_tensors") == "pt":
                example["words_from_tokens"] = torch.cat(
                    [
                        torch.tensor([start_token_idx]),
                        example["words_from_tokens"],
                        torch.tensor([stop_token_idx]),
                    ]
                )
            else:
                example["words_from_tokens"] = (
                    [start_token_idx] + example["words_from_tokens"] + [stop_token_idx]
                )
        return example

    def pad(self, encoded_inputs, return_tensors=None, **kwargs):
        if return_tensors != "pt":
            raise NotImplementedError("Only return_tensors='pt' is supported.")
        res = self.tokenizer.pad(
            [
                {k: v for k, v in example.items() if k != "words_from_tokens"}
                for example in encoded_inputs
            ],
            return_tensors=return_tensors,
            **kwargs
        )
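        # words_from_tokens is a custom field that tokenizer.pad does not know
        # about, so it is padded separately here using -100 as the padding value.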
        if self.tokenizer.padding_side == "right":
            res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence(
                [
                    torch.tensor(example["words_from_tokens"])
                    for example in encoded_inputs
                ],
                batch_first=True,
                padding_value=-100,
            )
        else:
            # XLNet adds padding tokens on the left of the sequence, so
            # words_from_tokens must be adjusted to skip the added padding tokens.
            assert self.tokenizer.padding_side == "left"
            res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence(
                [
                    torch.tensor(example["words_from_tokens"])
                    + (res["input_ids"].shape[-1] - len(example["input_ids"]))
                    for example in encoded_inputs
                ],
                batch_first=True,
                padding_value=-100,
            )
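        # T5 runs its decoder over the sequence, so decoder inputs are built by
        # shifting right: a pad token is prepended, matching the +1 adjustment
        # applied to words_from_tokens in __call__.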
        if self.is_t5:
            res["decoder_input_ids"] = torch.cat(
                [
                    torch.full_like(
                        res["input_ids"][:, :1], self.tokenizer.pad_token_id
                    ),
                    res["input_ids"],
                ],
                1,
            )
            res["decoder_attention_mask"] = torch.cat(
                [
                    torch.ones_like(res["attention_mask"][:, :1]),
                    res["attention_mask"],
                ],
                1,
            )
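        # Mark which entries of words_from_tokens point at real words rather
        # than at the -100 padding value.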
res["valid_token_mask"] = res["words_from_tokens"] != -100 | |
return res | |
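

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module):
    # it assumes the "bert-base-uncased" tokenizer is cached locally or can be
    # downloaded, and prints the word-to-subword alignment for one sentence.
    retok = Retokenizer("bert-base-uncased", retain_start_stop=True)
    words = ["What", "'s", "the", "weather", "like", "?"]
    space_after = [False, True, True, True, False, False]
    example = retok(words, space_after)
    print(example["input_ids"])
    print(example["words_from_tokens"])
    # pad() collates a list of examples into padded tensors; words_from_tokens
    # is padded with -100, and valid_token_mask marks the real entries.
    batch = retok.pad([example], return_tensors="pt")
    print(batch["valid_token_mask"])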