3v324v23 committed on
Commit a401787
1 Parent(s): c8aaa2b
Files changed (1)
  1. app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
# -*- coding: utf-8 -*-

from typing import Optional
import datetime
import os
from threading import Event, Thread
from uuid import uuid4

import gradio as gr
import requests
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

model_name = "golaxy/chinese-bloom-3b"
max_new_tokens = 2048


print(f"Starting to load the model {model_name} into memory")

tok = AutoTokenizer.from_pretrained(model_name)
# m = AutoModelForCausalLM.from_pretrained(model_name).eval()
m = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
print(f"Model device: {m.device}")
# stop_token_ids = tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
stop_token_ids = [tok.eos_token_id]

print(f"Successfully loaded the model {model_name} into memory")


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop generation as soon as the most recently generated token
        # matches one of the configured stop token ids.
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

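A quick way to convince yourself the criterion behaves as intended (a hypothetical check, not part of the app; the token ids are made up):

# Hypothetical sanity check: the criterion fires once the last token is a stop id.
crit = StopOnTokens()
print(crit(torch.tensor([[11, 22, stop_token_ids[0]]]), None))  # -> True
print(crit(torch.tensor([[11, 22, 33]]), None))  # -> False, assuming 33 is not a stop id
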
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


def generate_input(instruction: Optional[str] = None, input_str: Optional[str] = None) -> str:
    if input_str is None:
        return PROMPT_DICT["prompt_no_input"].format_map({"instruction": instruction})
    else:
        return PROMPT_DICT["prompt_input"].format_map({"instruction": instruction, "input": input_str})

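For reference, the no-input template renders like this (the instruction text is hypothetical, shown only to illustrate the format):

# Illustration of the rendered prompt (not part of the app).
print(generate_input("Summarize the following article."))
# Below is an instruction that describes a task. Write a response that
# appropriately completes the request.
#
# ### Instruction:
# Summarize the following article.
#
# ### Response:
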
def convert_history_to_text(history):
    # Note: only the latest user message is used; earlier turns in
    # `history` are not included in the prompt.
    user_input = history[-1][0]
    text = generate_input(user_input)
    return text

def log_conversation(conversation_id, history, messages, generate_kwargs):
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }

    try:
        # Time-box the request so a slow endpoint cannot hang the logging thread.
        requests.post(logging_url, json=data, timeout=10)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")

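log_conversation posts JSON to whatever LOGGING_URL points at. A minimal sketch of a compatible receiver, standard library only (the port and path are hypothetical):

# Hypothetical receiver for LOGGING_URL (e.g. http://localhost:8081/log).
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

class LogHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        body = self.rfile.read(int(self.headers["Content-Length"]))
        record = json.loads(body)
        print(record["conversation_id"], record["timestamp"])
        self.send_response(200)
        self.end_headers()

HTTPServer(("0.0.0.0", 8081), LogHandler).serve_forever()
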
def user(message, history):
    # Append the user's message to the conversation history
    return "", history + [[message, ""]]

def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    print(f"history: {history}")
    # Initialize a StopOnTokens object
    stop = StopOnTokens()

    # Build the prompt for the model from the latest user message
    messages = convert_history_to_text(history)

    # Tokenize the prompt and move it to the model's device
    input_ids = tok(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    print(generate_kwargs)
    stream_complete = Event()

    def generate_and_signal_complete():
        m.generate(**generate_kwargs)
        stream_complete.set()

    def log_after_stream_complete():
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    t2 = Thread(target=log_after_stream_complete)
    t2.start()

    # Accumulate streamed text and update the last history entry as it grows
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history

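The core pattern in bot() is worth isolating: generate() blocks, so it runs in a worker thread while the caller iterates over the streamer, which yields decoded text as tokens arrive. A minimal standalone sketch of that pattern (reusing the m and tok loaded above; the function name is made up):

# Minimal sketch of the thread-plus-streamer pattern used in bot().
def stream_demo(prompt: str):
    inputs = tok(prompt, return_tensors="pt").input_ids.to(m.device)
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs in a background thread ...
    Thread(target=m.generate, kwargs=dict(
        input_ids=inputs, max_new_tokens=64, streamer=streamer)).start()
    # ... while the caller consumes decoded chunks as they arrive.
    for chunk in streamer:
        print(chunk, end="", flush=True)
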
def get_uuid():
    return str(uuid4())

with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    conversation_id = gr.State(get_uuid)
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    with gr.Row():
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.1,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0.0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of the top-k tokens. Set to 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.1,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition. Set to 1.0 to disable.",
                        )
    # with gr.Row():
    #     gr.Markdown(
    #         "demo 2",
    #         elem_classes=["disclaimer"],
    #     )

    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=128, concurrency_count=2)
demo.launch(server_name="0.0.0.0", server_port=7777)
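Once launched, the app listens on all interfaces at port 7777. A quick smoke test from the same host (hypothetical; it only confirms the server is reachable):

# Hypothetical smoke test: confirm the Gradio server is up.
import requests
r = requests.get("http://127.0.0.1:7777", timeout=5)
print(r.status_code)  # expect 200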