# tax-free's picture
# Update app.py
# f0623c6 verified
import gradio as gr
import random
import time
import csv
import os
import pandas as pd
from openai import OpenAI
import numpy as np
import ast
import json
# Find the teachers whose stored profile embeddings best match the user's request.
def search_teacher(job_type, job_start_dates, tool_function_arguments):
    """Return the (up to) three best-matching teachers for the assistant's tool call.

    The request text built from ``tool_function_arguments`` is embedded with
    the ``text-embedding-3-small`` model and compared by cosine similarity
    against the ``embedding`` column of the CSV at ``$DATA_PATH``, after the
    rows have been filtered by employment type and start period.

    Args:
        job_type: Employment type from the UI radio ('常勤' or '非常勤').
            Translated to the CSV value ('full time' / 'part time'); any
            other value is compared as-is.
        job_start_dates: Start period(s) from the UI checkbox group: a list of
            labels ('今年度', '来年度', '来来年度') or a single label string.
            Each label is translated to its English CSV value; unknown labels
            are kept unchanged. ``None`` is treated as "nothing selected".
        tool_function_arguments: Parsed JSON arguments of the assistant's
            function call.

    Returns:
        ``"name: similarity, name: similarity, ..."`` for at most three rows,
        highest similarity first; an empty string if no row passes the filters.
    """
    import heapq

    # Fields that make up the search text. They are picked by name so the
    # embedded text does not depend on the key order of the model's JSON.
    query_fields = ("educational_goals", "student_profile", "required_skills_experience")
    user_request = ', '.join(
        str(tool_function_arguments[field])
        for field in query_fields
        if field in tool_function_arguments
    )
    user_vector = np.asarray(
        client.embeddings.create(
            input=[user_request.replace("\n", " ")],
            model='text-embedding-3-small',
        ).data[0].embedding
    )
    # The query norm is constant, so compute it once rather than per row.
    user_norm = np.linalg.norm(user_vector)

    data = pd.read_csv(os.environ['DATA_PATH'])

    # Translate the Japanese UI labels to the English values stored in the CSV.
    job_type_translation = {'常勤': 'full time', '非常勤': 'part time'}
    job_start_date_translation = {
        '今年度': 'This Year',
        '来年度': 'Next Year',
        '来来年度': 'Year After Next',
    }
    job_type = job_type_translation.get(job_type, job_type)
    if job_start_dates is None:
        job_start_dates = []
    elif isinstance(job_start_dates, str):
        # A single selection may arrive as a bare string.
        job_start_dates = [job_start_dates]
    job_start_dates = [job_start_date_translation.get(d, d) for d in job_start_dates]

    filtered_data = data[
        (data['job_type'] == job_type)
        & data['job_start_date'].isin(job_start_dates)
    ]

    def cosine_similarity(embedding_literal):
        # Each CSV cell holds the embedding as a Python literal; literal_eval
        # parses it without executing arbitrary code.
        teacher_vector = np.asarray(ast.literal_eval(embedding_literal))
        return np.dot(user_vector, teacher_vector) / (user_norm * np.linalg.norm(teacher_vector))

    scored = (
        (name, cosine_similarity(embedding))
        for name, embedding in zip(filtered_data['name'], filtered_data['embedding'])
    )
    # nlargest behaves like a stable descending sort: on ties the earlier row wins.
    top_matches = heapq.nlargest(3, scored, key=lambda item: item[1])
    return ', '.join(f"{name}: {similarity}" for name, similarity in top_matches)
def openai_api(job_type, job_start_dates, history):
    """Send the newest user message to the assistant thread and return the reply.

    Posts ``history[-1][0]`` to the module-level ``thread``, starts a run of the
    module-level ``assistant`` and polls it every 3 seconds. Whenever the run
    requires action, every requested tool call is answered with
    ``search_teacher(job_type, job_start_dates, <call arguments>)``.

    Args:
        job_type: Employment-type filter forwarded to ``search_teacher``.
        job_start_dates: Start-period filter forwarded to ``search_teacher``.
        history: Chat history; ``history[-1][0]`` is the message to send.

    Returns:
        The text of the newest message on the thread once the run has
        completed, i.e. the assistant's reply.

    Raises:
        RuntimeError: if the run ends as 'failed', 'cancelled', 'expired' or
            'incomplete' instead of 'completed'.
    """
    # Terminal statuses from which the run never reaches 'completed'; without
    # this check the polling loop below would spin forever.
    failed_statuses = ('failed', 'cancelled', 'expired', 'incomplete')

    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=history[-1][0],
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
    )
    while True:
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
        if run.status == 'completed':
            break
        if run.status in failed_statuses:
            raise RuntimeError(
                f"Assistant run {run.id} ended with status {run.status!r}: "
                f"{getattr(run, 'last_error', None)}"
            )
        if run.status == 'requires_action':
            # The model may request several tool calls in one step, and all of
            # their outputs must be submitted together before the run resumes.
            tool_outputs = []
            for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                arguments = json.loads(tool_call.function.arguments)
                tool_outputs.append({
                    "tool_call_id": tool_call.id,
                    "output": search_teacher(job_type, job_start_dates, arguments),
                })
            run = client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread.id,
                run_id=run.id,
                tool_outputs=tool_outputs,
            )
        time.sleep(3)
    # NOTE(review): short settle delay before reading messages; probably
    # unnecessary once the run is 'completed' - confirm before removing.
    time.sleep(0.5)
    # Fetch only the newest message. Listing in ascending order returns just the
    # first page of oldest messages, which stops containing the reply once the
    # thread grows past one page.
    messages = client.beta.threads.messages.list(
        thread_id=thread.id,
        order="desc",
        limit=1,
    )
    return messages.data[0].content[0].text.value
def user(user_job_type, user_job_start_date, user_message, history):
    """Append the submitted message to the chat history as a pending turn.

    ``user_job_type`` and ``user_job_start_date`` are unused here.

    Returns:
        ``(None, new_history)``. The ``None`` clears the input textbox.
        ``new_history`` is a new list (the input is not mutated) that ends
        with ``[user_message, None]``; ``bot`` later fills in the ``None``
        reply slot.
    """
    # Treat a missing history as empty: the "clear history" button sets the
    # Chatbot value to None.
    return None, (history or []) + [[user_message, None]]
def bot(job_type, job_start_date, history):
    """Fill the last chat turn with the assistant's reply, one character at a time.

    Args:
        job_type: Value of the "Job Type" radio, forwarded to ``openai_api``.
        job_start_date: Value of the "Start Date" checkbox group, forwarded
            to ``openai_api``.
        history: Chat history whose last entry is ``[user_message, None]``
            (as appended by ``user``); it is updated in place.

    Yields:
        ``history`` after each character is appended, so the Chatbot
        re-renders and the reply appears to be typed out.
    """
    bot_message = openai_api(job_type, job_start_date, history)
    history[-1][1] = ""
    for character in bot_message:
        history[-1][1] += character
        time.sleep(0.01)
        yield history
# Build the web UI. Everything below runs once, at import time.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # NOTE(review): a `with` statement does not create a new scope, so these
    # `global` declarations are no-ops here; the assignments that follow bind
    # the module-level names that openai_api and search_teacher read.
    global assistant
    global thread
    global client
    # OpenAI client shared by all requests (max_retries=5 lets the SDK retry
    # failed API calls).
    client = OpenAI(max_retries=5)
    # A new assistant is created every time the app starts. Its instructions
    # come from the INSTRUCTIONS environment variable, and it exposes a single
    # function tool with five required string arguments; openai_api answers
    # calls to this tool with search_teacher.
    # Argument descriptions (Japanese, sent to the model as-is):
    #   educational_goals           - purpose of the teaching
    #   student_profile             - profile of the learners
    #   required_skills_experience  - required skills
    #   teaching_method_environment - teaching environment
    #   evaluation_feedback         - how the teaching is evaluated
    assistant = client.beta.assistants.create(
        name='connpath_demo',
        instructions=os.environ['INSTRUCTIONS'],
        model="gpt-4-0125-preview",
        tools=[{
            "type": "function",
            "function": {
                "name": "connpath_demo_gpt",
                "description": "検索を支援する",  # "assists with the search"
                "parameters": {
                    "type": "object",
                    "properties": {
                        "educational_goals": {"type": "string", "description": "教育の目的"},
                        "student_profile": {"type": "string", "description": "対象のプロフィール"},
                        "required_skills_experience": {"type": "string", "description": "求められるスキル"},
                        "teaching_method_environment": {"type": "string", "description": "教育する環境"},
                        "evaluation_feedback": {"type": "string", "description": "教育の評価方法"}
                    },
                    "required": ["educational_goals", "student_profile", "required_skills_experience",
                                 "teaching_method_environment", "evaluation_feedback"]
                }
            }
        }]
    )
    # NOTE(review): this single thread is created at startup and shared by every
    # browser session. The "clear history" button below only clears the
    # on-screen chat, so the assistant keeps seeing earlier turns. A
    # per-session thread (e.g. kept in gr.State) would isolate users.
    thread = client.beta.threads.create()
    # Employment-type filter (full-time / part-time). Holds the job type as a
    # str; it is passed to `user` and `bot` on submit.
    job_type = gr.Radio(["常勤", "非常勤"],
                        label="Job Type",
                        info="探している雇用形態について")
    # Start-period filter (this year / next year / the year after). Holds the
    # selected start dates as a list; it is passed to `user` and `bot` on submit.
    job_start_date = gr.CheckboxGroup(["今年度", "来年度", "来来年度"],
                                      label="Start Date",
                                      info="探している就業時期について")
    # Free-text input; holds the user_message argument of `user`.
    msg = gr.Textbox(label="input message",
                     info="課題について教えてください(例:国語の先生を探しています)")
    # Chat transcript; holds the history passed to `user` and `bot`.
    chatbot = gr.Chatbot(show_copy_button=True)
    clear = gr.Button("clear history")
    # On submit: `user` clears the textbox and appends a pending turn, then
    # `bot` streams the assistant's reply into that turn.
    msg.submit(user,
               [job_type, job_start_date, msg, chatbot],
               [msg, chatbot],
               queue=True
               ).then(bot, [job_type, job_start_date, chatbot], chatbot)
    # Reset the on-screen chat to None; the assistant thread is not touched.
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve the UI behind a single
    # username/password pair read from the environment (KeyError if unset).
    demo.queue()
    demo.launch(
        auth=tuple(os.environ[key] for key in ('USER_NAME', 'PASSWORD')),
    )