|
import os |
|
import socket |
|
import time |
|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
import supervision as sv |
|
import cv2 |
|
import base64 |
|
import requests |
|
import json |
|
|
|
|
|
# Endpoint of the remote DL4EO ship-detection inference service (hosted on Modal).
DL4EO_API_URL = "https://dl4eo--ship-predict.modal.run"

# API key read at import time; this raises KeyError immediately if the
# DL4EO_API_KEY environment variable is not set (fail-fast on startup).
DL4EO_API_KEY = os.environ['DL4EO_API_KEY']

# Stroke width (pixels) for drawn bounding boxes.
# NOTE(review): not referenced anywhere in this chunk — verify it is used
# elsewhere in the file before relying on it.
LINE_WIDTH = 2

# Log the Gradio version at startup to aid debugging of UI/API mismatches.
print(f"Gradio version: {gr.__version__}")
|
|
|
|
|
def predict_image(img, threshold):
    """Run ship detection on a single RGB image via the DL4EO inference API.

    The image is shipped to the remote service as a base64-encoded raw pixel
    buffer together with its shape; the service returns detections and an
    annotated image encoded the same way.

    Args:
        img: input image, either a ``PIL.Image`` or an ``HxWx3`` RGB numpy array.
        threshold: confidence threshold in [0, 1], forwarded to the API.

    Returns:
        Tuple of ``(annotated PIL image, input shape, number of detections,
        inference duration in seconds)``.

    Raises:
        ValueError: if ``img`` is not a single RGB image.
        Exception: if the inference API responds with a non-200 status code.
    """
    if isinstance(img, Image.Image):
        img = np.array(img)

    # Validate: exactly one 3-channel (RGB) image.
    if not isinstance(img, np.ndarray) or img.ndim != 3 or img.shape[2] != 3:
        # Fixed: was `raise BaseException(...)` — input validation should raise
        # ValueError, which `except Exception` handlers can actually catch.
        raise ValueError(
            "predict_image(): input 'img' should be a single RGB image "
            "in PIL or Numpy array format."
        )

    # Raw pixel buffer, base64-encoded; the shape travels alongside so the
    # server can reconstruct the array.
    image_base64 = base64.b64encode(np.ascontiguousarray(img)).decode()

    payload = {
        'image': image_base64,
        'shape': img.shape,
        'threshold': threshold,
    }

    headers = {
        'Authorization': 'Bearer ' + DL4EO_API_KEY,
        'Content-Type': 'application/json',
    }

    # Fixed: added an explicit timeout — without one a stalled API call would
    # hang the Gradio worker indefinitely.
    response = requests.post(DL4EO_API_URL, json=payload, headers=headers, timeout=120)

    if response.status_code != 200:
        raise Exception(
            f"Received status code={response.status_code} in inference API"
        )

    # Fixed: use response.json() instead of json.loads(response.content).
    json_data = response.json()
    detections = json_data['detections']
    duration = json_data['duration']

    # NOTE: a previous client-side annotation pass (cv2 color conversion +
    # sv.OrientedBoxAnnotator) was removed as dead code — its result was never
    # used; the image returned below is the one annotated by the server.

    # The server returns the annotated image as a raw base64 pixel buffer with
    # the same shape as the input; decode it back into a PIL image.
    img_data_in = base64.b64decode(json_data['image'])
    np_img = np.frombuffer(img_data_in, dtype=np.uint8).reshape(img.shape)
    pil_img = Image.fromarray(np_img)

    return pil_img, img.shape, len(detections), duration
|
|
|
|
|
|
|
# Demo inputs for the Gradio Examples widget: each entry pairs a sample image
# path with the default confidence threshold.
_EXAMPLE_STEMS = (
    "12ab97857",
    "82f13510a",
    "836f35381",
    "848d2afef",
    "911b25478",
    "b86e4046f",
    "ce2220f49",
    "d9762ef5e",
    "fa613751e",
)
example_data = [[f"./demo/{stem}.jpg", 0.8] for stem in _EXAMPLE_STEMS]
|
|
|
|
|
# Custom CSS injected into the Gradio app: enlarges the output preview pane
# (the component tagged with elem_classes='image-preview').
css = """
.image-preview {
    height: 820px !important;
    width: 800px !important;
}
"""

# Page/tab title displayed by the Gradio app.
TITLE = "Oriented bounding boxes detection on Optical Satellite images"
|
|
|
|
|
# Build the Gradio UI: input column (image, run button, options) on the left,
# annotated output image on the right, with clickable cached examples below.
demo = gr.Blocks(title=TITLE, css=css).queue()
with demo:
    # Fixed: the header previously used opening tags where closing tags
    # belong ("<center><h1>" at the end), producing invalid HTML.
    gr.Markdown(f"<h1><center>{TITLE}</center></h1>")

    with gr.Row():
        # Left column: input image, run button and advanced options.
        with gr.Column(scale=0):
            input_image = gr.Image(type="pil", interactive=True)
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=True):
                threshold = gr.Slider(label="Confidence threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.01)
                dimensions = gr.Textbox(label="Image size", interactive=False)
                detections = gr.Textbox(label="Predicted objects", interactive=False)
                stopwatch = gr.Number(label="Execution time (sec.)", interactive=False, precision=3)

        # Right column: server-annotated output image (styled via .image-preview CSS).
        with gr.Column(scale=2):
            output_image = gr.Image(type="pil", elem_classes='image-preview', interactive=False)

    # Wire the Run button to the remote-inference wrapper.
    run_button.click(fn=predict_image, inputs=[input_image, threshold], outputs=[output_image, dimensions, detections, stopwatch])

    gr.Examples(
        examples=example_data,
        inputs=[input_image, threshold],
        outputs=[output_image, dimensions, detections, stopwatch],
        fn=predict_image,
        cache_examples=True,
        label='Try these images!'
    )

    # Footer / attribution. Fixed user-facing typos: "bouding" -> "bounding",
    # "puposes" -> "purposes".
    gr.Markdown("""
        <p>This demo is provided by <a href='https://www.linkedin.com/in/faudi/'>Jeff Faudi</a> and <a href='https://www.dl4eo.com/'>DL4EO</a>.
        This model is based on the <a href='https://github.com/open-mmlab/mmrotate'>MMRotate framework</a> which provides oriented bounding boxes.
        We believe that oriented bounding boxes are better suited for detection in satellite images. This model has been trained on the
        <a href='https://captain-whu.github.io/DOTA/dataset.html'>DOTA dataset</a> which contains 15 classes: plane, ship, storage tank,
        baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, large vehicle, small vehicle, helicopter,
        roundabout, soccer ball field and swimming pool. </p><p>The associated licenses are
        <a href='https://about.google/brand-resource-center/products-and-services/geo-guidelines/#google-earth-web-and-apps'>GoogleEarth fair use</a>
        and <a href='https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en'>CC-BY-SA-NC</a>. This demonstration CANNOT be used for commercial purposes.
        Please contact <a href='mailto:jeff@dl4eo.com'>me</a> for more information on how you could get access to a commercial grade model or API. </p>
        """)

# Start the app; queue() above serializes concurrent inference requests.
demo.launch(
    inline=False,
    show_api=False,
    debug=False
)
|
|