Vineedhar's picture
Update app.py
e234d10 verified
raw
history blame contribute delete
No virus
2.6 kB
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
def iou(y_true, y_pred):
    """Return the intersection-over-union of ``y_true`` and ``y_pred`` as float32.

    The arithmetic is done in NumPy and wrapped with ``tf.numpy_function`` so it
    can be used as a Keras metric. A 1e-15 epsilon is added to both numerator
    and denominator, so two all-zero masks score 1.0 instead of dividing by zero.
    """
    def _np_iou(truth, pred):
        overlap = (truth * pred).sum()
        union = truth.sum() + pred.sum() - overlap
        return ((overlap + 1e-15) / (union + 1e-15)).astype(np.float32)

    return tf.numpy_function(_np_iou, [y_true, y_pred], tf.float32)
def dice_coef(y_true, y_pred):
    """Return the Dice coefficient ``2|A∩B| / (|A| + |B|)`` of two masks.

    Both tensors are flattened to 1-D before reduction, so inputs whose shapes
    differ only by a trailing singleton axis (e.g. ``(B, H, W)`` vs.
    ``(B, H, W, 1)``) are compared element-by-element rather than broadcast.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor with the same number of elements.

    Returns:
        Scalar tensor with the Dice coefficient.
    """
    smooth = 1e-15
    # Align dtypes so the element-wise product is valid (e.g. integer labels).
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true * y_pred)
    # Smooth BOTH terms: with only the numerator smoothed, two empty masks
    # produced 1e-15 / 0 = inf instead of a perfect score of 1.0.
    return (2.0 * intersection + smooth) / (
        tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth
    )
def dice_loss(y_true, y_pred):
    """Return ``1 - dice_coef(y_true, y_pred)``: more overlap means a lower loss."""
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient
def read_image(file_path, target_size=(256, 256), mode="RGB"):
    """Load an image file, resize it, and scale its pixel values by 1/255.

    Args:
        file_path: Anything ``PIL.Image.open`` accepts (here, the upload path).
        target_size: ``(width, height)`` passed to ``Image.resize``.
        mode: PIL mode to convert to before resizing. The default ``"RGB"``
            makes grayscale, palette and RGBA uploads all yield three
            channels; pass ``None`` to keep the file's native mode.

    Returns:
        float32 NumPy array of shape ``(height, width, channels)`` (3 channels
        with the default ``mode``), with values divided by 255.
    """
    # The context manager closes the file handle. ``resize`` returns a new,
    # fully loaded image, so it remains usable after the file is closed.
    with Image.open(file_path) as img:
        if mode is not None and img.mode != mode:
            # Convert before resizing so palette images are resampled as colors,
            # not as palette indices.
            img = img.convert(mode)
        img = img.resize(target_size)
    x = np.array(img, dtype=np.float32)
    return x / 255.0
def preprocess_image(img):
    """Coerce a single image array to three channels and add a batch axis.

    Args:
        img: NumPy image of shape ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or
            ``(H, W, 4)``. Grayscale inputs are replicated across three
            channels, and a fourth (alpha) channel is dropped. Other channel
            counts are passed through unchanged.

    Returns:
        Array of shape ``(1, H, W, C)``; ``C == 3`` for the inputs listed above.
    """
    if img.ndim == 2:
        # Check rank first so a 2-D image exactly 4 pixels wide is not mistaken
        # for RGBA and cropped.
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]
    return np.expand_dims(img, axis=0)
def predict_image(model, img):
    """Run ``model.predict`` on a batched input and return the first result.

    Args:
        model: Any object with a ``predict`` method (here, the loaded Keras model).
        img: Batched input array, e.g. the output of ``preprocess_image``.

    Returns:
        The prediction for batch item 0, with all remaining axes kept.
    """
    batch_output = model.predict(img)
    # Index only the leading batch axis; ``...`` preserves the other axes.
    return batch_output[0, ...]
# Load the trained road-segmentation network from the bundled HDF5 file.
# The custom metric functions defined above are passed via ``custom_objects``
# so Keras can resolve references to them while deserializing the model.
loaded_model = tf.keras.models.load_model(
    "oryx_road_segmentation_model.h5",
    custom_objects=dict(dice_coef=dice_coef, iou=iou),
)
def process_image(image):
    """Segment roads in an uploaded image and return the mask as an RGB picture.

    Args:
        image: File path supplied by ``gr.Image(type="filepath")``. It may be
            ``None`` if the user submits without choosing an image.

    Returns:
        ``PIL.Image.Image`` whose three identical channels hold the model's
        prediction, clipped to [0, 1] and scaled to 0-255.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a traceback from ``Image.open(None)``.
    """
    if image is None:
        raise gr.Error("Please upload an image or choose one of the examples.")
    img = read_image(image)
    img_preprocessed = preprocess_image(img)
    pred = predict_image(loaded_model, img_preprocessed)
    # Drop singleton axes (such as a trailing 1-channel axis) to get an (H, W) mask.
    # NOTE(review): assumes the model emits a single-channel map — confirm.
    mask = np.squeeze(pred)
    mask = np.clip(mask, 0, 1)
    # Duplicate the grayscale mask across three channels -> (H, W, 3).
    mask_rgb = np.stack((mask,) * 3, axis=-1)
    mask_rgb = (mask_rgb * 255).astype(np.uint8)
    return Image.fromarray(mask_rgb)
# Example satellite tiles offered under the input widget. They are bare file
# names, so presumably they live in the app's working directory.
sample_images = [
    "234989_sat.jpg",
    "751359_sat.jpg",
    "168243_sat.jpg",
    "877873_sat.jpg",
    "836987_sat.jpg",
]
# Gradio UI: the uploaded file's path is fed to process_image, and the
# resulting mask is displayed as a PIL image.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Image(type="pil", label="Predicted Image"),
    title="orYx Models' - Road Segmentation Predictor",
    description="Upload an image or choose a sample and view the model's segmentation for roads on different terrains.",
    examples=sample_images,
)
# Start the Gradio server with debug enabled.
iface.launch(debug=True)