poudel commited on
Commit
b2be77c
1 Parent(s): c5154e7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ from models.neural_network.inference import load_model_and_preprocessor
6
+
7
+
8
+ # Load the pre-trained model and preprocessor
9
+ nn_model, nn_preprocessor = load_model_and_preprocessor('nn_model.keras', 'nn_preprocessor.pkl')
10
+
11
+ # Load the unique aircraft data and airport distances
12
+ aircraft_data = pd.read_csv('aircraft_data.csv').drop_duplicates(subset='model')
13
+ airport_data = pd.read_csv('airport_distances.csv')
14
+
15
+
16
+
17
+ def predict_fuel_burn(model_name, origin, destination, seats, distance):
18
+ # Validate the distance against seats
19
+ max_seats = aircraft_dict[model_name]['seats']
20
+ if seats > max_seats:
21
+ return f"The {model_name} aircraft has a maximum of {max_seats} seats."
22
+ if seats <= 0:
23
+ return "The number of seats must be greater than 0."
24
+ if distance <= 0:
25
+ return "The distance must be greater than 0."
26
+
27
+ # Prepare the input data for the model
28
+ data = {
29
+ 'model': [model_name],
30
+ 'Origin_Airport': [origin],
31
+ 'Destination_Airport': [destination],
32
+ 'seats': [seats],
33
+ 'distance': [distance],
34
+ 'J/T': [aircraft_dict[model_name]['J/T']],
35
+ 'CAT': [aircraft_dict[model_name]['CAT']],
36
+ '_Manufacturer': [aircraft_dict[model_name]['_Manufacturer']],
37
+ 'dist': [distance]
38
+ }
39
+
40
+ df = pd.DataFrame(data)
41
+
42
+ # Make the prediction
43
+ fuel_burn_prediction_nn = nn_model.predict(nn_preprocessor.transform(df))[0]
44
+
45
+ return f" {fuel_burn_prediction_nn[0]:.2f} kg"
46
+
47
+
48
+ def update_fields(model_name):
49
+ return {
50
+ jt: gr.update(value=aircraft_dict[model_name]['J/T']),
51
+ cat: gr.update(value=aircraft_dict[model_name]['CAT']),
52
+ manufacturer: gr.update(value=aircraft_dict[model_name]['_Manufacturer'])
53
+ }
54
+
55
+
56
+ def update_destination_options(origin):
57
+ destinations = airport_data[airport_data['Origin_Airport'] == origin]['Destination_Airport'].unique()
58
+ return gr.update(choices=list(destinations))
59
+
60
+
61
+ def update_distance(origin, destination):
62
+ distance_value = airport_dict.get((origin, destination), {}).get('distance', 'Distance not found')
63
+ if distance_value == 'Distance not found':
64
+ return gr.update(value=0) # Return 0 if distance is not found
65
+ return gr.update(value=distance_value)
66
+
67
+
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("## Fuel Burn Prediction")
70
+
71
+ with gr.Row():
72
+ model_name = gr.Dropdown(
73
+ label="Aircraft Model",
74
+ choices=list(aircraft_dict.keys()),
75
+ value=list(aircraft_dict.keys())[0],
76
+ )
77
+ origin = gr.Dropdown(
78
+ label="Origin Airport",
79
+ choices=sorted(airport_data['Origin_Airport'].unique())
80
+ )
81
+ destination = gr.Dropdown(
82
+ label="Destination Airport",
83
+ choices=[]
84
+ )
85
+
86
+ with gr.Row():
87
+ jt = gr.Textbox(label="J/T", interactive=False)
88
+ cat = gr.Textbox(label="CAT", interactive=False)
89
+ manufacturer = gr.Textbox(label="Manufacturer", interactive=False)
90
+ seats = gr.Number(label="Seats")
91
+
92
+ distance = gr.Number(label="Distance", interactive=False)
93
+
94
+ model_name.change(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])
95
+ origin.change(fn=update_destination_options, inputs=origin, outputs=destination)
96
+ destination.change(fn=update_distance, inputs=[origin, destination], outputs=distance)
97
+
98
+ submit_btn = gr.Button("Predict Fuel Burn")
99
+ result = gr.Textbox(label="Fuel Burn in Kg", interactive=False)
100
+
101
+ submit_btn.click(predict_fuel_burn, inputs=[model_name, origin, destination, seats, distance], outputs=result)
102
+
103
+ demo.launch()
104
+