khizon's picture
Update app.py
c182d0a
raw
history blame contribute delete
No virus
2.38 kB
from transformers import pipeline
from transformers import DetrFeatureExtractor, DetrForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
# Pretrained checkpoint identifier used for both the feature extractor and the model.
DETR_CHECKPOINT = 'facebook/detr-resnet-50'

# Load the DETR feature extractor and the DETR object-detection model.
feature_extractor = DetrFeatureExtractor.from_pretrained(DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(DETR_CHECKPOINT)

# Combine both into a single object-detection pipeline.
object_detector = pipeline(
    "object-detection",
    model=model,
    feature_extractor=feature_extractor,
)
# Draw bounding box definition
def draw_bounding_box(im, score, label, xmin, ymin, xmax, ymax, index, num_boxes):
    """Outline one detection on ``im`` with a colour chosen from its label.

    The rectangle is drawn directly onto ``im`` (via ``ImageDraw.Draw``), and
    that same image object is returned.

    Args:
        im: Image to draw on.
        score: Accepted for interface compatibility; not used here.
        label: Detected class name; selects the outline colour.
        xmin, ymin, xmax, ymax: Corners of the box to outline.
        index: Accepted for interface compatibility; not used here.
        num_boxes: Accepted for interface compatibility; not used here.

    Returns:
        ``im``, with the rounded rectangle drawn on it.
    """
    # (labels, colour) pairs checked in order; anything unmatched is blue.
    palette = (
        (('truck', 'car', 'motorcycle', 'bus'), 'red'),
        (('person', 'bicycle'), 'green'),
    )
    colour = next(
        (shade for names, shade in palette if label in names),
        'blue',
    )

    canvas = ImageDraw.Draw(im)
    corners = (xmin, ymin, xmax, ymax)
    canvas.rounded_rectangle(corners, outline=colour, width=3, radius=10)

    return im
def detect_image(im):
    """Run object detection on ``im`` and outline the tracked classes.

    Every detection returned by the module-level ``object_detector`` whose
    label is one of the tracked classes is passed to ``draw_bounding_box``.
    Detections with any other label are skipped.

    Args:
        im: Image handed to ``object_detector`` and to ``draw_bounding_box``.

    Returns:
        The image returned by the last ``draw_bounding_box`` call, or ``im``
        itself when no detection carries a tracked label.
    """
    tracked_labels = ('person', 'motorcycle', 'bicycle', 'truck', 'car', 'bus')

    # Perform object detection
    bounding_boxes = object_detector(im)
    num_boxes = len(bounding_boxes)

    # Default to the input image so the return value is always bound, even
    # when nothing is detected or no detection has a tracked label
    # (previously this path raised UnboundLocalError).
    output_image = im

    # Counts the boxes drawn so far; passed through to draw_bounding_box.
    index = 0
    for bounding_box in bounding_boxes:
        if bounding_box['label'] not in tracked_labels:
            continue
        box = bounding_box['box']
        output_image = draw_bounding_box(
            im,
            bounding_box['score'],
            bounding_box['label'],
            box['xmin'], box['ymin'],
            box['xmax'], box['ymax'],
            index, num_boxes,
        )
        index += 1

    return output_image
# Title and description passed to the Gradio interface.
TITLE = 'Active Transport Detection'
DESCRIPTION = 'This uses DETR as an object detection model and detects motor vehicles (red) and people and bikes (green). Much fine-tuning and optimization is still needed to make this a practical application'

# Example inputs passed to the interface, one single-item row per image file.
examples = [['bike.jpg'], ['bike2.jpg'], ['bike_3.jpg'], ['bike_4.jpg']]

# Build the interface first so that `iface` refers to the Interface object
# itself rather than to whatever launch() returns.
# NOTE(review): the gr.inputs / gr.outputs namespaces are deprecated in newer
# Gradio releases - confirm against the Gradio version this app is pinned to.
iface = gr.Interface(
    detect_image,
    gr.inputs.Image(type='pil'),
    gr.outputs.Image(),
    examples=examples,
    allow_flagging='never',
    title=TITLE,
    description=DESCRIPTION,
)
iface.launch(debug=True)