Spaces:
Sleeping
Sleeping
# type: ignore -- ignores linting import issues when using multiple virtual environments | |
import streamlit.components.v1 as components | |
import streamlit as st | |
import pandas as pd | |
import logging | |
from deeploy import Client | |
# reset Plotly theme after streamlit import | |
import plotly.io as pio | |
# Configure root logging at INFO level (the level name is accepted as a string).
logging.basicConfig(level="INFO")

# Reset the default Plotly template after the Streamlit import.
pio.templates.default = "plotly"

# Keep set_page_config as the first Streamlit command, then render the title.
st.set_page_config(layout="wide")
st.title("Your title")
def get_model_url(): | |
"""Function to get Deeploy model URL and split it into workspace and deployment ID.""" | |
model_url = st.text_area( | |
"Model URL (without the /explain endpoint, default is the demo deployment)", | |
"https://api.app.deeploy.ml/workspaces/708b5808-27af-461a-8ee5-80add68384c7/deployments/9155091a-0abb-45b3-8b3b-24ac33fa556b/", | |
height=125, | |
) | |
elems = model_url.split("/") | |
try: | |
workspace_id = elems[4] | |
deployment_id = elems[6] | |
except IndexError: | |
workspace_id = "" | |
deployment_id = "" | |
return model_url, workspace_id, deployment_id | |
def ChangeButtonColour(widget_label, font_color, background_color="transparent"): | |
"""Function to change the color of a button (after it is defined).""" | |
htmlstr = f""" | |
<script> | |
var elements = window.parent.document.querySelectorAll('button'); | |
for (var i = 0; i < elements.length; ++i) {{ | |
if (elements[i].innerText == '{widget_label}') {{ | |
elements[i].style.color ='{font_color}'; | |
elements[i].style.background = '{background_color}' | |
}} | |
}} | |
</script> | |
""" | |
components.html(f"{htmlstr}", height=0, width=0) | |
def predict_callback():
    """Request a prediction for the current request body and store it in session state.

    This is the ``on_click`` callback of the "Predict" button. It sends
    ``st.session_state.request_body`` to ``client.predict`` using the
    module-level ``client`` and ``deployment_id`` defined by the sidebar code.

    On success the result is stored in ``st.session_state.pred`` and
    ``st.session_state.evaluation_submitted`` is reset to False. On failure the
    exception is logged with its traceback, an error is shown, and session
    state is left unchanged.
    """
    with st.spinner("Loading prediction..."):
        try:
            request_body = st.session_state.request_body
            # Log the payload that is actually sent (not a module-level copy).
            logging.info("Request body: %s", request_body)
            pred = client.predict(
                request_body=request_body, deployment_id=deployment_id
            )
        except Exception:
            logging.exception("Prediction request failed")
            st.error(
                "Failed to get prediction. "
                "Check whether you are using the right model URL and token for predictions. "
                "Contact Deeploy if the problem persists."
            )
            return
    st.session_state.pred = pred
    st.session_state.evaluation_submitted = False
def submit_and_clear(evaluation: str):
    """Submit the user's agree/disagree feedback on the current prediction to Deeploy.

    Args:
        evaluation: ``"yes"`` when the user agrees with the prediction. Any
            other value is treated as disagreement.

    On success ``st.session_state.evaluation_submitted`` is set to True and the
    stored prediction is cleared. On failure the exception is logged with its
    traceback and an error is shown.
    """
    # NOTE(review): st.session_state.evaluation_input is not initialised in the
    # visible code; it must already be a dict when this runs — confirm.
    if evaluation == "yes":
        st.session_state.evaluation_input["result"] = 0  # Agree with the prediction
    else:
        # Disagree with the prediction
        st.session_state.evaluation_input["result"] = 1
        # In binary classification problems we can just flip the prediction.
        # NOTE(review): `predictions` is not defined in the visible code; it is
        # presumably meant to come from the stored prediction response. Also,
        # `not` yields a bool (True/False), not 0/1 — confirm what the API expects.
        desired_output = not predictions[0]
        st.session_state.evaluation_input["value"] = {"predictions": [desired_output]}
    try:
        # Send the evaluation for the logged request/prediction.
        # NOTE(review): `request_log_id` and `prediction_log_id` are not defined
        # in the visible code — confirm where they should come from.
        client.evaluate(
            deployment_id, request_log_id, prediction_log_id, st.session_state.evaluation_input
        )
        st.session_state.evaluation_submitted = True
        st.session_state.pred = None
    except Exception:
        logging.exception("Submitting evaluation failed")
        st.error(
            "Failed to submit feedback. "
            "Check whether you are using the right model URL and token for evaluations. "
            "Contact Deeploy if the problem persists."
        )
# Seed session-state defaults. A key is only written when it is absent, so
# values stored later by the callbacks are not overwritten on a rerun.
for state_key, initial_value in (("pred", None), ("evaluation_submitted", False)):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
# Define sidebar for configuration of Deeploy connection
with st.sidebar:
    st.image("deeploy_logo_wide.png", width=250)
    # Ask for host, model URL and token
    host = st.text_input("Host (Changing is optional)", "app.deeploy.ml")
    model_url, workspace_id, deployment_id = get_model_url()
    # Render the token as a password field so the secret is masked on screen.
    deployment_token = st.text_input(
        "Deeploy API token", "my-secret-token", type="password"
    )
    if deployment_token == "my-secret-token":
        # Still the placeholder value: no real token has been entered yet.
        st.warning("Please enter Deeploy API token.")
    # In case you need to debug the workspace and deployment ID:
    # st.write("Values below are for debug only:")
    # st.write("Workspace ID: ", workspace_id)
    # st.write("Deployment ID: ", deployment_id)
# Deeploy client used by predict_callback and submit_and_clear.
client_options = {
    "host": host,
    "deployment_token": deployment_token,
    "workspace_id": workspace_id,
}
client = Client(**client_options)
# For debugging the session state you can uncomment the following lines: | |
# with st.expander("Debug session state", expanded=False): | |
# st.write(st.session_state) | |
# Input (for IRIS dataset) | |
# with st.expander("Input values for prediction", expanded=True): | |
# st.write("Please input the values for the model.") | |
# col1, col2 = st.columns(2) | |
# with col1: | |
# sep_len = st.number_input("Sepal length", value=1.0, step=0.1, key="Sepal length") | |
# sep_wid = st.number_input("Sepal width", value=1.0, step=0.1, key="Sepal width") | |
# with col2: | |
# pet_len = st.number_input("Petal length", value=1.0, step=0.1, key="Petal length") | |
# pet_wid = st.number_input("Petal width", value=1.0, step=0.1, key="Petal width") | |
# Hard-coded example payload for the deployment: one instance made of 79 raw
# feature values (a mix of numbers and category strings).
# NOTE(review): the values look like a row of the Ames / Kaggle House Prices
# data set — confirm the column order matches what the deployed model expects.
request_body = {
    "instances": [
        [
            20, "RH", 80, 11622, "Pave", "missing", "Reg", "Lvl", "AllPub", "Inside",
            "Gtl", "NAmes", "Feedr", "Norm", "1Fam", "1Story", 5, 6, 1961, 1961,
            "Gable", "CompShg", "VinylSd", "VinylSd", "NA", 0, "TA", "TA", "CBlock", "TA",
            "TA", "No", "Rec", 468, "LwQ", 144, 270, 882, "GasA", "TA",
            "Y", "SBrkr", 896, 0, 0, 896, 0, 0, 1, 0,
            2, 1, "TA", 5, "Typ", 0, "missing", "Attchd", 1961, "Unf",
            1, 730, "TA", "TA", "Y", 140, 0, 0, 0, 120,
            0, "missing", "MnPrv", "missing", 0, 6, 2010, "WD", "Normal",
        ],
    ],
}
# Store the request in session state; predict_callback reads it from there.
st.session_state["request_body"] = request_body
# Predict button; the request itself is sent from predict_callback.
predict_button = st.button("Predict", on_click=predict_callback)
# Show the latest stored prediction, if there is one.
if st.session_state["pred"] is not None:
    st.write(st.session_state["pred"])