wiklif committed on
Commit
25d52ff
·
1 Parent(s): 1b4f1a9

revert to old api, fix prompt

Browse files
Files changed (1) hide show
  1. app.py +43 -47
app.py CHANGED
@@ -1,56 +1,52 @@
1
- import spaces
2
  from huggingface_hub import InferenceClient
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
# Remote Serverless Inference API client for the Mixtral 8x7B instruct model.
client = InferenceClient('mistralai/Mixtral-8x7B-Instruct-v0.1')
 
6
 
7
@spaces.GPU(duration=60)
def generate_response(chat, kwargs):
    """Stream a completion for *chat* and return the fully assembled text.

    chat: fully formatted prompt string (Mixtral ``[INST]`` format).
    kwargs: generation parameters forwarded to
        ``InferenceClient.text_generation``.
    Returns the concatenation of all streamed token texts.
    """
    output = ''
    # details=True makes each streamed chunk carry a .token object with .text.
    # NOTE(review): inference runs on the remote endpoint via InferenceClient,
    # so the @spaces.GPU decorator presumably only reserves ZeroGPU quota —
    # confirm it is actually needed here.
    stream = client.text_generation(chat, **kwargs, stream=True, details=True, return_full_text=False)
    for response in stream:
        output += response.token.text
    return output
14
 
15
def function(prompt, history):
    """Build the Mixtral chat prompt from *history* and yield the reply.

    prompt: the latest user message.
    history: list of (user_prompt, bot_response) pairs from prior turns.
    Yields the complete generated answer, or '' if generation fails.
    """
    # Canonical Mixtral instruct format: a single leading <s>, each completed
    # turn closed with </s>, and the newest turn left open after [/INST] so
    # the model produces the answer.
    chat = "<s>"
    for user_prompt, bot_response in history:
        chat += f"[INST] {user_prompt} [/INST] {bot_response}</s> "
    chat += f"[INST] {prompt} [/INST]"
    kwargs = dict(
        temperature=0.80,
        max_new_tokens=2048,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,  # make sure sampling is actually used
        seed=1337
    )

    try:
        output = generate_response(chat, kwargs)
        yield output
    except Exception:
        # Was a bare ``except:``, which also swallows KeyboardInterrupt and
        # SystemExit; keep the deliberate best-effort fallback but limit it
        # to ordinary errors.
        yield ''
34
-
35
# Chat UI wiring.  NOTE(review): the retry/undo/clear button kwargs exist
# only in the gradio version this Space pins (they were removed in later
# gradio releases) — confirm the pinned version before upgrading.
interface = gr.ChatInterface(
    fn=function,
    chatbot=gr.Chatbot(
        avatar_images=None,
        container=False,
        show_copy_button=True,
        layout='bubble',
        render_markdown=True,
        line_breaks=True
    ),
    # Shrink the default heading sizes in rendered markdown output.
    css='h1 {font-size:22px;} h2 {font-size:20px;} h3 {font-size:18px;} h4 {font-size:16px;}',
    autofocus=True,
    fill_height=True,
    analytics_enabled=False,
    submit_btn='Chat',
    stop_btn=None,
    retry_btn=None,
    undo_btn=None,
    clear_btn=None
)

interface.launch(show_api=True)
 
 
1
  from huggingface_hub import InferenceClient
2
  import gradio as gr
3
+ import spaces
4
+
5
# Remote Serverless Inference API client for the Mixtral 8x7B instruct model.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
6
+
7
def format_prompt(message, history):
    """Format chat *history* plus the new *message* into Mixtral's prompt.

    message: latest user message.
    history: list of (user_prompt, bot_response) pairs from prior turns.
    Returns the prompt string to pass to ``text_generation``.
    """
    prompt = ""
    for user_prompt, bot_response in history:
        # Each completed turn is a full <s>...</s> sequence.
        prompt += f"<s>[INST] {user_prompt} [/INST] {bot_response}</s>"
    # The newest turn must stay OPEN after [/INST]: the original appended
    # </s> here, which marks end-of-sequence *before* the model has produced
    # its answer and invites empty/degenerate completions.
    prompt += f"<s>[INST] {message} [/INST]"
    return prompt
13
+
14
@spaces.GPU
def generate(
    prompt, history, temperature=0, max_new_tokens=4096, top_p=0.95, repetition_penalty=1.0,
):
    """Stream the model's reply for *prompt*, yielding the growing text.

    prompt: latest user message.
    history: list of (user, bot) pairs supplied by gr.ChatInterface.
    temperature / max_new_tokens / top_p / repetition_penalty: sampling
        parameters forwarded to ``InferenceClient.text_generation``.
    Yields the cumulative output after each streamed token.
    """
    # Clamp temperature to a small positive floor so do_sample=True always
    # receives a strictly positive value (the default of 0 would otherwise
    # be passed through).
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: replies are reproducible rather than varied
    )

    formatted_prompt = format_prompt(prompt, history)

    # details=True exposes per-token objects (.token.text) on the stream.
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        # Yield the cumulative text so the UI renders the reply as it streams.
        yield output
    # The original ended with ``return output``; inside a generator that value
    # only lands on StopIteration and no caller reads it, so it was dropped.
41
 
42
# Chat display widget.  NOTE(review): bubble_full_width and likeable exist
# only in the gradio version this Space pins (removed in later releases) —
# confirm the pinned version before upgrading.
mychatbot = gr.Chatbot(
    bubble_full_width=False, show_label=False, show_copy_button=True, likeable=True,)

demo = gr.ChatInterface(fn=generate,
    chatbot=mychatbot,
    title="Test API :)",
    retry_btn=None,
    undo_btn=None
    )

# queue() enables gradio's request queue, which carries the generator's
# incremental yields to the browser as a streamed response.
demo.queue().launch(show_api=True)