alpindale committed on
Commit b931852
1 Parent(s): 8057df6

Upload folder using huggingface_hub

ChatApp/app.py ADDED
@@ -0,0 +1,253 @@
+# -*- coding:utf-8 -*-
+import os
+import logging
+import gradio as gr
+import gc
+from ChatApp.interface.hddr_llama_onnx_interface import LlamaOnnxInterface
+from ChatApp.interface.empty_stub_interface import EmptyStubInterface
+from ChatApp.app_modules.utils import (
+    reset_textbox,
+    transfer_input,
+    reset_state,
+    delete_last_conversation,
+    cancel_outputing,
+)
+from ChatApp.app_modules.presets import (
+    small_and_beautiful_theme,
+    title,
+    description_top,
+    description,
+)
+from ChatApp.app_modules.overwrites import postprocess
+
+logging.basicConfig(
+    level=logging.DEBUG,
+    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
+)
+
+# we can filter this dictionary at the start according to the actual available files on disk
+empty_stub_model_name = "_Empty Stub_"
+
+top_directory = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+
+tokenizer_path = os.path.join(top_directory, "tokenizer.model")
+
+available_models = {
+    "Llama-2 7B Float16": {
+        "onnx_file": os.path.join(
+            top_directory, "FP16", "LlamaV2_7B_float16.onnx"
+        ),
+        "tokenizer_path": tokenizer_path,
+        "embedding_file": os.path.join(top_directory, "embeddings.pth"),
+    },
+    "Llama-2 7B FP32": {
+        "onnx_file": os.path.join(
+            top_directory, "FP32", "LlamaV2_7B_FT_float16.onnx"
+        ),
+        "tokenizer_path": tokenizer_path,
+        "embedding_file": os.path.join(
+            top_directory, "embeddings.pth"
+        ),
+    },
+}
+
+
+interface = EmptyStubInterface()
+interface.initialize()
+
+# interface = None
+
+gr.Chatbot.postprocess = postprocess
+
+with open("ChatApp/assets/custom.css", "r", encoding="utf-8") as f:
+    custom_css = f.read()
+
+
+def change_model_listener(new_model_name):
+    if new_model_name is None:
+        new_model_name = empty_stub_model_name
+
+    global interface
+
+    # if a model exists, shut it down before trying to create the new one
+    if interface is not None:
+        interface.shutdown()
+        del interface
+        gc.collect()
+
+    logging.info(f"Creating a new model [{new_model_name}]")
+    if new_model_name == empty_stub_model_name:
+        interface = EmptyStubInterface()
+        interface.initialize()
+    else:
+        d = available_models[new_model_name]
+        interface = LlamaOnnxInterface(
+            onnx_file=d["onnx_file"],
+            tokenizer_path=d["tokenizer_path"],
+            embedding_file=d["embedding_file"],
+        )
+        interface.initialize()
+
+    return new_model_name
+
+
+def interface_predict(*args):
+    global interface
+    res = interface.predict(*args)
+
+    for x in res:
+        yield x
+
+
+def interface_retry(*args):
+    global interface
+    res = interface.retry(*args)
+
+    for x in res:
+        yield x
+
+
+with gr.Blocks(css=custom_css, theme=small_and_beautiful_theme) as demo:
+    history = gr.State([])
+    user_question = gr.State("")
+    with gr.Row():
+        gr.HTML(title)
+        status_display = gr.Markdown("Success", elem_id="status_display")
+    gr.Markdown(description_top)
+
+    with gr.Row():
+        with gr.Column(scale=5):
+            with gr.Row():
+                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot", height=900)
+            with gr.Row():
+                with gr.Column(scale=12):
+                    user_input = gr.Textbox(show_label=False, placeholder="Enter text")
+                with gr.Column(min_width=70, scale=1):
+                    submit_button = gr.Button("Send")
+                with gr.Column(min_width=70, scale=1):
+                    cancel_button = gr.Button("Stop")
+            with gr.Row():
+                empty_button = gr.Button(
+                    "🧹 New Conversation",
+                )
+                retry_button = gr.Button("🔄 Regenerate")
+                delete_last_button = gr.Button("🗑️ Remove Last Turn")
+        with gr.Column():
+            with gr.Column(min_width=50, scale=1):
+                with gr.Tab(label="Parameter Setting"):
+                    gr.Markdown("# Model")
+                    model_name = gr.Dropdown(
+                        choices=[empty_stub_model_name] + list(available_models.keys()),
+                        label="Model",
+                        show_label=False,  # default="Empty STUB",
+                    )
+                    model_name.change(
+                        change_model_listener, inputs=[model_name], outputs=[model_name]
+                    )
+
+                    gr.Markdown("# Parameters")
+                    top_p = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=0.9,
+                        step=0.05,
+                        interactive=True,
+                        label="Top-p",
+                    )
+                    temperature = gr.Slider(
+                        minimum=0.1,
+                        maximum=2.0,
+                        value=0.75,
+                        step=0.1,
+                        interactive=True,
+                        label="Temperature",
+                    )
+                    max_length_tokens = gr.Slider(
+                        minimum=0,
+                        maximum=512,
+                        value=256,
+                        step=8,
+                        interactive=True,
+                        label="Max Generation Tokens",
+                    )
+                    max_context_length_tokens = gr.Slider(
+                        minimum=0,
+                        maximum=4096,
+                        value=2048,
+                        step=128,
+                        interactive=True,
+                        label="Max History Tokens",
+                    )
+                    gr.Markdown(description)
+
+    predict_args = dict(
+        # fn=interface.predict,
+        fn=interface_predict,
+        inputs=[
+            user_question,
+            chatbot,
+            history,
+            top_p,
+            temperature,
+            max_length_tokens,
+            max_context_length_tokens,
+        ],
+        outputs=[chatbot, history, status_display],
+        show_progress=True,
+    )
+    retry_args = dict(
+        fn=interface_retry,
+        inputs=[
+            user_input,
+            chatbot,
+            history,
+            top_p,
+            temperature,
+            max_length_tokens,
+            max_context_length_tokens,
+        ],
+        outputs=[chatbot, history, status_display],
+        show_progress=True,
+    )
+
+    reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display])
+
+    # Chatbot
+    transfer_input_args = dict(
+        fn=transfer_input,
+        inputs=[user_input],
+        outputs=[user_question, user_input, submit_button],
+        show_progress=True,
+    )
+
+    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
+
+    predict_event2 = submit_button.click(**transfer_input_args).then(**predict_args)
+
+    empty_button.click(
+        reset_state,
+        outputs=[chatbot, history, status_display],
+        show_progress=True,
+    )
+    empty_button.click(**reset_args)
+
+    predict_event3 = retry_button.click(**retry_args)
+
+    delete_last_button.click(
+        delete_last_conversation,
+        [chatbot, history],
+        [chatbot, history, status_display],
+        show_progress=True,
+    )
+    cancel_button.click(
+        cancel_outputing,
+        [],
+        [status_display],
+        cancels=[predict_event1, predict_event2, predict_event3],
+    )
+
+    demo.load(change_model_listener, inputs=None, outputs=model_name)
+
+demo.title = "Llama-2 Chat UI"
+
+demo.queue(concurrency_count=1).launch()
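
A note on the `# we can filter this dictionary at the start...` comment in app.py above: a minimal sketch of that filtering step, assuming the `available_models` layout shown in this file (the `filter_available_models` name is hypothetical, not part of this upload):

```python
import os

def filter_available_models(models: dict) -> dict:
    """Drop entries whose ONNX, tokenizer, or embedding files are missing on disk."""
    return {
        name: paths
        for name, paths in models.items()
        if all(
            os.path.isfile(paths[key])
            for key in ("onnx_file", "tokenizer_path", "embedding_file")
        )
    }

# hypothetical usage, right after available_models is defined:
# available_models = filter_available_models(available_models)
```

This keeps the model dropdown from offering entries whose files were never downloaded.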
ChatApp/app_modules/__pycache__/overwrites.cpython-39.pyc ADDED
Binary file (1.15 kB)
 
ChatApp/app_modules/__pycache__/presets.cpython-39.pyc ADDED
Binary file (1.92 kB)
 
ChatApp/app_modules/__pycache__/utils.cpython-39.pyc ADDED
Binary file (6.25 kB)
 
ChatApp/app_modules/overwrites.py ADDED
@@ -0,0 +1,33 @@
+from __future__ import annotations
+from typing import List, Tuple
+
+from ChatApp.app_modules.presets import gr
+from ChatApp.app_modules.utils import detect_converted_mark, convert_asis, convert_mdtext
+
+
+def postprocess(
+    self, y: List[Tuple[str | None, str | None]]
+) -> List[Tuple[str | None, str | None]]:
+    """
+    Parameters:
+        y: List of tuples representing the message and response pairs.
+            Each message and response should be a string,
+            which may be in Markdown format.
+    Returns:
+        List of tuples representing the message and response.
+            Each message and response will be a string of HTML.
+    """
+    if y is None or y == []:
+        return []
+    temp = []
+    for x in y:
+        user, bot = x
+        if not detect_converted_mark(user):
+            user = convert_asis(user)
+        if not detect_converted_mark(bot):
+            bot = convert_mdtext(bot)
+        temp.append((user, bot))
+    return temp
+
+
+GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
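
For reference, a small sketch (not part of the upload) of the contract the patched `postprocess` is meant to satisfy, assuming the helpers behave as defined in `app_modules/utils.py`: the user side is escaped as-is, the bot side goes through the Markdown pipeline, and both are stamped with `ALREADY_CONVERTED_MARK` so a second pass leaves them unchanged:

```python
from ChatApp.app_modules.overwrites import postprocess

history = [("2 < 3 & so on", "Some *markdown* with `inline code`.")]

once = postprocess(None, history)   # user side escaped, bot side rendered to HTML
twice = postprocess(None, once)     # detect_converted_mark() short-circuits
assert once == twice
```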
ChatApp/app_modules/presets.py ADDED
@@ -0,0 +1,81 @@
+# -*- coding:utf-8 -*-
+import gradio as gr
+
+
+title = """<h1 align="left" style="min-width:200px; margin-top:0;">Llama-2 Chat UI</h1>"""
+description_top = """\
+<div align="left">
+<p>Use at your own risk...
+</p>
+</div>
+"""
+description = """\
+<div align="center" style="margin:16px 0">
+This is a chat demo using the ONNX versions of the Llama 2 model
+</div>
+"""
+CONCURRENT_COUNT = 100
+
+
+ALREADY_CONVERTED_MARK = "<!-- ALREADY CONVERTED BY PARSER. -->"
+
+small_and_beautiful_theme = gr.themes.Soft(
+    primary_hue=gr.themes.Color(
+        c50="#02C160",
+        c100="rgba(2, 193, 96, 0.2)",
+        c200="#02C160",
+        c300="rgba(2, 193, 96, 0.32)",
+        c400="rgba(2, 193, 96, 0.32)",
+        c500="rgba(2, 193, 96, 1.0)",
+        c600="rgba(2, 193, 96, 1.0)",
+        c700="rgba(2, 193, 96, 0.32)",
+        c800="rgba(2, 193, 96, 0.32)",
+        c900="#02C160",
+        c950="#02C160",
+    ),
+    secondary_hue=gr.themes.Color(
+        c50="#576b95",
+        c100="#576b95",
+        c200="#576b95",
+        c300="#576b95",
+        c400="#576b95",
+        c500="#576b95",
+        c600="#576b95",
+        c700="#576b95",
+        c800="#576b95",
+        c900="#576b95",
+        c950="#576b95",
+    ),
+    neutral_hue=gr.themes.Color(
+        name="gray",
+        c50="#f9fafb",
+        c100="#f3f4f6",
+        c200="#e5e7eb",
+        c300="#d1d5db",
+        c400="#B2B2B2",
+        c500="#808080",
+        c600="#636363",
+        c700="#515151",
+        c800="#393939",
+        c900="#272727",
+        c950="#171717",
+    ),
+    radius_size=gr.themes.sizes.radius_sm,
+).set(
+    button_primary_background_fill="#06AE56",
+    button_primary_background_fill_dark="#06AE56",
+    button_primary_background_fill_hover="#07C863",
+    button_primary_border_color="#06AE56",
+    button_primary_border_color_dark="#06AE56",
+    button_primary_text_color="#FFFFFF",
+    button_primary_text_color_dark="#FFFFFF",
+    button_secondary_background_fill="#F2F2F2",
+    button_secondary_background_fill_dark="#2B2B2B",
+    button_secondary_text_color="#393939",
+    button_secondary_text_color_dark="#FFFFFF",
+    background_fill_primary="#F7F7F7",
+    background_fill_primary_dark="#1F1F1F",
+    block_title_text_color="*primary_500",
+    block_title_background_fill="*primary_100",
+    input_background_fill="#F6F6F6",
+)
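
A quick way to eyeball the theme outside the full app — a sketch assuming gradio 3.x, with hypothetical preview widgets:

```python
import gradio as gr
from ChatApp.app_modules.presets import small_and_beautiful_theme, title

with gr.Blocks(theme=small_and_beautiful_theme) as preview:
    gr.HTML(title)
    gr.Button("Primary action", variant="primary")   # green #06AE56 fill
    gr.Button("Secondary action")                    # grey secondary fill

preview.launch()
```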
ChatApp/app_modules/utils.py ADDED
@@ -0,0 +1,235 @@
+# -*- coding:utf-8 -*-
+from __future__ import annotations
+import logging
+import re
+import html
+
+import gradio as gr
+import mdtex2html
+from markdown import markdown
+from pygments import highlight
+from pygments.lexers import guess_lexer, get_lexer_by_name, ClassNotFound
+from pygments.formatters import HtmlFormatter
+
+from ChatApp.app_modules.presets import ALREADY_CONVERTED_MARK
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
+)
+
+
+def markdown_to_html_with_syntax_highlight(md_str):
+    def replacer(match):
+        lang = match.group(1) or "text"
+        code = match.group(2)
+        lang = lang.strip()
+        # print(1,lang)
+        if lang == "text":
+            lexer = guess_lexer(code)
+            lang = lexer.name
+            # print(2,lang)
+        try:
+            lexer = get_lexer_by_name(lang, stripall=True)
+        except ValueError:
+            lexer = get_lexer_by_name("python", stripall=True)
+        formatter = HtmlFormatter()
+        # print(3,lexer.name)
+        highlighted_code = highlight(code, lexer, formatter)
+
+        return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
+
+    code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
+    md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
+
+    html_str = markdown(md_str)
+    return html_str
+
+
+def normalize_markdown(md_text: str) -> str:
+    lines = md_text.split("\n")
+    normalized_lines = []
+    inside_list = False
+
+    for i, line in enumerate(lines):
+        if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
+            if not inside_list and i > 0 and lines[i - 1].strip() != "":
+                normalized_lines.append("")
+            inside_list = True
+            normalized_lines.append(line)
+        elif inside_list and line.strip() == "":
+            if i < len(lines) - 1 and not re.match(
+                r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
+            ):
+                normalized_lines.append(line)
+            continue
+        else:
+            inside_list = False
+            normalized_lines.append(line)
+
+    return "\n".join(normalized_lines)
+
+
+def convert_mdtext(md_text):
+    code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
+    inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
+    code_blocks = code_block_pattern.findall(md_text)
+    non_code_parts = code_block_pattern.split(md_text)[::2]
+
+    result = []
+    for non_code, code in zip(non_code_parts, code_blocks + [""]):
+        if non_code.strip():
+            non_code = normalize_markdown(non_code)
+            if inline_code_pattern.search(non_code):
+                result.append(markdown(non_code, extensions=["tables"]))
+            else:
+                result.append(mdtex2html.convert(non_code, extensions=["tables"]))
+        if code.strip():
+            code = f"\n```{code}\n\n```"
+            code = markdown_to_html_with_syntax_highlight(code)
+            result.append(code)
+    result = "".join(result)
+    result += ALREADY_CONVERTED_MARK
+    return result
+
+
+def convert_asis(userinput):
+    return (
+        f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>'
+        + ALREADY_CONVERTED_MARK
+    )
+
+
+def detect_converted_mark(userinput):
+    if userinput.endswith(ALREADY_CONVERTED_MARK):
+        return True
+    else:
+        return False
+
+
+def detect_language(code):
+    if code.startswith("\n"):
+        first_line = ""
+    else:
+        first_line = code.strip().split("\n", 1)[0]
+    language = first_line.lower() if first_line else ""
+    code_without_language = code[len(first_line) :].lstrip() if first_line else code
+    return language, code_without_language
+
+
+def convert_to_markdown(text):
+    text = text.replace("$", "&#36;")
+
+    def replace_leading_tabs_and_spaces(line):
+        new_line = []
+
+        for char in line:
+            if char == "\t":
+                new_line.append("&#9;")
+            elif char == " ":
+                new_line.append("&nbsp;")
+            else:
+                break
+        return "".join(new_line) + line[len(new_line) :]
+
+    markdown_text = ""
+    lines = text.split("\n")
+    in_code_block = False
+
+    for line in lines:
+        if in_code_block is False and line.startswith("```"):
+            in_code_block = True
+            markdown_text += f"{line}\n"
+        elif in_code_block is True and line.startswith("```"):
+            in_code_block = False
+            markdown_text += f"{line}\n"
+        elif in_code_block:
+            markdown_text += f"{line}\n"
+        else:
+            line = replace_leading_tabs_and_spaces(line)
+            line = re.sub(r"^(#)", r"\\\1", line)
+            markdown_text += f"{line}  \n"
+
+    return markdown_text
+
+
+def add_language_tag(text):
+    def detect_language(code_block):
+        try:
+            lexer = guess_lexer(code_block)
+            return lexer.name.lower()
+        except ClassNotFound:
+            return ""
+
+    code_block_pattern = re.compile(r"(```)(\w*\n[^`]+```)", re.MULTILINE)
+
+    def replacement(match):
+        code_block = match.group(2)
+        if match.group(2).startswith("\n"):
+            language = detect_language(code_block)
+            if language:
+                return f"```{language}{code_block}```"
+            else:
+                return f"```\n{code_block}```"
+        else:
+            return match.group(1) + code_block + "```"
+
+    text2 = code_block_pattern.sub(replacement, text)
+    return text2
+
+
+def delete_last_conversation(chatbot, history):
+    if len(chatbot) > 0:
+        chatbot.pop()
+
+    if len(history) > 0:
+        history.pop()
+
+    return (
+        chatbot,
+        history,
+        "Delete Done",
+    )
+
+
+def reset_state():
+    return [], [], "Reset Done"
+
+
+def reset_textbox():
+    return gr.update(value=""), ""
+
+
+def cancel_outputing():
+    return "Stop Done"
+
+
+def transfer_input(inputs):
+    return (
+        inputs,
+        gr.update(value=""),
+        gr.Button.update(visible=True),
+    )
+
+
+class State:
+    interrupted = False
+
+    def interrupt(self):
+        self.interrupted = True
+
+    def recover(self):
+        self.interrupted = False
+
+
+shared_state = State()
+
+
+def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
+    for stop_word in stop_words:
+        if s.endswith(stop_word):
+            return True
+        for i in range(1, len(stop_word)):
+            if s.endswith(stop_word[:i]):
+                return True
+    return False
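
The prefix check in `is_stop_word_or_prefix` is what lets the streaming loop in `predict` hold back a stop word that has only partially arrived. A small illustration (not part of the upload), using the stop words the app passes in:

```python
from ChatApp.app_modules.utils import is_stop_word_or_prefix

stop_words = ["[|Human|]", "[|AI|]"]

# a complete stop word at the end of the stream is caught...
assert is_stop_word_or_prefix("Sure thing.[|Human|]", stop_words)
# ...as is a prefix that may still grow into one...
assert is_stop_word_or_prefix("Sure thing.[|Hu", stop_words)
# ...while ordinary text passes through.
assert not is_stop_word_or_prefix("Sure thing.", stop_words)
```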
ChatApp/assets/custom.css ADDED
@@ -0,0 +1,488 @@
+:root {
+    --chatbot-color-light: #F3F3F3;
+    --chatbot-color-dark: #121111;
+}
+
+/* status_display */
+#status_display {
+    display: flex;
+    min-height: 2.5em;
+    align-items: flex-end;
+    justify-content: flex-end;
+}
+
+#status_display p {
+    font-size: .85em;
+    font-family: monospace;
+    color: var(--body-text-color-subdued);
+}
+
+/* usage_display */
+#usage_display {
+    height: 1em;
+}
+
+#usage_display p {
+    padding: 0 1em;
+    font-size: .85em;
+    font-family: monospace;
+    color: var(--body-text-color-subdued);
+}
+
+/* list */
+ol:not(.options),
+ul:not(.options) {
+    padding-inline-start: 2em !important;
+}
+
+/* Thank @Keldos-Li for fixing it */
+/* Light mode (default) */
+#chuanhu_chatbot {
+    background-color: var(--chatbot-color-light) !important;
+    color: #000000 !important;
+}
+
+[data-testid="bot"] {
+    background-color: #FFFFFF !important;
+}
+
+[data-testid="user"] {
+    background-color: #95EC69 !important;
+}
+
+/* Dark mode */
+.dark #chuanhu_chatbot {
+    background-color: var(--chatbot-color-dark) !important;
+    color: #FFFFFF !important;
+}
+
+.dark [data-testid="bot"] {
+    background-color: #2C2C2C !important;
+}
+
+.dark [data-testid="user"] {
+    background-color: #26B561 !important;
+}
+
+#chuanhu_chatbot {
+    height: 100%;
+    min-height: 400px;
+}
+
+[class*="message"] {
+    border-radius: var(--radius-xl) !important;
+    border: none;
+    padding: var(--spacing-xl) !important;
+    font-size: var(--text-md) !important;
+    line-height: var(--line-md) !important;
+    min-height: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
+    min-width: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
+}
+
+[data-testid="bot"] {
+    max-width: 85%;
+    border-bottom-left-radius: 0 !important;
+}
+
+[data-testid="user"] {
+    max-width: 85%;
+    width: auto !important;
+    border-bottom-right-radius: 0 !important;
+}
+
+/* Table */
+table {
+    margin: 1em 0;
+    border-collapse: collapse;
+    empty-cells: show;
+}
+
+td,
+th {
+    border: 1.2px solid var(--border-color-primary) !important;
+    padding: 0.2em;
+}
+
+thead {
+    background-color: rgba(175, 184, 193, 0.2);
+}
+
+thead th {
+    padding: .5em .2em;
+}
+
+/* Inline code */
+#chuanhu_chatbot code {
+    display: inline;
+    white-space: break-spaces;
+    border-radius: 6px;
+    margin: 0 2px 0 2px;
+    padding: .2em .4em .1em .4em;
+    background-color: rgba(175, 184, 193, 0.2);
+}
+
+/* Code block */
+#chuanhu_chatbot pre code {
+    display: block;
+    overflow: auto;
+    white-space: pre;
+    background-color: hsla(0, 0%, 0%, 80%) !important;
+    border-radius: 10px;
+    padding: 1.4em 1.2em 0em 1.4em;
+    margin: 1.2em 2em 1.2em 0.5em;
+    color: #FFFFFF;
+    box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
+}
+
+/* Highlight (Pygments token classes; each trailing comment names the token) */
+#chuanhu_chatbot .highlight { background-color: transparent }
+#chuanhu_chatbot .highlight .hll { background-color: #49483e }
+#chuanhu_chatbot .highlight .c { color: #75715e } /* Comment */
+#chuanhu_chatbot .highlight .err { color: #960050; background-color: #1e0010 } /* Error */
+#chuanhu_chatbot .highlight .k { color: #66d9ef } /* Keyword */
+#chuanhu_chatbot .highlight .l { color: #ae81ff } /* Literal */
+#chuanhu_chatbot .highlight .n { color: #8828f2 } /* Name */
+#chuanhu_chatbot .highlight .o { color: #f92672 } /* Operator */
+#chuanhu_chatbot .highlight .p { color: #482822 } /* Punctuation */
+#chuanhu_chatbot .highlight .ch { color: #75715e } /* Comment.Hashbang */
+#chuanhu_chatbot .highlight .cm { color: #75715e } /* Comment.Multiline */
+#chuanhu_chatbot .highlight .cp { color: #75715e } /* Comment.Preproc */
+#chuanhu_chatbot .highlight .cpf { color: #75715e } /* Comment.PreprocFile */
+#chuanhu_chatbot .highlight .c1 { color: #75715e } /* Comment.Single */
+#chuanhu_chatbot .highlight .cs { color: #75715e } /* Comment.Special */
+#chuanhu_chatbot .highlight .gd { color: #f92672 } /* Generic.Deleted */
+#chuanhu_chatbot .highlight .ge { font-style: italic } /* Generic.Emph */
+#chuanhu_chatbot .highlight .gi { color: #a6e22e } /* Generic.Inserted */
+#chuanhu_chatbot .highlight .gs { font-weight: bold } /* Generic.Strong */
+#chuanhu_chatbot .highlight .gu { color: #75715e } /* Generic.Subheading */
+#chuanhu_chatbot .highlight .kc { color: #66d9ef } /* Keyword.Constant */
+#chuanhu_chatbot .highlight .kd { color: #66d9ef } /* Keyword.Declaration */
+#chuanhu_chatbot .highlight .kn { color: #f92672 } /* Keyword.Namespace */
+#chuanhu_chatbot .highlight .kp { color: #66d9ef } /* Keyword.Pseudo */
+#chuanhu_chatbot .highlight .kr { color: #66d9ef } /* Keyword.Reserved */
+#chuanhu_chatbot .highlight .kt { color: #66d9ef } /* Keyword.Type */
+#chuanhu_chatbot .highlight .ld { color: #162b74 } /* Literal.Date */
+#chuanhu_chatbot .highlight .m { color: #ae81ff } /* Literal.Number */
+#chuanhu_chatbot .highlight .s { color: #062b84 } /* Literal.String */
+#chuanhu_chatbot .highlight .na { color: #a6e22e } /* Name.Attribute */
+#chuanhu_chatbot .highlight .nb { color: #482822 } /* Name.Builtin */
+#chuanhu_chatbot .highlight .nc { color: #a6e22e } /* Name.Class */
+#chuanhu_chatbot .highlight .no { color: #66d9ef } /* Name.Constant */
+#chuanhu_chatbot .highlight .nd { color: #a6e22e } /* Name.Decorator */
+#chuanhu_chatbot .highlight .ni { color: #482822 } /* Name.Entity */
+#chuanhu_chatbot .highlight .ne { color: #a6e22e } /* Name.Exception */
+#chuanhu_chatbot .highlight .nf { color: #a6e22e } /* Name.Function */
+#chuanhu_chatbot .highlight .nl { color: #1818f2 } /* Name.Label */
+#chuanhu_chatbot .highlight .nn { color: #482822 } /* Name.Namespace */
+#chuanhu_chatbot .highlight .nx { color: #a6e22e } /* Name.Other */
+#chuanhu_chatbot .highlight .py { color: #482822 } /* Name.Property */
+#chuanhu_chatbot .highlight .nt { color: #f92672 } /* Name.Tag */
+#chuanhu_chatbot .highlight .nv { color: #482822 } /* Name.Variable */
+#chuanhu_chatbot .highlight .ow { color: #f92672 } /* Operator.Word */
+#chuanhu_chatbot .highlight .w { color: #482822 } /* Text.Whitespace */
+#chuanhu_chatbot .highlight .mb { color: #ae81ff } /* Literal.Number.Bin */
+#chuanhu_chatbot .highlight .mf { color: #ae81ff } /* Literal.Number.Float */
+#chuanhu_chatbot .highlight .mh { color: #ae81ff } /* Literal.Number.Hex */
+#chuanhu_chatbot .highlight .mi { color: #ae81ff } /* Literal.Number.Integer */
+#chuanhu_chatbot .highlight .mo { color: #ae81ff } /* Literal.Number.Oct */
+#chuanhu_chatbot .highlight .sa { color: #162b74 } /* Literal.String.Affix */
+#chuanhu_chatbot .highlight .sb { color: #161b74 } /* Literal.String.Backtick */
+#chuanhu_chatbot .highlight .sc { color: #162b74 } /* Literal.String.Char */
+#chuanhu_chatbot .highlight .dl { color: #162b74 } /* Literal.String.Delimiter */
+#chuanhu_chatbot .highlight .sd { color: #162b74 } /* Literal.String.Doc */
+#chuanhu_chatbot .highlight .s2 { color: #162b74 } /* Literal.String.Double */
+#chuanhu_chatbot .highlight .se { color: #ae81ff } /* Literal.String.Escape */
+#chuanhu_chatbot .highlight .sh { color: #162b74 } /* Literal.String.Heredoc */
+#chuanhu_chatbot .highlight .si { color: #162b74 } /* Literal.String.Interpol */
+#chuanhu_chatbot .highlight .sx { color: #162b74 } /* Literal.String.Other */
+#chuanhu_chatbot .highlight .sr { color: #162b74 } /* Literal.String.Regex */
+#chuanhu_chatbot .highlight .s1 { color: #162b74 } /* Literal.String.Single */
+#chuanhu_chatbot .highlight .ss { color: #162b74 } /* Literal.String.Symbol */
+#chuanhu_chatbot .highlight .bp { color: #482822 } /* Name.Builtin.Pseudo */
+#chuanhu_chatbot .highlight .fm { color: #a6e22e } /* Name.Function.Magic */
+#chuanhu_chatbot .highlight .vc { color: #482822 } /* Name.Variable.Class */
+#chuanhu_chatbot .highlight .vg { color: #482822 } /* Name.Variable.Global */
+#chuanhu_chatbot .highlight .vi { color: #482822 } /* Name.Variable.Instance */
+#chuanhu_chatbot .highlight .vm { color: #482822 } /* Name.Variable.Magic */
+#chuanhu_chatbot .highlight .il { color: #ae81ff } /* Literal.Number.Integer.Long */
ChatApp/assets/custom.js ADDED
@@ -0,0 +1 @@
+// custom javascript here
ChatApp/interface/__pycache__/base_interface.cpython-39.pyc ADDED
Binary file (574 Bytes)
 
ChatApp/interface/__pycache__/empty_stub_interface.cpython-39.pyc ADDED
Binary file (1.33 kB)
 
ChatApp/interface/__pycache__/hddr_llama_onnx_interface.cpython-39.pyc ADDED
Binary file (8.94 kB)
 
ChatApp/interface/base_interface.py ADDED
@@ -0,0 +1,6 @@
+class BaseLLMInterface:
+    def __init__(self):
+        pass
+
+    def foo(self):
+        pass
ChatApp/interface/empty_stub_interface.py ADDED
@@ -0,0 +1,39 @@
+import logging
+
+
+class EmptyStubInterface:
+    def __init__(self):
+        pass
+
+    def initialize(self):
+        pass
+
+    def shutdown(self):
+        pass
+
+    def predict(
+        self,
+        text,
+        chatbot,
+        history,
+        top_p,
+        temperature,
+        max_length_tokens,
+        max_context_length_tokens,
+    ):
+        logging.info("hi there")
+        logging.info("-" * 100)
+        # yield chatbot, history, "Empty context."
+        yield [[text, "No Model Found"]], [], "No Model Found"
+
+    def retry(
+        self,
+        text,
+        chatbot,
+        history,
+        top_p,
+        temperature,
+        max_length_tokens,
+        max_context_length_tokens,
+    ):
+        yield chatbot, history, "Empty context"
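
Both the stub and the real Llama interface expose the same generator-style `predict`/`retry` signature, which is what lets `change_model_listener` in `app.py` swap them freely. A minimal consumption sketch, assuming `ChatApp` is importable as a package:

```python
from ChatApp.interface.empty_stub_interface import EmptyStubInterface

stub = EmptyStubInterface()
stub.initialize()

# predict yields (chatbot, history, status) triples, just like the real model
for chatbot, history, status in stub.predict(
    "hello", [], [], top_p=0.9, temperature=0.75,
    max_length_tokens=256, max_context_length_tokens=2048,
):
    print(status)  # -> "No Model Found"
```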
ChatApp/interface/hddr_llama_onnx_interface.py ADDED
@@ -0,0 +1,395 @@
+import torch
+import onnxruntime
+import numpy as np
+from sentencepiece import SentencePieceProcessor
+from typing import List
+import os
+import logging
+import gc
+
+from .base_interface import BaseLLMInterface
+
+from ChatApp.app_modules.utils import (
+    is_stop_word_or_prefix,
+    convert_to_markdown,
+    shared_state,
+)
+
+
+class Tokenizer:
+    def __init__(self, model_path: str):
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        return self.sp_model.decode(t)
+
+
+class LlamaOnnxInterface(BaseLLMInterface):
+    def __init__(self, onnx_file="", embedding_file="", tokenizer_path=""):
+        super().__init__()
+
+        self.onnx_file = onnx_file
+        self.embedding_file = embedding_file
+        self.tokenizer_path = tokenizer_path
+
+        self.total_count = 0
+
+    def initialize(self):
+        # Create the ONNX session
+        logging.info(f"Creating ONNX session for [{self.onnx_file}]")
+        options = onnxruntime.SessionOptions()
+        self.llm_session = onnxruntime.InferenceSession(
+            self.onnx_file,
+            sess_options=options,
+            providers=[
+                "DmlExecutionProvider",
+                "CUDAExecutionProvider",
+                "CPUExecutionProvider",
+            ],
+        )
+
+        # get the data type used by the model
+        data_type_str = self.llm_session.get_inputs()[0].type
+        if data_type_str == "tensor(float16)":
+            self.data_type = np.float16
+        elif data_type_str == "tensor(float32)":
+            self.data_type = np.float32
+        else:
+            raise Exception(f"Unknown data type {data_type_str}")
+
+        logging.info(f"Detected Data Type [{self.data_type}]")
+
+        # Get the relevant shapes so we can create the inputs
+        for inputs_meta in self.llm_session._inputs_meta:
+            if inputs_meta.name == "x":
+                x_shape = inputs_meta.shape
+            elif inputs_meta.name == "attn_mask":
+                attn_mask_shape = inputs_meta.shape
+            elif inputs_meta.name == "k_cache":
+                k_cache_shape = inputs_meta.shape
+
+        self.hidden_size = x_shape[2]
+        self.max_seq_len = attn_mask_shape[1]
+        self.n_layers = k_cache_shape[1]
+        self.n_heads = k_cache_shape[3]
+
+        # Initialize the tokenizer and produce the initial tokens.
+        self.tokenizer = Tokenizer(model_path=self.tokenizer_path)
+
+        # create the embedding layer.
+        logging.info(
+            f"Creating the Embedding Layer. Size [{self.tokenizer.n_words}, {self.hidden_size}]"
+        )
+        self.embeddingLayer = torch.nn.Embedding(
+            self.tokenizer.n_words, self.hidden_size
+        )
+
+        # hack: we don't have the embeddings.pth file - taking it from the original llama model
+        d = torch.load(self.embedding_file)
+        self.embeddingLayer.load_state_dict(d)
+        self.embeddingLayer.eval()
+
+        # Create the attention mask.
+        self.attn_mask = -10000.0 * torch.triu(
+            torch.ones(attn_mask_shape), diagonal=1
+        ).cpu().detach().numpy().astype(self.data_type)
+
+        # Create the K and V caches.
+        self.head_dim = int(self.hidden_size / self.n_heads)
+        self.k_cache = np.zeros(
+            [1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
+            dtype=self.data_type,
+        )
+        self.v_cache = np.zeros(
+            [1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
+            dtype=self.data_type,
+        )
+
+    def shutdown(self):
+        pass
+
+    def generate_prompt_with_history(self, text, history, tokenizer, max_length=2048):
+        prompt = (
+            "[|Human|]Hey there I am a human that would like to have "
+            "a conversation with you.\n"
+            "[|AI|]Sure, I am happy to answer most questions\n"
+            "[|Human|]Great, I insist that we take turns.\n"
+            "[|AI|]I agree, we should take turns.\n"
+            "[|Human|]Great, can we also keep answers short\n"
+            "[|AI|]Yes, short answers are usually best"
+        )
+
+        history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
+        history.append("\n[|Human|]{}\n[|AI|]".format(text))
+        history_text = ""
+        flag = False
+        for x in history[::-1]:
+            # tokens = self.tokenizer.encode(text, bos=True, eos=False)
+            if (
+                len(
+                    self.tokenizer.encode(
+                        prompt + history_text + x, bos=True, eos=False
+                    )
+                )
+                <= max_length
+            ):
+                history_text = x + history_text
+                flag = True
+            else:
+                break
+        if flag:
+            return prompt + history_text, torch.tensor(
+                self.tokenizer.encode(prompt + history_text, bos=True, eos=False)
+            ).unsqueeze(0)
+        else:
+            return None
+
+    def sample_logits(
+        self,
+        logits: np.ndarray,
+        sampling_method: str = "greedy",
+        sampling_value: float = None,
+        temperature: float = 1.0,
+    ) -> np.ndarray:
+        if temperature == 0 or sampling_method == "greedy":
+            next_token = np.argmax(logits, axis=-1).astype(np.int64)
+
+        elif sampling_method == "top_k" or sampling_method == "top_p":
+            assert sampling_value is not None
+
+            # temperature, converting to probabilities and sorting are common to both top-k and top-p
+            # convert logits to 32-bit float to avoid numerical issues with np.exp
+            logits = logits.astype(np.float32)
+            # Scale the logits by the temperature
+            logits /= temperature
+            # Convert logits to probabilities
+            probs = np.exp(logits) / np.sum(np.exp(logits))
+            # Sort the probabilities and indices
+            sorted_probs = np.sort(probs)[:, ::-1]
+            sorted_indices = np.argsort(probs)[:, ::-1]
+
+            # find the index of interest for each of the methods.
+            if sampling_method == "top_k":
+                index_of_interest = int(sampling_value)
+            elif sampling_method == "top_p":
+                p = sampling_value
+                cumulative_probs = np.cumsum(sorted_probs, axis=-1)
+                # find the value of the first cumulative probability that exceeds p
+                for index_of_interest, cumulative_prob in enumerate(
+                    cumulative_probs[0]
+                ):
+                    if cumulative_prob > p:
+                        break
+
+            probs_of_interest = sorted_probs[:, : index_of_interest + 1]
+            indices_of_interest = sorted_indices[:, : index_of_interest + 1]
+            # Normalize the probabilities and select the next token
+            probs_of_interest /= np.sum(probs_of_interest)
+            next_token = np.array(
+                [np.random.choice(indices_of_interest[0], p=probs_of_interest[0])]
+            )
+        else:
+            raise Exception(f"Unknown sampling method {sampling_method}")
+
+        return next_token
+
+    def greedy_search(
+        self,
+        input_ids,
+        model,
+        tokenizer,
+        stop_words: list,
+        max_length: int,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = 25,
+    ):
+        generated_tokens = []
+        pos = np.array(0)
+
+        x = (
+            self.embeddingLayer(torch.tensor(input_ids))
+            .detach()
+            .cpu()
+            .numpy()
+            .astype(self.data_type)
+        )
+
+        for i in range(max_length):
+            results = self.llm_session.run(
+                None,
+                {
+                    "x": x,
+                    "attn_mask": self.attn_mask,
+                    "k_cache": self.k_cache[:, :, :pos],
+                    "v_cache": self.v_cache[:, :, :pos],
+                    "pos": pos.astype(np.int64),
+                },
+            )
+            logits, k_out, v_out = results[:3]
+
+            next_token = self.sample_logits(logits, "top_p", top_p, temperature)
+            next_token = next_token.reshape(1, -1)
+
+            # Stop if/when we get an EOS token before reaching maximum sequence length
+            if next_token[0] == tokenizer.eos_id:
+                del logits
+                gc.collect()
+                return
+
+            input_ids = torch.cat((input_ids, torch.tensor(next_token)), dim=-1)
+
+            generated_tokens.append(next_token[0].item())
+            text = tokenizer.decode(generated_tokens)
+
+            seq_len = x.shape[1]
+            self.k_cache[:, :, pos : pos + seq_len] = k_out
+            self.v_cache[:, :, pos : pos + seq_len] = v_out
+            pos = np.array(int(pos) + seq_len)
+
+            x = (
+                self.embeddingLayer(torch.tensor(next_token))
+                .unsqueeze(0)
+                .reshape([1, 1, self.hidden_size])
+                .cpu()
+                .detach()
+                .numpy()
+                .astype(self.data_type)
+            )
+
+            yield text
+
+            if any([x in text for x in stop_words]):
+                del logits
+                gc.collect()
+                return
+
+    def predict(
+        self,
+        text,
+        chatbot,
+        history,
+        top_p,
+        temperature,
+        max_length_tokens,
+        max_context_length_tokens,
+    ):
+        if text == "":
+            yield chatbot, history, "Empty context."
+            return
+        try:
+            self.llm_session
+        except AttributeError:
+            # initialize() was never called, so there is no session to run
+            yield [[text, "No Model Found"]], [], "No Model Found"
+            return
+
+        inputs = self.generate_prompt_with_history(
+            text, history, self.tokenizer, max_length=max_context_length_tokens
+        )
+
+        if inputs is None:
+            yield chatbot, history, "Input too long."
+            return
+        else:
+            prompt, inputs = inputs
+
+        input_ids = inputs[:, -max_context_length_tokens:]
+
+        # global total_count
+        self.total_count += 1
+        print(self.total_count)
+
+        self.head_dim = int(self.hidden_size / self.n_heads)
+        self.k_cache = np.zeros(
+            [1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
+            dtype=self.data_type,
+        )
+        self.v_cache = np.zeros(
+            [1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
+            dtype=self.data_type,
+        )
+
+        x = input_ids
+
+        for x in self.greedy_search(
+            input_ids,
+            self.llm_session,
+            self.tokenizer,
+            stop_words=["[|Human|]", "[|AI|]"],
+            max_length=max_length_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        ):
+            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
+                if "[|Human|]" in x:
+                    x = x[: x.index("[|Human|]")].strip()
+                if "[|AI|]" in x:
+                    x = x[: x.index("[|AI|]")].strip()
+                x = x.strip()
+                a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
+                    [text, convert_to_markdown(x)]
+                ], history + [[text, x]]
+                yield a, b, "Generating..."
+            if shared_state.interrupted:
+                shared_state.recover()
+                try:
+                    yield a, b, "Stop: Success"
+                    return
+                except Exception as e:
+                    print(type(e).__name__, e)
+
+        del input_ids
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        try:
+            yield a, b, "Generate: Success"
+        except Exception as e:
+            print(type(e).__name__, e)
+
+        return
+
+    def retry(
+        self,
+        text,
+        chatbot,
+        history,
+        top_p,
+        temperature,
+        max_length_tokens,
+        max_context_length_tokens,
+    ):
+        logging.info("Retry...")
+        if len(history) == 0:
+            yield chatbot, history, "Empty context"
+            return
+        chatbot.pop()
+        inputs = history.pop()[0]
+        for x in self.predict(
+            inputs,
+            chatbot,
+            history,
+            top_p,
+            temperature,
+            max_length_tokens,
+            max_context_length_tokens,
+        ):
+            yield x
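
The top-p branch of `sample_logits` above is standard nucleus sampling. A self-contained numpy sketch of the same idea for a single logits row (the `nucleus_sample` name is hypothetical, not part of the upload), useful for sanity-checking the cutoff logic:

```python
import numpy as np

def nucleus_sample(logits: np.ndarray, p: float = 0.9, temperature: float = 1.0) -> int:
    """Sample one token id from a 1-D logits vector with top-p (nucleus) filtering."""
    logits = logits.astype(np.float32) / temperature
    probs = np.exp(logits - logits.max())      # max-subtraction for numerical stability
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]            # token ids, most probable first
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, p)) + 1  # smallest nucleus reaching p
    nucleus, nucleus_probs = order[:cutoff], probs[order[:cutoff]]
    return int(np.random.choice(nucleus, p=nucleus_probs / nucleus_probs.sum()))

# e.g. nucleus_sample(np.array([2.0, 1.0, 0.5, -1.0]), p=0.9)
```

Unlike the class method, this sketch subtracts the max logit before exponentiating, which avoids overflow for large logits without changing the resulting distribution.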
ChatApp/requirements.txt ADDED
@@ -0,0 +1,18 @@
+gradio<4  # the app relies on gradio 3.x APIs (queue(concurrency_count=...), gr.Button.update)
+mdtex2html
+pypinyin
+tiktoken
+socksio
+tqdm
+colorama
+duckduckgo_search
+Pygments
+llama_index
+langchain
+markdown
+markdown2
+torch
+git+https://github.com/huggingface/peft.git
+git+https://github.com/huggingface/transformers.git
+SentencePiece
+onnxruntime-gpu