sentencepiece_ja / sentencepiece_ja.py
if001's picture
fix
de00515
raw
history blame
2.01 kB
import os
from typing import Union, List, Optional, Tuple
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
class SentencePieceJA(PreTrainedTokenizer):
def __init__(self,
model_path = "./tokenizer.json",
pad = "<PAD>",
bos = "<BOS>",
eos = "<EOS>",
unk = "<UNK>",
mask = "<MASK>",
**kwargs):
from tokenizers import Tokenizer
self._tokenizer = Tokenizer.from_file(model_path)
super().__init__(**kwargs)
self.add_special_tokens({
'pad_token': pad,
'bos_token': bos,
'eos_token': eos,
'unk_token': unk,
'mask_token': mask
})
def get_vocab(self) -> int:
return self._tokenizer.get_vocab()
def vocab_size(self) -> int:
return self._tokenizer.get_vocab_size()
def _tokenize(self, text, **kwargs):
return self._tokenizer.encode(text).tokens
def _convert_token_to_id(self, token):
return self._tokenizer.encode(token).ids[0]
def _convert_id_to_token(self, index: int) -> str:
return self._tokenizer.decode(index)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + 'vocab.txt'
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.get_vocab().items(), key=lambda kv: kv[1]):
if index != token_index:
index = token_index
writer.write(token + "\n")
index += 1
return (vocab_file,)