Spaces:
Sleeping
Sleeping
import ast
import csv
import heapq
import json
import os
import random
import time

import gradio as gr
import numpy as np
import pandas as pd
from openai import OpenAI
def search_teacher(job_type, job_start_dates, tool_function_arguments, top_k=3):
    """Find the teachers whose stored profiles best match a search request.

    The tool-call argument values are joined into one query string, embedded
    with OpenAI's ``text-embedding-3-small`` model, and compared by cosine
    similarity against the ``embedding`` column (a Python list literal) of the
    CSV at ``$DATA_PATH``. Only rows whose ``job_type`` and ``job_start_date``
    match the UI filters are scored.

    Args:
        job_type: Employment type selected in the UI. '常勤' / '非常勤' are
            translated to the CSV values 'full time' / 'part time'; any other
            value is used unchanged.
        job_start_dates: Start period(s) selected in the UI, either one string
            or a list of strings. '今年度' / '来年度' / '来来年度' are translated
            to 'This Year' / 'Next Year' / 'Year After Next'; other values are
            used unchanged. ``None`` is treated as "nothing selected".
        tool_function_arguments: Parsed JSON arguments of the assistant's tool
            call (a dict of string values).
        top_k: Maximum number of teachers to return (default 3).

    Returns:
        ``"<name>: <similarity>"`` entries joined by ", ", best match first.
        Returns an empty string when no row passes the filters.
    """
    job_type_translation = {'常勤': 'full time', '非常勤': 'part time'}
    job_start_date_translation = {
        '今年度': 'This Year',
        '来年度': 'Next Year',
        '来来年度': 'Year After Next',
    }

    # NOTE(review): the last two argument values are dropped, so which fields
    # reach the query depends on the key order of the model's JSON. With the
    # declared schema order this drops teaching_method_environment and
    # evaluation_feedback. Confirm that is intended.
    user_request = ', '.join(list(tool_function_arguments.values())[:-2])
    query_vector = np.asarray(
        client.embeddings.create(
            input=[user_request.replace("\n", " ")],
            model='text-embedding-3-small',
        ).data[0].embedding
    )
    # Loop-invariant: computed once rather than once per candidate row.
    query_norm = np.linalg.norm(query_vector)

    data = pd.read_csv(os.environ['DATA_PATH'])

    job_type = job_type_translation.get(job_type, job_type)
    if job_start_dates is None:
        job_start_dates = []
    elif isinstance(job_start_dates, str):
        # A single selection may arrive as a bare string.
        job_start_dates = [job_start_dates]
    job_start_dates = [job_start_date_translation.get(date, date) for date in job_start_dates]

    candidates = data[
        (data['job_type'] == job_type) & data['job_start_date'].isin(job_start_dates)
    ]

    def cosine_similarity(embedding_literal):
        # literal_eval only parses Python literals, so CSV content cannot execute code.
        teacher_vector = np.asarray(ast.literal_eval(embedding_literal))
        return np.dot(query_vector, teacher_vector) / (query_norm * np.linalg.norm(teacher_vector))

    scored = (
        (name, cosine_similarity(embedding))
        for name, embedding in zip(candidates['name'], candidates['embedding'])
    )
    # nlargest is equivalent to a stable descending sort truncated to top_k,
    # so ties keep CSV order exactly like the previous hand-rolled insertion.
    best = heapq.nlargest(top_k, scored, key=lambda pair: pair[1])
    return ', '.join(f"{name}: {similarity}" for name, similarity in best)
def openai_api(job_type, job_start_dates, history):
    """Send the newest user message to the shared Assistants thread and return the reply.

    The function posts ``history[-1][0]`` to the module-level ``thread`` and
    starts a run of the module-level ``assistant``. It then polls the run
    every 0.5 s until the run completes. When the run pauses for tool outputs,
    every requested call is answered with ``search_teacher`` using the UI
    filters.

    Args:
        job_type: Employment-type filter forwarded to ``search_teacher``.
        job_start_dates: Start-period filter forwarded to ``search_teacher``.
        history: Chat history as ``[user_message, bot_message]`` pairs. Only
            the last user message is sent.

    Returns:
        Text of the newest message in the thread, i.e. the assistant's reply.

    Raises:
        RuntimeError: if the run ends as failed, cancelled, expired or
            incomplete, instead of polling it forever.
    """
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=history[-1][0],
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
    )
    while True:
        run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
        if run.status == 'completed':
            break
        if run.status in ('failed', 'cancelled', 'expired', 'incomplete'):
            raise RuntimeError(
                f"Assistant run {run.id} ended with status {run.status!r}: {run.last_error}"
            )
        if run.status == 'requires_action':
            # All pending tool calls must be answered in one submission.
            tool_outputs = [
                {
                    "tool_call_id": tool_call.id,
                    "output": search_teacher(
                        job_type, job_start_dates, json.loads(tool_call.function.arguments)
                    ),
                }
                for tool_call in run.required_action.submit_tool_outputs.tool_calls
            ]
            run = client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread.id,
                run_id=run.id,
                tool_outputs=tool_outputs,
            )
            time.sleep(3)
        time.sleep(0.5)

    # Fetch only the newest message. Listing ascending and taking data[-1]
    # reads the end of the first page (20 messages by default), which becomes
    # stale once the thread holds more than one page.
    newest = client.beta.threads.messages.list(thread_id=thread.id, order="desc", limit=1)
    return newest.data[0].content[0].text.value
def user(user_job_type, user_job_start_date, user_message, history):
    """Record a submitted message as a new, not-yet-answered chat turn.

    The job-type and start-date inputs are accepted only because the submit
    event passes them. They are not used here.

    Returns:
        ``(None, new_history)``. ``None`` is written back to the message
        textbox, and ``new_history`` is a new list holding the previous turns
        plus ``[user_message, None]``, whose reply slot is filled in later.
    """
    pending_turn = [user_message, None]
    updated_history = history + [pending_turn]
    return None, updated_history
def bot(job_type, job_start_date, history):
    """Fetch the assistant's reply and stream it into the last chat turn.

    Args:
        job_type: Employment-type filter from the UI, passed to ``openai_api``.
        job_start_date: Start-period filter from the UI, passed to ``openai_api``.
        history: Chat history whose last entry is the pending
            ``[user_message, None]`` turn. It is updated in place.

    Yields:
        ``history`` after each character of the reply is appended, with a
        10 ms pause between characters, so the UI can render the reply
        progressively. An empty reply still yields the history once.
    """
    bot_message = openai_api(job_type, job_start_date, history)
    history[-1][1] = ""
    if not bot_message:
        yield history
        return
    for character in bot_message:
        history[-1][1] += character
        time.sleep(0.01)
        yield history
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # OpenAI objects used by openai_api() and search_teacher(). This code
    # already runs at module scope, so no `global` declarations are needed.
    client = OpenAI(max_retries=5)
    # NOTE(review): a new assistant is created on every start-up and is not
    # deleted anywhere in this file.
    assistant = client.beta.assistants.create(
        name='connpath_demo',
        instructions=os.environ['INSTRUCTIONS'],
        model="gpt-4-0125-preview",
        tools=[{
            "type": "function",
            "function": {
                "name": "connpath_demo_gpt",
                "description": "検索を支援する",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "educational_goals": {"type": "string", "description": "教育の目的"},
                        "student_profile": {"type": "string", "description": "対象のプロフィール"},
                        "required_skills_experience": {"type": "string", "description": "求められるスキル"},
                        "teaching_method_environment": {"type": "string", "description": "教育する環境"},
                        "evaluation_feedback": {"type": "string", "description": "教育の評価方法"}
                    },
                    "required": ["educational_goals", "student_profile", "required_skills_experience",
                                 "teaching_method_environment", "evaluation_feedback"]
                }
            }
        }]
    )
    # NOTE(review): this single thread is shared by every browser session, and
    # the "clear history" button below only clears the UI, not this thread.
    thread = client.beta.threads.create()

    # Holds user_job_type (str).
    job_type = gr.Radio(["常勤", "非常勤"],
                        label="Job Type",
                        info="探している雇用形態について")
    # Holds user_job_start_date (list).
    job_start_date = gr.CheckboxGroup(["今年度", "来年度", "来来年度"],
                                      label="Start Date",
                                      info="探している就業時期について")
    # Holds user_message.
    msg = gr.Textbox(label="input message",
                     info="課題について教えてください(例:国語の先生を探しています)")
    # Holds history.
    chatbot = gr.Chatbot(show_copy_button=True)
    clear = gr.Button("clear history")

    # On submit: `user` appends the new turn and empties the textbox, then
    # `bot` fills in and streams the assistant's reply.
    msg.submit(user,
               [job_type, job_start_date, msg, chatbot],
               [msg, chatbot],
               queue=True
               ).then(bot, [job_type, job_start_date, chatbot], chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve the app behind basic auth
    # whose credentials come from the environment.
    demo.queue()
    credentials = (os.environ['USER_NAME'], os.environ['PASSWORD'])
    demo.launch(auth=credentials)