Tuchuanhuhuhu commited on
Commit
9c45970
1 Parent(s): 893df38

使用tiktoken精确计数输入token

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. utils.py +15 -10
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  gradio
2
  mdtex2html
3
  pypinyin
4
- jieba
5
  socksio
6
  tqdm
 
 
1
  gradio
2
  mdtex2html
3
  pypinyin
4
+ tiktoken
5
  socksio
6
  tqdm
7
+ colorama
utils.py CHANGED
@@ -12,8 +12,9 @@ import csv
12
  import mdtex2html
13
  from pypinyin import lazy_pinyin
14
  from presets import *
15
- import jieba
16
  from tqdm import tqdm
 
17
 
18
  if TYPE_CHECKING:
19
  from typing import TypedDict
@@ -47,11 +48,12 @@ def postprocess(
47
  )
48
  return y
49
 
50
- def count_words(input_str):
51
- print("计算输入字数中……")
52
- words = jieba.lcut(input_str)
 
53
  print("计算完成!")
54
- return len(words)
55
 
56
  def parse_text(text):
57
  lines = text.split("\n")
@@ -97,8 +99,7 @@ def construct_assistant(text):
97
  return construct_text("assistant", text)
98
 
99
  def construct_token_message(token, stream=False):
100
- extra = "【粗略计数(因为实时传输回答)】 " if stream else ""
101
- return f"{extra}Token 计数: {token}"
102
 
103
  def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
104
  headers = {
@@ -135,10 +136,12 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
135
  counter = 0
136
  status_text = "开始实时传输回答……"
137
  history.append(construct_user(inputs))
 
138
  if len(previous_token_count) == 0:
139
- rough_user_token_count = count_words(inputs) + count_words(system_prompt)
140
  else:
141
- rough_user_token_count = count_words(inputs)
 
142
  try:
143
  response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
144
  except requests.exceptions.ConnectTimeout:
@@ -162,7 +165,7 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
162
  # decode each line as response data is in bytes
163
  if chunklength > 6 and "delta" in chunk['choices'][0]:
164
  finish_reason = chunk['choices'][0]['finish_reason']
165
- status_text = construct_token_message(sum(previous_token_count)+token_counter+rough_user_token_count, stream=True)
166
  if finish_reason == "stop":
167
  print("生成完毕")
168
  yield get_return_value()
@@ -197,6 +200,7 @@ def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, previou
197
 
198
 
199
  def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=False, should_check_token_count = True): # repetition_penalty, top_k
 
200
  if stream:
201
  print("使用流式传输")
202
  iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
@@ -207,6 +211,7 @@ def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count
207
  chatbot, history, status_text, token_count = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
208
  yield chatbot, history, status_text, token_count
209
  print(f"传输完毕。当前token计数为{token_count}")
 
210
  if stream:
211
  max_token = max_token_streaming
212
  else:
 
12
  import mdtex2html
13
  from pypinyin import lazy_pinyin
14
  from presets import *
15
+ import tiktoken
16
  from tqdm import tqdm
17
+ import colorama
18
 
19
  if TYPE_CHECKING:
20
  from typing import TypedDict
 
48
  )
49
  return y
50
 
51
def count_token(input_str):
    """Return the exact token count of *input_str* for the gpt-3.5-turbo model.

    Uses tiktoken's model-specific BPE encoding so the count matches what the
    OpenAI API actually bills, rather than a rough word-based estimate.

    Parameters
    ----------
    input_str : str
        The text to tokenize.

    Returns
    -------
    int
        Number of tokens in ``input_str``.
    """
    print("计算输入Token计数中……")
    # NOTE(review): encoding_for_model re-derives the encoding on every call
    # (and may download the BPE vocabulary on first use) — consider caching
    # the encoding at module level if this becomes hot.
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    # BUG FIX: the original encoded the hard-coded literal
    # "tiktoken is great!" instead of the caller's text, so every call
    # returned the same constant length regardless of input.
    length = len(encoding.encode(input_str))
    print("计算完成!")
    return length
57
 
58
  def parse_text(text):
59
  lines = text.split("\n")
 
99
  return construct_text("assistant", text)
100
 
101
  def construct_token_message(token, stream=False):
102
+ return f"Token 计数: {token}"
 
103
 
104
  def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
105
  headers = {
 
136
  counter = 0
137
  status_text = "开始实时传输回答……"
138
  history.append(construct_user(inputs))
139
+ user_token_count = 0
140
  if len(previous_token_count) == 0:
141
+ user_token_count = count_token(inputs) + count_token(system_prompt)
142
  else:
143
+ user_token_count = count_token(inputs)
144
+ print(f"输入token计数: {user_token_count}")
145
  try:
146
  response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
147
  except requests.exceptions.ConnectTimeout:
 
165
  # decode each line as response data is in bytes
166
  if chunklength > 6 and "delta" in chunk['choices'][0]:
167
  finish_reason = chunk['choices'][0]['finish_reason']
168
+ status_text = construct_token_message(sum(previous_token_count)+token_counter+user_token_count, stream=True)
169
  if finish_reason == "stop":
170
  print("生成完毕")
171
  yield get_return_value()
 
200
 
201
 
202
  def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=False, should_check_token_count = True): # repetition_penalty, top_k
203
+ print(colorama.Fore.BLUE + f"输入为:{inputs}" + colorama.Style.RESET_ALL)
204
  if stream:
205
  print("使用流式传输")
206
  iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
 
211
  chatbot, history, status_text, token_count = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
212
  yield chatbot, history, status_text, token_count
213
  print(f"传输完毕。当前token计数为{token_count}")
214
+ print(colorama.Fore.BLUE + f"回答为:{history[-1]['content']}" + colorama.Style.RESET_ALL)
215
  if stream:
216
  max_token = max_token_streaming
217
  else: