import gradio as gr
import time
from openai import OpenAI
import re
import os

client = OpenAI(api_key=os.environ.get("openai"))
model="gpt-3.5-turbo"

# Original local-model setup (llama-cpp-python), kept for reference:
# llm = Llama(model_path="./snorkel-mistral-pairrm-dpo.Q4_K_M.gguf",
#             chat_format="chatml",
#             n_gpu_layers=0,  # CPU only
#             n_ctx=6000)

def split_text(text, llm, chunk_size):
    # Split text into paragraph-aligned chunks of roughly chunk_size tokens.
    text_newline = text.split('\n')
    text_newline = [t for t in text_newline if len(t) > 0]
    summary_list = []
    new_list = []
    for i, t in enumerate(text_newline):
        new_list.append(t)
        # Count tokens with the same separator used to build the chunk.
        n_tokens = get_num_tokens('\n\n'.join(new_list), llm)
        if i == (len(text_newline) - 1):  # last paragraph: flush what remains
            summary_list.append('\n\n'.join(new_list))
        elif n_tokens > chunk_size:  # chunk is full: flush and start a new one
            summary_list.append('\n\n'.join(new_list))
            new_list = []
    return summary_list

def all_to_list(all_sum, llm, chunk_size):
    # Chunk the full text and report per-chunk token counts for inspection.
    summary_list = split_text(all_sum, llm, chunk_size)
    len_chunks = [get_num_tokens(chunk, llm) for chunk in summary_list]
    print(f'len_chunks: {len_chunks}')
    print(f'total parts: {len(summary_list)}')
    return summary_list
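
# Minimal usage sketch for the two chunking helpers above. `book_text` and
# `llm` are hypothetical here: `llm` must be a llama_cpp Llama instance,
# which this OpenAI-based version never constructs.
#   chunks = all_to_list(book_text, llm, chunk_size=1000)
#   print(chunks[0])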

def clean_output(text):
    # Strip markdown artifacts (backticks, bullets, numbering) from model
    # output. Currently unused in the chat flow below.
    text = text.replace('`', '')
    text = re.sub(r'\d+\.', '', text)  # remove numeric bullet points
    text = text.replace('- ', ' ')
    text = text.replace('*', '')
    text = text.replace('+', '')
    return text
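
# For example, clean_output('1. `Hello` - *world*') returns ' Hello  world'.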

def get_content_length(messages, llm):
    # Total token count across all message contents, system prompt included.
    system_list = [m for m in messages if m['role'] == 'system']
    content_total = system_list[0]['content']
    for m in messages[1:]:
        content_total += m['content']
    return get_num_tokens(content_total, llm)

def pop_first_user_assistant(messages):
    # Drop the oldest user/assistant pair (indices 1 and 2) to shrink context.
    return [entry for i, entry in enumerate(messages) if i not in (1, 2)]
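
# Sketch of how these two helpers could cap context growth on the local-model
# path (nothing calls them in this version; `llm` is the hypothetical Llama
# instance from the commented-out setup above):
#   while get_content_length(writer_messages, llm) > max_tokens:
#       writer_messages = pop_first_user_assistant(writer_messages)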
def get_num_tokens(text, llm):
    # Token count via the llama_cpp tokenizer; requires a Llama instance, so
    # this only works on the local-model path, not with the OpenAI client.
    bytes_string = text.encode('utf-8')
    tokens = llm.tokenize(bytes_string)
    return len(tokens)
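
# A possible OpenAI-side replacement using tiktoken (a sketch, not wired into
# the app; it would add a `tiktoken` dependency):
# def get_num_tokens_openai(text, model_name="gpt-3.5-turbo"):
#     import tiktoken
#     enc = tiktoken.encoding_for_model(model_name)
#     return len(enc.encode(text))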

def response_stream():
    # Unused stub: streaming is handled inline in adverse() below.
    global writer_messages, editor_messages, turn
    if turn == 'writer':
        yield
    else:
        yield
def adverse(message, history):
    # Adversarial loop: a writer agent and an editor agent take alternating
    # turns, each receiving the other's last message as user input.
    global writer_messages, editor_messages, turn
    total_response = ''
    for _ in range(4):
        # On the first call, seed the writer with the user's message.
        if len(writer_messages) == 1:
            writer_messages.append({
                'role': 'user',
                'content': message,
            })
        # Check whose turn it is: the writer speaks whenever its history
        # holds an even number of messages.
        turn = 'writer' if len(writer_messages) % 2 == 0 else 'editor'
        list_of_messages = writer_messages if turn == 'writer' else editor_messages
        print(f'turn: {turn}\n\nlist_of_messages: {list_of_messages}')
        total_response += f'\n\n\nturn: {turn}\n'
        #############################

        # Original local-model call, kept for reference:
        # response_iter = llm.create_chat_completion(
        #                 list_of_messages,  # Prompt
        #                 max_tokens=-1,
        #                 stop=["###"],
        #                 stream=True
        #                 )
        response_iter = client.chat.completions.create(
            model=model,
            messages=list_of_messages,
            stream=True,
        )
        response = ''
        for chunk in response_iter:
            delta = chunk.choices[0].delta.content
            if delta is None:  # the final chunk of a stream carries no content
                continue
            response += delta
            total_response += delta
            time.sleep(1)  # pace the stream so turns are readable in the UI
            yield total_response
        total_response += '\n\n'
        if turn == 'editor':
            # Nudge the writer to revise using the editor's feedback.
            response += '\nNow rewrite your response keeping my suggestions in mind.\n'
        #############################
        # Append this turn's output to both histories with mirrored roles.
        if turn == 'writer':
            writer_messages.append({
                'role': 'assistant',
                'content': response,
            })
            editor_messages.append({
                'role': 'user',
                'content': response,
            })
        else:  # editor
            writer_messages.append({
                'role': 'user',
                'content': response,
            })
            editor_messages.append({
                'role': 'assistant',
                'content': response,
            })
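
# Message flow over one writer/editor round, as each agent sees it:
#   writer_messages: system, user(prompt), assistant(draft), user(feedback), ...
#   editor_messages: system, user(draft), assistant(feedback), user(revision), ...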


# Leftover configuration from the local-model / book-summary version; only
# the prompts below are used by the OpenAI chat demo.
max_tokens = 4_000
chunk_size = 1000
max_words = 10_000
print(f'max_words: {max_words}')
# llm = Llama(model_path="E:\\yt\\bookSummary/Snorkel-Mistral-PairRM-DPO/snorkel-mistral-pairrm-dpo.Q4_K_M.gguf", chat_format="chatml", n_gpu_layers=-1, n_ctx=6000)

writer_system_prompt_unformatted = '''You are a writer.
The topic is {topic}.'''

editor_system_prompt_unformatted = '''You are an editor. You give feedback about my writing.
The topic is {topic}.
You should encourage me to adjust my content and improve it.
You should push me to keep my response as close as possible to the topic.'''


# Default to the unformatted prompts so the test button below works even
# before a topic has been set.
writer_system_prompt = writer_system_prompt_unformatted
editor_system_prompt = editor_system_prompt_unformatted

writer_messages = [{'role': 'system', 'content': writer_system_prompt}]
editor_messages = [{'role': 'system', 'content': editor_system_prompt}]

turn = 'writer'

def set_system_prompts(x):
    # Rebuild both system prompts around the chosen topic and reset histories.
    global writer_system_prompt, editor_system_prompt, writer_messages, editor_messages
    writer_system_prompt = writer_system_prompt_unformatted.format(topic=x)
    editor_system_prompt = editor_system_prompt_unformatted.format(topic=x)
    writer_messages = [{'role': 'system', 'content': writer_system_prompt}]
    editor_messages = [{'role': 'system', 'content': editor_system_prompt}]
    return f'writer system prompt:\n{writer_system_prompt}\n\neditor system prompt:\n{editor_system_prompt}'

with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Multi-Agent LLMs for End-to-End Story Generation
    Start typing below to see the output.
    """)
    gr.Interface(
        fn=set_system_prompts, 
        inputs=gr.Textbox(placeholder="What is the topic?", label = 'Topic', lines=4), 
        outputs=gr.Textbox(label='System prompt to use', lines=4)
        )

    out_test = gr.Textbox(lines=4)
    button = gr.Button("test")
    # Show the current system prompts (uses the defaults defined above).
    button.click(lambda: f"{writer_system_prompt}\n\n\n{editor_system_prompt}", outputs=out_test)
    
    chat = gr.ChatInterface(
        fn=adverse, 
        examples=["Start the story", "Write a poem", 'The funniest joke ever!'], 
        title="Multi-Agent Bot",
        autofocus=False,
        fill_height=True,
        ).queue()

if __name__ == "__main__":
    demo.launch()