CosyVoice committed
Commit ed87445
1 Parent(s): f6d44af

add 25hz text tokenizer

cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken ADDED
The diff for this file is too large to render. See raw diff
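The asset is a plain-text tiktoken vocabulary: as the loader in `tokenizer.py` below reads it, each line holds a base64-encoded token followed by its integer merge rank, separated by whitespace. A minimal parsing sketch, with an invented sample line purely for illustration:

```python
import base64

# Hypothetical example line from a *.tiktoken vocabulary file: "<base64 token> <rank>".
sample_line = "5L2g 1234"

token_b64, rank = sample_line.split()
token_bytes = base64.b64decode(token_b64)       # b'\xe4\xbd\xa0', the UTF-8 bytes of "你"
print(token_bytes.decode("utf-8"), int(rank))   # 你 1234
```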
 
cosyvoice/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,439 @@
+import base64
+import os
+import string
+from dataclasses import dataclass, field
+from functools import cached_property, lru_cache
+from typing import Dict, List, Optional, Tuple
+
+import tiktoken
+
+LANGUAGES = {
+    "en": "english",
+    "zh": "chinese",
+    "de": "german",
+    "es": "spanish",
+    "ru": "russian",
+    "ko": "korean",
+    "fr": "french",
+    "ja": "japanese",
+    "pt": "portuguese",
+    "tr": "turkish",
+    "pl": "polish",
+    "ca": "catalan",
+    "nl": "dutch",
+    "ar": "arabic",
+    "sv": "swedish",
+    "it": "italian",
+    "id": "indonesian",
+    "hi": "hindi",
+    "fi": "finnish",
+    "vi": "vietnamese",
+    "he": "hebrew",
+    "uk": "ukrainian",
+    "el": "greek",
+    "ms": "malay",
+    "cs": "czech",
+    "ro": "romanian",
+    "da": "danish",
+    "hu": "hungarian",
+    "ta": "tamil",
+    "no": "norwegian",
+    "th": "thai",
+    "ur": "urdu",
+    "hr": "croatian",
+    "bg": "bulgarian",
+    "lt": "lithuanian",
+    "la": "latin",
+    "mi": "maori",
+    "ml": "malayalam",
+    "cy": "welsh",
+    "sk": "slovak",
+    "te": "telugu",
+    "fa": "persian",
+    "lv": "latvian",
+    "bn": "bengali",
+    "sr": "serbian",
+    "az": "azerbaijani",
+    "sl": "slovenian",
+    "kn": "kannada",
+    "et": "estonian",
+    "mk": "macedonian",
+    "br": "breton",
+    "eu": "basque",
+    "is": "icelandic",
+    "hy": "armenian",
+    "ne": "nepali",
+    "mn": "mongolian",
+    "bs": "bosnian",
+    "kk": "kazakh",
+    "sq": "albanian",
+    "sw": "swahili",
+    "gl": "galician",
+    "mr": "marathi",
+    "pa": "punjabi",
+    "si": "sinhala",
+    "km": "khmer",
+    "sn": "shona",
+    "yo": "yoruba",
+    "so": "somali",
+    "af": "afrikaans",
+    "oc": "occitan",
+    "ka": "georgian",
+    "be": "belarusian",
+    "tg": "tajik",
+    "sd": "sindhi",
+    "gu": "gujarati",
+    "am": "amharic",
+    "yi": "yiddish",
+    "lo": "lao",
+    "uz": "uzbek",
+    "fo": "faroese",
+    "ht": "haitian creole",
+    "ps": "pashto",
+    "tk": "turkmen",
+    "nn": "nynorsk",
+    "mt": "maltese",
+    "sa": "sanskrit",
+    "lb": "luxembourgish",
+    "my": "myanmar",
+    "bo": "tibetan",
+    "tl": "tagalog",
+    "mg": "malagasy",
+    "as": "assamese",
+    "tt": "tatar",
+    "haw": "hawaiian",
+    "ln": "lingala",
+    "ha": "hausa",
+    "ba": "bashkir",
+    "jw": "javanese",
+    "su": "sundanese",
+    "yue": "cantonese",
+    "minnan": "minnan",
+    "wuyu": "wuyu",
+    "dialect": "dialect",
+    "zh/en": "zh/en",
+    "en/zh": "en/zh",
+}
+
+# language code lookup by name, with a few language aliases
+TO_LANGUAGE_CODE = {
+    **{language: code for code, language in LANGUAGES.items()},
+    "burmese": "my",
+    "valencian": "ca",
+    "flemish": "nl",
+    "haitian": "ht",
+    "letzeburgesch": "lb",
+    "pushto": "ps",
+    "panjabi": "pa",
+    "moldavian": "ro",
+    "moldovan": "ro",
+    "sinhalese": "si",
+    "castilian": "es",
+    "mandarin": "zh",
+}
+
+AUDIO_EVENT = {
+    "ASR": "ASR",
+    "AED": "AED",
+    "SER": "SER",
+    "Speech": "Speech",
+    "/Speech": "/Speech",
+    "BGM": "BGM",
+    "/BGM": "/BGM",
+    "Laughter": "Laughter",
+    "/Laughter": "/Laughter",
+    "Applause": "Applause",
+    "/Applause": "/Applause",
+}
+
+EMOTION = {
+    "HAPPY": "HAPPY",
+    "SAD": "SAD",
+    "ANGRY": "ANGRY",
+    "NEUTRAL": "NEUTRAL",
+}
+
+TTS_Vocal_Token = {
+    "TTS/B": "TTS/B",
+    "TTS/O": "TTS/O",
+    "TTS/Q": "TTS/Q",
+    "TTS/A": "TTS/A",
+    "TTS/CO": "TTS/CO",
+    "TTS/CL": "TTS/CL",
+    "TTS/H": "TTS/H",
+    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
+}
+
+
+@dataclass
+class Tokenizer:
+    """A thin wrapper around `tiktoken` providing quick access to special tokens"""
+
+    encoding: tiktoken.Encoding
+    num_languages: int
+    language: Optional[str] = None
+    task: Optional[str] = None
+    sot_sequence: Tuple[int] = ()
+    special_tokens: Dict[str, int] = field(default_factory=dict)
+
+    def __post_init__(self):
+        for special in self.encoding.special_tokens_set:
+            special_token = self.encoding.encode_single_token(special)
+            self.special_tokens[special] = special_token
+
+        sot: int = self.special_tokens["<|startoftranscript|>"]
+        translate: int = self.special_tokens["<|translate|>"]
+        transcribe: int = self.special_tokens["<|transcribe|>"]
+
+        langs = tuple(LANGUAGES.keys())[: self.num_languages]
+        sot_sequence = [sot]
+        if self.language is not None:
+            sot_sequence.append(sot + 1 + langs.index(self.language))
+        if self.task is not None:
+            task_token: int = transcribe if self.task == "transcribe" else translate
+            sot_sequence.append(task_token)
+
+        self.sot_sequence = tuple(sot_sequence)
+
+    def encode(self, text, **kwargs):
+        return self.encoding.encode(text, **kwargs)
+
+    def decode(self, token_ids: List[int], **kwargs) -> str:
+        token_ids = [t for t in token_ids if t < self.timestamp_begin]
+        return self.encoding.decode(token_ids, **kwargs)
+
+    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
+        """
+        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
+        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
+        """
+        return self.encoding.decode(token_ids, **kwargs)
+
+    def get_vocab_size(self) -> int:
+        return self.encoding.n_vocab
+
+    @cached_property
+    def eot(self) -> int:
+        return self.encoding.eot_token
+
+    @cached_property
+    def transcribe(self) -> int:
+        return self.special_tokens["<|transcribe|>"]
+
+    @cached_property
+    def translate(self) -> int:
+        return self.special_tokens["<|translate|>"]
+
+    @cached_property
+    def sot(self) -> int:
+        return self.special_tokens["<|startoftranscript|>"]
+
+    @cached_property
+    def sot_lm(self) -> int:
+        return self.special_tokens["<|startoflm|>"]
+
+    @cached_property
+    def sot_prev(self) -> int:
+        return self.special_tokens["<|startofprev|>"]
+
+    @cached_property
+    def no_speech(self) -> int:
+        return self.special_tokens["<|nospeech|>"]
+
+    @cached_property
+    def no_timestamps(self) -> int:
+        return self.special_tokens["<|notimestamps|>"]
+
+    @cached_property
+    def timestamp_begin(self) -> int:
+        return self.special_tokens["<|0.00|>"]
+
+    @cached_property
+    def language_token(self) -> int:
+        """Returns the token id corresponding to the value of the `language` field"""
+        if self.language is None:
+            raise ValueError("This tokenizer does not have language token configured")
+
+        return self.to_language_token(self.language)
+
+    def to_language_token(self, language):
+        if token := self.special_tokens.get(f"<|{language}|>", None):
+            return token
+
+        raise KeyError(f"Language {language} not found in tokenizer.")
+
+    @cached_property
+    def all_language_tokens(self) -> Tuple[int]:
+        result = []
+        for token, token_id in self.special_tokens.items():
+            if token.strip("<|>") in LANGUAGES:
+                result.append(token_id)
+        return tuple(result)[: self.num_languages]
+
+    @cached_property
+    def all_language_codes(self) -> Tuple[str]:
+        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
+
+    @cached_property
+    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
+        return tuple(list(self.sot_sequence) + [self.no_timestamps])
+
+    @cached_property
+    def non_speech_tokens(self) -> Tuple[int]:
+        """
+        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+
+        - ♪♪♪
+        - ( SPEAKING FOREIGN LANGUAGE )
+        - [DAVID] Hey there,
+
+        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
+        """
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        )
+
+        # symbols that may be a single token or multiple tokens depending on the tokenizer.
+        # In case they're multiple tokens, suppress the first token, which is safe because:
+        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+        miscellaneous = set("♩♪♫♬♭♮♯")
+        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
+        for symbol in symbols + list(miscellaneous):
+            for tokens in [
+                self.encoding.encode(symbol),
+                self.encoding.encode(" " + symbol),
+            ]:
+                if len(tokens) == 1 or symbol in miscellaneous:
+                    result.add(tokens[0])
+
+        return tuple(sorted(result))
+
+    def split_to_word_tokens(self, tokens: List[int]):
+        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
+            # These languages don't typically use spaces, so it is difficult to split words
+            # without morpheme analysis. Here, we instead split words at any
+            # position where the tokens are decoded as valid unicode points
+            return self.split_tokens_on_unicode(tokens)
+
+        return self.split_tokens_on_spaces(tokens)
+
+    def split_tokens_on_unicode(self, tokens: List[int]):
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
+        words = []
+        word_tokens = []
+        current_tokens = []
+        unicode_offset = 0
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+
+            if (
+                replacement_char not in decoded
+                or decoded_full[unicode_offset + decoded.index(replacement_char)]
+                == replacement_char
+            ):
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+                unicode_offset += len(decoded)
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(self, tokens: List[int]):
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            if special or with_space or punctuation or len(words) == 0:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens
+
+
+@lru_cache(maxsize=None)
+def get_encoding(name: str = "gpt2", num_languages: int = 99):
+    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+    ranks = {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in open(vocab_path) if line)
+    }
+    n_vocab = len(ranks)
+    special_tokens = {}
+
+    specials = [
+        "<|endoftext|>",
+        "<|startoftranscript|>",
+        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
+        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
+        "<|translate|>",
+        "<|transcribe|>",
+        "<|startoflm|>",
+        "<|startofprev|>",
+        "<|nospeech|>",
+        "<|notimestamps|>",
+        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
+        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
+        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+    ]
+
+    for token in specials:
+        special_tokens[token] = n_vocab
+        n_vocab += 1
+
+    return tiktoken.Encoding(
+        name=os.path.basename(vocab_path),
+        explicit_n_vocab=n_vocab,
+        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        mergeable_ranks=ranks,
+        special_tokens=special_tokens,
+    )
+
+
+@lru_cache(maxsize=None)
+def get_tokenizer(
+    multilingual: bool,
+    *,
+    num_languages: int = 99,
+    language: Optional[str] = None,
+    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+) -> Tokenizer:
+    if language is not None:
+        language = language.lower()
+        if language not in LANGUAGES:
+            if language in TO_LANGUAGE_CODE:
+                language = TO_LANGUAGE_CODE[language]
+            else:
+                raise ValueError(f"Unsupported language: {language}")
+
+    if multilingual:
+        encoding_name = "multilingual_zh_ja_yue_char_del"
+        language = language or "en"
+        task = task or "transcribe"
+    else:
+        encoding_name = "gpt2"
+        language = None
+        task = None
+
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+    return Tokenizer(
+        encoding=encoding, num_languages=num_languages, language=language, task=task
+    )
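For orientation, a minimal usage sketch of the new module (not part of the commit). It assumes the repository root is on PYTHONPATH so the file is importable as `cosyvoice.tokenizer.tokenizer`, and that the `multilingual_zh_ja_yue_char_del.tiktoken` asset added above sits in the adjacent `assets/` directory; the example text is arbitrary.

```python
from cosyvoice.tokenizer.tokenizer import get_tokenizer

# Build the multilingual tokenizer; language/task only affect the start-of-transcript sequence.
tokenizer = get_tokenizer(multilingual=True, language="zh", task="transcribe")

text = "你好，世界"
token_ids = tokenizer.encode(text)      # list of BPE ids from the tiktoken vocabulary
print(token_ids)
print(tokenizer.decode(token_ids))      # round-trips back to the input text
print(tokenizer.sot_sequence)           # ids for <|startoftranscript|>, <|zh|>, <|transcribe|>
```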