File size: 6,429 Bytes
d10ecd7
 
 
d27a756
d10ecd7
 
79b95c3
814ee6b
9495a4f
d10ecd7
 
9495a4f
 
d10ecd7
 
9495a4f
d10ecd7
 
d27a756
 
 
 
d10ecd7
 
 
 
 
 
 
 
f4973d4
d10ecd7
 
 
 
 
5425d5d
79b95c3
 
d10ecd7
 
9495a4f
d10ecd7
 
 
9495a4f
d10ecd7
e6543ac
 
 
 
d10ecd7
 
a6c67ec
d10ecd7
 
 
 
 
9495a4f
d10ecd7
 
 
 
 
e6543ac
d10ecd7
 
9495a4f
d10ecd7
 
9495a4f
 
 
 
 
b15345c
d10ecd7
 
 
 
7cb27ea
d10ecd7
 
814ee6b
 
 
 
 
 
 
 
 
 
d10ecd7
 
9495a4f
d10ecd7
 
 
9495a4f
 
 
 
 
 
 
 
 
 
 
 
 
d10ecd7
9495a4f
 
d10ecd7
 
 
9495a4f
7a8d6d6
9495a4f
 
 
 
 
 
7a8d6d6
 
 
 
9495a4f
7cb27ea
9495a4f
 
 
 
 
 
7a8d6d6
 
 
480ae5d
 
 
7a8d6d6
9495a4f
 
 
814ee6b
 
 
d10ecd7
 
 
 
 
 
9495a4f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr
import json
import pandas as pd
import config
from vocab import load_tokener
from utils.zh_util import iter_vocab
from utils.log_util import logger
from utils.compress_rate_util import tokenize_corpus, unit_convertor
from functools import lru_cache


@lru_cache
def tokenize(text, tokenizer_type, color_num=5):
    """Tokenize ``text`` and build the highlighted-text and table outputs.

    Args:
        text: Input string to encode.
        tokenizer_type: Name handed to ``load_tokener`` to obtain the tokenizer.
        color_num: Number of distinct highlight labels; the token at position
            ``i`` is labelled ``str(i % color_num)``.

    Returns:
        A 2-tuple ``(update, table_df)``. ``update`` is a ``gr.update`` whose
        value is the list of ``(decoded_text, label)`` pairs and whose label is
        ``"Tokens: N"``. ``table_df`` is a ``pandas.DataFrame`` built from one
        dict per token with keys ``TokenID``, ``Token``, ``Text`` and
        ``UTF8 Bytes``.

    Note:
        The function is wrapped in ``lru_cache``: arguments must be hashable
        and callers passing identical arguments receive the same cached
        objects.
    """
    logger.info("param=" + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
    pos_tokens = []
    tokenizer = load_tokener(tokenizer_type)
    # One call instead of two duplicated branches; the flag is coerced to a
    # real bool so the tokenizer receives exactly True or False as before.
    encoding = tokenizer.encode(text, add_special_tokens=bool(config.ADD_SPECIAL_TOKEN))

    table = []

    for idx, token_id in enumerate(encoding):
        # Original author's note: special characters all decode to the
        # replacement character "\ufffd".
        decode_text = tokenizer.decode([token_id])
        pos_tokens.append((decode_text, str(idx % color_num)))

        # The raw token may be bytes or str depending on the tokenizer
        # (bytes are presumably UTF-8 encoded - unconfirmed in the original).
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        if isinstance(token, bytes):
            try:
                token_str = token.decode("utf-8")
            except UnicodeDecodeError:
                # A single token can hold a partial multi-byte sequence; fall
                # back to a lossy decode and record it (the original author
                # notes gpt_35_turbo hits this frequently).
                token_str = token.decode("utf-8", errors="ignore")
                logger.error(f"{idx}: decode_error: " + json.dumps(
                    {"tokenizer_type": tokenizer_type, "token": str(token), "token_str": token_str},
                    ensure_ascii=False))

            token_bytes = token
        elif isinstance(token, str):
            token_str = token
            token_bytes = bytes(token_str, "utf-8")
        else:
            # Unexpected token type: log it and pass the raw value through so
            # the row is still emitted.
            logger.error(f"{idx}: wrong type for token {token_id} {type(token)} " + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
            token_str = token
            token_bytes = token

        # TODO: for gpt_35_turbo only the id and text are correct while the
        # token and UTF-8 bytes are wrong, which suggests that
        # convert_ids_to_tokens misbehaves for that tokenizer.
        table.append(
            {"TokenID": token_id,
             # UTF-8 decoded string. Open question from the original author:
             # why do some tokens appear as e.g. <0xE7> (seen with llama)?
             "Token": token_str,
             "Text": decode_text,
             # bytes values are rendered as decoded text by the gradio front
             # end (b'\xe4\xb8\xad' still shows as the character itself), so
             # the repr string is stored instead.
             "UTF8 Bytes": str(token_bytes),
             }
        )

    table_df = pd.DataFrame(table)
    logger.info(f"tokenizer_type={tokenizer_type}, Tokens={table[:4]}")

    return gr.update(value=pos_tokens, label=f"Tokens: {len(encoding)}"), table_df


@lru_cache
def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2):
    """Tokenize ``text`` with two tokenizers (handler for ``input_text.change``).

    Returns a flat 4-tuple: the two values produced by ``tokenize`` for
    ``tokenizer_type_1`` followed by the two produced for ``tokenizer_type_2``.
    """
    per_tokenizer = [
        tokenize(text, name) for name in (tokenizer_type_1, tokenizer_type_2)
    ]
    (highlight_a, frame_a), (highlight_b, frame_b) = per_tokenizer
    return highlight_a, frame_a, highlight_b, frame_b


@lru_cache
def basic_count(tokenizer_type):
    """Return ``(vocab_size, chinese_token_count)`` for a tokenizer.

    The second element is the ``"中文token数"`` entry of the statistics
    returned by ``iter_vocab``, formatted as a string.
    """
    tok = load_tokener(tokenizer_type)
    vocab_stats = iter_vocab(tok)
    size = tok.vocab_size
    chinese_tokens = vocab_stats["中文token数"]
    # NOTE: an earlier variant reported single-character / multi-character
    # Chinese counts from stats["中文汉字数"] instead.
    return size, f'{chinese_tokens}'

def get_compress_rate(tokenizer_type, all_corpus, unit):
    """Compute the compression rate of a tokenizer on a corpus.

    Only the first entry of ``all_corpus`` is used. The statistics returned by
    ``tokenize_corpus`` are converted to ``unit`` via ``unit_convertor`` and
    that converted value is returned.
    """
    # Index first so an empty selection fails before any tokenizer is loaded.
    first_corpus = all_corpus[0]
    stats = tokenize_corpus(load_tokener(tokenizer_type), first_corpus)
    return unit_convertor(stats, unit)


@lru_cache
def get_overlap_token_size(tokenizer_type_1, tokenizer_type_2):
    """Count the tokens shared by the vocabularies of two tokenizers.

    Vocabulary keys may be ``str`` for one tokenizer and ``bytes`` for the
    other. When the first key of each vocabulary disagrees on that, every
    ``str`` key on both sides is encoded as UTF-8 so the comparison is done
    on ``bytes``.

    Args:
        tokenizer_type_1: Name handed to ``load_tokener`` for the first side.
        tokenizer_type_2: Name handed to ``load_tokener`` for the second side.

    Returns:
        ``(n, n)`` - the overlap size in both positions (presumably one value
        per UI output; confirm against the caller).
    """
    tokenizer1 = load_tokener(tokenizer_type_1)
    tokenizer2 = load_tokener(tokenizer_type_2)

    vocab_set_1 = tokenizer1.get_vocab().keys()
    vocab_set_2 = tokenizer2.get_vocab().keys()

    # Peek at one key per side; the default avoids StopIteration on an empty
    # vocabulary (the overlap is then simply empty).
    sample_1 = next(iter(vocab_set_1), None)
    sample_2 = next(iter(vocab_set_2), None)
    if isinstance(sample_1, str) != isinstance(sample_2, str):  # bytes vs str
        # Normalise both sides to bytes. Keys that are not str are kept as-is,
        # so a vocabulary holding some bytes entries no longer fails on
        # `.encode`.
        vocab_set_1 = {
            token.encode("utf-8") if isinstance(token, str) else token
            for token in vocab_set_1
        }
        vocab_set_2 = {
            token.encode("utf-8") if isinstance(token, str) else token
            for token in vocab_set_2
        }

    overlap_tokens = vocab_set_1 & vocab_set_2
    overlap_token_size = len(overlap_tokens)
    logger.info(
        f"{overlap_token_size} OverlapTokens of {tokenizer_type_1} {tokenizer_type_2}: {list(overlap_tokens)[:10]}")
    return overlap_token_size, overlap_token_size



def on_load(url_params, request: gr.Request):
    """Resolve the initial text and tokenizer selection on page load.

    Args:
        url_params: JSON-encoded string of URL query parameters. Anything that
            does not parse to a JSON object is treated as "no parameters".
        request: The incoming gradio request; used for logging the headers and
            client address.

    Returns:
        ``(text, tokenizer_type_1, tokenizer_type_2)``. When ``request`` is
        truthy each value comes from ``url_params`` (keys ``text``,
        ``tokenizer1``, ``tokenizer2``) with the ``config`` defaults as
        fallback; otherwise all three are ``None``.
    """
    text = None
    tokenizer_type_1 = None
    tokenizer_type_2 = None
    try:
        url_params = json.loads(url_params)
    except (TypeError, ValueError, RecursionError):
        # TypeError: not a str/bytes payload; ValueError: malformed JSON
        # (JSONDecodeError is a subclass); RecursionError: pathologically
        # nested input. All mean "no usable parameters".
        url_params = {}
    if not isinstance(url_params, dict):
        # Valid JSON that is not an object (e.g. "[]" or "null") has no
        # `.get`; fall back to the defaults instead of crashing.
        url_params = {}
    if request:
        logger.info(str(request.headers))
        client_ip = request.client.host
        # NOTE: per the original author, deriving the parameters from the
        # "referer" header does not work on a huggingface space, which is why
        # they arrive through `url_params` instead.
        tokenizer_type_1 = url_params.get("tokenizer1", config.default_tokenizer_type_1)
        tokenizer_type_2 = url_params.get("tokenizer2", config.default_tokenizer_type_2)
        text = url_params.get("text", config.default_user_input)
        logger.info(f"client_ip: {client_ip}; params: {url_params}")
    return text, tokenizer_type_1, tokenizer_type_2


def compress_rate_unit_change(unit):
    """Return two ``gr.update`` objects that relabel outputs with ``unit``.

    Both updates carry the same label, ``"Compress Rate: <unit>"``.
    """
    new_label = f"Compress Rate: {unit}"
    first_update = gr.update(label=new_label)
    second_update = gr.update(label=new_label)
    return first_update, second_update

def test_coding():
    """Print the repr of a fixed three-byte sequence (0xE4 0xB8 0xAD)."""
    sample = bytes.fromhex("e4b8ad")
    print(sample)  # b'\xe4\xb8\xad'


if __name__ == "__main__":
    # Manual smoke check: report the vocabulary overlap of two tokenizers.
    overlap = get_overlap_token_size("gpt_35_turbo", "gpt_4")
    print(overlap)
    # print(basic_count("internlm_chat_7b"))