Vineedhar commited on
Commit
8db8b38
1 Parent(s): 6564e85

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import tensorflow as tf
5
+ import matplotlib.pyplot as plt
6
+
7
+ def iou(y_true, y_pred):
8
+ def f(y_true, y_pred):
9
+ intersection = (y_true * y_pred).sum()
10
+ union = y_true.sum() + y_pred.sum() - intersection
11
+ x = (intersection + 1e-15) / (union + 1e-15)
12
+ x = x.astype(np.float32)
13
+ return x
14
+ return tf.numpy_function(f, [y_true, y_pred], tf.float32)
15
+
16
+ def dice_coef(y_true, y_pred):
17
+ y_true = tf.keras.layers.Flatten()(y_true)
18
+ y_pred = tf.keras.layers.Flatten()(y_pred)
19
+ intersection = tf.reduce_sum(y_true * y_pred)
20
+ return (2. * intersection + 1e-15) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred))
21
+
22
+ def dice_loss(y_true, y_pred):
23
+ return 1.0 - dice_coef(y_true, y_pred)
24
+
25
+ def read_image(file, target_size=(256, 256)):
26
+ img = Image.open(file).convert('RGB')
27
+ img = img.resize(target_size)
28
+ x = np.array(img, dtype=np.float32)
29
+ x = x / 255.0
30
+ return x
31
+
32
+ def preprocess_image(img):
33
+ if img.shape[-1] == 4:
34
+ img = img[..., :3]
35
+ img_expanded = np.expand_dims(img, axis=0)
36
+ return img_expanded
37
+
38
+ def predict_image(model, img):
39
+ pred = model.predict(img)
40
+ return pred
41
+
42
+ def visualize_prediction(img, pred):
43
+ fig, axs = plt.subplots(1, 2, figsize=(10, 5))
44
+ axs[0].imshow(img)
45
+ axs[0].set_title('Original Image')
46
+ axs[1].imshow(pred[0, ...], cmap='gray') # Assuming the prediction is a mask or similar
47
+ axs[1].set_title('Predicted Image')
48
+ plt.close(fig)
49
+ return fig
50
+
51
+ # Load the model with custom loss and metric
52
+ model = tf.keras.models.load_model("oryx_road_segmentation_model.h5", custom_objects={'dice_coef': dice_coef, 'iou': iou})
53
+
54
+ def process_image(image):
55
+ img = read_image(image)
56
+ img_preprocessed = preprocess_image(img)
57
+ pred = predict_image(model, img_preprocessed)
58
+ return visualize_prediction(img, pred)
59
+
60
+ iface = gr.Interface(fn=process_image, inputs="file", outputs="plot", title="orYx Models - Road Segmentation")
61
+ iface.launch()