# Hugging Face Space by SatwikKambham — "Added example images" (commit 4818da1)
import numpy as np
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
import gradio as gr
class ObjectDetection:
    """DETR-based object detector fine-tuned for low-light images.

    Loads a DETR ResNet-50 backbone from torch.hub, swaps in a 12-way
    classification head, and restores weights from a local checkpoint.
    """

    def __init__(self, ckpt_path):
        """Build the preprocessing pipeline and load the model.

        Args:
            ckpt_path: Path to a state-dict checkpoint loadable on CPU.
        """
        # CLAHE boosts local contrast, which helps low-light inputs.
        # The mean/std values are dataset-specific statistics, not ImageNet's.
        self.test_transform = A.Compose(
            [
                A.Resize(800, 600),
                A.CLAHE(clip_limit=10, p=1),
                A.Normalize(
                    [0.29278653, 0.25276296, 0.22975405],
                    [0.22653664, 0.19836408, 0.17775835],
                ),
                ToTensorV2(),
            ],
        )
        self.model = torch.hub.load(
            "facebookresearch/detr", "detr_resnet50", pretrained=False
        )
        in_features = self.model.class_embed.in_features
        # NOTE(review): DETR conventionally uses num_classes + 1 outputs
        # (the last logit meaning "no object"). With 12 outputs and the
        # last softmax column dropped in predict(), only the first 11 of
        # the 12 labels below are reachable ("Boat" never is) — confirm
        # this matches how the checkpoint was trained. out_features is
        # kept at 12 so the checkpoint still loads.
        self.model.class_embed = nn.Linear(
            in_features=in_features,
            out_features=12,
        )
        self.labels = [
            "Dog",
            "Motorbike",
            "People",
            "Cat",
            "Chair",
            "Table",
            "Car",
            "Bicycle",
            "Bottle",
            "Bus",
            "Cup",
            "Boat",
        ]
        model_ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(model_ckpt)
        self.model.eval()

    def predict(self, img):
        """Detect objects in a PIL image.

        Args:
            img: PIL.Image input (any mode; converted to RGB internally).

        Returns:
            Tuple of (original image, list of ((x1, y1, x2, y2), label)
            annotations) suitable for gr.AnnotatedImage.
        """
        score_threshold, iou_threshold = 0.05, 0.1
        img_w, img_h = img.size
        inp = self.test_transform(image=np.array(img.convert("RGB")))["image"]
        # Inference only: disable autograd so no graph is built.
        with torch.no_grad():
            out = self.model(inp.unsqueeze(0))
        # Drop the last logit column (DETR "no object" slot) before softmax max.
        probas = out["pred_logits"].softmax(-1)[0, :, :-1]
        bboxes = []
        scores = []
        for idx, bbox in enumerate(out["pred_boxes"][0]):
            confidence = probas[idx].max().item()
            if confidence < score_threshold:
                continue
            # DETR boxes are normalized (cx, cy, w, h); convert to pixel corners.
            x_c, y_c, w, h = bbox.numpy()
            x1 = int((x_c - w * 0.5) * img_w)
            y1 = int((y_c - h * 0.5) * img_h)
            x2 = int((x_c + w * 0.5) * img_w)
            y2 = int((y_c + h * 0.5) * img_h)
            label_idx = probas[idx].argmax().item()
            label = self.labels[label_idx] + f" {confidence:.2f}"
            bboxes.append(((x1, y1, x2, y2), label))
            scores.append(confidence)
        selected_indices = self.non_max_suppression(
            bboxes,
            scores,
            iou_threshold,
        )
        bboxes = [bboxes[i] for i in selected_indices]
        return (img, bboxes)

    def non_max_suppression(self, boxes, scores, iou_threshold):
        """Greedy non-maximum suppression.

        Args:
            boxes: List of ((x1, y1, x2, y2), label) tuples.
            scores: Confidence score per box, parallel to ``boxes``.
            iou_threshold: Boxes overlapping a kept box by more than this
                IoU are discarded.

        Returns:
            Indices into ``boxes`` of the kept detections, in descending
            score order.
        """
        if not boxes:
            return []
        remaining = sorted(
            range(len(scores)), key=scores.__getitem__, reverse=True
        )
        selected_indices = []
        while remaining:
            current = remaining.pop(0)
            selected_indices.append(current)
            # Keep only boxes that do not overlap the chosen one too much.
            remaining = [
                i
                for i in remaining
                if self.calculate_iou(boxes[current][0], boxes[i][0])
                <= iou_threshold
            ]
        return selected_indices

    def calculate_iou(self, box1, box2):
        """
        Calculate the Intersection over Union (IoU) of two bounding boxes.
        Args:
            box1: [x1, y1, x2, y2] for the first box.
            box2: [x1, y1, x2, y2] for the second box.
        Returns:
            IoU value in [0, 1]; 0.0 when the union area is zero
            (degenerate boxes).
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union_area = box1_area + box2_area - intersection_area
        # Guard against division by zero for zero-area boxes.
        if union_area <= 0:
            return 0.0
        return intersection_area / union_area
# Fetch the fine-tuned DETR checkpoint from the Hugging Face Hub
# (cached locally after the first download).
model_path = hf_hub_download(
    repo_id="SatwikKambham/detr_low_light",
    filename="detr.pt",
)
detector = ObjectDetection(ckpt_path=model_path)
# Gradio UI: one PIL image in, annotated image (boxes + labels) out.
iface = gr.Interface(
    fn=detector.predict,
    inputs=[
        gr.Image(
            type="pil",
            label="Input",
            height=400,
        ),
    ],
    outputs=gr.AnnotatedImage(
        height=400,
    ),
    # Path to a directory of sample images shown below the interface —
    # presumably the "Examples" folder added alongside this file; verify
    # it exists in the Space repo.
    examples="Examples",
)
iface.launch()