File size: 2,164 Bytes
6c63bd9 da15cde 6c63bd9 da15cde 6c63bd9 da15cde 6c63bd9 da15cde 6c63bd9 da15cde 6c63bd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import os
from typing import Dict, List, Optional, Tuple, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
class SentencePieceJA(PreTrainedTokenizer):
    """``PreTrainedTokenizer`` backed by a serialized ``tokenizers.Tokenizer``.

    Tokenization and vocabulary lookups are delegated to the ``tokenizers``
    model loaded from ``model_path``; the ``transformers`` base class supplies
    the surrounding machinery (special tokens, padding, batching, saving).

    Args:
        model_path: Path to a ``tokenizers`` JSON file (``tokenizer.json``).
        pad, bos, eos, unk, mask: Strings used for the corresponding special
            tokens.
        **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
    """

    def __init__(self,
                 model_path="./tokenizer.json",
                 pad="<PAD>",
                 bos="<BOS>",
                 eos="<EOS>",
                 unk="<UNK>",
                 mask="<MASK>",
                 **kwargs):
        from tokenizers import Tokenizer

        # Set before super().__init__() — presumably because the base
        # initializer may query the vocabulary while registering special
        # tokens; keep this ordering.
        self._tokenizer = Tokenizer.from_file(model_path)
        super().__init__(
            pad_token=pad,
            bos_token=bos,
            eos_token=eos,
            unk_token=unk,
            mask_token=mask,
            **kwargs)
        # NOTE(review): this re-registers the same tokens already passed to
        # super().__init__(); kept to preserve existing behaviour.
        self.add_special_tokens({
            'pad_token': pad,
            'bos_token': bos,
            'eos_token': eos,
            'unk_token': unk,
            'mask_token': mask
        })

    def get_vocab(self) -> Dict[str, int]:
        """Return the backend's token -> id mapping."""
        return self._tokenizer.get_vocab()

    @property
    def vocab_size(self) -> int:
        """Number of entries in the backend vocabulary.

        Exposed as a property to match ``PreTrainedTokenizer.vocab_size``,
        which library code reads as an int rather than calling it.
        """
        return self._tokenizer.get_vocab_size()

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into token strings using the backend tokenizer."""
        return self._tokenizer.encode(text).tokens

    def _convert_token_to_id(self, token):
        """Map one token string to its id, falling back to the unk token's id.

        The token is looked up directly rather than re-encoded: encoding a
        token string can split it into several pieces or prepend
        post-processor tokens, which made ``ids[0]`` the wrong id (and
        raised ``IndexError`` for an empty encoding).
        """
        token_id = self._tokenizer.token_to_id(token)
        if token_id is None:
            token_id = self._tokenizer.token_to_id(self.unk_token)
        return token_id

    def _convert_id_to_token(self, index: int) -> str:
        """Map one id to its token string, falling back to the unk token.

        ``Tokenizer.decode`` expects a list of ids (and post-processes the
        text), so the id is looked up with ``id_to_token`` instead.
        """
        token = self._tokenizer.id_to_token(index)
        return self.unk_token if token is None else token

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join token strings into text using the backend's decoder.

        Falls back to the base-class behaviour when the loaded model defines
        no decoder.
        """
        decoder = self._tokenizer.decoder
        if decoder is None:
            return super().convert_tokens_to_string(tokens)
        return decoder.decode(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary as one token per line, ordered by id.

        Args:
            save_directory: A directory (the file is named ``vocab.txt``) or
                a target file path.
            filename_prefix: Optional prefix, joined to the file name with
                ``"-"``.

        Returns:
            A 1-tuple containing the path that was written.
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + "vocab.txt")
        else:
            # Apply the prefix to the file name, not to the front of the
            # whole path (which produced e.g. "pre-out/vocab.txt").
            head, tail = os.path.split(save_directory)
            vocab_file = os.path.join(head, prefix + tail)

        tokens_by_id = sorted(self.get_vocab().items(), key=lambda kv: kv[1])
        # NOTE: a line's number equals the token id only if ids are
        # contiguous from 0; gaps are not padded.
        with open(vocab_file, "w", encoding="utf-8") as writer:
            writer.writelines(token + "\n" for token, _ in tokens_by_id)
        return (vocab_file,)