Upload folder using huggingface_hub
- README.md +143 -12
- README_en.md +140 -0
- openai_api_request.py +88 -0
- openai_api_server.py +549 -0
- requirements.txt +27 -0
- trans_batch_demo.py +90 -0
- trans_cli_demo.py +112 -0
- trans_cli_vision_demo.py +121 -0
- trans_stress_test.py +135 -0
- trans_web_demo.py +167 -0
- vllm_cli_demo.py +111 -0
README.md
CHANGED
@@ -1,12 +1,143 @@
|
|
1 |
-
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
---
title: basic_demo
app_file: trans_web_demo.py
sdk: gradio
sdk_version: 4.36.0
---
# Basic Demo

Read this in [English](README_en.md).

In this demo, you will learn how to use the open-source GLM-4-9B model for basic tasks.

Please follow the steps in this document closely to avoid unnecessary errors.

## Device and dependency check

### Inference test data

**All figures in this document were measured in the hardware environment below. Actual requirements and GPU memory usage may differ slightly on your machine, so treat your own environment as authoritative.**

Test hardware:

+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.12.3
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8

Inference stress test results:

**All tests were run on a single GPU, and all GPU memory figures are approximate peak values.**

#### GLM-4-9B-Chat

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks                |
|-------|------------|------------|---------------|------------------------|
| BF16  | 19 GB      | 0.2s       | 27.8 tokens/s | Input length is 1000   |
| BF16  | 21 GB      | 0.8s       | 31.8 tokens/s | Input length is 8000   |
| BF16  | 28 GB      | 4.3s       | 14.4 tokens/s | Input length is 32000  |
| BF16  | 58 GB      | 38.1s      | 3.4 tokens/s  | Input length is 128000 |

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks               |
|-------|------------|------------|---------------|-----------------------|
| INT4  | 8 GB       | 0.2s       | 23.3 tokens/s | Input length is 1000  |
| INT4  | 10 GB      | 0.8s       | 23.4 tokens/s | Input length is 8000  |
| INT4  | 17 GB      | 4.3s       | 14.6 tokens/s | Input length is 32000 |

#### GLM-4-9B-Chat-1M

| Dtype | GPU Memory | Prefilling | Decode Speed | Remarks                |
|-------|------------|------------|--------------|------------------------|
| BF16  | 75 GB      | 98.4s      | 2.3 tokens/s | Input length is 200000 |

If your input exceeds 200K, we recommend using the vLLM backend with multiple GPUs for better performance.

#### GLM-4V-9B

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
|-------|------------|------------|---------------|----------------------|
| BF16  | 28 GB      | 0.1s       | 33.4 tokens/s | Input length is 1000 |
| BF16  | 33 GB      | 0.7s       | 39.2 tokens/s | Input length is 8000 |

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
|-------|------------|------------|---------------|----------------------|
| INT4  | 10 GB      | 0.1s       | 28.7 tokens/s | Input length is 1000 |
| INT4  | 15 GB      | 0.8s       | 24.2 tokens/s | Input length is 8000 |

### Minimum hardware requirements

If you want to run the most basic official code (transformers backend), you need:

+ Python >= 3.10
+ At least 32 GB of memory

If you want to run all of the official code in this folder, you also need:

+ Linux operating system (Debian-based distributions work best)
+ A GPU with more than 8GB of memory that supports CUDA or ROCm and `BF16` inference. (`FP16` precision cannot be used for training, and inference with it has a small chance of problems.)

Install dependencies

```shell
pip install -r requirements.txt
```

## Basic function calls

**Unless otherwise specified, the demos in this folder do not support advanced usage such as Function Call and All Tools.**

### Use transformers backend code

+ Use the command line to chat with the GLM-4-9B model.

```shell
python trans_cli_demo.py # GLM-4-9B-Chat
python trans_cli_vision_demo.py # GLM-4V-9B
```

+ Use the Gradio web client to chat with the GLM-4-9B-Chat model.

```shell
python trans_web_demo.py
```

+ Use batch inference.

```shell
python trans_batch_demo.py
```

### Use vLLM backend code

+ Use the command line to chat with the GLM-4-9B-Chat model.

```shell
python vllm_cli_demo.py
```

+ Build the server yourself and use the `OpenAI API` request format to chat with the GLM-4-9B-Chat model. This demo supports Function Call and All Tools.

Start the server:

```shell
python openai_api_server.py
```

Client request:

```shell
python openai_api_request.py
```

## Stress test

You can use this script to measure the model's generation speed with the transformers backend on your own device:

```shell
python trans_stress_test.py
```

README_en.md
ADDED
@@ -0,0 +1,140 @@
# Basic Demo

In this demo, you will learn how to use the open-source GLM-4-9B model for basic tasks.

Please follow the steps in this document closely to avoid unnecessary errors.

## Device and dependency check

### Related inference test data

**All figures in this document were measured in the hardware environment below. Actual requirements and GPU memory
usage may differ slightly on your machine, so treat your own environment as authoritative.**

Test hardware information:

+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.12.3
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8

Inference stress test results:

**All tests were run on a single GPU, and all GPU memory figures are approximate peak values.**

#### GLM-4-9B-Chat

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks                |
|-------|------------|------------|---------------|------------------------|
| BF16  | 19 GB      | 0.2s       | 27.8 tokens/s | Input length is 1000   |
| BF16  | 21 GB      | 0.8s       | 31.8 tokens/s | Input length is 8000   |
| BF16  | 28 GB      | 4.3s       | 14.4 tokens/s | Input length is 32000  |
| BF16  | 58 GB      | 38.1s      | 3.4 tokens/s  | Input length is 128000 |

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks               |
|-------|------------|------------|---------------|-----------------------|
| INT4  | 8 GB       | 0.2s       | 23.3 tokens/s | Input length is 1000  |
| INT4  | 10 GB      | 0.8s       | 23.4 tokens/s | Input length is 8000  |
| INT4  | 17 GB      | 4.3s       | 14.6 tokens/s | Input length is 32000 |

#### GLM-4-9B-Chat-1M

| Dtype | GPU Memory         | Prefilling | Decode Speed    | Remarks                |
|-------|--------------------|------------|-----------------|------------------------|
| BF16  | 75 GB (74497 MiB)  | 98.4s      | 2.3653 tokens/s | Input length is 200000 |

If your input exceeds 200K, we recommend using the vLLM backend with multiple GPUs for better performance.

#### GLM-4V-9B

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
|-------|------------|------------|---------------|----------------------|
| BF16  | 28 GB      | 0.1s       | 33.4 tokens/s | Input length is 1000 |
| BF16  | 33 GB      | 0.7s       | 39.2 tokens/s | Input length is 8000 |

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
|-------|------------|------------|---------------|----------------------|
| INT4  | 10 GB      | 0.1s       | 28.7 tokens/s | Input length is 1000 |
| INT4  | 15 GB      | 0.8s       | 24.2 tokens/s | Input length is 8000 |

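As a rough reading aid for the tables above, the time for one request can be approximated as the prefill time plus the number of generated tokens divided by the decode speed. The sketch below applies that arithmetic to the BF16 GLM-4-9B-Chat row with an input length of 8000. The function name `estimate_latency_s` and the 256-token output length are illustrative assumptions, not part of the official demos, and the estimate ignores any change in decode speed as the output grows.

```python
def estimate_latency_s(prefill_s: float, decode_tokens_per_s: float, new_tokens: int) -> float:
    """Approximate request time: prefill once, then decode token by token."""
    if decode_tokens_per_s <= 0:
        raise ValueError("decode speed must be positive")
    return prefill_s + new_tokens / decode_tokens_per_s


# Figures copied from the BF16 GLM-4-9B-Chat row with input length 8000.
latency = estimate_latency_s(prefill_s=0.8, decode_tokens_per_s=31.8, new_tokens=256)
print(f"{latency:.1f} s")
```

The same helper can be pointed at any other row; the measured figures remain the authoritative numbers.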
### Minimum hardware requirements

If you want to run the most basic official code (transformers backend), you need:

+ Python >= 3.10
+ At least 32 GB of memory

If you want to run all of the official code in this folder, you also need:

+ Linux operating system (Debian-based distributions work best)
+ A GPU with more than 8GB of memory that supports CUDA or ROCm and `BF16` inference (`FP16` precision cannot be
  used for fine-tuning, and inference with it has a small chance of problems)

Install dependencies

```shell
pip install -r requirements.txt
```

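The Python and memory minimums listed above can be checked with a short pre-flight script. This is a minimal sketch using only the standard library; the names `meets_minimum` and `total_memory_bytes` are hypothetical, the memory probe relies on POSIX `os.sysconf` (so it assumes Linux), and "32 GB" is interpreted as 32 × 10^9 bytes because machines sold with 32 GB usually report slightly less than 32 GiB. It does not check the GPU requirements.

```python
import os
import sys

MIN_PYTHON = (3, 10)
MIN_MEMORY_BYTES = 32 * 10 ** 9  # assumption: "32 GB" read as decimal gigabytes


def meets_minimum(python_version: tuple, memory_bytes: int) -> bool:
    """Return True when both the Python and system-memory minimums are met."""
    return tuple(python_version[:2]) >= MIN_PYTHON and memory_bytes >= MIN_MEMORY_BYTES


def total_memory_bytes() -> int:
    """Read physical memory via POSIX sysconf; return 0 if it is unavailable."""
    try:
        return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError, OSError):
        return 0


ok = meets_minimum(sys.version_info, total_memory_bytes())
print("minimum requirements met" if ok else "minimum requirements NOT met")
```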
## Basic function calls

**Unless otherwise specified, the demos in this folder do not support advanced usage such as Function Call and All
Tools.**

### Use transformers backend code

+ Use the command line to chat with the GLM-4-9B model.

```shell
python trans_cli_demo.py # GLM-4-9B-Chat
python trans_cli_vision_demo.py # GLM-4V-9B
```

+ Use the Gradio web client to chat with the GLM-4-9B-Chat model.

```shell
python trans_web_demo.py
```

+ Use batch inference.

```shell
python trans_batch_demo.py
```

### Use vLLM backend code

+ Use the command line to chat with the GLM-4-9B-Chat model.

```shell
python vllm_cli_demo.py
```

+ Build the server yourself and use the `OpenAI API` request format to chat with the GLM-4-9B-Chat model. This
  demo supports Function Call and All Tools.

Start the server:

```shell
python openai_api_server.py
```

Client request:

```shell
python openai_api_request.py
```

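To make the request format concrete, the sketch below builds the kind of JSON body an OpenAI-style client sends to the server. It uses only the standard library rather than the `openai` package. The address `http://127.0.0.1:8000/v1` and the model name `glm-4` are taken from `openai_api_request.py`, and `/v1/chat/completions` is the route defined in `openai_api_server.py`; the helper names `build_chat_request` and `send` are hypothetical.

```python
import json
import urllib.request

BASE_URL = "http://127.0.0.1:8000/v1"  # same address openai_api_request.py uses


def build_chat_request(prompt: str, max_tokens: int = 256) -> urllib.request.Request:
    """Build (but do not send) a POST request in the OpenAI chat-completions format."""
    payload = {
        "model": "glm-4",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.8,
        "top_p": 0.8,
        "stream": False,
    }
    return urllib.request.Request(
        url=f"{BASE_URL}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request) -> str:
    """Send the request to a running server and return the assistant's reply text."""
    with urllib.request.urlopen(request) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```

With the server running, `send(build_chat_request("Hello!"))` should return the reply text; without a server it raises a connection error, so the builder is kept separate from the network call.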
## Stress test

You can use this script to measure the model's generation speed with the transformers backend on your own device:

```shell
python trans_stress_test.py
```
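The idea behind a generation-speed measurement can be sketched as: time a generation call, count the tokens it produced, and divide. The code below is an illustrative harness under that assumption, not the contents of `trans_stress_test.py`; `measure_tokens_per_second` is a hypothetical helper and `fake_generate` is a stand-in that sleeps instead of calling a model.

```python
import time
from typing import Callable, List


def measure_tokens_per_second(generate: Callable[[str], List[int]], prompt: str, runs: int = 3) -> float:
    """Average throughput over several runs of a token-producing callable."""
    total_tokens = 0
    total_seconds = 0.0
    for _ in range(runs):
        start = time.perf_counter()
        tokens = generate(prompt)
        total_seconds += time.perf_counter() - start
        total_tokens += len(tokens)
    return total_tokens / total_seconds if total_seconds > 0 else float("inf")


def fake_generate(prompt: str) -> List[int]:
    """Hypothetical stand-in for a real model call: sleeps briefly, returns 50 token ids."""
    time.sleep(0.01)
    return list(range(50))


print(f"{measure_tokens_per_second(fake_generate, 'hello'):.1f} tokens/s")
```

To measure a real model, replace `fake_generate` with a function that runs generation and returns only the newly generated token ids.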
openai_api_request.py
ADDED
@@ -0,0 +1,88 @@
"""
This script is an OpenAI-style request demo for the glm-4-9b model. It uses the OpenAI API client to interact with
the model.
"""

from openai import OpenAI

base_url = "http://127.0.0.1:8000/v1/"
client = OpenAI(api_key="EMPTY", base_url=base_url)


def function_chat():
    messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    # All Tools capability: drawing (prompt: "Please draw me a picture of the sky")
    # messages = [{"role": "user", "content": "帮我画一张天空的画画吧"}]
    # tools = [{"type": "cogview"}]
    #
    # All Tools capability: web search (prompt: "Today's gold price")
    # messages = [{"role": "user", "content": "今天黄金的价格"}]
    # tools = [{"type": "simple_browser"}]

    response = client.chat.completions.create(
        model="glm-4",
        messages=messages,
        tools=tools,
        tool_choice="auto",  # use "auto" to let the model choose the tool automatically
        # tool_choice={"type": "function", "function": {"name": "my_function"}},
    )
    if response:
        content = response.choices[0].message.content
        print(content)
    else:
        # The response object has no status_code attribute, so report the empty result directly.
        print("Error: empty response")


def simple_chat(use_stream=False):
    messages = [
        {
            "role": "system",
            # "You are GLM-4. Please answer the user's questions warmly."
            "content": "你是 GLM-4,请你热情回答用户的问题。",
        },
        {
            "role": "user",
            # "Hello, please tell me a short story in vivid language."
            "content": "你好,请你用生动的话语给我讲一个小故事吧"
        }
    ]
    response = client.chat.completions.create(
        model="glm-4",
        messages=messages,
        stream=use_stream,
        max_tokens=1024,
        temperature=0.8,
        presence_penalty=1.1,
        top_p=0.8)
    if response:
        if use_stream:
            for chunk in response:
                # delta.content may be None (for example on the final chunk), so fall back to "".
                print(chunk.choices[0].delta.content or "", end="", flush=True)
            print()
        else:
            content = response.choices[0].message.content
            print(content)
    else:
        print("Error: empty response")


if __name__ == "__main__":
    simple_chat()
    function_chat()
openai_api_server.py
ADDED
@@ -0,0 +1,549 @@
import os
import time
from asyncio.log import logger

import uvicorn
import gc
import json
import torch

from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
MODEL_PATH = 'THUDM/glm-4-9b-chat'
MAX_MODEL_LENGTH = 8192


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # On shutdown, release cached GPU memory.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "tool"]
    content: Optional[str] = None  # annotated Optional because the default is None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class EmbeddingRequest(BaseModel):
    input: Union[List[str], str]
    model: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: CompletionUsage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    tool_choice: Optional[Union[str, dict]] = "None"
    repetition_penalty: Optional[float] = 1.1


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # If the logits contain NaN or inf, reset them and force token id 5.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
    content = ""
    for response in output.split("<|assistant|>"):
        if "\n" in response:
            metadata, content = response.split("\n", maxsplit=1)
        else:
            metadata, content = "", response
        if not metadata.strip():
            content = content.strip()
        else:
            if use_tool:
                # NOTE: eval() executes model output as Python; treat the output as untrusted.
                parameters = eval(content.strip())
                content = {
                    "name": metadata.strip(),
                    "arguments": json.dumps(parameters, ensure_ascii=False)
                }
            else:
                content = {
                    "name": metadata.strip(),
                    "content": content
                }
    return content


@torch.inference_mode()
async def generate_stream_glm4(params):
    messages = params["messages"]
    tools = params["tools"]
    tool_choice = params["tool_choice"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 8192))
    messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "repetition_penalty": repetition_penalty,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": max_new_tokens,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    sampling_params = SamplingParams(**params_dict)
    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
        output_len = len(output.outputs[0].token_ids)
        input_len = len(output.prompt_token_ids)
        ret = {
            "text": output.outputs[0].text,
            "usage": {
                "prompt_tokens": input_len,
                "completion_tokens": output_len,
                "total_tokens": output_len + input_len
            },
            "finish_reason": output.outputs[0].finish_reason,
        }
        yield ret
    gc.collect()
    torch.cuda.empty_cache()


def process_messages(messages, tools=None, tool_choice="none"):
    _messages = messages
    messages = []
    msg_has_sys = False

    def filter_tools(tool_choice, tools):
        function_name = tool_choice.get('function', {}).get('name', None)
        if not function_name:
            return []
        filtered_tools = [
            tool for tool in tools
            if tool.get('function', {}).get('name') == function_name
        ]
        return filtered_tools

    if tool_choice != "none":
        if isinstance(tool_choice, dict):
            tools = filter_tools(tool_choice, tools)
        if tools:
            messages.append(
                {
                    "role": "system",
                    "content": None,
                    "tools": tools
                }
            )
            msg_has_sys = True

    # add to metadata
    if isinstance(tool_choice, dict) and tools:
        messages.append(
            {
                "role": "assistant",
                "metadata": tool_choice["function"]["name"],
                "content": ""
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        if role == "function":
            messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )
        elif role == "assistant" and func_call is not None:
            for response in content.split("<|assistant|>"):
                if "\n" in response:
                    metadata, sub_content = response.split("\n", maxsplit=1)
                else:
                    metadata, sub_content = "", response
                messages.append(
                    {
                        "role": role,
                        "metadata": metadata,
                        "content": sub_content.strip()
                    }
                )
        else:
            if role == "system" and msg_has_sys:
                msg_has_sys = False
                continue
            messages.append({"role": role, "content": content})

    return messages


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="glm-4")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    if request.stream:
        predict_stream_generator = predict_stream(request.model, gen_params)
        output = await anext(predict_stream_generator)
        if output:
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
        logger.debug(f"First result output:\n{output}")

        function_call = None
        if output and request.tools:
            try:
                function_call = process_response(output, use_tool=True)
            except Exception:
                logger.warning("Failed to parse tool call")

        # CallFunction
        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)
            tool_response = ""
            if not gen_params.get("messages"):
                gen_params["messages"] = []
            gen_params["messages"].append(ChatMessage(role="assistant", content=output))
            gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
            generate = predict(request.model, gen_params)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            generate = parse_output_text(request.model, output)
            return EventSourceResponse(generate, media_type="text/event-stream")

    response = ""
    async for response in generate_stream_glm4(gen_params):
        pass

    if response["text"].startswith("\n"):
        response["text"] = response["text"][1:]
    response["text"] = response["text"].strip()

    usage = UsageInfo()
    function_call, finish_reason = None, "stop"
    if request.tools:
        try:
            function_call = process_response(response["text"], use_tool=True)
        except Exception:
            logger.warning(
                "Failed to parse tool call; the response may not be a function call (such as CogView drawing) or may already have been answered.")

    if isinstance(function_call, dict):
        finish_reason = "function_call"
        function_call = FunctionCallResponse(**function_call)

    message = ChatMessage(
        role="assistant",
        content=response["text"],
        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
    )

    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(
        model=request.model,
        id="",  # for open_source model, id is empty
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )


async def predict(model_id: str, params: dict):
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    previous_text = ""
    async for new_response in generate_stream_glm4(params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                logger.warning(
                    "Failed to parse tool call; the response may not be a tool call or may already have been answered.")

        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(
            model=model_id,
            id="",
            choices=[choice_data],
            object="chat.completion.chunk"
        )
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id,
        id="",
        choices=[choice_data],
        object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield '[DONE]'


|
453 |
+
async def predict_stream(model_id, gen_params):
|
454 |
+
output = ""
|
455 |
+
is_function_call = False
|
456 |
+
has_send_first_chunk = False
|
457 |
+
async for new_response in generate_stream_glm4(gen_params):
|
458 |
+
decoded_unicode = new_response["text"]
|
459 |
+
delta_text = decoded_unicode[len(output):]
|
460 |
+
output = decoded_unicode
|
461 |
+
|
462 |
+
if not is_function_call and len(output) > 7:
|
463 |
+
is_function_call = output and 'get_' in output
|
464 |
+
if is_function_call:
|
465 |
+
continue
|
466 |
+
|
467 |
+
finish_reason = new_response["finish_reason"]
|
468 |
+
if not has_send_first_chunk:
|
469 |
+
message = DeltaMessage(
|
470 |
+
content="",
|
471 |
+
role="assistant",
|
472 |
+
function_call=None,
|
473 |
+
)
|
474 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
475 |
+
index=0,
|
476 |
+
delta=message,
|
477 |
+
finish_reason=finish_reason
|
478 |
+
)
|
479 |
+
chunk = ChatCompletionResponse(
|
480 |
+
model=model_id,
|
481 |
+
id="",
|
482 |
+
choices=[choice_data],
|
483 |
+
created=int(time.time()),
|
484 |
+
object="chat.completion.chunk"
|
485 |
+
)
|
486 |
+
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
487 |
+
|
488 |
+
send_msg = delta_text if has_send_first_chunk else output
|
489 |
+
has_send_first_chunk = True
|
490 |
+
message = DeltaMessage(
|
491 |
+
content=send_msg,
|
492 |
+
role="assistant",
|
493 |
+
function_call=None,
|
494 |
+
)
|
495 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
496 |
+
index=0,
|
497 |
+
delta=message,
|
498 |
+
finish_reason=finish_reason
|
499 |
+
)
|
500 |
+
chunk = ChatCompletionResponse(
|
501 |
+
model=model_id,
|
502 |
+
id="",
|
503 |
+
choices=[choice_data],
|
504 |
+
created=int(time.time()),
|
505 |
+
object="chat.completion.chunk"
|
506 |
+
)
|
507 |
+
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
508 |
+
|
509 |
+
if is_function_call:
|
510 |
+
yield output
|
511 |
+
else:
|
512 |
+
yield '[DONE]'
|
513 |
+
|
514 |
+
|
515 |
+
async def parse_output_text(model_id: str, value: str):
|
516 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
517 |
+
index=0,
|
518 |
+
delta=DeltaMessage(role="assistant", content=value),
|
519 |
+
finish_reason=None
|
520 |
+
)
|
521 |
+
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
|
522 |
+
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
523 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
524 |
+
index=0,
|
525 |
+
delta=DeltaMessage(),
|
526 |
+
finish_reason="stop"
|
527 |
+
)
|
528 |
+
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
|
529 |
+
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
530 |
+
yield '[DONE]'
|
531 |
+
|
532 |
+
|
533 |
+
if __name__ == "__main__":
|
534 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
535 |
+
engine_args = AsyncEngineArgs(
|
536 |
+
model=MODEL_PATH,
|
537 |
+
tokenizer=MODEL_PATH,
|
538 |
+
tensor_parallel_size=1,
|
539 |
+
dtype="bfloat16",
|
540 |
+
trust_remote_code=True,
|
541 |
+
gpu_memory_utilization=0.9,
|
542 |
+
enforce_eager=True,
|
543 |
+
worker_use_ray=True,
|
544 |
+
engine_use_ray=False,
|
545 |
+
disable_log_requests=True,
|
546 |
+
max_model_len=MAX_MODEL_LENGTH,
|
547 |
+
)
|
548 |
+
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
549 |
+
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
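The streaming handlers above receive the full decoded text on every step and forward only the newly generated suffix. A minimal sketch of that delta computation follows. The helper name `iter_deltas` is hypothetical, and plain strings stand in for the dictionaries yielded by `generate_stream_glm4`; it is an illustration under those assumptions, not the server's actual code path.

```python
from typing import Iterable, Iterator


def iter_deltas(cumulative_texts: Iterable[str]) -> Iterator[str]:
    """Yield only the newly generated suffix of each cumulative text."""
    previous_text = ""
    for text in cumulative_texts:
        # Everything past the previously seen length is new output.
        delta = text[len(previous_text):]
        previous_text = text
        if delta:  # skip steps that produced no new characters
            yield delta
```

For example, `list(iter_deltas(["He", "Hello", "Hello", "Hello!"]))` gives `["He", "llo", "!"]`. Slicing by length assumes each cumulative text extends the previous one, which is what the loops above also rely on.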
requirements.txt
ADDED
@@ -0,0 +1,27 @@
1 |
+
# use vllm
|
2 |
+
# vllm>=0.4.3
|
3 |
+
|
4 |
+
torch>=2.3.0
|
5 |
+
torchvision>=0.18.0
|
6 |
+
transformers==4.40.0
|
7 |
+
huggingface-hub>=0.23.1
|
8 |
+
sentencepiece>=0.2.0
|
9 |
+
pydantic>=2.7.1
|
10 |
+
timm>=0.9.16
|
11 |
+
tiktoken>=0.7.0
|
12 |
+
accelerate>=0.30.1
|
13 |
+
sentence_transformers>=2.7.0
|
14 |
+
|
15 |
+
# web demo
|
16 |
+
gradio>=4.33.0
|
17 |
+
|
18 |
+
# openai demo
|
19 |
+
openai>=1.31.1
|
20 |
+
einops>=0.7.0
|
21 |
+
sse-starlette>=2.1.0
|
22 |
+
|
23 |
+
# INT4
|
24 |
+
bitsandbytes>=0.43.1
|
25 |
+
|
26 |
+
# PEFT, not needed unless you load a PEFT fine-tuned model.
|
27 |
+
peft>=0.11.0
|
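The requirements file mixes minimum versions (`>=`), one exact pin (`transformers==4.40.0`), and commented-out optional entries. As a rough sketch of how such lines could be read programmatically, the hypothetical helper `parse_requirement` below handles only the `name>=version` and `name==version` forms that appear in this file; it is not a full requirements parser.

```python
import re
from typing import Optional, Tuple

# Matches lines such as "torch>=2.3.0" or "transformers==4.40.0".
_REQ = re.compile(r"^([A-Za-z0-9_.\-]+)\s*(>=|==)\s*([0-9][0-9A-Za-z.]*)$")


def parse_requirement(line: str) -> Optional[Tuple[str, str, str]]:
    """Return (name, operator, version), or None for blanks and comments."""
    stripped = line.split("#", 1)[0].strip()  # drop comments and whitespace
    if not stripped:
        return None
    match = _REQ.match(stripped)
    if match is None:
        return None
    return match.group(1), match.group(2), match.group(3)
```

With this sketch, `parse_requirement("torch>=2.3.0")` returns `("torch", ">=", "2.3.0")`, while the commented `# vllm>=0.4.3` line returns `None`, which mirrors the fact that vLLM is optional here.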
trans_batch_demo.py
ADDED
@@ -0,0 +1,90 @@
1 |
+
"""
|
2 |
+
|
3 |
+
This example shows how to send batch requests to glm-4-9b:
|
4 |
+
you build the conversation format yourself, then call the batch function to issue the requests.
|
5 |
+
Please note that batch inference in this demo consumes significantly more memory than single-request inference.
|
6 |
+
|
7 |
+
"""
|
8 |
+
|
9 |
+
from typing import Optional, Union
|
10 |
+
from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
|
11 |
+
|
12 |
+
MODEL_PATH = 'THUDM/glm-4-9b-chat'
|
13 |
+
|
14 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
15 |
+
MODEL_PATH,
|
16 |
+
trust_remote_code=True,
|
17 |
+
encode_special_tokens=True)
|
18 |
+
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
|
19 |
+
|
20 |
+
|
21 |
+
def process_model_outputs(inputs, outputs, tokenizer):
|
22 |
+
responses = []
|
23 |
+
for input_ids, output_ids in zip(inputs.input_ids, outputs):
|
24 |
+
response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
|
25 |
+
responses.append(response)
|
26 |
+
return responses
|
27 |
+
|
28 |
+
|
29 |
+
def batch(
|
30 |
+
model,
|
31 |
+
tokenizer,
|
32 |
+
messages: Union[str, list[str]],
|
33 |
+
max_input_tokens: int = 8192,
|
34 |
+
max_new_tokens: int = 8192,
|
35 |
+
num_beams: int = 1,
|
36 |
+
do_sample: bool = True,
|
37 |
+
top_p: float = 0.8,
|
38 |
+
temperature: float = 0.8,
|
39 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
40 |
+
):
|
41 |
+
messages = [messages] if isinstance(messages, str) else messages
|
42 |
+
batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
|
43 |
+
max_length=max_input_tokens).to(model.device)
|
44 |
+
|
45 |
+
gen_kwargs = {
|
46 |
+
"max_new_tokens": max_new_tokens,
|
47 |
+
"num_beams": num_beams,
|
48 |
+
"do_sample": do_sample,
|
49 |
+
"top_p": top_p,
|
50 |
+
"temperature": temperature,
|
51 |
+
"logits_processor": logits_processor,
|
52 |
+
"eos_token_id": model.config.eos_token_id
|
53 |
+
}
|
54 |
+
batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
|
55 |
+
batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer)
|
56 |
+
return batched_response
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
|
61 |
+
batch_message = [
|
62 |
+
[
|
63 |
+
{"role": "user", "content": "我的爸爸和妈妈结婚为什么不能带我去"},
|
64 |
+
{"role": "assistant", "content": "因为他们结婚时你还没有出生"},
|
65 |
+
{"role": "user", "content": "我刚才的提问是"}
|
66 |
+
],
|
67 |
+
[
|
68 |
+
{"role": "user", "content": "你好,你是谁"}
|
69 |
+
]
|
70 |
+
]
|
71 |
+
|
72 |
+
batch_inputs = []
|
73 |
+
max_input_tokens = 1024
|
74 |
+
for i, messages in enumerate(batch_message):
|
75 |
+
new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
76 |
+
max_input_tokens = max(max_input_tokens, len(new_batch_input))  # len() counts characters, not tokens; used as a rough padding length
|
77 |
+
batch_inputs.append(new_batch_input)
|
78 |
+
gen_kwargs = {
|
79 |
+
"max_input_tokens": max_input_tokens,
|
80 |
+
"max_new_tokens": 8192,
|
81 |
+
"do_sample": True,
|
82 |
+
"top_p": 0.8,
|
83 |
+
"temperature": 0.8,
|
84 |
+
"num_beams": 1,
|
85 |
+
}
|
86 |
+
|
87 |
+
batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
|
88 |
+
for response in batch_responses:
|
89 |
+
print("=" * 10)
|
90 |
+
print(response)
|
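`process_model_outputs` recovers each reply by slicing the prompt tokens off the front of the generated sequence. The sketch below shows the same slicing on plain Python lists. The helper name `strip_prompts` is hypothetical, and it assumes, as the demo does, that every output sequence starts with its full (padded) input sequence.

```python
from typing import List, Sequence


def strip_prompts(input_ids: Sequence[Sequence[int]],
                  output_ids: Sequence[Sequence[int]]) -> List[List[int]]:
    """Drop the prompt prefix from each generated sequence."""
    # Each output is prompt + completion, so skip len(prompt) tokens.
    return [list(out[len(inp):]) for inp, out in zip(input_ids, output_ids)]
```

For instance, `strip_prompts([[1, 2, 3]], [[1, 2, 3, 9, 8]])` returns `[[9, 8]]`. Because the inputs are padded to a common length, the same offset works for every row in the batch.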
trans_cli_demo.py
ADDED
@@ -0,0 +1,112 @@
1 |
+
"""
|
2 |
+
This script creates a CLI demo with transformers backend for the glm-4-9b model,
|
3 |
+
allowing users to interact with the model through a command-line interface.
|
4 |
+
|
5 |
+
Usage:
|
6 |
+
- Run the script to start the CLI demo.
|
7 |
+
- Interact with the model by typing questions and receiving responses.
|
8 |
+
|
9 |
+
Note: The script includes a modification to handle markdown to plain text conversion,
|
10 |
+
ensuring that the CLI interface displays formatted text correctly.
|
11 |
+
"""
|
12 |
+
|
13 |
+
import os
|
14 |
+
import torch
|
15 |
+
from threading import Thread
|
16 |
+
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, AutoModel
|
17 |
+
|
18 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
|
19 |
+
|
20 |
+
## If you use a PEFT model, load it with the function below.
|
21 |
+
# def load_model_and_tokenizer(model_dir, trust_remote_code: bool = True):
|
22 |
+
# if (model_dir / 'adapter_config.json').exists():
|
23 |
+
# model = AutoModel.from_pretrained(
|
24 |
+
# model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
25 |
+
# )
|
26 |
+
# tokenizer_dir = model.peft_config['default'].base_model_name_or_path
|
27 |
+
# else:
|
28 |
+
# model = AutoModel.from_pretrained(
|
29 |
+
# model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
30 |
+
# )
|
31 |
+
# tokenizer_dir = model_dir
|
32 |
+
# tokenizer = AutoTokenizer.from_pretrained(
|
33 |
+
# tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
|
34 |
+
# )
|
35 |
+
# return model, tokenizer
|
36 |
+
|
37 |
+
|
38 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
39 |
+
MODEL_PATH,
|
40 |
+
trust_remote_code=True,
|
41 |
+
encode_special_tokens=True
|
42 |
+
)
|
43 |
+
model = AutoModel.from_pretrained(
|
44 |
+
MODEL_PATH,
|
45 |
+
trust_remote_code=True,
|
46 |
+
device_map="auto").eval()
|
47 |
+
|
48 |
+
|
49 |
+
class StopOnTokens(StoppingCriteria):
|
50 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
51 |
+
stop_ids = model.config.eos_token_id
|
52 |
+
for stop_id in stop_ids:
|
53 |
+
if input_ids[0][-1] == stop_id:
|
54 |
+
return True
|
55 |
+
return False
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
history = []
|
60 |
+
max_length = 8192
|
61 |
+
top_p = 0.8
|
62 |
+
temperature = 0.6
|
63 |
+
stop = StopOnTokens()
|
64 |
+
|
65 |
+
print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
|
66 |
+
while True:
|
67 |
+
user_input = input("\nYou: ")
|
68 |
+
if user_input.lower() in ["exit", "quit"]:
|
69 |
+
break
|
70 |
+
history.append([user_input, ""])
|
71 |
+
|
72 |
+
messages = []
|
73 |
+
for idx, (user_msg, model_msg) in enumerate(history):
|
74 |
+
if idx == len(history) - 1 and not model_msg:
|
75 |
+
messages.append({"role": "user", "content": user_msg})
|
76 |
+
break
|
77 |
+
if user_msg:
|
78 |
+
messages.append({"role": "user", "content": user_msg})
|
79 |
+
if model_msg:
|
80 |
+
messages.append({"role": "assistant", "content": model_msg})
|
81 |
+
model_inputs = tokenizer.apply_chat_template(
|
82 |
+
messages,
|
83 |
+
add_generation_prompt=True,
|
84 |
+
tokenize=True,
|
85 |
+
return_tensors="pt"
|
86 |
+
).to(model.device)
|
87 |
+
streamer = TextIteratorStreamer(
|
88 |
+
tokenizer=tokenizer,
|
89 |
+
timeout=60,
|
90 |
+
skip_prompt=True,
|
91 |
+
skip_special_tokens=True
|
92 |
+
)
|
93 |
+
generate_kwargs = {
|
94 |
+
"input_ids": model_inputs,
|
95 |
+
"streamer": streamer,
|
96 |
+
"max_new_tokens": max_length,
|
97 |
+
"do_sample": True,
|
98 |
+
"top_p": top_p,
|
99 |
+
"temperature": temperature,
|
100 |
+
"stopping_criteria": StoppingCriteriaList([stop]),
|
101 |
+
"repetition_penalty": 1.2,
|
102 |
+
"eos_token_id": model.config.eos_token_id,
|
103 |
+
}
|
104 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
105 |
+
t.start()
|
106 |
+
print("GLM-4:", end="", flush=True)
|
107 |
+
for new_token in streamer:
|
108 |
+
if new_token:
|
109 |
+
print(new_token, end="", flush=True)
|
110 |
+
history[-1][1] += new_token
|
111 |
+
|
112 |
+
history[-1][1] = history[-1][1].strip()
|
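The chat loop above turns the `[user, assistant]` pairs kept in `history` into the role/content list that `apply_chat_template` expects. A self-contained sketch of that conversion is shown here. The function name `build_messages` is an assumption introduced for illustration; the logic follows the loop in the script.

```python
from typing import Dict, List, Sequence


def build_messages(history: Sequence[Sequence[str]]) -> List[Dict[str, str]]:
    """Flatten [user, assistant] pairs into a chat-template message list."""
    messages: List[Dict[str, str]] = []
    for idx, (user_msg, model_msg) in enumerate(history):
        is_last = idx == len(history) - 1
        if is_last and not model_msg:
            # The pending turn: only the user's message exists so far.
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages
```

Keeping this step in a pure function makes it easy to check the prompt structure without loading the model.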
trans_cli_vision_demo.py
ADDED
@@ -0,0 +1,121 @@
1 |
+
"""
|
2 |
+
This script creates a CLI demo with transformers backend for the glm-4v-9b model,
|
3 |
+
allowing users to interact with the model through a command-line interface.
|
4 |
+
|
5 |
+
Usage:
|
6 |
+
- Run the script to start the CLI demo.
|
7 |
+
- Interact with the model by typing questions and receiving responses.
|
8 |
+
|
9 |
+
Note: The script includes a modification to handle markdown to plain text conversion,
|
10 |
+
ensuring that the CLI interface displays formatted text correctly.
|
11 |
+
"""
|
12 |
+
|
13 |
+
import os
|
14 |
+
import torch
|
15 |
+
from threading import Thread
|
16 |
+
from transformers import (
|
17 |
+
AutoTokenizer,
|
18 |
+
StoppingCriteria,
|
19 |
+
StoppingCriteriaList,
|
20 |
+
TextIteratorStreamer, AutoModel, BitsAndBytesConfig
|
21 |
+
)
|
22 |
+
|
23 |
+
from PIL import Image
|
24 |
+
|
25 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
|
26 |
+
|
27 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
28 |
+
MODEL_PATH,
|
29 |
+
trust_remote_code=True,
|
30 |
+
encode_special_tokens=True
|
31 |
+
)
|
32 |
+
model = AutoModel.from_pretrained(
|
33 |
+
MODEL_PATH,
|
34 |
+
trust_remote_code=True,
|
35 |
+
device_map="auto",
|
36 |
+
torch_dtype=torch.bfloat16
|
37 |
+
).eval()
|
38 |
+
|
39 |
+
## For INT4 inference
|
40 |
+
# model = AutoModel.from_pretrained(
|
41 |
+
# MODEL_PATH,
|
42 |
+
# trust_remote_code=True,
|
43 |
+
# quantization_config=BitsAndBytesConfig(load_in_4bit=True),
|
44 |
+
# torch_dtype=torch.bfloat16,
|
45 |
+
# low_cpu_mem_usage=True
|
46 |
+
# ).eval()
|
47 |
+
|
48 |
+
class StopOnTokens(StoppingCriteria):
|
49 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
50 |
+
stop_ids = model.config.eos_token_id
|
51 |
+
for stop_id in stop_ids:
|
52 |
+
if input_ids[0][-1] == stop_id:
|
53 |
+
return True
|
54 |
+
return False
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
history = []
|
59 |
+
max_length = 1024
|
60 |
+
top_p = 0.8
|
61 |
+
temperature = 0.6
|
62 |
+
stop = StopOnTokens()
|
63 |
+
uploaded = False
|
64 |
+
image = None
|
65 |
+
print("Welcome to the GLM-4V-9B CLI chat. Type your messages below.")
|
66 |
+
image_path = input("Image Path:")
|
67 |
+
try:
|
68 |
+
image = Image.open(image_path).convert("RGB")
|
69 |
+
except Exception:
|
70 |
+
print("Invalid image path. Continuing with text conversation.")
|
71 |
+
while True:
|
72 |
+
user_input = input("\nYou: ")
|
73 |
+
if user_input.lower() in ["exit", "quit"]:
|
74 |
+
break
|
75 |
+
history.append([user_input, ""])
|
76 |
+
|
77 |
+
messages = []
|
78 |
+
for idx, (user_msg, model_msg) in enumerate(history):
|
79 |
+
if idx == len(history) - 1 and not model_msg:
|
80 |
+
messages.append({"role": "user", "content": user_msg})
|
81 |
+
if image and not uploaded:
|
82 |
+
messages[-1].update({"image": image})
|
83 |
+
uploaded = True
|
84 |
+
break
|
85 |
+
if user_msg:
|
86 |
+
messages.append({"role": "user", "content": user_msg})
|
87 |
+
if model_msg:
|
88 |
+
messages.append({"role": "assistant", "content": model_msg})
|
89 |
+
model_inputs = tokenizer.apply_chat_template(
|
90 |
+
messages,
|
91 |
+
add_generation_prompt=True,
|
92 |
+
tokenize=True,
|
93 |
+
return_tensors="pt",
|
94 |
+
return_dict=True
|
95 |
+
).to(next(model.parameters()).device)
|
96 |
+
streamer = TextIteratorStreamer(
|
97 |
+
tokenizer=tokenizer,
|
98 |
+
timeout=60,
|
99 |
+
skip_prompt=True,
|
100 |
+
skip_special_tokens=True
|
101 |
+
)
|
102 |
+
generate_kwargs = {
|
103 |
+
**model_inputs,
|
104 |
+
"streamer": streamer,
|
105 |
+
"max_new_tokens": max_length,
|
106 |
+
"do_sample": True,
|
107 |
+
"top_p": top_p,
|
108 |
+
"temperature": temperature,
|
109 |
+
"stopping_criteria": StoppingCriteriaList([stop]),
|
110 |
+
"repetition_penalty": 1.2,
|
111 |
+
"eos_token_id": [151329, 151336, 151338],
|
112 |
+
}
|
113 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
114 |
+
t.start()
|
115 |
+
print("GLM-4:", end="", flush=True)
|
116 |
+
for new_token in streamer:
|
117 |
+
if new_token:
|
118 |
+
print(new_token, end="", flush=True)
|
119 |
+
history[-1][1] += new_token
|
120 |
+
|
121 |
+
history[-1][1] = history[-1][1].strip()
|
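`StopOnTokens` iterates over `model.config.eos_token_id`, which works when that value is a list, as with the three IDs passed to `generate` in this script. The sketch below shows one way to make the same check tolerant of a single integer or an unset value. Both helper names are hypothetical, and the assumption that the config value may take those other forms is mine rather than something stated in the script.

```python
from typing import Iterable, List, Union

EosType = Union[int, Iterable[int], None]


def normalize_stop_ids(eos_token_id: EosType) -> List[int]:
    """Return the stop IDs as a list, whatever form the config uses."""
    if eos_token_id is None:
        return []
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    return list(eos_token_id)


def should_stop(last_token_id: int, eos_token_id: EosType) -> bool:
    """True when the most recent token is one of the stop IDs."""
    return last_token_id in normalize_stop_ids(eos_token_id)
```

For example, `should_stop(151336, [151329, 151336, 151338])` is `True`, and `should_stop(5, None)` is `False`.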
trans_stress_test.py
ADDED
@@ -0,0 +1,135 @@
1 |
+
import argparse
|
2 |
+
import time
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
|
4 |
+
import torch
|
5 |
+
from threading import Thread
|
6 |
+
|
7 |
+
MODEL_PATH = 'THUDM/glm-4-9b-chat'
|
8 |
+
|
9 |
+
|
10 |
+
def stress_test(token_len, n, num_gpu):
|
11 |
+
device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
13 |
+
MODEL_PATH,
|
14 |
+
trust_remote_code=True,
|
15 |
+
padding_side="left"
|
16 |
+
)
|
17 |
+
model = AutoModelForCausalLM.from_pretrained(
|
18 |
+
MODEL_PATH,
|
19 |
+
trust_remote_code=True,
|
20 |
+
torch_dtype=torch.bfloat16
|
21 |
+
).to(device).eval()
|
22 |
+
|
23 |
+
# Use INT4 weights for inference
|
24 |
+
# model = AutoModelForCausalLM.from_pretrained(
|
25 |
+
# MODEL_PATH,
|
26 |
+
# trust_remote_code=True,
|
27 |
+
# quantization_config=BitsAndBytesConfig(load_in_4bit=True),
|
28 |
+
# low_cpu_mem_usage=True,
|
29 |
+
# ).eval()
|
30 |
+
|
31 |
+
times = []
|
32 |
+
decode_times = []
|
33 |
+
|
34 |
+
print("Warming up...")
|
35 |
+
vocab_size = tokenizer.vocab_size
|
36 |
+
warmup_token_len = 20
|
37 |
+
random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
|
38 |
+
start_tokens = [151331, 151333, 151336, 198]
|
39 |
+
end_tokens = [151337]
|
40 |
+
input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(
|
41 |
+
device)
|
42 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
|
43 |
+
position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
|
44 |
+
warmup_inputs = {
|
45 |
+
'input_ids': input_ids,
|
46 |
+
'attention_mask': attention_mask,
|
47 |
+
'position_ids': position_ids
|
48 |
+
}
|
49 |
+
with torch.no_grad():
|
50 |
+
_ = model.generate(
|
51 |
+
input_ids=warmup_inputs['input_ids'],
|
52 |
+
attention_mask=warmup_inputs['attention_mask'],
|
53 |
+
max_new_tokens=2048,
|
54 |
+
do_sample=False,
|
55 |
+
repetition_penalty=1.0,
|
56 |
+
eos_token_id=[151329, 151336, 151338]
|
57 |
+
)
|
58 |
+
print("Warming up complete. Starting stress test...")
|
59 |
+
|
60 |
+
for i in range(n):
|
61 |
+
random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
|
62 |
+
input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(
|
63 |
+
0).to(device)
|
64 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
|
65 |
+
position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
|
66 |
+
test_inputs = {
|
67 |
+
'input_ids': input_ids,
|
68 |
+
'attention_mask': attention_mask,
|
69 |
+
'position_ids': position_ids
|
70 |
+
}
|
71 |
+
|
72 |
+
streamer = TextIteratorStreamer(
|
73 |
+
tokenizer=tokenizer,
|
74 |
+
timeout=36000,
|
75 |
+
skip_prompt=True,
|
76 |
+
skip_special_tokens=True
|
77 |
+
)
|
78 |
+
|
79 |
+
generate_kwargs = {
|
80 |
+
"input_ids": test_inputs['input_ids'],
|
81 |
+
"attention_mask": test_inputs['attention_mask'],
|
82 |
+
"max_new_tokens": 512,
|
83 |
+
"do_sample": False,
|
84 |
+
"repetition_penalty": 1.0,
|
85 |
+
"eos_token_id": [151329, 151336, 151338],
|
86 |
+
"streamer": streamer
|
87 |
+
}
|
88 |
+
|
89 |
+
start_time = time.time()
|
90 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
91 |
+
t.start()
|
92 |
+
|
93 |
+
first_token_time = None
|
94 |
+
all_token_times = []
|
95 |
+
|
96 |
+
for token in streamer:
|
97 |
+
current_time = time.time()
|
98 |
+
if first_token_time is None:
|
99 |
+
first_token_time = current_time
|
100 |
+
times.append(first_token_time - start_time)
|
101 |
+
all_token_times.append(current_time)
|
102 |
+
|
103 |
+
t.join()
|
104 |
+
end_time = time.time()
|
105 |
+
|
106 |
+
avg_decode_time_per_token = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0
|
107 |
+
decode_times.append(avg_decode_time_per_token)
|
108 |
+
print(
|
109 |
+
f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Average Decode Speed: {avg_decode_time_per_token:.4f} tokens/second")
|
110 |
+
|
111 |
+
torch.cuda.empty_cache()
|
112 |
+
|
113 |
+
avg_first_token_time = sum(times) / n
|
114 |
+
avg_decode_time = sum(decode_times) / n
|
115 |
+
print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
|
116 |
+
print(f"Average Decode Speed over {n} iterations: {avg_decode_time:.4f} tokens/second")
|
117 |
+
return times, avg_first_token_time, decode_times, avg_decode_time
|
118 |
+
|
119 |
+
|
120 |
+
def main():
|
121 |
+
parser = argparse.ArgumentParser(description="Stress test for model inference")
|
122 |
+
parser.add_argument('--token_len', type=int, default=1000, help='Number of tokens for each test')
|
123 |
+
parser.add_argument('--n', type=int, default=3, help='Number of iterations for the stress test')
|
124 |
+
parser.add_argument('--num_gpu', type=int, default=1, help='GPU count; the model runs on the last device, cuda:{num_gpu - 1}')
|
125 |
+
args = parser.parse_args()
|
126 |
+
|
127 |
+
token_len = args.token_len
|
128 |
+
n = args.n
|
129 |
+
num_gpu = args.num_gpu
|
130 |
+
|
131 |
+
stress_test(token_len, n, num_gpu)
|
132 |
+
|
133 |
+
|
134 |
+
if __name__ == "__main__":
|
135 |
+
main()
|
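The stress test derives two numbers per iteration: the time until the first streamed chunk (reported as prefilling time) and the number of chunks divided by the time from the first chunk to the end (reported in tokens/second). The sketch below isolates that arithmetic from the model calls. The function name `summarize_timings` is hypothetical, and it assumes, as the script does, that each item yielded by the streamer counts as one token, which may not hold exactly if the streamer groups several tokens into one chunk.

```python
from typing import List, Tuple


def summarize_timings(start_time: float,
                      token_times: List[float],
                      end_time: float) -> Tuple[float, float]:
    """Return (prefill_seconds, decode_tokens_per_second)."""
    if not token_times:
        return 0.0, 0.0
    first_token_time = token_times[0]
    # Prefill: delay between starting generation and the first chunk.
    prefill_seconds = first_token_time - start_time
    # Decode speed: chunks received per second after the first chunk.
    decode_window = end_time - first_token_time
    tokens_per_second = len(token_times) / decode_window if decode_window > 0 else 0.0
    return prefill_seconds, tokens_per_second
```

With `start_time=0.0`, chunk timestamps `[0.5, 1.0, 1.5, 2.0]`, and `end_time=2.5`, this returns `(0.5, 2.0)`: half a second to the first chunk, then four chunks over two seconds.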
trans_web_demo.py
ADDED
@@ -0,0 +1,167 @@
1 |
+
"""
|
2 |
+
This script creates an interactive web demo for the GLM-4-9B model using Gradio,
|
3 |
+
a Python library for building quick and easy UI components for machine learning models.
|
4 |
+
It's designed to showcase the capabilities of the GLM-4-9B model in a user-friendly interface,
|
5 |
+
allowing users to interact with the model through a chat-like interface.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import gradio as gr
|
10 |
+
import torch
|
11 |
+
from threading import Thread
|
12 |
+
|
13 |
+
from typing import Union
|
14 |
+
from pathlib import Path
|
15 |
+
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
|
16 |
+
from transformers import (
|
17 |
+
AutoModelForCausalLM,
|
18 |
+
AutoTokenizer,
|
19 |
+
PreTrainedModel,
|
20 |
+
PreTrainedTokenizer,
|
21 |
+
PreTrainedTokenizerFast,
|
22 |
+
StoppingCriteria,
|
23 |
+
StoppingCriteriaList,
|
24 |
+
TextIteratorStreamer
|
25 |
+
)
|
26 |
+
|
27 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
28 |
+
|
29 |
+
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
|
30 |
+
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
31 |
+
|
32 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', '../models/glm-4-9b-chat')
|
33 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
34 |
+
|
35 |
+
|
36 |
+
def _resolve_path(path: Union[str, Path]) -> Path:
|
37 |
+
return Path(path).expanduser().resolve()
|
38 |
+
|
39 |
+
|
40 |
+
def load_model_and_tokenizer(
|
41 |
+
model_dir: Union[str, Path], trust_remote_code: bool = True
|
42 |
+
) -> tuple[ModelType, TokenizerType]:
|
43 |
+
model_dir = _resolve_path(model_dir)
|
44 |
+
if (model_dir / 'adapter_config.json').exists():
|
45 |
+
model = AutoPeftModelForCausalLM.from_pretrained(
|
46 |
+
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
47 |
+
)
|
48 |
+
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
|
49 |
+
else:
|
50 |
+
model = AutoModelForCausalLM.from_pretrained(
|
51 |
+
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
52 |
+
).eval()
|
53 |
+
tokenizer_dir = model_dir
|
54 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
55 |
+
tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
|
56 |
+
)
|
57 |
+
return model, tokenizer
|
58 |
+
|
59 |
+
|
60 |
+
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
|
61 |
+
|
62 |
+
|
63 |
+
class StopOnTokens(StoppingCriteria):
|
64 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
65 |
+
stop_ids = model.config.eos_token_id
|
66 |
+
for stop_id in stop_ids:
|
67 |
+
if input_ids[0][-1] == stop_id:
|
68 |
+
return True
|
69 |
+
return False
|
70 |
+
|
71 |
+
|
72 |
+
def parse_text(text):
|
73 |
+
lines = text.split("\n")
|
74 |
+
lines = [line for line in lines if line != ""]
|
75 |
+
count = 0
|
76 |
+
for i, line in enumerate(lines):
|
77 |
+
if "```" in line:
|
78 |
+
count += 1
|
79 |
+
items = line.split('`')
|
80 |
+
if count % 2 == 1:
|
81 |
+
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
82 |
+
else:
|
83 |
+
lines[i] = '<br></code></pre>'
|
84 |
+
else:
|
85 |
+
if i > 0:
|
86 |
+
if count % 2 == 1:
|
87 |
+
line = line.replace("`", "\\`")
|
88 |
+
line = line.replace("<", "&lt;")
|
89 |
+
line = line.replace(">", "&gt;")
|
90 |
+
line = line.replace(" ", "&nbsp;")
|
91 |
+
line = line.replace("*", "&ast;")
|
92 |
+
line = line.replace("_", "&lowbar;")
|
93 |
+
line = line.replace("-", "&#45;")
|
94 |
+
line = line.replace(".", "&#46;")
|
95 |
+
line = line.replace("!", "&#33;")
|
96 |
+
line = line.replace("(", "&#40;")
|
97 |
+
line = line.replace(")", "&#41;")
|
98 |
+
line = line.replace("$", "&#36;")
|
99 |
+
lines[i] = "<br>" + line
|
100 |
+
text = "".join(lines)
|
101 |
+
return text
|
102 |
+
|
103 |
+
|
104 |
+
def predict(history, max_length, top_p, temperature):
|
105 |
+
stop = StopOnTokens()
|
106 |
+
messages = []
|
107 |
+
for idx, (user_msg, model_msg) in enumerate(history):
|
108 |
+
if idx == len(history) - 1 and not model_msg:
|
109 |
+
messages.append({"role": "user", "content": user_msg})
|
110 |
+
break
|
111 |
+
if user_msg:
|
112 |
+
messages.append({"role": "user", "content": user_msg})
|
113 |
+
if model_msg:
|
114 |
+
messages.append({"role": "assistant", "content": model_msg})
|
115 |
+
|
116 |
+
model_inputs = tokenizer.apply_chat_template(messages,
|
117 |
+
add_generation_prompt=True,
|
118 |
+
tokenize=True,
|
119 |
+
return_tensors="pt").to(next(model.parameters()).device)
|
120 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
|
121 |
+
generate_kwargs = {
|
122 |
+
"input_ids": model_inputs,
|
123 |
+
"streamer": streamer,
|
124 |
+
"max_new_tokens": max_length,
|
125 |
+
"do_sample": True,
|
126 |
+
"top_p": top_p,
|
127 |
+
"temperature": temperature,
|
128 |
+
"stopping_criteria": StoppingCriteriaList([stop]),
|
129 |
+
"repetition_penalty": 1.2,
|
130 |
+
"eos_token_id": model.config.eos_token_id,
|
131 |
+
}
|
132 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
133 |
+
t.start()
|
134 |
+
for new_token in streamer:
|
135 |
+
if new_token:
|
136 |
+
history[-1][1] += new_token
|
137 |
+
yield history
|
138 |
+
|
139 |
+
|
140 |
+
with gr.Blocks() as demo:
|
141 |
+
gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
|
142 |
+
chatbot = gr.Chatbot()
|
143 |
+
|
144 |
+
with gr.Row():
|
145 |
+
with gr.Column(scale=4):
|
146 |
+
with gr.Column(scale=12):
|
147 |
+
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
|
148 |
+
with gr.Column(min_width=32, scale=1):
|
149 |
+
submitBtn = gr.Button("Submit")
|
150 |
+
with gr.Column(scale=1):
|
151 |
+
emptyBtn = gr.Button("Clear History")
|
152 |
+
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
|
153 |
+
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
|
154 |
+
temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
|
155 |
+
|
156 |
+
|
157 |
+
def user(query, history):
|
158 |
+
return "", history + [[parse_text(query), ""]]
|
159 |
+
|
160 |
+
|
161 |
+
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
|
162 |
+
predict, [chatbot, max_length, top_p, temperature], chatbot
|
163 |
+
)
|
164 |
+
emptyBtn.click(lambda: None, None, chatbot, queue=False)
|
165 |
+
|
166 |
+
demo.queue()
|
167 |
+
demo.launch(server_name="0.0.0.0", server_port=8501, inbrowser=False, share=True)
|
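`parse_text` rewrites characters line by line with chained `str.replace` calls before the text is shown in the chat box. As a hedged alternative sketch, the standard library's `html.escape` covers the HTML-significant characters in one call. The helper name `escape_for_chat` is hypothetical, and its behavior differs from `parse_text`: it also escapes `&` and does not touch Markdown characters such as `*` or `_`.

```python
import html


def escape_for_chat(line: str) -> str:
    """Escape HTML-significant characters and preserve runs of spaces."""
    # html.escape handles &, < and >; quote=False leaves quotes untouched.
    escaped = html.escape(line, quote=False)
    # Non-breaking spaces keep indentation visible in rendered HTML.
    return escaped.replace(" ", "&nbsp;")
```

For example, `escape_for_chat("a < b")` returns `"a&nbsp;&lt;&nbsp;b"`. Escaping `&` first, as `html.escape` does, avoids double-escaping entities produced by later replacements.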
vllm_cli_demo.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
"""
|
2 |
+
This script creates a CLI demo with the vLLM backend for the glm-4-9b model,
|
3 |
+
allowing users to interact with the model through a command-line interface.
|
4 |
+
|
5 |
+
Usage:
|
6 |
+
- Run the script to start the CLI demo.
|
7 |
+
- Interact with the model by typing questions and receiving responses.
|
8 |
+
|
9 |
+
Note: The script includes a modification to handle markdown to plain text conversion,
|
10 |
+
ensuring that the CLI interface displays formatted text correctly.
|
11 |
+
"""
|
12 |
+
import time
|
13 |
+
import asyncio
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
|
16 |
+
from typing import List, Dict
|
17 |
+
|
18 |
+
MODEL_PATH = 'THUDM/glm-4-9b'
|
19 |
+
|
20 |
+
|
21 |
+
def load_model_and_tokenizer(model_dir: str):
|
22 |
+
engine_args = AsyncEngineArgs(
|
23 |
+
model=model_dir,
|
24 |
+
tokenizer=model_dir,
|
25 |
+
tensor_parallel_size=1,
|
26 |
+
dtype="bfloat16",
|
27 |
+
trust_remote_code=True,
|
28 |
+
gpu_memory_utilization=0.3,
|
29 |
+
enforce_eager=True,
|
30 |
+
worker_use_ray=True,
|
31 |
+
engine_use_ray=False,
|
32 |
+
disable_log_requests=True,
|
33 |
+
# If you run into OOM errors, consider enabling the parameters below
|
34 |
+
# enable_chunked_prefill=True,
|
35 |
+
# max_num_batched_tokens=8192
|
36 |
+
)
|
37 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
38 |
+
model_dir,
|
39 |
+
trust_remote_code=True,
|
40 |
+
encode_special_tokens=True
|
41 |
+
)
|
42 |
+
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
43 |
+
return engine, tokenizer
|
44 |
+
|
45 |
+
|
46 |
+
engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
|
47 |
+
|
48 |
+
|
49 |
+
async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
|
50 |
+
inputs = tokenizer.apply_chat_template(
|
51 |
+
messages,
|
52 |
+
add_generation_prompt=True,
|
53 |
+
tokenize=False
|
54 |
+
)
|
55 |
+
params_dict = {
|
56 |
+
"n": 1,
|
57 |
+
"best_of": 1,
|
58 |
+
"presence_penalty": 1.0,
|
59 |
+
"frequency_penalty": 0.0,
|
60 |
+
"temperature": temperature,
|
61 |
+
"top_p": top_p,
|
62 |
+
"top_k": -1,
|
63 |
+
"use_beam_search": False,
|
64 |
+
"length_penalty": 1,
|
65 |
+
"early_stopping": False,
|
66 |
+
"stop_token_ids": [151329, 151336, 151338],
|
67 |
+
"ignore_eos": False,
|
68 |
+
"max_tokens": max_dec_len,
|
69 |
+
"logprobs": None,
|
70 |
+
"prompt_logprobs": None,
|
71 |
+
"skip_special_tokens": True,
|
72 |
+
}
|
73 |
+
sampling_params = SamplingParams(**params_dict)
|
74 |
+
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
|
75 |
+
yield output.outputs[0].text
|
76 |
+
|
77 |
+
|
78 |
+
async def chat():
|
79 |
+
history = []
|
80 |
+
max_length = 8192
|
81 |
+
top_p = 0.8
|
82 |
+
temperature = 0.6
|
83 |
+
|
84 |
+
print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
|
85 |
+
while True:
|
86 |
+
user_input = input("\nYou: ")
|
87 |
+
if user_input.lower() in ["exit", "quit"]:
|
88 |
+
break
|
89 |
+
history.append([user_input, ""])
|
90 |
+
|
91 |
+
messages = []
|
92 |
+
for idx, (user_msg, model_msg) in enumerate(history):
|
93 |
+
if idx == len(history) - 1 and not model_msg:
|
94 |
+
messages.append({"role": "user", "content": user_msg})
|
95 |
+
break
|
96 |
+
if user_msg:
|
97 |
+
messages.append({"role": "user", "content": user_msg})
|
98 |
+
if model_msg:
|
99 |
+
messages.append({"role": "assistant", "content": model_msg})
|
100 |
+
|
101 |
+
print("\nGLM-4: ", end="")
|
102 |
+
current_length = 0
|
103 |
+
output = ""
|
104 |
+
async for output in vllm_gen(messages, top_p, temperature, max_length):
|
105 |
+
print(output[current_length:], end="", flush=True)
|
106 |
+
current_length = len(output)
|
107 |
+
history[-1][1] = output
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
asyncio.run(chat())
|
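The chat loop in `vllm_cli_demo.py` relies on two small ideas: the `[user, assistant]` history pairs are flattened into role-tagged messages, and each streamed result is treated as the full text so far, so only the new suffix is printed. The sketch below isolates both without requiring vLLM or a GPU. `build_messages`, `fake_cumulative_stream`, and `collect_deltas` are hypothetical names introduced here for illustration, and `fake_cumulative_stream` is only a stand-in for `vllm_gen`, assuming it yields cumulative text as the demo's slicing logic implies. `build_messages` is a simplified variant of the demo's loop: it skips empty entries rather than special-casing the last turn.

```python
import asyncio
from typing import AsyncIterator, Dict, List


def build_messages(history: List[List[str]]) -> List[Dict[str, str]]:
    """Flatten [user, assistant] pairs into role-tagged chat messages."""
    messages = []
    for user_msg, model_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            # The pending turn has an empty reply, so nothing is added for it.
            messages.append({"role": "assistant", "content": model_msg})
    return messages


async def fake_cumulative_stream(text: str, step: int = 4) -> AsyncIterator[str]:
    """Hypothetical stand-in for vllm_gen: yields the full text so far each step."""
    for end in range(step, len(text) + step, step):
        yield text[:end]
        await asyncio.sleep(0)  # hand control back to the event loop


async def collect_deltas(stream: AsyncIterator[str]) -> List[str]:
    """Return only the newly generated suffix of each cumulative snapshot."""
    deltas = []
    current_length = 0
    async for output in stream:
        deltas.append(output[current_length:])
        current_length = len(output)
    return deltas


if __name__ == "__main__":
    history = [["Hi", "Hello!"], ["How are you?", ""]]
    print(build_messages(history))
    for piece in asyncio.run(collect_deltas(fake_cumulative_stream("Hello, world!"))):
        print(piece, end="", flush=True)
    print()
```

Tracking `current_length` instead of diffing strings keeps the per-step work small, at the cost of assuming that each snapshot extends the previous one.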