Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
from models.neural_network.inference import load_model_and_preprocessor | |
# Load the pre-trained model and preprocessor
nn_model, nn_preprocessor = load_model_and_preprocessor('nn_model.keras', 'nn_preprocessor.pkl')

# Load the unique aircraft data and airport distances
aircraft_data = pd.read_csv('aircraft_data.csv').drop_duplicates(subset='model')
airport_data = pd.read_csv('airport_distances.csv')

# Lookup tables used by the callbacks and the UI below (they are referenced by
# name throughout the app, so they must exist at module level).
#   aircraft_dict: {model: {column: value, ...}} -- one entry per model, since
#   aircraft_data was deduplicated on 'model' above.
aircraft_dict = aircraft_data.set_index('model').to_dict('index')
#   airport_dict: {(origin, destination): {column: value, ...}}.
#   to_dict('index') requires a unique index, so keep the first row per route.
airport_dict = (
    airport_data
    .drop_duplicates(subset=['Origin_Airport', 'Destination_Airport'])
    .set_index(['Origin_Airport', 'Destination_Airport'])
    .to_dict('index')
)
def predict_fuel_burn(model_name, origin, destination, seats, distance):
    """Predict the fuel burn for one flight with the neural-network model.

    Args:
        model_name: Aircraft model; key into ``aircraft_dict``.
        origin: Selected origin airport.
        destination: Selected destination airport.
        seats: Number of seats; must be > 0 and at most the model's ``seats`` value.
        distance: Route distance from the Distance field; must be > 0.

    Returns:
        str: The prediction formatted as ``" <value> kg"`` (two decimals), or a
        human-readable validation message when an input is missing or invalid.
    """
    # Gradio passes None for an empty dropdown / Number field. Catch those
    # before indexing aircraft_dict (unknown key -> KeyError) and before the
    # numeric comparisons below (None > number -> TypeError).
    if model_name not in aircraft_dict:
        return "Please select an aircraft model."
    if not origin or not destination:
        return "Please select both an origin and a destination airport."
    if seats is None:
        return "Please enter the number of seats."
    if distance is None:
        return "The distance must be greater than 0."

    aircraft = aircraft_dict[model_name]

    # Validate the requested seats against the aircraft's capacity.
    max_seats = aircraft['seats']
    if seats > max_seats:
        return f"The {model_name} aircraft has a maximum of {max_seats} seats."
    if seats <= 0:
        return "The number of seats must be greater than 0."
    if distance <= 0:
        return "The distance must be greater than 0."

    # Single-row feature frame; presumably these are the columns the
    # preprocessor was fitted on (distance is supplied as both 'distance' and
    # 'dist') -- TODO confirm against the training pipeline.
    features = pd.DataFrame({
        'model': [model_name],
        'Origin_Airport': [origin],
        'Destination_Airport': [destination],
        'seats': [seats],
        'distance': [distance],
        'J/T': [aircraft['J/T']],
        'CAT': [aircraft['CAT']],
        '_Manufacturer': [aircraft['_Manufacturer']],
        'dist': [distance],
    })

    # Assumes predict() returns an (n_rows, n_outputs) array: take the first
    # row, then its first output.
    prediction = nn_model.predict(nn_preprocessor.transform(features))[0]
    return f" {prediction[0]:.2f} kg"
def update_fields(model_name):
    """Fill the read-only J/T, CAT and Manufacturer boxes for *model_name*.

    Returns a dict keyed by the output components (Gradio's dict-return form).
    When the dropdown is cleared or holds an unknown model, the boxes are
    emptied instead of raising KeyError.
    """
    specs = aircraft_dict.get(model_name)
    if specs is None:
        return {
            jt: gr.update(value=""),
            cat: gr.update(value=""),
            manufacturer: gr.update(value=""),
        }
    return {
        jt: gr.update(value=specs['J/T']),
        cat: gr.update(value=specs['CAT']),
        manufacturer: gr.update(value=specs['_Manufacturer']),
    }
def update_destination_options(origin):
    """Restrict the destination dropdown to routes departing from *origin*.

    The choices are sorted (matching the origin dropdown) with missing values
    dropped. The current destination is cleared because it may not be a valid
    route from the new origin; keeping it would leave a stale selection.
    """
    routes = airport_data[airport_data['Origin_Airport'] == origin]
    destinations = sorted(routes['Destination_Airport'].dropna().unique())
    return gr.update(choices=destinations, value=None)
def update_distance(origin, destination):
    """Look up the distance of the (origin, destination) route for the UI.

    Falls back to 0 when the route, or its 'distance' entry, is missing.
    """
    not_found = 'Distance not found'
    route_info = airport_dict.get((origin, destination), {})
    looked_up = route_info.get('distance', not_found)
    # Show 0 rather than an error string in the numeric Distance box.
    return gr.update(value=0 if looked_up == not_found else looked_up)
with gr.Blocks() as demo:
    gr.Markdown("## Fuel Burn Prediction")

    # Compute the model list once; guard against an empty aircraft table.
    model_choices = list(aircraft_dict)

    with gr.Row():
        model_name = gr.Dropdown(
            label="Aircraft Model",
            choices=model_choices,
            value=model_choices[0] if model_choices else None,
        )
        origin = gr.Dropdown(
            label="Origin Airport",
            choices=sorted(airport_data['Origin_Airport'].dropna().unique()),
        )
        destination = gr.Dropdown(
            label="Destination Airport",
            choices=[],
        )

    with gr.Row():
        jt = gr.Textbox(label="J/T", interactive=False)
        cat = gr.Textbox(label="CAT", interactive=False)
        manufacturer = gr.Textbox(label="Manufacturer", interactive=False)

    seats = gr.Number(label="Seats")
    distance = gr.Number(label="Distance", interactive=False)

    model_name.change(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])
    origin.change(fn=update_destination_options, inputs=origin, outputs=destination)
    destination.change(fn=update_distance, inputs=[origin, destination], outputs=distance)
    # .change does not fire for the dropdown's initial value, so populate the
    # aircraft detail boxes once when the page loads.
    demo.load(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])

    submit_btn = gr.Button("Predict Fuel Burn")
    result = gr.Textbox(label="Fuel Burn in Kg", interactive=False)
    submit_btn.click(
        predict_fuel_burn,
        inputs=[model_name, origin, destination, seats, distance],
        outputs=result,
    )

if __name__ == "__main__":
    demo.launch()