File size: 4,147 Bytes
22007d6 c3cdc2f ee47fb7 c3cdc2f 22007d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import pickle

import gradio as gr
import numpy as np
import pandas as pd
# Load the fitted model, encoder and scaler.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these artifacts from a trusted source (presumably this project's
# own training step -- confirm).
def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes (the original ``pickle.load(open(...))`` leaked it).
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


model = _load_pickle("model.pkl")
encoder = _load_pickle("encoder.pkl")
scaler = _load_pickle("scaler.pkl")
# Load the data; its columns drive both the form widgets built from it and
# the categorical/numeric column split used at prediction time.
data = pd.read_csv('data.csv')
# Define the input and output interfaces for the Gradio app
def create_gradio_inputs(data):
    """Build one Gradio input component per supported column of *data*.

    - object columns with more than 3 distinct non-missing values -> ``gr.Dropdown``
    - object columns with at most 3 distinct non-missing values -> ``gr.Radio``
    - ``int64`` / ``float64`` columns -> ``gr.Slider`` up to the column maximum

    Columns of any other dtype get no component. Gradio passes component
    values to the prediction function positionally, so the returned list
    follows ``data.columns`` order.

    Args:
        data: DataFrame whose columns define the form fields.

    Returns:
        A list of Gradio input components.
    """
    input_components = []
    for column in data.columns:
        series = data[column]
        if series.dtype == 'object':
            # Drop missing values so NaN is neither offered as a choice nor
            # counted toward the dropdown/radio threshold.
            choices = list(series.dropna().unique())
            if len(choices) > 3:
                input_components.append(gr.Dropdown(choices=choices, label=column))
            else:
                input_components.append(gr.Radio(choices=choices, label=column))
        elif series.dtype in ['int64', 'float64']:
            # Same numeric dtypes that are scaled at prediction time.
            is_integer = series.dtype == 'int64'
            col_min = series.min()
            col_max = series.max()
            # Keep the original lower bound (1 for columns starting at 1,
            # Gradio's default of 0 otherwise) but extend it so negative
            # values remain reachable.
            minimum = 1 if col_min == 1 else min(0, col_min)
            # Whole-number steps for integer features; floats keep 0.5 steps.
            step = 1 if is_integer else 0.5
            # Plain Python numbers rather than NumPy scalars for the widget config.
            to_number = int if is_integer else float
            input_components.append(
                gr.Slider(
                    minimum=to_number(minimum),
                    maximum=to_number(col_max),
                    step=step,
                    label=column,
                )
            )
    return input_components
# Form widgets: one input per usable column of the loaded data, and a single
# label component that displays the prediction text.
input_components = create_gradio_inputs(data)
output_components = [gr.Label(label="Churn Prediction")]
# Pack the raw form values into a one-row DataFrame keyed by feature name.
def input_df_creator(gender, SeniorCitizen, Partner, Dependents, tenure,
                     PhoneService, InternetService, OnlineBackup, TechSupport,
                     Contract, PaperlessBilling, PaymentMethod, MonthlyCharges,
                     TotalCharges, StreamingService, SecurityService):
    """Return a single-row DataFrame holding one customer's form values.

    ``tenure`` is coerced with ``int`` (truncating any fractional part) and
    the two charge fields with ``float``; every other value is stored as
    received. Column order follows the mapping below, which places
    ``StreamingService``/``SecurityService`` before the charge columns.

    Returns:
        A ``pandas.DataFrame`` with one row (index 0) and 16 columns.
    """
    row = {
        "gender": gender,
        "SeniorCitizen": SeniorCitizen,
        "Partner": Partner,
        "Dependents": Dependents,
        "tenure": int(tenure),
        "PhoneService": PhoneService,
        "InternetService": InternetService,
        "OnlineBackup": OnlineBackup,
        "TechSupport": TechSupport,
        "Contract": Contract,
        "PaperlessBilling": PaperlessBilling,
        "PaymentMethod": PaymentMethod,
        "StreamingService": StreamingService,
        "SecurityService": SecurityService,
        "MonthlyCharges": float(MonthlyCharges),
        "TotalCharges": float(TotalCharges),
    }
    # Wrap each scalar in a one-element list so pandas builds exactly one row.
    return pd.DataFrame({name: [value] for name, value in row.items()})
def _to_dense(values):
    """Return *values* as a dense NumPy array.

    Objects exposing ``toarray`` (e.g. SciPy sparse matrices, as returned by
    default by scikit-learn's ``OneHotEncoder``) are densified first; anything
    else (NumPy arrays, DataFrames) goes through ``np.asarray``.
    """
    if hasattr(values, "toarray"):
        values = values.toarray()
    return np.asarray(values)


def _combine_features(num_scaled, cat_encoded):
    """Join scaled numeric and encoded categorical features column-wise.

    Numeric features come first, as in the original ``pd.concat`` call.
    If both parts are DataFrames (transformers configured for pandas output)
    they are concatenated as DataFrames so their column names reach the model.
    Otherwise both are densified and stacked with ``np.hstack``, because
    ``pd.concat`` raises ``TypeError`` for NumPy arrays and sparse matrices,
    which is what scikit-learn transformers return by default.
    """
    if isinstance(num_scaled, pd.DataFrame) and isinstance(cat_encoded, pd.DataFrame):
        return pd.concat(
            [num_scaled.reset_index(drop=True), cat_encoded.reset_index(drop=True)],
            axis=1,
        )
    return np.hstack([_to_dense(num_scaled), _to_dense(cat_encoded)])


# Define the function to be called when the Gradio app is run
def predict_churn(gender, SeniorCitizen, Partner, Dependents, tenure,
                  PhoneService, InternetService, OnlineBackup, TechSupport,
                  Contract, PaperlessBilling, PaymentMethod, MonthlyCharges,
                  TotalCharges, StreamingService, SecurityService):
    """Predict churn for one customer from the submitted form values.

    Gradio passes the values of ``input_components`` positionally; their
    order is assumed to match this signature -- TODO confirm against the
    columns of ``data.csv``.

    Returns:
        ``"Churn"`` when the model's first prediction equals 1, otherwise
        ``"No Churn"``.
    """
    input_df = input_df_creator(gender, SeniorCitizen, Partner, Dependents, tenure,
                                PhoneService, InternetService, OnlineBackup, TechSupport,
                                Contract, PaperlessBilling, PaymentMethod, MonthlyCharges,
                                TotalCharges, StreamingService, SecurityService)
    # Split columns by dtype the same way as the loaded data, presumably the
    # split the encoder and scaler were fitted on -- confirm.
    # Encode categorical variables
    cat_cols = data.select_dtypes(include=['object']).columns
    cat_encoded = encoder.transform(input_df[cat_cols])
    # Scale numerical variables
    num_cols = data.select_dtypes(include=['int64', 'float64']).columns
    num_scaled = scaler.transform(input_df[num_cols])
    # Join encoded and scaled features back together (numeric first).
    processed = _combine_features(num_scaled, cat_encoded)
    # Make prediction
    prediction = model.predict(processed)
    # NOTE(review): assumes the model was trained with 1 meaning "churn"; a
    # model trained on string labels (e.g. "Yes"/"No") would always yield
    # "No Churn" here.
    return "Churn" if prediction[0] == 1 else "No Churn"
# Header images shown above the form. They are rendered through the
# Interface ``description`` (which accepts Markdown/HTML): a gr.Markdown
# created outside any gr.Blocks context is never attached to the app, and
# ``unsafe_allow_html`` is a Streamlit argument, not a Gradio one.
# The HTML has no blank lines so Markdown keeps it as one raw HTML block.
HEADER_HTML = """
<div class="churn-header-row">
<div class="churn-header-column">
<img src="https://user-images.githubusercontent.com/115732734/271723332-6c824e95-5e2f-48ec-af1c-b66ac7db1d7a.jpeg">
</div>
<div class="churn-header-column">
<img src="https://user-images.githubusercontent.com/115732734/271723345-50f27ca9-94ee-4e7c-ad3b-2b10f27d31bb.jpeg">
</div>
</div>
"""

# Layout rules for the header, passed via ``css=``. The class names are scoped
# because generic names like ``.row``/``.column`` risk matching classes Gradio
# uses for its own layout. The width needs a unit ("550" alone is invalid CSS).
HEADER_CSS = """
.churn-header-row { display: flex; }
.churn-header-column { flex: 33.33%; padding: 5px; }
.churn-header-column img { width: 550px; max-width: 100%; }
"""

# Build the Gradio app
iface = gr.Interface(
    predict_churn,
    inputs=input_components,
    outputs=output_components,
    description=HEADER_HTML,
    css=HEADER_CSS,
)

# Launch only when run as a script, so importing this module has no
# server-starting side effect.
if __name__ == "__main__":
    iface.launch(inbrowser=True, show_error=True)
|