import numpy as np

import torch
import torch.nn as nn

import albumentations as A
from albumentations.pytorch import ToTensorV2

from huggingface_hub import hf_hub_download
import gradio as gr


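# DETR-based object detector served through a Gradio interface.
# The checkpoint (detr_low_light) is downloaded from the Hugging Face Hub below.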
class ObjectDetection:
    def __init__(self, ckpt_path):
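        # Preprocessing: resize, CLAHE contrast enhancement, normalization with
        # custom mean/std values, then conversion to a tensor.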
        self.test_transform = A.Compose(
            [
                A.Resize(800, 600),
                A.CLAHE(clip_limit=10, p=1),
                A.Normalize(
                    [0.29278653, 0.25276296, 0.22975405],
                    [0.22653664, 0.19836408, 0.17775835],
                ),
                ToTensorV2(),
            ],
        )

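        # Load the DETR ResNet-50 architecture from torch.hub (without pretrained
        # weights) and replace the classification head to match the fine-tuned checkpoint.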
        self.model = torch.hub.load(
            "facebookresearch/detr", "detr_resnet50", pretrained=False
        )
        in_features = self.model.class_embed.in_features
        self.model.class_embed = nn.Linear(
            in_features=in_features,
            out_features=12,
        )
        self.labels = [
            "Dog",
            "Motorbike",
            "People",
            "Cat",
            "Chair",
            "Table",
            "Car",
            "Bicycle",
            "Bottle",
            "Bus",
            "Cup",
            "Boat",
        ]

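        # Load the fine-tuned weights on CPU and switch to evaluation mode.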
        model_ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(model_ckpt)
        self.model.eval()

    def predict(self, img):
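        # Run the detector on a PIL image and return (image, [(box, label), ...])
        # in the format expected by gr.AnnotatedImage. The thresholds below control
        # confidence filtering and NMS.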
        score_threshold, iou_threshold = 0.05, 0.1
        img_w, img_h = img.size
        inp = self.test_transform(image=np.array(img.convert("RGB")))["image"]
        # A single forward pass is enough for inference; disable gradient tracking.
        with torch.no_grad():
            out = self.model(inp.unsqueeze(0))
        probas = out["pred_logits"].softmax(-1)[0, :, :-1]
        bboxes = []
        scores = []
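        # DETR predicts boxes as normalized (cx, cy, w, h); convert each sufficiently
        # confident prediction to absolute (x1, y1, x2, y2) pixel coordinates.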
        for idx, bbox in enumerate(out["pred_boxes"][0]):
            if probas[idx].max().item() < score_threshold:
                continue
            x_c, y_c, w, h = bbox.detach().numpy()
            x1 = int((x_c - w * 0.5) * img_w)
            y1 = int((y_c - h * 0.5) * img_h)
            x2 = int((x_c + w * 0.5) * img_w)
            y2 = int((y_c + h * 0.5) * img_h)
            label_idx = probas[idx].argmax().item()
            label = self.labels[label_idx] + f" {probas[idx].max().item():.2f}"
            bboxes.append(((x1, y1, x2, y2), label))
            scores.append(probas[idx].max().item())
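        # Suppress overlapping detections with greedy NMS.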
        selected_indices = self.non_max_suppression(
            bboxes,
            scores,
            iou_threshold,
        )
        bboxes = [bboxes[i] for i in selected_indices]
        return (img, bboxes)

    def non_max_suppression(self, boxes, scores, iou_threshold):
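        """
        Greedy non-maximum suppression.

        Args:
            boxes: list of ((x1, y1, x2, y2), label) tuples.
            scores: confidence score for each box.
            iou_threshold: boxes overlapping a kept box above this IoU are dropped.

        Returns:
            Indices of the boxes to keep.
        """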
        if len(boxes) == 0:
            return []

        sorted_indices = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )
        selected_indices = []

        while sorted_indices:
            current_index = sorted_indices[0]
            selected_indices.append(current_index)
            sorted_indices.pop(0)

            ious = [
                self.calculate_iou(boxes[current_index][0], boxes[i][0])
                for i in sorted_indices
            ]

            indices_to_remove = [i for i, iou in enumerate(ious) if iou > iou_threshold]

            sorted_indices = [
                i for j, i in enumerate(sorted_indices) if j not in indices_to_remove
            ]

        return selected_indices

    def calculate_iou(self, box1, box2):
        """
        Calculate the Intersection over Union (IoU) of two bounding boxes.

        Args:
            box1: [x1, y1, x2, y2] for the first box.
            box2: [x1, y1, x2, y2] for the second box.

        Returns:
            IoU value.
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])

        intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

        iou = intersection_area / (box1_area + box2_area - intersection_area)

        return iou


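# Download the fine-tuned checkpoint from the Hugging Face Hub, then wire the
# detector into a Gradio interface ("Examples" refers to a local directory of
# sample inputs).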
model_path = hf_hub_download(
    repo_id="SatwikKambham/detr_low_light",
    filename="detr.pt",
)
detector = ObjectDetection(ckpt_path=model_path)
iface = gr.Interface(
    fn=detector.predict,
    inputs=[
        gr.Image(
            type="pil",
            label="Input",
            height=400,
        ),
    ],
    outputs=gr.AnnotatedImage(
        height=400,
    ),
    examples="Examples",
)
iface.launch()