# sentencepiece_ja/sentencepiece_ja.py
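"""A PreTrainedTokenizer wrapper around a `tokenizers` model file for Japanese
SentencePiece. Note the unusual design: `_tokenize` returns token ids rather
than string tokens, so the token<->id conversion hooks below are pass-throughs.
"""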
import os
from typing import Dict, List, Optional, Tuple, Union

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizer


class SentencePieceJA(PreTrainedTokenizer):
    def __init__(self, model_path, **kwargs):
        # Load the backend tokenizer before calling the base __init__, which
        # may invoke methods (e.g. get_vocab) that rely on self._tokenizer.
        self._tokenizer = Tokenizer.from_file(model_path)
        super().__init__(**kwargs)
        # Cache the ids of the special tokens; _tokenize returns ids directly.
        self.__pad_id = self._tokenize("<PAD>")[0]
        self.__bos_id = self._tokenize("<BOS>")[0]
        self.__eos_id = self._tokenize("<EOS>")[0]
        self.__unk_id = self._tokenize("<UNK>")[0]
        self.__mask_id = self._tokenize("<MASK>")[0]

    def get_vocab(self) -> Dict[str, int]:
        return self._tokenizer.get_vocab()

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.get_vocab_size()

    def _tokenize(self, text, **kwargs):
        # Unusually, this returns token ids rather than string tokens; the
        # conversion hooks below are therefore pass-throughs.
        return self._tokenizer.encode(text).ids

    def _convert_token_to_id(self, token):
        # "Tokens" are already ids here (see _tokenize), so return unchanged.
        return token

    def _convert_id_to_token(self, index: int) -> str:
        # Tokenizer.decode expects a list of ids, so wrap the single index.
        return self._tokenizer.decode([index])
        # return self._tokenizer.id_to_token(index)

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        return tokens

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        # Tokenizer.decode expects a list of ids; accept a bare int too.
        if isinstance(ids, int):
            ids = [ids]
        return self._tokenizer.decode(ids)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Write tokens in id order; if ids are non-consecutive, jump the
            # running index forward to match.
            for token, token_index in sorted(self.get_vocab().items(), key=lambda kv: kv[1]):
                if index != token_index:
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
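

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original file). It assumes
# a `tokenizers`-format model file ("tokenizer.json" is a hypothetical path)
# trained with the <PAD>/<BOS>/<EOS>/<UNK>/<MASK> tokens that __init__ looks
# up at load time.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tok = SentencePieceJA("tokenizer.json")  # hypothetical model path
    ids = tok._tokenize("こんにちは、世界")  # encodes text straight to ids
    print(ids)
    print(tok.convert_ids_to_tokens(ids))  # decodes ids back to a string
    tok.save_vocabulary(".")  # writes ./vocab.txt in id order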