CreitinGameplays committed
Commit 1dc34d7
1 Parent(s): 067742a

Update app.py

Files changed (1)
  1. app.py +19 -22
app.py CHANGED
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 import os
 from threading import Thread
 from typing import Iterator
@@ -9,32 +7,39 @@ import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-DESCRIPTION = "# ConvAI 9b"
-hf_token = os.getenv("hf_token")
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+DESCRIPTION = """\
+# ConvAI 9b
+"""
 
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 512
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 if torch.cuda.is_available():
     model_id = "CreitinGameplays/ConvAI-9b"
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=hf_token)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.use_default_system_prompt = False
+
 
 @spaces.GPU
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    max_new_tokens: int = 512,
-    temperature: float = 0.2,
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.4,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
@@ -45,7 +50,7 @@ def generate(
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
@@ -69,6 +74,7 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
@@ -81,7 +87,7 @@ chat_interface = gr.ChatInterface(
             minimum=0.1,
             maximum=4.0,
             step=0.1,
-            value=0.2,
+            value=0.4,
         ),
         gr.Slider(
             label="Top-p (nucleus sampling)",
@@ -115,14 +121,5 @@ chat_interface = gr.ChatInterface(
     ],
 )
 
-with gr.Blocks(css="style.css") as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(
-        value="Duplicate Space for private use",
-        elem_id="duplicate-button",
-        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
-    )
-    chat_interface.render()
-
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
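For orientation, since the hunks only show fragments of generate(), below is a minimal, self-contained sketch of the streaming pattern the file relies on: build the chat list (now with the optional system prompt), tokenize it via apply_chat_template, and run model.generate on a background Thread while a TextIteratorStreamer yields the partial reply. The wiring between the visible lines follows the common Transformers streaming recipe and is assumed, not copied from the hidden parts of app.py; the generation kwargs shown are illustrative only.

# Hypothetical sketch: not the hidden lines of app.py, just the standard
# streaming recipe that the visible hunks point to.
from threading import Thread
from typing import Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "CreitinGameplays/ConvAI-9b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")


def stream_reply(
    message: str,
    system_prompt: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.4,
) -> Iterator[str]:
    # Build the chat the same way the diff does: optional system turn, then the user turn.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.append({"role": "user", "content": message})

    # Tokenize through the model's chat template and move the ids to the model device.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)

    # The streamer decodes generated tokens into text chunks as they arrive.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )

    # generate() blocks, so it runs on a worker thread while this generator
    # yields the growing reply for the Gradio UI to render incrementally.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)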