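# Gradio chat demo for dicta-il/dictalm2.0-instruct.
# Streams completions from an OpenAI-compatible endpoint configured via the
# API_URL and API_KEY environment variables, rendering the chat right-to-left.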
import os
import gradio as gr
import openai
from typing import Dict, Generator, List, Optional, Tuple

# Endpoint and key for the OpenAI-compatible inference server, read from the environment.
API_URL = os.getenv('API_URL')
API_KEY = os.getenv('API_KEY')

oai_client = openai.OpenAI(api_key=API_KEY, base_url=API_URL)

# A history entry is a (user_message, assistant_message) pair; messages use the
# OpenAI chat format ({'role': ..., 'content': ...}).
History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]

def clear_session() -> Tuple[str, History]:
    # Reset both the input textbox and the chat history.
    return '', []

def history_to_messages(history: History) -> Messages:
    # Flatten (user, assistant) pairs into a flat list of OpenAI-style chat messages.
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({'role': 'user', 'content': user_msg})
        messages.append({'role': 'assistant', 'content': assistant_msg})
    return messages

def messages_to_history(messages: Messages) -> History:
    # Re-pair alternating user/assistant messages back into chatbot history entries.
    history = []
    for user_msg, assistant_msg in zip(messages[0::2], messages[1::2]):
        history.append([user_msg['content'], assistant_msg['content']])
    return history

def model_chat(query: Optional[str], history: Optional[History]) -> Generator[Tuple[str, History], None, None]:
    if query is None:
        query = ''
    if history is None:
        history = []

    # Rebuild the conversation and append the new user turn.
    messages = history_to_messages(history)
    messages.append({'role': 'user', 'content': query})

    gen = oai_client.chat.completions.create(
        model='dicta-il/dictalm2.0-instruct',
        messages=messages,
        temperature=0.7,
        max_tokens=1024,
        top_p=0.9,
        stream=True
    )

    # Accumulate streamed tokens into the assistant message, yielding the updated
    # history after each chunk so the chatbot refreshes incrementally.
    messages.append({'role': 'assistant', 'content': ''})
    for completion in gen:
        print(completion)  # debug: log each raw streamed chunk
        text = completion.choices[0].delta.content
        if not text:
            continue
        messages[-1]['content'] += text
        history = messages_to_history(messages)
        yield '', history

with gr.Blocks() as demo:
    gr.Markdown("""<center><font size=8>DictaLM2.0-Instruct Chat Demo</font></center>""")
    # Right-to-left rendering for Hebrew conversations.
    chatbot = gr.Chatbot(label='dicta-il/dictalm2.0-instruct', rtl=True)
    textbox = gr.Textbox(lines=2, label='Input')
    with gr.Row():
        clear_history = gr.Button("🧹 Clear history")
        submit = gr.Button("🚀 Send")

    # Stream model_chat output back into the textbox (cleared) and the chatbot.
    submit.click(model_chat,
                 inputs=[textbox, chatbot],
                 outputs=[textbox, chatbot])
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[textbox, chatbot])

demo.queue(api_open=False).launch(max_threads=10, height=800, share=False)
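# Local run (a sketch; assumes this file is saved as app.py and that API_URL points
# at an OpenAI-compatible server hosting dicta-il/dictalm2.0-instruct):
#   API_URL=<endpoint> API_KEY=<key> python app.py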