shawn810720 and yentinglin committed
Commit 2973e07 (0 parents)

Duplicate from yentinglin/Taiwan-LLaMa2


Co-authored-by: Yen-Ting Lin <yentinglin@users.noreply.huggingface.co>

Files changed (5)
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +266 -0
  4. conversation.py +271 -0
  5. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Tw Llama Demo
+ emoji: 💻
+ colorFrom: indigo
+ colorTo: red
+ sdk: gradio
+ sdk_version: 3.39.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: yentinglin/Taiwan-LLaMa2
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,266 @@
+ import os
+
+ import gradio as gr
+ from text_generation import Client
+ from conversation import get_default_conv_template
+ from transformers import AutoTokenizer
+ from pymongo import MongoClient
+
+ DB_NAME = os.getenv("MONGO_DBNAME", "taiwan-llm")
+ USER = os.getenv("MONGO_USER")
+ PASSWORD = os.getenv("MONGO_PASSWORD")
+
+ uri = f"mongodb+srv://{USER}:{PASSWORD}@{DB_NAME}.kvwjiok.mongodb.net/?retryWrites=true&w=majority"
+ mongo_client = MongoClient(uri)
+ db = mongo_client[DB_NAME]
+ conversations_collection = db['conversations']
+
+ DESCRIPTION = """
+ # Language Models for Taiwanese Culture
+
+ <p align="center">
+ ✍️ <a href="https://huggingface.co/spaces/yentinglin/Taiwan-LLaMa2" target="_blank">Online Demo</a>
+
+ 🤗 <a href="https://huggingface.co/yentinglin" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/yentinglin56" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/pdf/2305.13711.pdf" target="_blank">[Paper Coming Soon]</a>
+ • 👨️ <a href="https://github.com/MiuLab/Taiwan-LLaMa/tree/main" target="_blank">Github Repo</a>
+ <br/><br/>
+ <img src="https://www.csie.ntu.edu.tw/~miulab/taiwan-llama/logo-v2.png" width="100"> <br/>
+ </p>
+
+
+ Taiwan-LLaMa is a model fine-tuned specifically for Traditional Mandarin applications. It is built on the LLaMa 2 architecture, with a pretraining phase on over 5 billion tokens and fine-tuning on over 490k multi-turn conversations in Traditional Mandarin.
+
+ ## Key Features
+
+ 1. **Traditional Mandarin Support**: The model is fine-tuned to understand and generate text in Traditional Mandarin, making it suitable for Taiwanese culture and related applications.
+
+ 2. **Instruction-Tuned**: Further fine-tuned on conversational data to offer context-aware and instruction-following responses.
+
+ 3. **Performance on the Vicuna Benchmark**: Taiwan-LLaMa's relative performance on the Vicuna Benchmark is measured against models like GPT-4 and ChatGPT. It is particularly optimized for Taiwanese culture.
+
+ 4. **Flexible Customization**: Advanced options for controlling the model's behavior, such as the system prompt, temperature, top-p, and top-k, are available in the demo.
+
+ ## Model Versions
+
+ Different versions of Taiwan-LLaMa are available:
+
+ - **Taiwan-LLaMa v1.0 (this demo)**: Optimized for Taiwanese culture
+ - **Taiwan-LLaMa v0.9**: Partial instruction set
+ - **Taiwan-LLaMa v0.0**: No Traditional Mandarin pretraining
+
+ The models can be accessed from the provided links in the Hugging Face repository.
+
+ Try out the demo to interact with Taiwan-LLaMa and experience its capabilities in handling Traditional Mandarin!
+ """
+
+ LICENSE = """
+ ## Licenses
+
+ - Code is licensed under the Apache 2.0 License.
+ - Models are licensed under the LLAMA 2 Community License.
+ - By using this model, you agree to the terms and conditions specified in the license.
+ - By using this demo, you agree to share your input utterances with us to improve the model.
+
+ ## Acknowledgements
+
+ The Taiwan-LLaMa project acknowledges the efforts of the [Meta LLaMa team](https://github.com/facebookresearch/llama) and the [Vicuna team](https://github.com/lm-sys/FastChat) in democratizing large language models.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. You are built by NTU Miulab by Yen-Ting Lin for research purpose."
+
+ endpoint_url = os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080")
+ client = Client(endpoint_url, timeout=120)
+ eos_token = "</s>"
+ MAX_MAX_NEW_TOKENS = 1024
+ DEFAULT_MAX_NEW_TOKENS = 1024
+
+ max_prompt_length = 4096 - MAX_MAX_NEW_TOKENS - 10
+
+ model_name = "yentinglin/Taiwan-LLaMa-v1.0"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ with gr.Blocks() as demo:
+     gr.Markdown(DESCRIPTION)
+
+     chatbot = gr.Chatbot()
+     with gr.Row():
+         msg = gr.Textbox(
+             container=False,
+             show_label=False,
+             placeholder='Type a message...',
+             scale=10,
+         )
+         submit_button = gr.Button('Submit',
+                                   variant='primary',
+                                   scale=1,
+                                   min_width=0)
+
+     with gr.Row():
+         retry_button = gr.Button('🔄 Retry', variant='secondary')
+         undo_button = gr.Button('↩️ Undo', variant='secondary')
+         clear = gr.Button('🗑️ Clear', variant='secondary')
+
+     saved_input = gr.State()
+
+     with gr.Accordion(label='Advanced options', open=False):
+         system_prompt = gr.Textbox(label='System prompt',
+                                    value=DEFAULT_SYSTEM_PROMPT,
+                                    lines=6)
+         max_new_tokens = gr.Slider(
+             label='Max new tokens',
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         )
+         temperature = gr.Slider(
+             label='Temperature',
+             minimum=0.1,
+             maximum=1.0,
+             step=0.1,
+             value=0.7,
+         )
+         top_p = gr.Slider(
+             label='Top-p (nucleus sampling)',
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         )
+         top_k = gr.Slider(
+             label='Top-k',
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         )
+
+     def user(user_message, history):
+         return "", history + [[user_message, None]]
+
+
+     def bot(history, max_new_tokens, temperature, top_p, top_k, system_prompt):
+         conv = get_default_conv_template("vicuna").copy()
+         roles = {"human": conv.roles[0], "gpt": conv.roles[1]}  # map human to USER and gpt to ASSISTANT
+         conv.system = system_prompt
+         for user, bot in history:
+             conv.append_message(roles['human'], user)
+             conv.append_message(roles["gpt"], bot)
+         msg = conv.get_prompt()
+         prompt_tokens = tokenizer.encode(msg)
+         length_of_prompt = len(prompt_tokens)
+         if length_of_prompt > max_prompt_length:
+             msg = tokenizer.decode(prompt_tokens[-max_prompt_length + 1:])
+
+         history[-1][1] = ""
+         for response in client.generate_stream(
+             msg,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+         ):
+             if not response.token.special:
+                 character = response.token.text
+                 history[-1][1] += character
+                 yield history
+
+         # After generating the response, store the conversation history in MongoDB
+         conversation_document = {
+             "model_name": model_name,
+             "history": history,
+             "system_prompt": system_prompt,
+             "max_new_tokens": max_new_tokens,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+         }
+         conversations_collection.insert_one(conversation_document)
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             system_prompt,
+         ],
+         outputs=chatbot
+     )
+     submit_button.click(
+         user, [msg, chatbot], [msg, chatbot], queue=False
+     ).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             system_prompt,
+         ],
+         outputs=chatbot
+     )
+
+
+     def delete_prev_fn(
+             history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
+         try:
+             message, _ = history.pop()
+         except IndexError:
+             message = ''
+         return history, message or ''
+
+
+     def display_input(message: str,
+                       history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+         history.append((message, ''))
+         return history
+
+     retry_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             system_prompt,
+         ],
+         outputs=chatbot,
+     )
+
+     undo_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=lambda x: x,
+         inputs=[saved_input],
+         outputs=msg,
+         api_name=False,
+         queue=False,
+     )
+
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     gr.Markdown(LICENSE)
+
+ demo.queue(concurrency_count=4, max_size=128)
+ demo.launch()
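
The `bot()` handler above builds a Vicuna-style prompt from the chat history via conversation.py and streams tokens back from a text-generation-inference endpoint before logging the exchange to MongoDB. The snippet below is a minimal sketch of that same flow outside Gradio; the endpoint default and the example question are assumptions for illustration, not part of this commit.

```python
# Sketch only: assumes a text-generation-inference server is reachable at ENDPOINT_URL
# and that conversation.py from this Space is importable.
import os

from text_generation import Client
from conversation import get_default_conv_template

client = Client(os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080"), timeout=120)

conv = get_default_conv_template("vicuna").copy()
conv.append_message(conv.roles[0], "台灣最高的山是哪一座？")  # USER turn (example question)
conv.append_message(conv.roles[1], None)  # open ASSISTANT turn so the prompt ends with "ASSISTANT:"
prompt = conv.get_prompt()

reply = ""
for response in client.generate_stream(
    prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, top_k=50
):
    if not response.token.special:  # skip special tokens such as </s>
        reply += response.token.text
print(reply)
```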
conversation.py ADDED
@@ -0,0 +1,271 @@
+ """
+ Conversation prompt template.
+ Now we support
+ - Vicuna
+ - Koala
+ - OpenAssistant/oasst-sft-1-pythia-12b
+ - StabilityAI/stablelm-tuned-alpha-7b
+ - databricks/dolly-v2-12b
+ - THUDM/chatglm-6b
+ - Alpaca/LLaMa
+ """
+
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple, Any
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+
+     SINGLE = auto()
+     TWO = auto()
+     DOLLY = auto()
+     OASST_PYTHIA = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+
+     # Used for gradio server
+     skip_next: bool = False
+     conv_id: Any = None
+
+     def get_prompt(self):
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system
+             for role, message in self.messages:
+                 if message:
+                     ret += self.sep + " " + role + ": " + message
+                 else:
+                     ret += self.sep + " " + role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.DOLLY:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ":\n" + message + seps[i % 2]
+                     if i % 2 == 1:
+                         ret += "\n\n"
+                 else:
+                     ret += role + ":\n"
+             return ret
+         elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
+             ret = self.system
+             for role, message in self.messages:
+                 if message:
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+             return ret
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             conv_id=self.conv_id,
+         )
+
+     def dict(self):
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+             "conv_id": self.conv_id,
+         }
+
+
+ conv_one_shot = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         (
+             "Human",
+             "What are the key differences between renewable and non-renewable energy sources?",
+         ),
+         (
+             "Assistant",
+             "Renewable energy sources are those that can be replenished naturally in a relatively "
+             "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+             "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+             "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+             "renewable and non-renewable energy sources:\n"
+             "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+             "energy sources are finite and will eventually run out.\n"
+             "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+             "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+             "and other negative effects.\n"
+             "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+             "have lower operational costs than non-renewable sources.\n"
+             "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+             "locations than non-renewable sources.\n"
+             "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+             "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+             "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+             "non-renewable sources are not, and their depletion can lead to economic and social instability.",
+         ),
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+
+ conv_vicuna_v1_1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the user's questions. You are built by NTU Miulab by Yen-Ting Lin for research purpose.",
+     # system="一位好奇的用戶和一個人工智能助理之間的聊天。你是一位助理。請對用戶的問題提供有用、詳細和有禮貌的答案。",
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_story = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="<|endoftext|>",
+ )
+
+ conv_koala_v1 = Conversation(
+     system="BEGINNING OF CONVERSATION:",
+     roles=("USER", "GPT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_dolly = Conversation(
+     system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
+     roles=("### Instruction", "### Response"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.DOLLY,
+     sep="\n\n",
+     sep2="### End",
+ )
+
+ conv_oasst = Conversation(
+     system="",
+     roles=("<|prompter|>", "<|assistant|>"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.OASST_PYTHIA,
+     sep="<|endoftext|>",
+ )
+
+ conv_stablelm = Conversation(
+     system="""<|SYSTEM|># StableLM Tuned (Alpha version)
+ - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
+ - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+ - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
+ - StableLM will refuse to participate in anything that could harm a human.
+ """,
+     roles=("<|USER|>", "<|ASSISTANT|>"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.OASST_PYTHIA,
+     sep="",
+ )
+
+ conv_templates = {
+     "conv_one_shot": conv_one_shot,
+     "vicuna_v1.1": conv_vicuna_v1_1,
+     "koala_v1": conv_koala_v1,
+     "dolly": conv_dolly,
+     "oasst": conv_oasst,
+ }
+
+
+ def get_default_conv_template(model_name):
+     model_name = model_name.lower()
+     if "vicuna" in model_name or "output" in model_name:
+         return conv_vicuna_v1_1
+     elif "koala" in model_name:
+         return conv_koala_v1
+     elif "dolly-v2" in model_name:
+         return conv_dolly
+     elif "oasst" in model_name and "pythia" in model_name:
+         return conv_oasst
+     elif "stablelm" in model_name:
+         return conv_stablelm
+     return conv_one_shot
+
+
+ def compute_skip_echo_len(model_name, conv, prompt):
+     model_name = model_name.lower()
+     if "chatglm" in model_name:
+         skip_echo_len = len(conv.messages[-2][1]) + 1
+     elif "dolly-v2" in model_name:
+         special_toks = ["### Instruction:", "### Response:", "### End"]
+         skip_echo_len = len(prompt)
+         for tok in special_toks:
+             skip_echo_len -= prompt.count(tok) * len(tok)
+     elif "oasst" in model_name and "pythia" in model_name:
+         special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
+         skip_echo_len = len(prompt)
+         for tok in special_toks:
+             skip_echo_len -= prompt.count(tok) * len(tok)
+     elif "stablelm" in model_name:
+         special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
+         skip_echo_len = len(prompt)
+         for tok in special_toks:
+             skip_echo_len -= prompt.count(tok) * len(tok)
+     else:
+         skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
+     return skip_echo_len
+
+
+ if __name__ == "__main__":
+     print(get_default_conv_template("vicuna").get_prompt())
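
app.py always selects `conv_vicuna_v1_1` via `get_default_conv_template("vicuna")`, which uses SeparatorStyle.TWO with `sep=" "` and `sep2="</s>"`. The sketch below illustrates the prompt string that template produces; the example turns are made up for illustration and are not part of this commit.

```python
# Sketch: what conv_vicuna_v1_1 renders for a short exchange (example turns are hypothetical).
from conversation import get_default_conv_template

conv = get_default_conv_template("vicuna").copy()
conv.append_message(conv.roles[0], "你好")
conv.append_message(conv.roles[1], "你好！有什麼我可以幫忙的嗎？")
conv.append_message(conv.roles[0], "請介紹台灣")
conv.append_message(conv.roles[1], None)  # open turn, so the prompt ends with "ASSISTANT:"

print(conv.get_prompt())
# Expected shape (system prompt abbreviated):
# "<system prompt> USER: 你好 ASSISTANT: 你好！有什麼我可以幫忙的嗎？</s>USER: 請介紹台灣 ASSISTANT:"
```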
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ text-generation==0.6.0
+ transformers==4.31.0
+ pymongo==4.4.1
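
requirements.txt does not pin gradio because the README declares `sdk: gradio` with `sdk_version: 3.39.0`, so the Space runtime supplies it. For the MongoDB logging in app.py to work, `MONGO_USER`, `MONGO_PASSWORD`, and optionally `MONGO_DBNAME` must be set as Space secrets. Below is a hedged sketch for checking that connection; the URI construction mirrors app.py and the environment variables are assumed to be set.

```python
# Sketch only: assumes MONGO_USER / MONGO_PASSWORD / MONGO_DBNAME are set in the
# environment, mirroring the URI built in app.py.
import os

from pymongo import MongoClient

db_name = os.getenv("MONGO_DBNAME", "taiwan-llm")
uri = (
    f"mongodb+srv://{os.getenv('MONGO_USER')}:{os.getenv('MONGO_PASSWORD')}"
    f"@{db_name}.kvwjiok.mongodb.net/?retryWrites=true&w=majority"
)
client = MongoClient(uri)
# Count documents in the collection app.py writes conversations to
print(client[db_name]["conversations"].estimated_document_count())
```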