Files changed (3)
  1. app.py +25 -17
  2. model.py +8 -8
  3. settings.py +0 -8
app.py CHANGED
@@ -7,15 +7,26 @@ import gradio as gr
 import torch
 
 from model import run
-from settings import (ALLOW_CHANGING_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS,
-                      DEFAULT_SYSTEM_PROMPT, MAX_MAX_NEW_TOKENS)
 
-DESCRIPTION = '# Llama-2 7B chat'
+DEFAULT_SYSTEM_PROMPT = """
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+"""
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+
+DESCRIPTION = """
+# Llama-2 7B Chat
+
+This Space demonstrates model [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, a Llama 2 model with 7B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
+
+🔎 For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
+
+🔨 Looking for an even more powerful model? Check out the large [70B model demo](https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI).
+"""
+
 if not torch.cuda.is_available():
     DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
 
-WRITEUP = """This Space demonstrates model [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, running transformers latest release. Read more about the Llamav2 release on Huggingface in our [Blog](https://huggingface.co/blog/llama2). To have your own dedicated endpoint, you can [deploy it on Inference Endpoints](https://ui.endpoints.huggingface.co/) or duplicate the Space and provide for a GPU. We also have the [Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) demo running on Spaces. """
-
 def clear_and_save_textbox(message: str) -> tuple[str, str]:
     return '', message
 
@@ -35,7 +46,7 @@ def delete_prev_fn(
     return history, message or ''
 
 
-def fn(
+def generate(
     message: str,
     history_with_input: list[tuple[str, str]],
     system_prompt: str,
@@ -61,10 +72,8 @@ def fn(
 
 with gr.Blocks(css='style.css') as demo:
     gr.Markdown(DESCRIPTION)
-    gr.Markdown(WRITEUP)
     gr.DuplicateButton(value='Duplicate Space for private use',
-                       elem_id='duplicate-button',
-                       visible=os.getenv('SHOW_DUPLICATE_BUTTON') == '1')
+                       elem_id='duplicate-button')
 
     with gr.Group():
         chatbot = gr.Chatbot(label='Chatbot')
@@ -89,8 +98,7 @@ with gr.Blocks(css='style.css') as demo:
     with gr.Accordion(label='Advanced options', open=False):
         system_prompt = gr.Textbox(label='System prompt',
                                    value=DEFAULT_SYSTEM_PROMPT,
-                                   lines=6,
-                                   interactive=ALLOW_CHANGING_SYSTEM_PROMPT)
+                                   lines=6)
         max_new_tokens = gr.Slider(
             label='Max new tokens',
             minimum=1,
@@ -101,9 +109,9 @@ with gr.Blocks(css='style.css') as demo:
         temperature = gr.Slider(
             label='Temperature',
             minimum=0.1,
-            maximum=5.0,
+            maximum=4.0,
             step=0.1,
-            value=0.8,
+            value=1.0,
         )
         top_p = gr.Slider(
             label='Top-p (nucleus sampling)',
@@ -115,7 +123,7 @@ with gr.Blocks(css='style.css') as demo:
         top_k = gr.Slider(
             label='Top-k',
             minimum=1,
-            maximum=50,
+            maximum=1000,
             step=1,
             value=50,
         )
@@ -133,7 +141,7 @@ with gr.Blocks(css='style.css') as demo:
         api_name=False,
         queue=False,
     ).then(
-        fn=fn,
+        fn=generate,
         inputs=[
             saved_input,
             chatbot,
@@ -160,7 +168,7 @@ with gr.Blocks(css='style.css') as demo:
         api_name=False,
         queue=False,
     ).then(
-        fn=fn,
+        fn=generate,
         inputs=[
             saved_input,
             chatbot,
@@ -187,7 +195,7 @@ with gr.Blocks(css='style.css') as demo:
         api_name=False,
         queue=False,
    ).then(
-        fn=fn,
+        fn=generate,
         inputs=[
             saved_input,
             chatbot,
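
The event chains now call `generate` instead of `fn`, but the body of the renamed handler is not part of this diff. As a rough sketch of how a streaming chat handler like it is typically wired to `gr.Chatbot` — the parameters after `system_prompt`, the return type, and the whole body are assumptions, not taken from this change:

```python
from typing import Iterator

from model import run  # streaming generator defined in model.py

# Hedged sketch: only the rename to `generate` and its first three parameters
# appear in this diff; everything below is an assumed, typical implementation.
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    # Assuming an earlier .then() step already appended the user's message to
    # the chatbot history, drop it before handing the history to the model.
    history = history_with_input[:-1]
    for response in run(message, history, system_prompt, max_new_tokens,
                        temperature, top_p, top_k):
        # Yield the growing conversation so the Chatbot updates as tokens stream in.
        yield history + [(message, response)]
```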
model.py CHANGED
@@ -2,16 +2,16 @@ from threading import Thread
 from typing import Iterator
 
 import torch
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          TextIteratorStreamer)
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 model_id = 'meta-llama/Llama-2-7b-chat-hf'
-device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
-if device.type == 'cuda':
-    model = AutoModelForCausalLM.from_pretrained(model_id,
-                                                 load_in_8bit=True,
-                                                 device_map='auto')
+if torch.cuda.is_available():
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16,
+        device_map='auto'
+    )
 else:
     model = None
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -34,7 +34,7 @@ def run(message: str,
         top_p: float = 0.95,
         top_k: int = 50) -> Iterator[str]:
     prompt = get_prompt(message, chat_history, system_prompt)
-    inputs = tokenizer([prompt], return_tensors='pt').to(device)
+    inputs = tokenizer([prompt], return_tensors='pt').to("cuda")
 
     streamer = TextIteratorStreamer(tokenizer,
                                     timeout=10.,
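
Since `model` is only loaded when CUDA is available (it is `None` otherwise), the hard-coded `.to("cuda")` is only reached on GPU runs. The diff stops at the `TextIteratorStreamer` construction; the remainder of `run` presumably follows the usual threaded streaming pattern, roughly like the sketch below. The generation kwargs and the helper's name are assumptions, not part of this change:

```python
from threading import Thread
from typing import Iterator

from transformers import TextIteratorStreamer

# Hedged sketch of the threaded streaming pattern run() presumably continues with.
def stream_generate(model, tokenizer, inputs, max_new_tokens: int,
                    temperature: float, top_p: float, top_k: int) -> Iterator[str]:
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,                      # token ids + attention mask from the tokenizer
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    # model.generate blocks until completion, so run it in a background thread
    # and yield progressively longer decoded text as chunks arrive on the streamer.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)
```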
settings.py DELETED
@@ -1,8 +0,0 @@
-import os
-
-DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
-ALLOW_CHANGING_SYSTEM_PROMPT = os.getenv('ALLOW_CHANGING_SYSTEM_PROMPT',
-                                         '0') == '1'
-
-MAX_MAX_NEW_TOKENS = int(os.getenv('MAX_MAX_NEW_TOKENS', '1024'))
-DEFAULT_MAX_NEW_TOKENS = int(os.getenv('DEFAULT_MAX_NEW_TOKENS', '256'))
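
With settings.py gone, the environment-variable overrides disappear: the token limits are now inlined in app.py with larger fixed defaults (MAX_MAX_NEW_TOKENS = 2048, DEFAULT_MAX_NEW_TOKENS = 1024), and the system-prompt textbox is no longer gated by ALLOW_CHANGING_SYSTEM_PROMPT. If environment-based configuration is still wanted, a hypothetical snippet in app.py could restore it while keeping the new defaults:

```python
import os

# Hypothetical, not part of this change: re-introduce env overrides for the
# constants that app.py now hard-codes, keeping the new default values.
MAX_MAX_NEW_TOKENS = int(os.getenv('MAX_MAX_NEW_TOKENS', '2048'))
DEFAULT_MAX_NEW_TOKENS = int(os.getenv('DEFAULT_MAX_NEW_TOKENS', '1024'))
```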