# databricks_rag / app.py
# Author: taka-yayoi
# Commit: 148dcd4 "Update app.py" (verified)
import itertools
import gradio as gr
import requests
import os
from gradio.themes.utils import sizes
import json
import pandas as pd
import base64
import io
from PIL import Image
import numpy as np
def respond(message, history):
    """Send *message* to the Databricks model-serving endpoint and return its answer.

    Used as the Gradio ``ChatInterface`` callback.

    Args:
        message: The user's question from the chat textbox.
        history: Previous chat turns supplied by Gradio. Currently unused:
            only the latest message is sent to the endpoint.

    Returns:
        The first element of the endpoint's ``"predictions"`` list on success.
        Otherwise a human-readable error string starting with ``"ERROR"``, so
        the chat UI shows the failure instead of the callback crashing.
    """
    # Empty / whitespace-only input: ask the user to type a question.
    if not message or not message.strip():
        return "質問を入力してください"

    local_token = os.getenv('API_TOKEN')
    local_endpoint = os.getenv('API_ENDPOINT')
    if local_token is None or local_endpoint is None:
        return "ERROR missing env variables"

    # Bearer-token auth for the serving endpoint.
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {local_token}',
    }

    # Request body in the serving endpoint's "dataframe_split" input format:
    # a single-row frame with one "query" column.
    prompt = pd.DataFrame({"query": [message]})
    ds_dict = {"dataframe_split": prompt.to_dict(orient="split")}
    data_json = json.dumps(ds_dict, allow_nan=True)

    # Upper bound (seconds) so a stalled endpoint cannot hang the UI forever.
    # NOTE(review): value chosen generously for LLM latency; tune as needed.
    timeout_s = 120

    try:
        # Query the model-serving endpoint.
        response = requests.post(
            local_endpoint, headers=headers, data=data_json, timeout=timeout_s
        )
    except Exception as error:
        # UI-callback boundary: report the failure in the chat instead of
        # raising. The full error is logged server-side only.
        print(f"Request to serving endpoint failed: {error!r}")
        return f"ERROR request failed: {type(error).__name__}"

    if not response.ok:
        print(f"Serving endpoint returned {response.status_code}: {response.text}")
        return f"ERROR status_code: {response.status_code}"

    try:
        response_data = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page from a proxy).
        print(f"Serving endpoint returned non-JSON body: {response.text}")
        return "ERROR invalid JSON response"
    print(response_data)

    # The answer is expected as the first entry of the "predictions" list.
    predictions = (
        response_data.get("predictions") if isinstance(response_data, dict) else None
    )
    if not predictions:
        return "ERROR no predictions in response"
    return predictions[0]
# Compact "Soft" theme for the chat UI: small corner radius, tight spacing,
# and small text.
theme = gr.themes.Soft(
    spacing_size=sizes.spacing_sm,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_sm,
)
# Chat transcript area: no label or container, a copy button on each message,
# and full-width message bubbles.
_chatbot = gr.Chatbot(
    show_label=False,
    container=False,
    show_copy_button=True,
    bubble_full_width=True,
)

# Input box; the placeholder reads "Please enter your question".
_textbox = gr.Textbox(
    placeholder="質問を入力してください",
    container=False,
    scale=7,
)

# Sample questions offered to the user: "What is a Databricks cluster?",
# "How to enable Unity Catalog", "Lineage retention period".
_examples = [
    ["Databricksクラスターとは?"],
    ["Unity Catalogの有効化方法"],
    ["リネージの保持期間"],
]

# Chat app wired to the `respond` callback ("Databricks QA chatbot").
demo = gr.ChatInterface(
    respond,
    chatbot=_chatbot,
    textbox=_textbox,
    title="Databricks QAチャットボット",
    description="TBD",
    examples=_examples,
    cache_examples=False,
    theme=theme,
    # No retry/undo buttons; only a "Clear" button.
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)

if __name__ == "__main__":
    # Start the Gradio server when executed as a script.
    demo.launch()