|
import os |
|
from typing import Union, List, Optional, Tuple |
|
|
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast |
|
|
|
class SentencePieceJA(PreTrainedTokenizer):
    """Slow-tokenizer adapter around a ``tokenizers.Tokenizer`` model file.

    Unlike most ``PreTrainedTokenizer`` subclasses, this class uses token
    *ids* as its "tokens": ``_tokenize`` returns ids directly, so the
    token -> id hooks are identities, and the id -> token hooks decode ids
    back to text through the backend tokenizer.
    """

    def __init__(self, model_path, **kwargs):
        """Load the backend tokenizer and initialise the Hugging Face base class.

        Args:
            model_path: Path (``str`` or ``os.PathLike``) of a file readable
                by ``tokenizers.Tokenizer.from_file``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        from tokenizers import Tokenizer

        # The backend must exist *before* the base initialiser runs: recent
        # `transformers` versions call `get_vocab()` from inside
        # `PreTrainedTokenizer.__init__`, which needs `self._tokenizer`.
        self._tokenizer = Tokenizer.from_file(os.fspath(model_path))

        super().__init__(**kwargs)

        self.__pad_id = self._lookup_special_id("<PAD>")
        self.__bos_id = self._lookup_special_id("<BOS>")
        self.__eos_id = self._lookup_special_id("<EOS>")
        self.__unk_id = self._lookup_special_id("<UNK>")
        self.__mask_id = self._lookup_special_id("<MASK>")

    def _lookup_special_id(self, token: str) -> int:
        """Return the id of the literal ``token``.

        An exact vocabulary lookup is preferred: ``encode`` may split the
        literal into several pieces or emit post-processor tokens first, so
        its first id is not guaranteed to be ``token``'s id. If ``token`` is
        not in the vocabulary, fall back to the first encoded id.
        """
        token_id = self._tokenizer.token_to_id(token)
        if token_id is None:
            token_id = self._tokenize(token)[0]
        return token_id

    def get_vocab(self) -> dict:
        """Return the backend tokenizer's ``{token: id}`` vocabulary."""
        return self._tokenizer.get_vocab()

    @property
    def vocab_size(self) -> int:
        """Number of entries in the backend tokenizer's vocabulary.

        Exposed as a property to match the ``PreTrainedTokenizer`` contract,
        which reads ``self.vocab_size`` as an int.
        """
        return self._tokenizer.get_vocab_size()

    def _tokenize(self, text, **kwargs):
        """Encode ``text`` with the backend and return its token ids.

        The returned ids are used directly as this tokenizer's "tokens".
        """
        return self._tokenizer.encode(text).ids

    def _convert_token_to_id(self, token):
        """Identity: tokens produced by ``_tokenize`` are already ids."""
        return token

    def _convert_id_to_token(self, index: int) -> str:
        """Decode a single id back to text."""
        # `Tokenizer.decode` expects a sequence of ids, not a bare int.
        return self._tokenizer.decode([index])

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """Identity: this tokenizer's tokens are already ids."""
        return tokens

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """Decode ``ids`` to text with the backend tokenizer.

        A single int is decoded on its own. A list is decoded as one string,
        not as a per-id list of tokens.

        NOTE(review): ``skip_special_tokens`` is not forwarded. The backend's
        own ``decode`` default applies, so this preserves existing output.
        """
        if isinstance(ids, int):
            # `Tokenizer.decode` requires a sequence of ids.
            return self._tokenizer.decode([ids])
        return self._tokenizer.decode(ids)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary, one token per line, ordered by id.

        Args:
            save_directory: An existing directory (the file is written there
                as ``[prefix-]vocab.txt``) or a full output file path.
            filename_prefix: Optional prefix joined to the file name with "-".

        Returns:
            A one-element tuple holding the path that was written.
        """
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + "vocab.txt")
        else:
            # Apply the prefix to the file name, not to the whole path.
            head, tail = os.path.split(save_directory)
            vocab_file = os.path.join(head, prefix + tail)

        ordered = sorted(self.get_vocab().items(), key=lambda kv: kv[1])
        with open(vocab_file, "w", encoding="utf-8") as writer:
            writer.writelines(f"{token}\n" for token, _ in ordered)

        return (vocab_file,)