add
sentencepiece_ja.py +56 -0
ADDED
@@ -0,0 +1,56 @@
import os
from typing import Dict, List, Optional, Tuple, Union

from transformers import PreTrainedTokenizer


class SentencePieceJA(PreTrainedTokenizer):
    def __init__(self, model_path, **kwargs):
        from tokenizers import Tokenizer

        # Load the underlying tokenizers-format model before calling the
        # base-class constructor, which may query the vocabulary.
        self._tokenizer = Tokenizer.from_file(model_path)
        super().__init__(**kwargs)
        # Cache the ids of the special tokens; this assumes the model was
        # trained with these five tokens in its vocabulary.
        self.__pad_id = self._tokenize("<PAD>")[0]
        self.__bos_id = self._tokenize("<BOS>")[0]
        self.__eos_id = self._tokenize("<EOS>")[0]
        self.__unk_id = self._tokenize("<UNK>")[0]
        self.__mask_id = self._tokenize("<MASK>")[0]

    def get_vocab(self) -> Dict[str, int]:
        return self._tokenizer.get_vocab()

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.get_vocab_size()

    def _tokenize(self, text, **kwargs):
        # Unconventionally, this returns ids rather than token strings, so
        # _convert_token_to_id below is the identity function.
        return self._tokenizer.encode(text).ids

    def _convert_token_to_id(self, token):
        return token

    def _convert_id_to_token(self, index: int) -> str:
        # id_to_token maps a single id; decode() expects a list of ids and
        # would fail if handed a bare int.
        return self._tokenizer.id_to_token(index)

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        # _tokenize already produced ids, so there is nothing to convert.
        return tokens

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        if isinstance(ids, int):
            ids = [ids]
        return self._tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Write one token per line, ordered by token id.
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.get_vocab().items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # Jump over holes in the id space so the file still
                    # reflects id order.
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
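
A minimal usage sketch for the wrapper might look like the following. The file name "tokenizer.json" and the sample sentence are placeholder assumptions, and the snippet presumes the model was trained with the <PAD>/<BOS>/<EOS>/<UNK>/<MASK> tokens that __init__ probes for.

# Hypothetical smoke test -- "tokenizer.json" is a placeholder path to a
# tokenizers-format model file.
tok = SentencePieceJA("tokenizer.json")
ids = tok._tokenize("こんにちは、世界")     # encode() returns raw ids ("Hello, world")
print(ids)                                  # model-dependent list of ids
print(tok.convert_ids_to_tokens(ids))       # decodes the ids back into a string
print(tok.vocab_size)                       # size of the underlying vocabulary
tok.save_vocabulary(".")                    # writes ./vocab.txt, one token per line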