Manu101 committed on
Commit
693faa9
1 Parent(s): 0974415

Upload 12 files

app.py ADDED
@@ -0,0 +1,77 @@
+ import pathlib
+ import random
+
+ import gradio as gr
+ from src import HindiTokenizer, BasicTokenizer
+
+ Basic = BasicTokenizer()
+ Basic._build_vocab()
+
+ Hindi = HindiTokenizer()
+ Hindi.load(
+     model_file_path=pathlib.Path(
+         "saved_vocabs/batch_1_Hindi_Tokenizer-test-all_batches-100_000_batchsize-initial_vocab_size_5000.model"))
+
+
+ def tokenize_and_color(text, tokenizer_choice="HindiTokenizer"):
+     if tokenizer_choice == "BasicTokenizer":
+         tokenizer = Basic
+     else:
+         tokenizer = Hindi
+
+     tokens = tokenizer.encode(text)
+
+     # colors = [
+     #     "#FF5733", "#33FF57", "#3357FF", "#F333FF",
+     #     "#33FFF3", "#F3FF33", "#FF3380", "#3380FF",
+     #     "#83FF33", "#FF8333"
+     # ]
+     colors = [
+         "#FF5733", "#33FF57", "#3357FF", "#F333FF",
+         "#33FFF3", "#F3FF33", "#FF3380", "#3380FF",
+         "#83FF33", "#FF8333", "#7FDBFF", "#0074D9",
+         "#39CCCC", "#3D9970", "#2ECC40", "#01FF70",
+         "#FFDC00", "#FF851B", "#FF4136", "#85144b",
+         "#F012BE", "#B10DC9", "#AAAAAA", "#DDDDDD"
+     ]
+
+     colored_text = '<div style="word-wrap: break-word; white-space: pre-wrap;">'
+     token_color_mapping = {}
+     last_color = ""
+     for index, token in enumerate(tokens):
+         token_id = token
+         if token_id in token_color_mapping:
+             color = token_color_mapping[token_id]
+         else:
+             color = random.choice([c for c in colors if c != last_color])
+             last_color = color
+             token_color_mapping[token_id] = color
+         colored_text += f'<span id="{token_id}" style="color: {color}; margin-right: 20px;">{token}</span>'
+     colored_text += '</div>'
+
+     return colored_text
+
+
+ examples = [
+     ["आप कैसे हैं??"],
+     ["यह एक परीक्षण है।"],
+     ["लोरेम इप्सम एक छद्म-लैटिन पाठ है जिसका उपयोग मुद्रण और टाइपसेटिंग उद्योगों में किया जाता है।"]
+ ]
+
+ iface = gr.Interface(fn=tokenize_and_color,
+                      title="Hindi Text Tokenizer",
+                      description="Enter text to see the tokenized output with each token colored differently.",
+                      inputs=[
+                          gr.Textbox(lines=2, label="Input Text"),
+                          # gr.Radio(choices=["BasicTokenizer", "HindiTokenizer"], label="Tokenizer Choice",
+                          #          value="HindiTokenizer")
+                      ],
+                      outputs=[
+                          gr.HTML(label="Tokenized and Colored Text")
+                      ],
+                      examples=examples,
+                      # theme=gr.themes.Soft()
+                      theme=gr.themes.Base()
+                      )
+ if __name__ == "__main__":
+     iface.launch()
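
A quick way to sanity-check the function above outside the Gradio UI; this is a hypothetical snippet, not part of the upload, and it assumes the repo root is on the import path and the .model file referenced in app.py is present:

from app import tokenize_and_color  # importing app also loads the saved tokenizer model

html = tokenize_and_color("यह एक परीक्षण है।")
print(html)  # a <div> containing one colored <span> per token id

# To serve the full demo instead, run:  python app.py
# Gradio binds to http://127.0.0.1:7860 by default.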
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ gradio
+ scrapy
saved_vocabs/batch_2_Hindi_Tokenizer-test-all_batches-100_000_batchsize-initial_vocab_size_5000.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6a08206e8219876b874bdb5aedbd4080a0504e1de86b794cc4655b3d1847ee59
+ size 47214
src/Basictokenizer.py ADDED
@@ -0,0 +1,172 @@
+ """
+ Minimal (byte-level) Byte Pair Encoding tokenizer.
+
+ Algorithmically follows along the GPT tokenizer:
+ https://github.com/openai/gpt-2/blob/master/src/encoder.py
+
+ But:
+ - Does not handle the regular expression splitting pattern.
+ - Does not handle any special tokens.
+ """
+ import copy
+
+ from .base import Tokenizer, get_stats, merge
+
+
+ # class BasicTokenizer(Tokenizer):
+ #
+ #     def __init__(self):
+ #         super().__init__()
+ #
+ #     def train(self, text, vocab_size, verbose=False):
+ #         assert vocab_size >= 256
+ #         num_merges = vocab_size - 256
+ #
+ #         # input text preprocessing
+ #         text_bytes = text.encode("utf-8")  # raw bytes
+ #         ids = list(text_bytes)  # list of integers in range 0..255
+ #
+ #         # iteratively merge the most common pairs to create new tokens
+ #         merges = {}  # (int, int) -> int
+ #         vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
+ #         for i in range(num_merges):
+ #             # count up the number of times every consecutive pair appears
+ #             stats = get_stats(ids)
+ #             # find the pair with the highest count
+ #             pair = max(stats, key=stats.get)
+ #             # mint a new token: assign it the next available id
+ #             idx = 256 + i
+ #             # replace all occurrences of pair in ids with idx
+ #             ids = merge(ids, pair, idx)
+ #             # save the merge
+ #             merges[pair] = idx
+ #             vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
+ #             # prints
+ #             if verbose:
+ #                 print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
+ #
+ #         # save class variables
+ #         self.merges = merges  # used in encode()
+ #         self.vocab = vocab  # used in decode()
+ #
+ #     def decode(self, ids):
+ #         # given ids (list of integers), return Python string
+ #         text_bytes = b"".join(self.vocab[idx] for idx in ids)
+ #         text = text_bytes.decode("utf-8", errors="replace")
+ #         return text
+ #
+ #     def encode(self, text):
+ #         # given a string text, return the token ids
+ #         text_bytes = text.encode("utf-8")  # raw bytes
+ #         ids = list(text_bytes)  # list of integers in range 0..255
+ #         while len(ids) >= 2:
+ #             # find the pair with the lowest merge index
+ #             stats = get_stats(ids)
+ #             pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
+ #             # subtle: if there are no more merges available, the key will
+ #             # result in an inf for every single pair, and the min will be
+ #             # just the first pair in the list, arbitrarily
+ #             # we can detect this terminating case by a membership check
+ #             if pair not in self.merges:
+ #                 break  # nothing else can be merged anymore
+ #             # otherwise let's merge the best pair (lowest merge index)
+ #             idx = self.merges[pair]
+ #             ids = merge(ids, pair, idx)
+ #         return ids
+
+
+ class BasicTokenizer(Tokenizer):
+
+     def __init__(self):
+         super().__init__()
+         self.merge_counter = 0
+
+     def train(self, text, vocab_size, verbose=False):
+         # left the assert in place as a hard check on the requested vocab size and number of merges
+         assert vocab_size >= 256
+         num_merges = vocab_size - 256
+
+         current_batch_merge_counter = 0  # in case not exactly `num_merges` merges happen
+
+         # input text preprocessing
+         text_bytes = text.encode("utf-8")  # encode to get all raw bytes
+         ids = list(text_bytes)  # represent the bytes as ints
+
+         # reuse the same merges dict if it exists
+         self.merges = {} if self.merges is None else self.merges  # holds all merges (int, int) -> int
+
+         # reuse the same vocab for this Tokenizer object if it exists
+         # Tokenizer vocab: int -> bytes
+         self.vocab = {idx: bytes([idx]) for idx in range(256)} if self.vocab is None else self.vocab
+
+         # iteratively merge the MOST COMMON pair from the text
+         for i in range(num_merges):
+             # get count of pairs
+             stats = get_stats(ids)
+
+             # find the pair with the highest count
+             # pair = max(stats, key=stats.get)
+
+             # tmp_stats = copy.deepcopy(stats)
+
+             # get the most occurring pair from ids
+             pair = max(stats, key=stats.get)
+
+             while pair in self.merges:
+                 # pair was previously merged ... use it first to update ids
+                 # no need to add to merges and vocab, use the previously stored token
+                 already_merged_idx = self.merges[pair]
+
+                 # just replace already merged pairs in ids; no need to add to merges and vocab again
+                 ids = merge(ids, pair, already_merged_idx)
+
+                 stats = get_stats(ids)
+
+                 if stats and len(ids) >= 2:
+                     pair = max(stats, key=stats.get)
+                 else:
+                     # no new merges found in this incoming data batch
+                     print("\n\nstopping merges as no new byte pair found in the current batch")
+                     pair = None
+                     break
+
+             if pair is None:
+                 # nothing left to merge in this batch, stop here instead of minting from a stale pair
+                 break
+
+             # this most occurring pair has not been merged yet in any data batch
+             # mint a new token, considering how many have been generated so far for this tokenizer
+             idx = len(self.vocab) + 1
+
+             # count newly generated tokens to add to self.merge_counter later
+             current_batch_merge_counter += 1
+
+             # replace all occurrences of `pair` in `ids` with the NEW `idx` token, add it to merges & vocab
+             # Note: this pair has never been seen for merging before
+             ids = merge(ids, pair, idx)
+             self.merges[pair] = idx
+             self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
+             if verbose:
+                 print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({self.vocab[idx]}) had {stats[pair]} count")
+
+         self.merge_counter += current_batch_merge_counter
+
+     def decode(self, ids):
+         # given ids (list of integers), return Python string
+         text_bytes = b"".join(self.vocab[idx] for idx in ids)
+         text = text_bytes.decode("utf-8", errors="replace")
+         return text
+
+     def encode(self, text):
+         # input a string text, returns the token ids
+         text_bytes = text.encode("utf-8")
+         ids = list(text_bytes)
+         while len(ids) >= 2:
+             # find the pair with the lowest merge index
+             stats = get_stats(ids)
+             pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
+             # if there are no merges, i.e. the pair is not in the merges dict,
+             # the key will result in an `inf` for every single pair,
+             # and the min will be just the first pair in the list;
+             # we can detect this terminating case by a membership check
+             if pair not in self.merges:
+                 break  # nothing else can be merged anymore
+             # otherwise merge the best pair NOTE: (lowest merge index)
+             idx = self.merges[pair]
+             ids = merge(ids, pair, idx)
+         return ids
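
A minimal usage sketch for the class above (assuming the package is importable as `src` from the repo root, as in app.py); the sample text and vocab size are arbitrary:

from src import BasicTokenizer

tok = BasicTokenizer()
tok.train("hello hello world", vocab_size=260, verbose=True)  # up to 4 merges
ids = tok.encode("hello world")
print(ids)                               # byte ids plus any merged token ids
assert tok.decode(ids) == "hello world"  # byte-level BPE round-trips exactly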
src/HindiTokenizer.py ADDED
@@ -0,0 +1,473 @@
+ import os
+ import pathlib
+ import time
+ from textwrap import dedent
+
+ import regex as re
+ import unicodedata
+
+ import utilities
+ from src.base import Tokenizer, get_stats, merge
+
+ whitespace = ' \t\n\r\v\f'
+ ascii_lowercase = 'abcdefghijklmnopqrstuvwxyz'
+ ascii_uppercase = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+ ascii_letters = ascii_lowercase + ascii_uppercase
+ digits = '0123456789'
+ hexdigits = digits + 'abcdef' + 'ABCDEF'
+ octdigits = '01234567'
+ punctuation = r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""
+
+ ascii_printable = whitespace + ascii_letters + hexdigits + punctuation
+
+ # the main GPT text split patterns, see
+ # https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
+ GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+ GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
+
+ """
+ Basic Devanagari: \u0900 to \u097F
+ Vedic Extensions: \u1CD0 to \u1CFF
+ Extended Devanagari: \uA8E0 to \uA8FF
+ """
+ # ignore case in compile below
+ SIMPLE_HINDI_PATTERN = r"""[\t\n\r\f\v]?|[^\r\n\p{Devanagari}\p{N}]?+\p{Devanagari}+|\\p{N}{1,}| ?[^\s\p{Devanagari}+\p{N}]++[\r\n]*|\s*[\r\n]*|\s+(?!\S)|\s+"""
+ EXTENDED_HINDI_PATTERN = r"""[\t\n\r\f\v]?|[^\r\n\p{Devanagari}\uA8E0-\uA8FF\u1CD0-\u1CFF\p{N}]?+[\p{Devanagari}\uA8E0-\uA8FF\u1CD0-\u1CFF]+|\p{N}{1,}| ?[^\s\p{Devanagari}+\p{N}\uA8E0-\uA8FF\u1CD0-\u1CFF]++[\r\n]*|\s*[\r\n]*|\s+(?!\S)|\s+"""
+
+
+ def replace_control_characters(s: str) -> str:
+     chars = []
+     for ch in s:
+         if unicodedata.category(ch)[0] != "C":
+             chars.append(ch)  # this character is ok
+         else:
+             chars.append(f"\\u{ord(ch):04x}")  # escape
+     return "".join(chars)
+
+
+ def render_token(t: bytes) -> str:
+     # pretty print a token, escaping control characters
+     s = t.decode('utf-8', errors='replace')
+     s = replace_control_characters(s)
+     return s
+
+
+ class HindiTokenizer:
+     def __init__(self, pattern=None, encoding="utf-8"):
+         self.pattern = SIMPLE_HINDI_PATTERN if pattern is None else pattern
+         self.compiled_pattern = re.compile(self.pattern, re.IGNORECASE | re.UNICODE)
+         self.inverse_special_tokens = {}
+         self.merges = None
+         self.vocab = None
+         self.encoding = encoding
+         self.hindi_varnmala_and_key_units = dedent("""
+             अ आ इ ई उ ऊ ए ऐ ओ औ अं अः ऋ ॠ
+             ा ि ी ु ू ृॄ ॅॆ े ैॉ ॊ ो ौ
+             क ख ग घ ङ क़ ख़ ग़ घ़ ङ़
+             च छ ज झ ञ ज़ झ़ ञ़
+             ट ठ ड ढ ण ड़ ढ़ ण़
+             त थ द ध न त़ थ़ द़ ध़ ऩ
+             प फ ब भ म प़ फ़ ब़ म़
+             य र ल ळ व य़ ऱ ल़ ऴ व़
+             श ष ॺ स ह श़ ष़ स़ ह़
+             ० १ २ ३ ४ ५ ६ ७ ८ ९
+
+         """)
+         self.special_tokens = {}
+         super().__init__()
+
+     def _build_vocab(self):
+         '''add other important ASCII units except English letters'''
+
+         print("\n====================================\n\n"
+               "Building initial Hindi vocabulary with basic Hindi letters and key tokens")
+         self.vocab = {}
+         ascii_letters_encoded = ascii_letters.encode(
+             encoding="utf-8")  # was using this to ignore ASCII English letters; revisit/TODO: everyday Hindi usage and chats may include English letters, so what should fill those blank idxs?
+         for idx in range(256):
+             self.vocab[idx] = bytes([idx])
+
+         max_idx = max(self.vocab.keys()) + 1
+
+         basic_hindi_alphabet = self.hindi_varnmala_and_key_units.strip().split()
+
+         for idx in range(len(basic_hindi_alphabet)):
+             encoded_char = basic_hindi_alphabet[idx].encode(encoding=self.encoding)
+
+             new_idx = idx + max_idx
+             self.vocab[new_idx] = encoded_char
+
+         if self.merges is not None:
+             for (pos0, pos1), idx in self.merges.items():
+                 self.vocab[idx] = self.vocab[pos0] + self.vocab[pos1]
+
+         # NOW add special tokens defined in __init__()
+         # NOTE encode special tokens using .encode with UTF-8 encoding
+         for tok, idx in self.special_tokens.items():
+             self.vocab[idx] = tok.encode("utf-8")
+
+         print("\n=================\nVocab initialisation done...")
+         # verified that the resumed letter from the .model file, b'\xe0\xa4\x85'.decode("utf-8"), is indeed the character 'अ';
+         # one extra index is skipped (idx 357), so +1 had to be added where needed when re-building the vocab 😅
+         # not needed here though.
+         return self.vocab
+
+     # @utilities.log_to_file("HindiTokenizer-train.log")
+     def train(self, text, vocab_size, verbose=False,
+               default_initial_vocab_size=256 + 101,
+               encoding="utf-8",
+               save_tokenizer_at_train_end: bool = False,
+               prefix_for_save: str = "Hindi_Tokenizer",
+               just_replacing_already_seen_tokens_counter_threshold=100,
+               minting_new_token_for_merge_threshold=10,
+               current_batch_num=None,
+               save_at_every_nth_iteration=100
+               ):
+         """
+         text: the incoming text data as a str
+
+         vocab_size: int: the new target vocab size to build, used to determine how many merges to run
+
+         verbose: bool: print when a new token is generated and used to merge pairs in the data's ids
+
+         encoding: str = "utf-8": the encoding to use
+
+         save_tokenizer_at_train_end: bool: a flag to save the growing vocab and merges dictionaries so they can be resumed and re-used later
+
+         prefix_for_save: str: the prefix for the saved tokenizer files
+
+         just_replacing_already_seen_tokens_counter_threshold: int = 100: a threshold on how many replacements in the current batch only re-apply pairs created previously;
+         the idea is that if a new data batch has no or very few pairs that could become new entries, training quickly stops and moves on to the next data batch
+
+         minting_new_token_for_merge_threshold: int = 10: another threshold checking whether the number of newly minted tokens is below or above this, used in conjunction with the previous threshold
+
+         current_batch_num: int or None: indicates which batch number is currently running, used for print logs and save-file names
+         """
+         if self.vocab is None:
+             self._build_vocab()
+
+         print("\n`Training`...for HindiTokenizer")
+
+         assert vocab_size >= default_initial_vocab_size
+         num_merges = vocab_size - default_initial_vocab_size
+         stop_this_batch = False
+
+         if current_batch_num is not None and isinstance(current_batch_num, int):
+             current_batch_num = "batch_" + str(current_batch_num) + "_"
+             prefix_for_save = current_batch_num + prefix_for_save
+
+         # split the text up into text chunks
+         text_chunks = re.findall(self.compiled_pattern, text)
+
+         # input text preprocessing
+         ids = [list(ch.encode("utf-8")) for ch in text_chunks if len(ch) > 1]
+
+         # iteratively merge the MOST COMMON pair from the text
+         # reuse the same merges dict if it exists
+         self.merges = {} if self.merges is None else self.merges  # holds all merges (int, int) -> int
+
+         '''Some counters to help check whether the running batch is only replacing already
+         created/existing tokens OR actually finding something new to mint as a token & add to merges and vocab'''
+         minting_new_token_for_merge_counter = 0
+         just_replacing_already_seen_tokens_counter = 0
+
+         # run merging iteratively
+         for i in range(num_merges):
+             if (i + 1) % save_at_every_nth_iteration == 0:
+                 self.save(file_prefix=prefix_for_save + f"_at_{i}_iteration_",
+                           save_to_folder=pathlib.Path("saved_vocabs"))
+
+             merge_start_time = time.perf_counter()
+             # count the number of times every consecutive pair appears
+             stats = {}
+             for chunk_ids in ids:
+                 # passing in stats will update it in place, adding up counts
+                 get_stats(chunk_ids, stats)
+
+             # find the pair with the highest count
+             pair = max(stats, key=stats.get)
+
+             while pair in self.merges:
+                 replacing_time_start = time.perf_counter()
+                 just_replacing_already_seen_tokens_counter += 1
+
+                 '''A simple check that says: if this batch mostly just replaces already
+                 existing pairs, far more often than it generates new tokens,
+                 it is best to skip this batch...
+                 [use the thresholds to experiment further]'''
+
+                 if just_replacing_already_seen_tokens_counter > just_replacing_already_seen_tokens_counter_threshold \
+                         and minting_new_token_for_merge_counter < minting_new_token_for_merge_threshold:
+                     print("\n\n===========\nStopping current batch as replacing previously learned merges is way"
+                           f" higher than creating new merges\njust_replacing_already_seen_tokens_counter:"
+                           f" {just_replacing_already_seen_tokens_counter}"
+                           f" and minting_new_token_for_merge_counter: {minting_new_token_for_merge_counter}")
+                     stop_this_batch = True
+                     break
+
+                 # pair was previously merged ... use it first to update ids
+                 # no need to add to merges and vocab, use the previously seen and stored token
+                 already_merged_idx = self.merges[pair]
+                 print(f"\nPair: {pair} already in merged tokens... replacing in ids...")
+                 print(f"with.. id.. {already_merged_idx}")
+
+                 # just replace already merged pairs in ids; no need to add to merges and vocab again
+                 ids = [merge(chunk_ids, pair, already_merged_idx) for chunk_ids in ids]
+
+                 print(
+                     f"\nReplacing existing pair: {pair} in ids took: {time.perf_counter() - replacing_time_start} seconds")
+
+                 # get updated stats now; here ids is a list of lists, so update stats the same way as above
+                 stats = {}
+                 for chunk_ids in ids:
+                     # passing in stats will update it in place
+                     get_stats(chunk_ids, stats)
+
+                 # just avoiding merging when ids become less than 2
+                 if stats and len(ids) >= 2:
+                     pair = max(stats, key=stats.get)
+                 else:
+                     # no new merges found in this incoming data batch
+                     print("\n\nstopping merges as no new byte pair found in the current batch")
+                     stop_this_batch = True
+                     break
+
+             if stop_this_batch is True:
+                 break
+
+             # mint a new token as the pair was not already in merges: assign it the next available id
+             idx = len(self.vocab) + 1
+
+             minting_new_token_for_merge_counter += 1
+
+             # replace all occurrences of pair in ids with idx
+             ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
+
+             # save the merge
+             self.merges[pair] = idx
+             self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
+
+             if verbose:
+                 print(
+                     f"\n\nmerge {i + 1}/{num_merges}: {pair} -> {idx} ({self.vocab[idx]}) had"
+                     f" {stats[pair]:_} occurrences."
+                     f"\ntime taken: {time.perf_counter() - merge_start_time} seconds")
+
+         if save_tokenizer_at_train_end:
+             self.save(file_prefix=prefix_for_save, save_to_folder=pathlib.Path("saved_vocabs"))
+
+     def register_special_tokens(self, special_tokens):
+         # special_tokens is a dictionary of str -> int
+         # example: {"<|endoftext|>": 100257}
+         self.special_tokens = special_tokens
+         self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
+
+     @utilities.log_to_file("HindiTokenizer-decode.log")
+     def decode(self, ids):
+         print("\nDecoding...for HindiTokenizer")
+         # given ids (list of integers), return Python string
+         part_bytes = []
+         for idx in ids:
+             if idx in self.vocab:
+                 part_bytes.append(self.vocab[idx])
+             elif idx in self.inverse_special_tokens:
+                 part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
+             else:
+                 raise ValueError(f"invalid token id: {idx}")
+         text_bytes = b"".join(part_bytes)
+         text = text_bytes.decode("utf-8", errors="replace")
+         return text
+
+     def _encode_chunk(self, text_bytes):
+         # return the token ids
+         # let's begin. first, convert all bytes to integers in range 0..255
+         ids = list(text_bytes)
+         while len(ids) >= 2:
+             # find the pair with the lowest merge index
+             stats = get_stats(ids)
+             pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
+             # subtle: if there are no more merges available, the key will
+             # result in an inf for every single pair, and the min will be
+             # just the first pair in the list, arbitrarily
+             # we can detect this terminating case by a membership check
+             if pair not in self.merges:
+                 break  # nothing else can be merged anymore
+             # otherwise let's merge the best pair (lowest merge index)
+             idx = self.merges[pair]
+             ids = merge(ids, pair, idx)
+         return ids
+
+     def encode_ordinary(self, text):
+         """Encoding that ignores any special tokens."""
+         # split text into chunks by the categories defined in the regex pattern
+         text_chunks = re.findall(self.compiled_pattern, text)
+         # all chunks of text are encoded separately, then results are joined
+         ids = []
+         for chunk in text_chunks:
+             chunk_bytes = chunk.encode("utf-8")  # raw bytes
+             chunk_ids = self._encode_chunk(chunk_bytes)
+             ids.extend(chunk_ids)
+         return ids
+
+     @utilities.log_to_file("HindiTokenizer-encode.log")
+     def encode(self, text, allowed_special="none_raise"):
+         """
+         Unlike encode_ordinary, this function handles special tokens.
+         allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
+         if none_raise, then an error is raised if any special token is encountered in text
+         this is the default tiktoken behavior right now as well
+         any other behavior is either annoying, or a major footgun
+         """
+         # decode the user desire w.r.t. handling of special tokens
+         special = None
+         if allowed_special == "all":
+             special = self.special_tokens
+         elif allowed_special == "none":
+             special = {}
+         elif allowed_special == "none_raise":
+             special = {}
+             assert all(token not in text for token in self.special_tokens)
+         elif isinstance(allowed_special, set):
+             special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
+         else:
+             raise ValueError(f"allowed_special={allowed_special} not understood")
+         if not special:
+             # shortcut: if no special tokens, just use the ordinary encoding
+             return self.encode_ordinary(text)
+         # otherwise, we have to be careful with potential special tokens in text
+         # we handle special tokens by splitting the text
+         # based on the occurrence of any exact match with any of the special tokens
+         # we can use re.split for this. note that surrounding the pattern with ()
+         # makes it into a capturing group, so the special tokens will be included
+         special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
+         special_chunks = re.split(special_pattern, text)
+         # now all the special characters are separated from the rest of the text
+         # all chunks of text are encoded separately, then results are joined
+         ids = []
+         for part in special_chunks:
+             if part in special:
+                 # this is a special token, encode it separately as a special case
+                 ids.append(special[part])
+             else:
+                 # this is an ordinary sequence, encode it normally
+                 ids.extend(self.encode_ordinary(part))
+         return ids
+
+     # directly from BPE repo
+     def save(self, file_prefix, save_to_folder: pathlib.Path, version=1):
+         """
+         Saves two files: file_prefix.vocab and file_prefix.model
+         This is inspired (but not equivalent to!) sentencepiece's model saving:
+         - model file is the critical one, intended for load()
+         - vocab file is just a pretty printed version for human inspection only
+         """
+         print("Saving tokenizer...")
+         # write the model: to be used in load() later
+         assert save_to_folder is not None and isinstance(save_to_folder, pathlib.Path), \
+             "the Path passed to store vocab and models seems to be wrong"
+
+         model_file = file_prefix + ".model"
+         model_file = os.path.join(os.path.abspath(save_to_folder), model_file)
+
+         with open(model_file, 'w') as f:
+             f.write(f"version:{version}\n")
+             f.write(f"{self.pattern}\n")
+             # write the special tokens, first the number of them, then each one
+             f.write(f"{len(self.special_tokens)}\n")
+             for special, idx in self.special_tokens.items():
+                 f.write(f"{special} {idx}\n")
+             # the merges dict
+             for idx1, idx2 in self.merges:
+                 f.write(f"{idx1} {idx2}\n")
+
+         # write the vocab
+         vocab_file = file_prefix + ".vocab"
+         vocab_file = os.path.join(save_to_folder, vocab_file)
+         inverted_merges = {idx: pair for pair, idx in self.merges.items()}
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             for idx, token in self.vocab.items():
+                 # note: many tokens may be partial utf-8 sequences
+                 # and cannot be decoded into valid strings. Here we're using
+                 # errors='replace' to replace them with the replacement char �.
+                 # this also means that we couldn't possibly use .vocab in load()
+                 # because decoding this way is a lossy operation!
+                 s = render_token(token)
+                 # find the children of this token, if any
+                 if idx in inverted_merges:
+                     # if this token has children, render it nicely as a merge
+                     idx0, idx1 = inverted_merges[idx]
+                     s0 = render_token(self.vocab[idx0])
+                     s1 = render_token(self.vocab[idx1])
+                     f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
+                 else:
+                     # otherwise this is a leaf token, just print it
+                     # (this should just be the first 256 tokens, the bytes)
+                     f.write(f"[{s}] {idx}\n")
+
+     def load(self, model_file_path):
+         """Inverse of save() but only for the model file"""
+         if isinstance(model_file_path, pathlib.Path):
+             model_file_path = str(model_file_path.absolute())
+         assert model_file_path.endswith(".model")
+         # read the model file
+         merges = {}
+         special_tokens = {}
+         # 256 for the default first 256 bytes, the next 101 for the Hindi alphabet
+         idx = 256 + 101 + 1  # one extra index is skipped initially when creating merges (idx 357), so +1 is needed when re-building the vocab
+         with open(model_file_path, 'r', encoding="utf-8") as f:
+             # read the version
+             version = f.readline().strip()
+             print(version)
+
+             # read the pattern
+             self.pattern = f.readline().strip()
+
+             # read the special tokens
+             num_special = int(f.readline().strip())
+             for _ in range(num_special):
+                 special, special_idx = f.readline().strip().split()
+                 special_tokens[special] = int(special_idx)
+             # read the merges
+             for line in f:
+                 idx1, idx2 = map(int, line.split())
+                 merges[(idx1, idx2)] = idx
+                 idx += 1
+         self.merges = merges
+         self.special_tokens = special_tokens
+         self.vocab = self._build_vocab()
+
+ # if __name__ == "__main__":
+ #     custom_text = """
+ #     <|endoftext|>ूज रहा है जहाँ चकित हो जन-जन देख अकाज
+ #     सात वर्ष हो गये राह में, अटका कहाँ स्वराज?
+ #
+ #     अटका कहाँ स्वराज? बोल दिल्ली! तू क्या कहती है?
+ #     तू रानी बन गयी वेदना जनता क्यों सहती है?
+ #     सबके भाग्य दबा रखे हैं किसने अपने कर में?
+ #     उतरी थी जो विभा, हुई बंदिनी बता किस घर में
+ #
+ #     समर शेष है, यह प्रकाश बंदीगृह से छूटेगा
+ #     और नहीं तो तुझ पर पापिनी! महावज्र टूटेगा
+ #
+ #     समर शेष है, उस स्वराज को सत्य बनाना होगा
+ #     जिसका है ये न्यास उसे सत्वर पहुँचाना होगा
+ #     धारा के मग में अनेक जो पर्वत खडे हुए हैं
+ #     गंगा का पथ रोक इन्द्र के गज जो अडे हुए हैं
+ #
+ #     कह दो उनसे झुके अगर तो जग मे यश पाएंगे
+ #     अड़े रहे अगर तो ऐरावत पत्तों से बह जाऐंगे<|fim_prefix|><|endofprompt|>
+ #     """.strip()
+ #     special_tokens = {
+ #         '<|endoftext|>': 100257,
+ #         '<|fim_prefix|>': 100258,
+ #         '<|fim_middle|>': 100259,
+ #         '<|fim_suffix|>': 100260,
+ #         '<|endofprompt|>': 100276
+ #     }
+ #     text = custom_text
+ #     # create a Tokenizer and do 64 merges
+ #     tokenizer = HindiTokenizer()
+ #     tokenizer.train(text, 256 + 2, verbose=True)
+ #     tokenizer.register_special_tokens(special_tokens)
+ #     # verify that decode(encode(x)) == x
+ #     assert tokenizer.decode(tokenizer.encode(text, "all")) == text
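
A short usage sketch for the tokenizer above; it loads the .model file that app.py points to (assumed to be present under saved_vocabs/, with the repo root and its utilities helper importable) and runs an encode/decode round trip:

import pathlib
from src import HindiTokenizer

hindi = HindiTokenizer()
hindi.load(model_file_path=pathlib.Path(
    "saved_vocabs/batch_1_Hindi_Tokenizer-test-all_batches-100_000_batchsize-initial_vocab_size_5000.model"))

ids = hindi.encode("यह एक परीक्षण है।")  # uses the learned merges; no special tokens in this text
print(ids)
print(hindi.decode(ids))                 # should reproduce the input string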
src/HuggingFace-based-tokenizer.py ADDED
@@ -0,0 +1,77 @@
+ # source: https://huggingface.co/learn/nlp-course/en/chapter6/8?fw=pt
+
+ from tokenizers import normalizers, models, decoders, pre_tokenizers, trainers, Tokenizer, processors
+ from datasets import load_dataset
+
+ dataset = load_dataset("wikitext", name="wikitext-2-raw-v1", split="train")
+
+
+ def get_training_corpus(batch_size=1000):
+     for i in range(0, len(dataset), batch_size):
+         yield dataset[i: i + batch_size]["text"]
+
+
+ tokenizer = Tokenizer(model=models.WordPiece(unk_token="[UNK]"))
+
+ tokenizer.normalizer = normalizers.Sequence([normalizers.NFD(), normalizers.Lowercase(), normalizers.StripAccents()])
+
+ print(tokenizer.normalizer.normalize_str("Héllò hôw are ü?"))
+
+ tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()  # pre_tokenizers.BertPreTokenizer()
+
+ print(tokenizer.pre_tokenizer.pre_tokenize_str("Let's test my pre-tokenizer."))
+ pre_tokenizer = pre_tokenizers.WhitespaceSplit()
+ print(pre_tokenizer.pre_tokenize_str("Let's test my pre-tokenizer."))
+
+ # manually selecting individual splitters
+ pre_tokenizer = pre_tokenizers.Sequence(
+     [pre_tokenizers.WhitespaceSplit(), pre_tokenizers.Punctuation()]
+ )
+ print(pre_tokenizer.pre_tokenize_str("Let's test my pre-tokenizer."))
+
+ special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]
+ trainer = trainers.WordPieceTrainer(vocab_size=25000, special_tokens=special_tokens)
+
+ # train from an iterator
+ tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)
+ cls_token_id = tokenizer.token_to_id("[CLS]")
+ sep_token_id = tokenizer.token_to_id("[SEP]")
+
+ print(cls_token_id, sep_token_id)
+
+ """
+ To write the template for the TemplateProcessor, we have to specify how to treat a single sentence and a pair of sentences.
+ For both, we write the special tokens we want to use; the first (or single) sentence is represented by $A,
+ while the second sentence (if encoding a pair) is represented by $B. For each of these (special tokens and sentences),
+ we also specify the corresponding token type ID after a colon.
+
+ The classic BERT template is thus defined as follows:
+ """
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+     single=f"[CLS]:0 $A:0 [SEP]:0",
+     pair=f"[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
+     special_tokens=[("[CLS]", cls_token_id), ("[SEP]", sep_token_id)],
+ )
+
+ encoding = tokenizer.encode("Let's test this tokenizer.")
+ print(encoding.tokens)
+
+ encoding = tokenizer.encode("Let's test this tokenizer...", "on a pair of sentences.")
+ print(encoding.tokens)
+ print(encoding.type_ids)
+
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ from transformers import PreTrainedTokenizerFast
+
+ wrapped_tokenizer = PreTrainedTokenizerFast(
+     tokenizer_object=tokenizer,
+     # tokenizer_file="tokenizer.json",  # You can load from the tokenizer file, alternatively
+     unk_token="[UNK]",
+     pad_token="[PAD]",
+     cls_token="[CLS]",
+     sep_token="[SEP]",
+     mask_token="[MASK]",
+ )
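
A possible follow-up (not in the uploaded file) for persisting the wrapped tokenizer so it can be reloaded later; the directory name is arbitrary:

# Save the fast tokenizer (tokenizer.json, special tokens map, config) to a folder.
wrapped_tokenizer.save_pretrained("wordpiece-tokenizer")

from transformers import AutoTokenizer

reloaded = AutoTokenizer.from_pretrained("wordpiece-tokenizer")
print(reloaded("Let's test this tokenizer.").input_ids)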
src/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .base import Tokenizer
+ from .Basictokenizer import BasicTokenizer
+ from .HindiTokenizer import HindiTokenizer
src/__pycache__/Basictokenizer.cpython-312.pyc ADDED
Binary file (4.35 kB).
 
src/__pycache__/HindiTokenizer.cpython-312.pyc ADDED
Binary file (20.3 kB).
 
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (280 Bytes).
 
src/__pycache__/base.cpython-312.pyc ADDED
Binary file (7.25 kB).
 
src/base.py ADDED
@@ -0,0 +1,163 @@
+ import unicodedata
+
+
+ def get_stats(ids, counts=None):
+     """
+     Given a list of ints/ids, count the pairwise occurrences
+     Returns the count dict
+     """
+     counts = {} if counts is None else counts
+     for pair in zip(ids, ids[1:]):
+         counts[pair] = counts.get(pair, 0) + 1
+
+     return counts
+
+
+ def merge(ids, pair_to_merge, idx_to_use):
+     """
+     Find the given `pair_to_merge` and replace it with `idx_to_use` in the given list of ints/ids
+     Returns the updated list
+     """
+     new_ids = []
+
+     i = 0
+
+     while i < len(ids):
+         # check pair match AND that position i is NOT the last element
+         if i < len(ids) - 1 and (pair_to_merge[0] == ids[i] and pair_to_merge[1] == ids[i + 1]):
+             new_ids.append(idx_to_use)  # pair found, append the new id to the new list of ids
+             i += 2  # skip two elements as the pair was found
+         else:
+             # pair not found at this position, normal 1-element update
+             new_ids.append(ids[i])  # append the current item from the old list as it is not part of the pair
+             i += 1
+     return new_ids
+
+
+ # helper functions taken directly from Karpathy's BPE repo
+ def replace_control_characters(s: str) -> str:
+     chars = []
+     for ch in s:
+         if unicodedata.category(ch)[0] != "C":
+             chars.append(ch)  # this character is ok
+         else:
+             chars.append(f"\\u{ord(ch):04x}")  # escape
+     return "".join(chars)
+
+
+ def render_token(t: bytes) -> str:
+     # pretty print a token, escaping control characters
+     s = t.decode('utf-8', errors='replace')
+     s = replace_control_characters(s)
+     return s
+
+
+ # base Tokenizer class
+
+ class Tokenizer:
+     """Base Tokenizer class, MUST inherit for use"""
+
+     def __init__(self) -> None:
+         # defaults -> no patterns used, no merges, use the usual first 256 bytes as mapping/vocab items
+         self.merges = {}  # holds the actual merges, e.g. (101, 32) -> 256: if chr 101 'e' and 32 ' ' (space) had the max pair count, the pair is replaced with the next id in order
+         self.pattern = ""  # any regular expression pattern to be used on raw text
+         self.special_tokens = {}  # a mapping to hold any special tokens; empty here, to be used by subclasses, str -> int, e.g. {'<|endoftext|>': 90257}
+         self.vocab = self._build_vocab()  # int -> bytes
+
+     def train(self, text, vocab_size, verbose=False):
+         # Tokenizer can train a vocabulary of size vocab_size from text
+         raise NotImplementedError
+
+     def encode(self, text):
+         # Tokenizer can encode a string into a list of integers
+         raise NotImplementedError
+
+     def decode(self, ids):
+         # Tokenizer can decode a list of integers into a string
+         raise NotImplementedError
+
+     def _build_vocab(self):
+         # here the vocab starts from the normal 256 byte ints, with the merges coming after
+         vocab = {idx: bytes([idx]) for idx in range(256)}
+
+         for (pos0, pos1), idx in self.merges.items():
+             vocab[idx] = vocab[pos0] + vocab[pos1]
+
+         # NOW add special tokens defined in __init__()
+         # NOTE encode special tokens using .encode with UTF-8 encoding
+         for tok, idx in self.special_tokens.items():
+             vocab[idx] = tok.encode("utf-8")
+
+         return vocab
+
+     # directly from BPE repo
+     def save(self, file_prefix):
+         """
+         Saves two files: file_prefix.vocab and file_prefix.model
+         This is inspired (but not equivalent to!) sentencepiece's model saving:
+         - model file is the critical one, intended for load()
+         - vocab file is just a pretty printed version for human inspection only
+         """
+         print("Saving tokenizer...")
+         # write the model: to be used in load() later
+         model_file = file_prefix + ".model"
+         with open(model_file, 'w') as f:
+             # write the version, pattern and merges, that's all that's needed
+             f.write("base v1\n")
+             f.write(f"{self.pattern}\n")
+             # write the special tokens, first the number of them, then each one
+             f.write(f"{len(self.special_tokens)}\n")
+             for special, idx in self.special_tokens.items():
+                 f.write(f"{special} {idx}\n")
+             # the merges dict
+             for idx1, idx2 in self.merges:
+                 f.write(f"{idx1} {idx2}\n")
+         # write the vocab: for the human to look at
+         vocab_file = file_prefix + ".vocab"
+         inverted_merges = {idx: pair for pair, idx in self.merges.items()}
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             for idx, token in self.vocab.items():
+                 # note: many tokens may be partial utf-8 sequences
+                 # and cannot be decoded into valid strings. Here we're using
+                 # errors='replace' to replace them with the replacement char �.
+                 # this also means that we couldn't possibly use .vocab in load()
+                 # because decoding this way is a lossy operation!
+                 s = render_token(token)
+                 # find the children of this token, if any
+                 if idx in inverted_merges:
+                     # if this token has children, render it nicely as a merge
+                     idx0, idx1 = inverted_merges[idx]
+                     s0 = render_token(self.vocab[idx0])
+                     s1 = render_token(self.vocab[idx1])
+                     f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
+                 else:
+                     # otherwise this is a leaf token, just print it
+                     # (this should just be the first 256 tokens, the bytes)
+                     f.write(f"[{s}] {idx}\n")
+
+     def load(self, model_file):
+         """Inverse of save() but only for the model file"""
+         assert model_file.endswith(".model")
+         # read the model file
+         merges = {}
+         special_tokens = {}
+         idx = 256
+         with open(model_file, 'r', encoding="utf-8") as f:
+             # read the version
+             version = f.readline().strip()
+             print(version)
+
+             # read the pattern
+             self.pattern = f.readline().strip()
+
+             # read the special tokens
+             num_special = int(f.readline().strip())
+             for _ in range(num_special):
+                 special, special_idx = f.readline().strip().split()
+                 special_tokens[special] = int(special_idx)
+             # read the merges
+             for line in f:
+                 idx1, idx2 = map(int, line.split())
+                 merges[(idx1, idx2)] = idx
+                 idx += 1
+         self.merges = merges
+         self.special_tokens = special_tokens
+         self.vocab = self._build_vocab()
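
A tiny worked example of the two helpers at the top of this module, using made-up byte ids; it shows the single merge step that both tokenizers repeat during training:

from src.base import get_stats, merge

ids = [101, 32, 101, 32, 104]
stats = get_stats(ids)
print(stats)                        # {(101, 32): 2, (32, 101): 1, (32, 104): 1}
best_pair = max(stats, key=stats.get)
print(best_pair)                    # (101, 32)
print(merge(ids, best_pair, 256))   # [256, 256, 104] -- both occurrences replaced by the new id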