davoodwadi commited on
Commit
acce665
1 Parent(s): ce8b1fb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -0
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ from llama_cpp import Llama
4
+ from pathlib import Path
5
+ from tqdm.auto import tqdm
6
+ import re
7
+
8
+ llm = Llama(model_path="./snorkel-mistral-pairrm-dpo.Q4_K_M.gguf",
9
+ chat_format="chatml",
10
+ n_gpu_layers=0, # cpu only
11
+ n_ctx=6000)
12
+ def split_text(text, llm, chunk_size):
13
+ text_newline = text.split('\n')
14
+ text_newline = [t for t in text_newline if len(t)>0]
15
+ summary_list=[]
16
+ new_list=[]
17
+ for i, t in enumerate(text_newline):
18
+ new_list.append(t)
19
+ n_tokens=get_num_tokens('\n\n\n'.join(new_list), llm)
20
+ if i==(len(text_newline)-1):
21
+ summary_list.append('\n\n'.join(new_list))
22
+ elif n_tokens>chunk_size:
23
+ summary_list.append('\n\n'.join(new_list))
24
+ new_list=[]
25
+ return summary_list
26
+ def all_to_list(all_sum, llm, chunk_size):
27
+ summary_list = split_text(all_sum, llm, chunk_size)
28
+ len_chunks = [get_num_tokens(chunk, llm) for chunk in summary_list]
29
+ print(f'len_chunks: {len_chunks}')
30
+ print(f'total parts: {len(summary_list)}')
31
+ return summary_list
32
+
33
+ def clean_output(text):
34
+ text = text.replace('`','')
35
+ text = re.sub(r'\d+\.', '', text) # removes numeric bullet points
36
+ text = text.replace('- ',' ')
37
+ text = text.replace('*','')
38
+ text = text.replace('+','')
39
+ return text
40
+
41
+ def get_content_length(messages, llm):
42
+ # print(messages)
43
+ # user_list=[m for m in messages if m['role']=='user']
44
+ # assistant_list=[m for m in messages if m['role']=='assistant']
45
+ system_list=[m for m in messages if m['role']=='system']
46
+ # print(f'system: {system_list}')
47
+ content_total=system_list[0]['content']
48
+ for i, (m) in enumerate(messages[1:]):
49
+ content_total+=m['content']
50
+ return get_num_tokens(content_total, llm)
51
+ def pop_first_user_assistant(messages):
52
+ new_messages=[entry for i, entry in enumerate(messages) if i not in [1,2]]
53
+ return new_messages
54
+ def get_num_tokens(text, llm):
55
+ bytes_string = text.encode('utf-8')
56
+ tokens = llm.tokenize(bytes_string)
57
+ return len(tokens)
58
+
59
+ def response_stream():
60
+ global writer_messages, editor_messages, turn
61
+ if turn=='writer':
62
+ yield
63
+ else:
64
+ yield
65
+ def adverse(message, history):
66
+ global writer_messages, editor_messages, turn
67
+ total_response = ''
68
+ for i in range(4):
69
+ # update writer_messages
70
+ if len(writer_messages)==1: # first call
71
+ writer_messages.append({
72
+ 'role':'user',
73
+ 'content':message,
74
+ })
75
+ # check whose turn it is
76
+ turn = 'writer' if len(writer_messages)%2==0 else 'editor'
77
+ list_of_messages = writer_messages if turn=='writer' else editor_messages
78
+ print(f'turn: {turn}\n\nlist_of_messages: {list_of_messages}')
79
+ total_response+=f'\n\n\nturn: {turn}\n'
80
+ #############################
81
+
82
+ # call llm.create_chat_completion for whoever's turn
83
+ # response_iter
84
+ # response_str = f'writer {len(writer_messages)}' if turn=='writer' else f'editor {len(editor_messages)}'
85
+ # response_iter = iter(response_str.split(' '))
86
+ response_iter = llm.create_chat_completion(
87
+ list_of_messages, # Prompt
88
+ max_tokens=-1,
89
+ stop=["###"],
90
+ stream=True
91
+ )
92
+ response=''
93
+ for chunk in response_iter:
94
+ try:
95
+ response+=chunk['choices'][0]['delta']['content']
96
+ total_response+=chunk['choices'][0]['delta']['content']
97
+ time.sleep(1)
98
+ # print(f'chunk: {chunk}')
99
+ yield total_response
100
+ except Exception as e:
101
+ print(e)
102
+ total_response+='\n\n'
103
+ if turn=='editor':
104
+ response+='\nNow rewrite your response keeping my suggestions in mind.\n'
105
+ #############################
106
+ # update writer_messages and editor_messages
107
+ if turn=='writer':
108
+ writer_messages.append({
109
+ 'role':'assistant',
110
+ 'content':response,
111
+ })
112
+ editor_messages.append({
113
+ 'role':'user',
114
+ 'content':response,
115
+ })
116
+ else: # editor
117
+ writer_messages.append({
118
+ 'role':'user',
119
+ 'content':response,
120
+ })
121
+ editor_messages.append({
122
+ 'role':'assistant',
123
+ 'content':response,
124
+ })
125
+
126
+
127
+ max_tokens=4_000
128
+ chunk_size=1000
129
+ max_words = 10_000
130
+ print(f'max_words: {max_words}')
131
+ # llm = Llama(model_path="E:\\yt\\bookSummary/Snorkel-Mistral-PairRM-DPO/snorkel-mistral-pairrm-dpo.Q4_K_M.gguf", chat_format="chatml", n_gpu_layers=-1, n_ctx=6000)
132
+
133
+ setting = 'Scotland in twenty twenty five'
134
+ book = 'The Great Gatsby by F. Scott Fitzgerald'
135
+ total_summary=''
136
+ writer_system_prompt_unformatted = '''You are a writer.
137
+ The topic is {topic}.
138
+ I am your writing editor. I will give you feedback to guide your response in the right direction. You will obey my feedback adjust your response based on my feedback.'''
139
+
140
+ editor_system_prompt_unformatted = '''You are an editor.
141
+ The topic is {topic}.
142
+ You should reinforce me to adjust my content and improve it.
143
+ You should push me to make my response as close as possible to the topic'''
144
+
145
+ # print(writer_system_prompt.format(book=book, setting=setting))
146
+
147
+ writer_messages = [{'role':'system','content':writer_system_prompt_unformatted}]
148
+ editor_messages = [{'role':'system','content':editor_system_prompt_unformatted}]
149
+
150
+ turn = 'writer'
151
+ def set_system_prompts(x):
152
+ global writer_system_prompt, editor_system_prompt, writer_messages, editor_messages, writer_system_prompt_unformatted, editor_system_prompt_unformatted
153
+ writer_system_prompt = writer_system_prompt_unformatted.format(topic=x)
154
+ editor_system_prompt = editor_system_prompt_unformatted.format(topic=x)
155
+ writer_messages = [{'role':'system','content':writer_system_prompt}]
156
+ editor_messages = [{'role':'system','content':editor_system_prompt}]
157
+ return f'writer system prompt:\n{writer_system_prompt}\n\neditor system prompt:\n{editor_system_prompt}'
158
+ with gr.Blocks() as demo:
159
+ gr.Markdown(
160
+ """
161
+ # Multi Agent LLMs for End-to-End Story Generation
162
+ Start typing below to see the output.
163
+ """)
164
+ # inp_topic = gr.Textbox(placeholder="What is the topic?", label = 'Topic', lines=4)
165
+ # # inp_setting = gr.Textbox(placeholder="What is the setting of the book?", label='Setting')
166
+ # out_system = gr.Textbox(label='System prompt to use', lines=4)
167
+ # button = gr.Button("Set system prompt")
168
+ # button.click(set_system_prompts, inputs=inp_topic, outputs=out_system)
169
+ gr.Interface(
170
+ fn=set_system_prompts,
171
+ inputs=gr.Textbox(placeholder="What is the topic?", label = 'Topic', lines=4),
172
+ outputs=gr.Textbox(label='System prompt to use', lines=4)
173
+ )
174
+
175
+ out_test = gr.Textbox(lines=4)
176
+ button = gr.Button("test")
177
+ button.click(lambda : f"{writer_system_prompt} \n\n\n{editor_system_prompt}", outputs=out_test)
178
+
179
+ chat = gr.ChatInterface(fn=adverse, examples=["Start the story", "Write a poem", 'The funniest joke ever!'], title="Multi-Agent Bot").queue()
180
+
181
+ if __name__ == "__main__":
182
+ demo.launch()