import gradio as gr
import time
from openai import OpenAI
import re
import os
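# The OpenAI client reads its API key from the `openai` environment variable.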
client = OpenAI(api_key=os.environ.get("openai"))
model = "gpt-3.5-turbo"
# Previous local-model backend (llama-cpp-python), kept for reference:
# llm = Llama(model_path="./snorkel-mistral-pairrm-dpo.Q4_K_M.gguf",
#             chat_format="chatml",
#             n_gpu_layers=0,  # CPU only
#             n_ctx=6000)
def split_text(text, llm, chunk_size):
    """Split `text` into chunks of roughly `chunk_size` tokens, breaking on paragraphs."""
    paragraphs = [t for t in text.split('\n') if len(t) > 0]
    summary_list = []
    new_list = []
    for i, t in enumerate(paragraphs):
        new_list.append(t)
        n_tokens = get_num_tokens('\n\n'.join(new_list), llm)
        if i == (len(paragraphs) - 1):  # last paragraph: flush whatever is left
            summary_list.append('\n\n'.join(new_list))
        elif n_tokens > chunk_size:  # chunk is full (may slightly exceed chunk_size)
            summary_list.append('\n\n'.join(new_list))
            new_list = []
    return summary_list
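# With chunk_size=1000, a text whose paragraphs total ~2,500 tokens would come
# back as roughly three chunks (hypothetical numbers), each flushed as soon as
# its running token count first exceeds the limit.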
def all_to_list(all_sum, llm, chunk_size):
    """Chunk `all_sum` with split_text and log the token length of each chunk."""
    summary_list = split_text(all_sum, llm, chunk_size)
    len_chunks = [get_num_tokens(chunk, llm) for chunk in summary_list]
    print(f'len_chunks: {len_chunks}')
    print(f'total parts: {len(summary_list)}')
    return summary_list
def clean_output(text):
    """Strip markdown-style formatting (backticks, bullets, numbered lists) from `text`."""
    text = text.replace('`', '')
    text = re.sub(r'\d+\.', '', text)  # remove numeric bullet points, e.g. "1."
    text = text.replace('- ', ' ')
    text = text.replace('*', '')
    text = text.replace('+', '')
    return text
def get_content_length(messages, llm):
    """Return the total token count of a chat history (system prompt plus all turns)."""
    system_list = [m for m in messages if m['role'] == 'system']
    content_total = system_list[0]['content']
    for m in messages[1:]:  # the system prompt is messages[0]
        content_total += m['content']
    return get_num_tokens(content_total, llm)
def pop_first_user_assistant(messages):
    """Drop the oldest user/assistant pair (indices 1 and 2) to shrink the context."""
    return [entry for i, entry in enumerate(messages) if i not in (1, 2)]
def get_num_tokens(text, llm):
    """Count tokens in `text` with a llama_cpp.Llama tokenizer (leftover from the local backend)."""
    bytes_string = text.encode('utf-8')
    tokens = llm.tokenize(bytes_string)
    return len(tokens)
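# NOTE: get_num_tokens expects a llama_cpp.Llama instance, which the OpenAI
# backend above no longer creates. A minimal sketch of an equivalent counter
# using the tiktoken package (an assumption; tiktoken is not a dependency of
# the original app) would be:
#
#     import tiktoken
#     def get_num_tokens_openai(text, model_name="gpt-3.5-turbo"):
#         enc = tiktoken.encoding_for_model(model_name)
#         return len(enc.encode(text))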
def response_stream():
    # Unused stub for a turn-aware streaming generator; kept as a placeholder.
    global writer_messages, editor_messages, turn
    if turn == 'writer':
        yield
    else:
        yield
def adverse(message, history):
    """Alternate between the writer and editor agents for four half-turns,
    streaming the running transcript back to the chat UI."""
    global writer_messages, editor_messages, turn
    total_response = ''
    for i in range(4):
        # Seed the writer's history with the user's message on the first call.
        if len(writer_messages) == 1:
            writer_messages.append({
                'role': 'user',
                'content': message,
            })
        # The writer speaks whenever its history ends on a user message.
        turn = 'writer' if len(writer_messages) % 2 == 0 else 'editor'
        list_of_messages = writer_messages if turn == 'writer' else editor_messages
        print(f'turn: {turn}\n\nlist_of_messages: {list_of_messages}')
        total_response += f'\n\n\nturn: {turn}\n'
        # Stream a completion for whichever agent holds the turn.
        # (The local backend used llm.create_chat_completion here instead.)
        response_iter = client.chat.completions.create(
            model=model,
            messages=list_of_messages,
            stream=True,
        )
        response = ''
        for chunk in response_iter:
            delta = chunk.choices[0].delta.content
            if delta is None:  # the final stream chunk carries no content
                continue
            response += delta
            total_response += delta
            time.sleep(1)  # throttle so the UI streams visibly
            yield total_response
        total_response += '\n\n'
        if turn == 'editor':
            # Tell the writer to revise rather than merely continue.
            response += '\nNow rewrite your response keeping my suggestions in mind.\n'
        # Mirror the reply into both histories: one agent's assistant message
        # becomes the other agent's user message.
        if turn == 'writer':
            writer_messages.append({
                'role': 'assistant',
                'content': response,
            })
            editor_messages.append({
                'role': 'user',
                'content': response,
            })
        else:  # editor
            writer_messages.append({
                'role': 'user',
                'content': response,
            })
            editor_messages.append({
                'role': 'assistant',
                'content': response,
            })
max_tokens = 4_000  # leftover config from the local-model version; unused below
chunk_size = 1000
max_words = 10_000
print(f'max_words: {max_words}')
# llm = Llama(model_path="E:\\yt\\bookSummary/Snorkel-Mistral-PairRM-DPO/snorkel-mistral-pairrm-dpo.Q4_K_M.gguf", chat_format="chatml", n_gpu_layers=-1, n_ctx=6000)
writer_system_prompt_unformatted = '''You are a writer.
The topic is {topic}.'''
editor_system_prompt_unformatted = '''You are an editor. You give feedback about my writing.
The topic is {topic}.
You should encourage me to adjust my content and improve it.
You should push me to keep my response as close as possible to the topic.'''
# Defaults before a topic is chosen; set_system_prompts() overwrites these.
writer_system_prompt = writer_system_prompt_unformatted
editor_system_prompt = editor_system_prompt_unformatted
writer_messages = [{'role': 'system', 'content': writer_system_prompt_unformatted}]
editor_messages = [{'role': 'system', 'content': editor_system_prompt_unformatted}]
turn = 'writer'
def set_system_prompts(x):
    """Format both system prompts with the chosen topic and reset both histories."""
    global writer_system_prompt, editor_system_prompt, writer_messages, editor_messages
    writer_system_prompt = writer_system_prompt_unformatted.format(topic=x)
    editor_system_prompt = editor_system_prompt_unformatted.format(topic=x)
    writer_messages = [{'role': 'system', 'content': writer_system_prompt}]
    editor_messages = [{'role': 'system', 'content': editor_system_prompt}]
    return f'writer system prompt:\n{writer_system_prompt}\n\neditor system prompt:\n{editor_system_prompt}'
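# For example, set_system_prompts("a heist gone wrong") re-seeds both agents
# with topic-specific system prompts and starts the conversation fresh.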
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Multi-Agent LLMs for End-to-End Story Generation
        Start typing below to see the output.
        """)
gr.Interface(
fn=set_system_prompts,
inputs=gr.Textbox(placeholder="What is the topic?", label = 'Topic', lines=4),
outputs=gr.Textbox(label='System prompt to use', lines=4)
)
    out_test = gr.Textbox(lines=4)
    button = gr.Button("test")
    # Echo the currently active system prompts (set via the Topic box above).
    button.click(lambda: f"{writer_system_prompt} \n\n\n{editor_system_prompt}", outputs=out_test)
chat = gr.ChatInterface(
fn=adverse,
examples=["Start the story", "Write a poem", 'The funniest joke ever!'],
title="Multi-Agent Bot",
autofocus=False,
fill_height=True,
).queue()
if __name__ == "__main__":
demo.launch()