"""Gradio demo: detect people and vehicles with DETR and draw color-coded boxes.

Source metadata: 1,957 bytes; revisions beb576f, 948d0c1.
"""

from transformers import pipeline
from transformers import DetrFeatureExtractor, DetrForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import gradio as gr

# Load the 'facebook/detr-resnet-50' checkpoint: the preprocessor and the
# detection model are both fetched from the same pretrained model id.
_DETR_CHECKPOINT = 'facebook/detr-resnet-50'
feature_extractor = DetrFeatureExtractor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)

# Bundle model + preprocessor into a callable object-detection pipeline.
object_detector = pipeline(
    "object-detection",
    model=model,
    feature_extractor=feature_extractor,
)

# Outline color per detection label; any label not listed falls back to
# DEFAULT_BOX_COLOR.
BOX_COLORS = {
    'truck': 'red',
    'car': 'red',
    'motorcycle': 'red',
    'bus': 'red',
    'person': 'green',
    'bicycle': 'green',
}
DEFAULT_BOX_COLOR = 'blue'


def box_color(label):
    """Return the outline color used for a detection *label*.

    'truck', 'car', 'motorcycle' and 'bus' are red; 'person' and 'bicycle'
    are green; every other label is blue.
    """
    return BOX_COLORS.get(label, DEFAULT_BOX_COLOR)


# Draw bounding box definition
def draw_bounding_box(im, score, label, xmin, ymin, xmax, ymax, index, num_boxes):
    """Draw one rounded, color-coded bounding box onto *im* in place.

    Args:
        im: Image to draw on; it is modified in place via ``ImageDraw.Draw``.
        score: Detection confidence. Currently unused; kept so existing
            callers keep working.
        label: Detection label; selects the outline color via ``box_color``.
        xmin, ymin, xmax, ymax: Box corners in pixel coordinates.
        index, num_boxes: Position of this box among the detections.
            Currently unused; kept so existing callers keep working.

    Returns:
        The same ``im`` object, with the rectangle drawn on it.
    """
    draw = ImageDraw.Draw(im)
    draw.rounded_rectangle(
        (xmin, ymin, xmax, ymax),
        outline=box_color(label),
        width=3,
        radius=10,
    )
    return im

def detect_image(im):
    """Run object detection on *im* and draw boxes for people and vehicles.

    Only detections labelled 'person', 'motorcycle', 'bicycle', 'truck',
    'car' or 'bus' are drawn; all other detections are skipped.

    Args:
        im: Input image, passed straight to ``object_detector`` and
            ``draw_bounding_box``. The Gradio input is configured with
            ``type='pil'``, so presumably a PIL image.

    Returns:
        The annotated image. When no detection passes the label filter, the
        input image is returned unchanged.
    """
    # Perform object detection
    bounding_boxes = object_detector(im)
    # Total number of detections, before label filtering.
    num_boxes = len(bounding_boxes)

    # Start from the input so an image is returned even when nothing is drawn;
    # otherwise the final return would hit an unbound local.
    output_image = im
    index = 0  # counts only the boxes actually drawn

    for bounding_box in bounding_boxes:
        if bounding_box['label'] not in {'person', 'motorcycle', 'bicycle', 'truck', 'car', 'bus'}:
            continue
        box = bounding_box['box']
        output_image = draw_bounding_box(
            im,
            bounding_box['score'],
            bounding_box['label'],
            box['xmin'], box['ymin'],
            box['xmax'], box['ymax'],
            index, num_boxes,
        )
        index += 1

    return output_image
    
# Build the web UI: one PIL image in, one annotated image out. Keep `iface`
# bound to the Interface itself rather than to the return value of launch().
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces; if Gradio is upgraded, switch to gr.Image(type='pil') /
# gr.Image() — confirm the pinned Gradio version first.
iface = gr.Interface(detect_image, gr.inputs.Image(type='pil'), gr.outputs.Image())

# Start the server only when run as a script, so importing this module does
# not launch it.
if __name__ == "__main__":
    iface.launch()