import numpy as np

import torch
import torch.nn as nn

import albumentations as A
from albumentations.pytorch import ToTensorV2

from huggingface_hub import hf_hub_download
import gradio as gr

class ObjectDetection:
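    """Fine-tuned DETR detector for low-light images, wrapped for Gradio."""
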
    def __init__(self, ckpt_path):
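        # Inference-time preprocessing: CLAHE boosts local contrast in
        # low-light inputs; the Normalize mean/std are custom per-channel
        # statistics (not the ImageNet defaults), presumably computed on
        # the training set.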
        self.test_transform = A.Compose(
            [
                A.Resize(800, 600),
                A.CLAHE(clip_limit=10, p=1),
                A.Normalize(
                    [0.29278653, 0.25276296, 0.22975405],
                    [0.22653664, 0.19836408, 0.17775835],
                ),
                ToTensorV2(),
            ],
        )

        # Load the DETR ResNet-50 architecture from torch.hub without the
        # COCO weights; the fine-tuned weights are loaded further below.
        self.model = torch.hub.load(
            "facebookresearch/detr", "detr_resnet50", pretrained=False
        )
        # Swap in a classification head sized for this checkpoint.
        in_features = self.model.class_embed.in_features
        self.model.class_embed = nn.Linear(
            in_features=in_features,
            out_features=12,
        )
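        # DETR reserves the last logit of the class head for "no object";
        # predict() strips that slot off before taking the argmax.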
        self.labels = [
            "Dog",
            "Motorbike",
            "People",
            "Cat",
            "Chair",
            "Table",
            "Car",
            "Bicycle",
            "Bottle",
            "Bus",
            "Cup",
            "Boat",
        ]

        # Load the fine-tuned weights on CPU and switch to inference mode.
        model_ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(model_ckpt)
        self.model.eval()

    def predict(self, img):
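        """Detect objects in a PIL image.

        Returns an (image, annotations) pair where annotations is a list of
        ((x1, y1, x2, y2), label) tuples, the format gr.AnnotatedImage expects.
        """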
        score_threshold, iou_threshold = 0.05, 0.1
        img_w, img_h = img.size
        inp = self.test_transform(image=np.array(img.convert("RGB")))["image"]
        with torch.no_grad():
            out = self.model(inp.unsqueeze(0))
        # Per-query class probabilities, excluding the "no object" slot.
        probas = out["pred_logits"].softmax(-1)[0, :, :-1]
        bboxes = []
        scores = []
        for idx, bbox in enumerate(out["pred_boxes"][0]):
            if probas[idx].max().item() < score_threshold:
                continue
            # Convert normalized (cx, cy, w, h) to absolute corner coordinates.
            x_c, y_c, w, h = bbox.numpy()
            x1 = int((x_c - w * 0.5) * img_w)
            y1 = int((y_c - h * 0.5) * img_h)
            x2 = int((x_c + w * 0.5) * img_w)
            y2 = int((y_c + h * 0.5) * img_h)
            label_idx = probas[idx].argmax().item()
            label = self.labels[label_idx] + f" {probas[idx].max().item():.2f}"
            bboxes.append(((x1, y1, x2, y2), label))
            scores.append(probas[idx].max().item())
        selected_indices = self.non_max_suppression(
            bboxes,
            scores,
            iou_threshold,
        )
        bboxes = [bboxes[i] for i in selected_indices]
        return (img, bboxes)

    def non_max_suppression(self, boxes, scores, iou_threshold):
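        """Greedy non-maximum suppression; returns indices of boxes to keep."""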
        if len(boxes) == 0:
            return []

        # Visit boxes in order of decreasing confidence.
        sorted_indices = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )
        selected_indices = []

        while sorted_indices:
            current_index = sorted_indices.pop(0)
            selected_indices.append(current_index)

            # Drop remaining boxes that overlap the kept box too heavily.
            ious = [
                self.calculate_iou(boxes[current_index][0], boxes[i][0])
                for i in sorted_indices
            ]
            sorted_indices = [
                i
                for j, i in enumerate(sorted_indices)
                if ious[j] <= iou_threshold
            ]

        return selected_indices

    def calculate_iou(self, box1, box2):
        """
        Calculate the Intersection over Union (IoU) of two bounding boxes.

        Args:
            box1: [x1, y1, x2, y2] for the first box.
            box2: [x1, y1, x2, y2] for the second box.

        Returns:
            IoU value.
        """
        # Corners of the intersection rectangle.
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])

        intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

        # Union = sum of the two areas minus the double-counted intersection.
        iou = intersection_area / (box1_area + box2_area - intersection_area)

        return iou


# Download the fine-tuned checkpoint from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="SatwikKambham/detr_low_light",
    filename="detr.pt",
)
detector = ObjectDetection(ckpt_path=model_path)
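
# Build the Gradio UI: predict() returns (image, annotations), which
# gr.AnnotatedImage renders as labeled boxes over the input image.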
iface = gr.Interface(
    fn=detector.predict,
    inputs=[
        gr.Image(
            type="pil",
            label="Input",
            height=400,
        ),
    ],
    outputs=gr.AnnotatedImage(
        height=400,
    ),
    # A string value for `examples` points at a local directory of
    # example inputs (here, a folder named "Examples").
    examples="Examples",
)
iface.launch()