sanket09 commited on
Commit
cf7e785
1 Parent(s): 2146269

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -6
app.py CHANGED
@@ -15,6 +15,8 @@ import rasterio
15
  import matplotlib.pyplot as plt
16
  from tensorflow.keras.applications import ResNet50
17
  from tensorflow.keras.models import Model
 
 
18
 
19
  # Load crop data
20
  def load_data():
@@ -55,7 +57,64 @@ def predict_traditional(model_name, year, state, crop, yield_):
55
  else:
56
  return "Model not found"
57
 
58
- # Load pre-trained deep learning models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def load_deep_learning_model(model_name):
60
  base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
61
  base_model.trainable = False
@@ -123,9 +182,9 @@ def predict_deep_learning(model_name, file):
123
  plt.colorbar()
124
 
125
  # Save the plot to a file
126
- plt.savefig('/tmp/prediction_overlay.png')
127
 
128
- return '/tmp/prediction_overlay.png'
129
  else:
130
  return "No file uploaded"
131
  else:
@@ -141,7 +200,7 @@ inputs_traditional = [
141
  outputs_traditional = gr.Textbox(label='Predicted Profit')
142
 
143
  inputs_deep_learning = [
144
- gr.Dropdown(choices=list(deep_learning_models.keys()), label='Model'),
145
  gr.File(label='Upload TIFF File')
146
  ]
147
  outputs_deep_learning = gr.Image(label='Prediction Overlay')
@@ -157,10 +216,10 @@ with gr.Blocks() as demo:
157
 
158
  with gr.Tab("Deep Learning Models"):
159
  gr.Interface(
160
- fn=predict_deep_learning,
161
  inputs=inputs_deep_learning,
162
  outputs=outputs_deep_learning,
163
- title="Crop Yield Prediction using Deep Learning Models"
164
  )
165
 
166
  demo.launch()
 
15
  import matplotlib.pyplot as plt
16
  from tensorflow.keras.applications import ResNet50
17
  from tensorflow.keras.models import Model
18
+ import cv2
19
+ import joblib
20
 
21
  # Load crop data
22
  def load_data():
 
57
  else:
58
  return "Model not found"
59
 
60
+ # Train RandomForestRegressor model for deep learning model
61
+ def train_random_forest_model():
62
+ def process_tiff(file_path):
63
+ with rasterio.open(file_path) as src:
64
+ tiff_data = src.read()
65
+ B2_image = tiff_data[1, :, :] # Assuming B2 is the second band
66
+ target_size = (50, 50)
67
+ B2_resized = cv2.resize(B2_image, target_size, interpolation=cv2.INTER_NEAREST)
68
+ return B2_resized.reshape(-1, 1)
69
+
70
+ data_dir = 'Data'
71
+ X_list = []
72
+ y_list = []
73
+
74
+ for root, dirs, files in os.walk(data_dir):
75
+ for file in files:
76
+ if file.endswith('.tiff'):
77
+ file_path = os.path.join(root, file)
78
+ X_list.append(process_tiff(file_path))
79
+ y_list.append(np.random.rand(2500)) # Replace with actual target data
80
+
81
+ X = np.vstack(X_list)
82
+ y = np.hstack(y_list)
83
+
84
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
85
+
86
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
87
+ model.fit(X_train, y_train)
88
+
89
+ return model
90
+
91
+ rf_model = train_random_forest_model()
92
+
93
+ def predict_random_forest(file):
94
+ if file is not None:
95
+ def process_tiff(file_path):
96
+ with rasterio.open(file_path) as src:
97
+ tiff_data = src.read()
98
+ B2_image = tiff_data[1, :, :]
99
+ target_size = (50, 50)
100
+ B2_resized = cv2.resize(B2_image, target_size, interpolation=cv2.INTER_NEAREST)
101
+ return B2_resized.reshape(-1, 1)
102
+
103
+ tiff_processed = process_tiff(file.name)
104
+ prediction = rf_model.predict(tiff_processed)
105
+ prediction_reshaped = prediction.reshape((50, 50))
106
+
107
+ plt.figure(figsize=(10, 10))
108
+ plt.imshow(prediction_reshaped, cmap='viridis')
109
+ plt.colorbar()
110
+ plt.title('Yield Prediction for Single TIFF File')
111
+ plt.savefig('/tmp/rf_prediction_overlay.png')
112
+
113
+ return '/tmp/rf_prediction_overlay.png'
114
+ else:
115
+ return "No file uploaded"
116
+
117
+ # Load deep learning models
118
  def load_deep_learning_model(model_name):
119
  base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
120
  base_model.trainable = False
 
182
  plt.colorbar()
183
 
184
  # Save the plot to a file
185
+ plt.savefig('/tmp/dl_prediction_overlay.png')
186
 
187
+ return '/tmp/dl_prediction_overlay.png'
188
  else:
189
  return "No file uploaded"
190
  else:
 
200
  outputs_traditional = gr.Textbox(label='Predicted Profit')
201
 
202
  inputs_deep_learning = [
203
+ gr.Dropdown(choices=list(deep_learning_models.keys()) + ['Random Forest'], label='Model'),
204
  gr.File(label='Upload TIFF File')
205
  ]
206
  outputs_deep_learning = gr.Image(label='Prediction Overlay')
 
216
 
217
  with gr.Tab("Deep Learning Models"):
218
  gr.Interface(
219
+ fn=lambda model_name, file: predict_deep_learning(model_name, file) if model_name != 'Random Forest' else predict_random_forest(file),
220
  inputs=inputs_deep_learning,
221
  outputs=outputs_deep_learning,
222
+ title="Crop Yield Prediction using Deep Learning Models and Random Forest"
223
  )
224
 
225
  demo.launch()