# Hugging Face Space: multi-agent (writer vs. editor) story-generation demo.
import gradio as gr
import os
import re
import time

from openai import OpenAI

# API key is injected through the Space secret named "openai".
client = OpenAI(api_key=os.environ.get("openai"))
model = "gpt-3.5-turbo"

# Earlier local-model backend, kept for reference:
# llm = Llama(model_path="./snorkel-mistral-pairrm-dpo.Q4_K_M.gguf",
#             chat_format="chatml",
#             n_gpu_layers=0,  # CPU only
#             n_ctx=6000)
def split_text(text, llm, chunk_size):
    """Split *text* into chunks of roughly *chunk_size* tokens.

    The text is broken on newlines, empty lines are dropped, and lines are
    accumulated into a chunk until the running token count (per *llm*'s
    tokenizer) exceeds *chunk_size*.  The final partial chunk is always
    emitted, so no text is lost.

    Args:
        text: the full text to chunk.
        llm: tokenizer provider passed through to ``get_num_tokens``.
        chunk_size: soft token limit per chunk (a chunk may exceed it by
            one line, since the check happens after appending).

    Returns:
        list[str]: chunks with their lines joined by '\n\n'.
    """
    lines = [ln for ln in text.split('\n') if ln]
    chunks = []
    current = []
    for i, line in enumerate(lines):
        current.append(line)
        # Bug fix: measure with the same '\n\n' separator used to build the
        # emitted chunk.  The original counted tokens on a '\n\n\n' join but
        # emitted a '\n\n' join, so the estimate never matched the output.
        n_tokens = get_num_tokens('\n\n'.join(current), llm)
        if i == len(lines) - 1 or n_tokens > chunk_size:
            chunks.append('\n\n'.join(current))
            current = []
    return chunks
def all_to_list(all_sum, llm, chunk_size):
    """Chunk *all_sum* with split_text, logging per-chunk token counts.

    Returns the list of chunk strings unchanged.
    """
    chunks = split_text(all_sum, llm, chunk_size)
    token_counts = [get_num_tokens(chunk, llm) for chunk in chunks]
    print(f'len_chunks: {token_counts}')
    print(f'total parts: {len(chunks)}')
    return chunks
def clean_output(text):
    """Strip markdown-ish artifacts (backticks, bullets, numbering) from text.

    Order matters: backticks first, then numeric bullets like "12.", then
    dash bullets ("- " becomes a space), then stray '*' and '+' markers.
    """
    without_ticks = text.replace('`', '')
    without_numbers = re.sub(r'\d+\.', '', without_ticks)
    cleaned = without_numbers.replace('- ', ' ')
    for marker in ('*', '+'):
        cleaned = cleaned.replace(marker, '')
    return cleaned
def get_content_length(messages, llm):
    """Token count of the system prompt plus all messages after the first.

    NOTE(review): this assumes messages[0] is the (only) 'system' entry;
    a system message appearing later in the list would be counted twice.
    Confirm against the callers before changing.
    """
    system_entries = [m for m in messages if m['role'] == 'system']
    parts = [system_entries[0]['content']]
    parts.extend(m['content'] for m in messages[1:])
    return get_num_tokens(''.join(parts), llm)
def pop_first_user_assistant(messages):
    """Return a copy of *messages* without the entries at positions 1 and 2.

    Drops the oldest user/assistant exchange while keeping the system
    message at index 0 (used to trim history when the context fills up).
    """
    return messages[:1] + messages[3:]
def get_num_tokens(text, llm):
    """Return the number of tokens *llm*'s tokenizer produces for *text*.

    The tokenizer is fed UTF-8 bytes (llama.cpp-style ``tokenize`` API).
    """
    encoded = text.encode('utf-8')
    return len(llm.tokenize(encoded))
def response_stream():
    # Stub generator: yields a single None whichever agent's turn it is.
    # Presumably a placeholder for per-agent streaming that was superseded
    # by the streaming loop inside adverse() — TODO confirm before use.
    global writer_messages, editor_messages, turn
    if turn=='writer':
        yield
    else:
        yield
def adverse(message, history):
    """Gradio ChatInterface handler: run a writer/editor debate for 4 turns.

    Streams a growing transcript (``total_response``) as a generator so the
    UI updates token by token.  Each agent keeps its own message history;
    every response is mirrored into both histories with the roles swapped
    (the writer's output becomes the editor's 'user' input and vice versa).

    NOTE(review): ``history`` (Gradio's chat history) is ignored — state
    lives in the module-level writer_messages/editor_messages, so the bot
    is effectively single-session.
    """
    global writer_messages, editor_messages, turn
    total_response = ''
    for i in range(4):
        # Seed the writer with the user's message on the very first call
        # (at that point its history holds only the system prompt).
        if len(writer_messages)==1: # first call
            writer_messages.append({
                'role':'user',
                'content':message,
            })
        # Whose turn: the writer speaks when its history length is even
        # (system + alternating pairs); otherwise the editor critiques.
        turn = 'writer' if len(writer_messages)%2==0 else 'editor'
        list_of_messages = writer_messages if turn=='writer' else editor_messages
        print(f'turn: {turn}\n\nlist_of_messages: {list_of_messages}')
        total_response+=f'\n\n\nturn: {turn}\n'
        # Stream a completion for whichever agent holds the turn.
        response_iter = client.chat.completions.create(
            model=model,
            messages=list_of_messages,
            stream=True,
        )
        response=''
        for chunk in response_iter:
            try:
                # delta.content is None on the final stream chunk; '+=' then
                # raises TypeError, which the except below swallows — this is
                # the (implicit) end-of-stream handling.
                response+=chunk.choices[0].delta.content
                total_response+=chunk.choices[0].delta.content
                # NOTE(review): throttles output to ~1 chunk/second — confirm
                # this pacing is intentional and not leftover debugging.
                time.sleep(1)
                yield total_response
            except Exception as e:
                print(e)
        total_response+='\n\n'
        if turn=='editor':
            # Nudge appended to the editor's critique before it is handed
            # back to the writer as the next 'user' message.
            response+='\nNow rewrite your response keeping my suggestions in mind.\n'
        # Mirror the response into both histories with swapped roles.
        if turn=='writer':
            writer_messages.append({
                'role':'assistant',
                'content':response,
            })
            editor_messages.append({
                'role':'user',
                'content':response,
            })
        else: # editor
            writer_messages.append({
                'role':'user',
                'content':response,
            })
            editor_messages.append({
                'role':'assistant',
                'content':response,
            })
# Budget knobs.  NOTE(review): none of these appear to be passed to the
# chunking helpers in this file — confirm they are still used anywhere.
max_tokens=4_000
chunk_size=1000
max_words = 10_000
print(f'max_words: {max_words}')
# Earlier local-model backend, kept for reference:
# llm = Llama(model_path="E:\\yt\\bookSummary/Snorkel-Mistral-PairRM-DPO/snorkel-mistral-pairrm-dpo.Q4_K_M.gguf", chat_format="chatml", n_gpu_layers=-1, n_ctx=6000)
# Prompt templates; '{topic}' is filled in by set_system_prompts().
writer_system_prompt_unformatted = '''You are a writer.
The topic is {topic}.'''
editor_system_prompt_unformatted = '''You are an editor. You give feedback about my writing.
The topic is {topic}.
You should reinforce me to adjust my content and improve it.
You should push me to make my response as close as possible to the topic.'''
# Mutable conversation state shared by adverse(); reset by set_system_prompts().
writer_messages = [{'role':'system','content':writer_system_prompt_unformatted}]
editor_messages = [{'role':'system','content':editor_system_prompt_unformatted}]
turn = 'writer'
def set_system_prompts(x):
    """Format both system prompts with topic *x* and reset conversation state.

    Rebuilds writer_messages/editor_messages so each history contains only
    its freshly formatted system message.  Returns a display string showing
    both prompts for the UI textbox.
    """
    # The *_unformatted templates are only read, so no global declaration
    # is needed for them — only for the names we rebind.
    global writer_system_prompt, editor_system_prompt, writer_messages, editor_messages
    writer_system_prompt = writer_system_prompt_unformatted.format(topic=x)
    editor_system_prompt = editor_system_prompt_unformatted.format(topic=x)
    writer_messages = [{'role': 'system', 'content': writer_system_prompt}]
    editor_messages = [{'role': 'system', 'content': editor_system_prompt}]
    return (
        f'writer system prompt:\n{writer_system_prompt}'
        f'\n\neditor system prompt:\n{editor_system_prompt}'
    )
def _current_prompts():
    # Bug fix: the original inline lambda read writer_system_prompt /
    # editor_system_prompt, which only exist after set_system_prompts()
    # has run — clicking "test" before entering a topic raised NameError.
    # Fall back to the unformatted templates instead.
    w = globals().get('writer_system_prompt', writer_system_prompt_unformatted)
    e = globals().get('editor_system_prompt', editor_system_prompt_unformatted)
    return f"{w} \n\n\n{e}"

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Multi Agent LLMs for End-to-End Story Generation
        Start typing below to see the output.
        """)
    # Topic entry: formats both system prompts and resets conversation state.
    gr.Interface(
        fn=set_system_prompts,
        inputs=gr.Textbox(placeholder="What is the topic?", label='Topic', lines=4),
        outputs=gr.Textbox(label='System prompt to use', lines=4),
    )
    # Debug view of the currently active system prompts.
    out_test = gr.Textbox(lines=4)
    button = gr.Button("test")
    button.click(_current_prompts, outputs=out_test)
    # Writer/editor debate; adverse() streams the combined transcript.
    chat = gr.ChatInterface(
        fn=adverse,
        examples=["Start the story", "Write a poem", 'The funniest joke ever!'],
        title="Multi-Agent Bot",
        autofocus=False,
        fill_height=True,
    ).queue()

if __name__ == "__main__":
    demo.launch()