Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
•
a2dfe6a
1
Parent(s):
b0a1d94
改进减少token逻辑
Browse files- ChuanhuChatbot.py +1 -1
- chat_func.py +15 -12
- utils.py +46 -26
ChuanhuChatbot.py
CHANGED
@@ -359,7 +359,7 @@ with gr.Blocks(
|
|
359 |
token_count,
|
360 |
top_p,
|
361 |
temperature,
|
362 |
-
|
363 |
model_select_dropdown,
|
364 |
],
|
365 |
[chatbot, history, status_display, token_count],
|
|
|
359 |
token_count,
|
360 |
top_p,
|
361 |
temperature,
|
362 |
+
gr.State(0),
|
363 |
model_select_dropdown,
|
364 |
],
|
365 |
[chatbot, history, status_display, token_count],
|
chat_func.py
CHANGED
@@ -371,9 +371,8 @@ def predict(
|
|
371 |
all_token_counts,
|
372 |
top_p,
|
373 |
temperature,
|
374 |
-
|
375 |
selected_model=selected_model,
|
376 |
-
hidden=True,
|
377 |
)
|
378 |
for chatbot, history, status_text, all_token_counts in iter:
|
379 |
status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
|
@@ -410,9 +409,10 @@ def retry(
|
|
410 |
stream=stream,
|
411 |
selected_model=selected_model,
|
412 |
)
|
413 |
-
logging.info("
|
414 |
for x in iter:
|
415 |
yield x
|
|
|
416 |
|
417 |
|
418 |
def reduce_token_size(
|
@@ -423,9 +423,8 @@ def reduce_token_size(
|
|
423 |
token_count,
|
424 |
top_p,
|
425 |
temperature,
|
426 |
-
|
427 |
selected_model=MODELS[0],
|
428 |
-
hidden=False,
|
429 |
):
|
430 |
logging.info("开始减少token数量……")
|
431 |
iter = predict(
|
@@ -437,17 +436,21 @@ def reduce_token_size(
|
|
437 |
token_count,
|
438 |
top_p,
|
439 |
temperature,
|
440 |
-
stream=stream,
|
441 |
selected_model=selected_model,
|
442 |
should_check_token_count=False,
|
443 |
)
|
444 |
logging.info(f"chatbot: {chatbot}")
|
|
|
445 |
for chatbot, history, status_text, previous_token_count in iter:
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
|
|
|
|
|
|
452 |
), token_count
|
|
|
453 |
logging.info("减少token数量完毕")
|
|
|
371 |
all_token_counts,
|
372 |
top_p,
|
373 |
temperature,
|
374 |
+
max_token//2,
|
375 |
selected_model=selected_model,
|
|
|
376 |
)
|
377 |
for chatbot, history, status_text, all_token_counts in iter:
|
378 |
status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
|
|
|
409 |
stream=stream,
|
410 |
selected_model=selected_model,
|
411 |
)
|
412 |
+
logging.info("重试中……")
|
413 |
for x in iter:
|
414 |
yield x
|
415 |
+
logging.info("重试完毕")
|
416 |
|
417 |
|
418 |
def reduce_token_size(
|
|
|
423 |
token_count,
|
424 |
top_p,
|
425 |
temperature,
|
426 |
+
max_token_count,
|
427 |
selected_model=MODELS[0],
|
|
|
428 |
):
|
429 |
logging.info("开始减少token数量……")
|
430 |
iter = predict(
|
|
|
436 |
token_count,
|
437 |
top_p,
|
438 |
temperature,
|
|
|
439 |
selected_model=selected_model,
|
440 |
should_check_token_count=False,
|
441 |
)
|
442 |
logging.info(f"chatbot: {chatbot}")
|
443 |
+
flag = False
|
444 |
for chatbot, history, status_text, previous_token_count in iter:
|
445 |
+
num_chat = find_n(previous_token_count, max_token_count)
|
446 |
+
if flag:
|
447 |
+
chatbot = chatbot[:-1]
|
448 |
+
flag = True
|
449 |
+
history = history[-2*num_chat:] if num_chat > 0 else []
|
450 |
+
token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
|
451 |
+
msg = f"保留了最近{num_chat}轮对话"
|
452 |
+
yield chatbot, history, msg + "," + construct_token_message(
|
453 |
+
sum(token_count) if len(token_count) > 0 else 0,
|
454 |
), token_count
|
455 |
+
logging.info(msg)
|
456 |
logging.info("减少token数量完毕")
|
utils.py
CHANGED
@@ -37,9 +37,10 @@ def count_token(message):
|
|
37 |
length = len(encoding.encode(input_str))
|
38 |
return length
|
39 |
|
|
|
40 |
def markdown_to_html_with_syntax_highlight(md_str):
|
41 |
def replacer(match):
|
42 |
-
lang = match.group(1) or
|
43 |
code = match.group(2)
|
44 |
|
45 |
try:
|
@@ -50,60 +51,65 @@ def markdown_to_html_with_syntax_highlight(md_str):
|
|
50 |
formatter = HtmlFormatter()
|
51 |
highlighted_code = highlight(code, lexer, formatter)
|
52 |
|
53 |
-
return f
|
54 |
|
55 |
-
code_block_pattern = r
|
56 |
md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
|
57 |
|
58 |
html_str = markdown(md_str)
|
59 |
return html_str
|
60 |
|
|
|
61 |
def normalize_markdown(md_text: str) -> str:
|
62 |
-
lines = md_text.split(
|
63 |
normalized_lines = []
|
64 |
inside_list = False
|
65 |
|
66 |
for i, line in enumerate(lines):
|
67 |
-
if re.match(r
|
68 |
-
if not inside_list and i > 0 and lines[i - 1].strip() !=
|
69 |
-
normalized_lines.append(
|
70 |
inside_list = True
|
71 |
normalized_lines.append(line)
|
72 |
-
elif inside_list and line.strip() ==
|
73 |
-
if i < len(lines) - 1 and not re.match(
|
|
|
|
|
74 |
normalized_lines.append(line)
|
75 |
continue
|
76 |
else:
|
77 |
inside_list = False
|
78 |
normalized_lines.append(line)
|
79 |
|
80 |
-
return
|
|
|
81 |
|
82 |
def convert_mdtext(md_text):
|
83 |
-
code_block_pattern = re.compile(r
|
84 |
code_blocks = code_block_pattern.findall(md_text)
|
85 |
non_code_parts = code_block_pattern.split(md_text)[::2]
|
86 |
|
87 |
result = []
|
88 |
-
for non_code, code in zip(non_code_parts, code_blocks + [
|
89 |
if non_code.strip():
|
90 |
non_code = normalize_markdown(non_code)
|
91 |
-
result.append(mdtex2html.convert(non_code, extensions=[
|
92 |
if code.strip():
|
93 |
-
_, code = detect_language(code)
|
94 |
code = f"```{code}\n\n```"
|
95 |
code = markdown_to_html_with_syntax_highlight(code)
|
96 |
result.append(code)
|
97 |
result = "".join(result)
|
98 |
return result
|
99 |
|
|
|
100 |
def detect_language(code):
|
101 |
if code.startswith("\n"):
|
102 |
first_line = ""
|
103 |
else:
|
104 |
-
first_line = code.strip().split(
|
105 |
-
language = first_line.lower() if first_line else
|
106 |
-
code_without_language = code[len(first_line):].lstrip() if first_line else code
|
107 |
return language, code_without_language
|
108 |
|
109 |
|
@@ -336,26 +342,40 @@ def replace_today(prompt):
|
|
336 |
today = datetime.datetime.today().strftime("%Y-%m-%d")
|
337 |
return prompt.replace("{current_date}", today)
|
338 |
|
|
|
339 |
def get_geoip():
|
340 |
-
response = requests.get(
|
341 |
try:
|
342 |
data = response.json()
|
343 |
except:
|
344 |
-
data = {
|
345 |
-
"error": True,
|
346 |
-
"reason" : "连接ipapi失败"
|
347 |
-
}
|
348 |
if "error" in data.keys():
|
349 |
logging.warning(f"无法获取IP地址信息。\n{data}")
|
350 |
-
if data[
|
351 |
-
return
|
|
|
|
|
352 |
else:
|
353 |
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
|
354 |
else:
|
355 |
-
country = data[
|
356 |
if country == "China":
|
357 |
text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
|
358 |
else:
|
359 |
text = f"您的IP区域:{country}。"
|
360 |
logging.info(text)
|
361 |
-
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
length = len(encoding.encode(input_str))
|
38 |
return length
|
39 |
|
40 |
+
|
41 |
def markdown_to_html_with_syntax_highlight(md_str):
|
42 |
def replacer(match):
|
43 |
+
lang = match.group(1) or "text"
|
44 |
code = match.group(2)
|
45 |
|
46 |
try:
|
|
|
51 |
formatter = HtmlFormatter()
|
52 |
highlighted_code = highlight(code, lexer, formatter)
|
53 |
|
54 |
+
return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
|
55 |
|
56 |
+
code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
|
57 |
md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
|
58 |
|
59 |
html_str = markdown(md_str)
|
60 |
return html_str
|
61 |
|
62 |
+
|
63 |
def normalize_markdown(md_text: str) -> str:
|
64 |
+
lines = md_text.split("\n")
|
65 |
normalized_lines = []
|
66 |
inside_list = False
|
67 |
|
68 |
for i, line in enumerate(lines):
|
69 |
+
if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
|
70 |
+
if not inside_list and i > 0 and lines[i - 1].strip() != "":
|
71 |
+
normalized_lines.append("")
|
72 |
inside_list = True
|
73 |
normalized_lines.append(line)
|
74 |
+
elif inside_list and line.strip() == "":
|
75 |
+
if i < len(lines) - 1 and not re.match(
|
76 |
+
r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
|
77 |
+
):
|
78 |
normalized_lines.append(line)
|
79 |
continue
|
80 |
else:
|
81 |
inside_list = False
|
82 |
normalized_lines.append(line)
|
83 |
|
84 |
+
return "\n".join(normalized_lines)
|
85 |
+
|
86 |
|
87 |
def convert_mdtext(md_text):
|
88 |
+
code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
|
89 |
code_blocks = code_block_pattern.findall(md_text)
|
90 |
non_code_parts = code_block_pattern.split(md_text)[::2]
|
91 |
|
92 |
result = []
|
93 |
+
for non_code, code in zip(non_code_parts, code_blocks + [""]):
|
94 |
if non_code.strip():
|
95 |
non_code = normalize_markdown(non_code)
|
96 |
+
result.append(mdtex2html.convert(non_code, extensions=["tables"]))
|
97 |
if code.strip():
|
98 |
+
_, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
|
99 |
code = f"```{code}\n\n```"
|
100 |
code = markdown_to_html_with_syntax_highlight(code)
|
101 |
result.append(code)
|
102 |
result = "".join(result)
|
103 |
return result
|
104 |
|
105 |
+
|
106 |
def detect_language(code):
|
107 |
if code.startswith("\n"):
|
108 |
first_line = ""
|
109 |
else:
|
110 |
+
first_line = code.strip().split("\n", 1)[0]
|
111 |
+
language = first_line.lower() if first_line else ""
|
112 |
+
code_without_language = code[len(first_line) :].lstrip() if first_line else code
|
113 |
return language, code_without_language
|
114 |
|
115 |
|
|
|
342 |
today = datetime.datetime.today().strftime("%Y-%m-%d")
|
343 |
return prompt.replace("{current_date}", today)
|
344 |
|
345 |
+
|
346 |
def get_geoip():
|
347 |
+
response = requests.get("https://ipapi.co/json/", timeout=5)
|
348 |
try:
|
349 |
data = response.json()
|
350 |
except:
|
351 |
+
data = {"error": True, "reason": "连接ipapi失败"}
|
|
|
|
|
|
|
352 |
if "error" in data.keys():
|
353 |
logging.warning(f"无法获取IP地址信息。\n{data}")
|
354 |
+
if data["reason"] == "RateLimited":
|
355 |
+
return (
|
356 |
+
f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
|
357 |
+
)
|
358 |
else:
|
359 |
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
|
360 |
else:
|
361 |
+
country = data["country_name"]
|
362 |
if country == "China":
|
363 |
text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
|
364 |
else:
|
365 |
text = f"您的IP区域:{country}。"
|
366 |
logging.info(text)
|
367 |
+
return text
|
368 |
+
|
369 |
+
|
370 |
+
def find_n(lst, max_num):
|
371 |
+
n = len(lst)
|
372 |
+
total = sum(lst)
|
373 |
+
|
374 |
+
if total < max_num:
|
375 |
+
return n
|
376 |
+
|
377 |
+
for i in range(len(lst)):
|
378 |
+
if total - lst[i] < max_num:
|
379 |
+
return n - i -1
|
380 |
+
total = total - lst[i]
|
381 |
+
return 1
|