from CircumSpect.object_detection.groundingdino.util.inference import load_model, predict
import CircumSpect.object_detection.groundingdino.datasets.transforms as T
from torchvision.ops import box_convert
from utils import setup_device
from typing import Tuple, List
import supervision as sv
from io import BytesIO
from PIL import Image
import numpy as np
import requests
import torch
import cv2

# Load GroundingDINO with the SwinT-OGC config and its pretrained weights.
model = load_model("./CircumSpect/object_detection/groundingdino/config/GroundingDINO_SwinT_OGC.py",
                   "./CircumSpect/object_detection/weights/groundingdino_swint_ogc.pth")


def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    """Load an image from a local path or an HTTP(S) URL.

    Returns the raw RGB image as a numpy array together with the resized,
    normalized tensor that GroundingDINO expects.
    """
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    if image_path.startswith("http"):
        image_source = Image.open(BytesIO(requests.get(image_path).content)).convert("RGB")
    else:
        image_source = Image.open(image_path).convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed
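# Usage sketch (illustrative only; "sample.jpg" is a hypothetical local path):
#   image_np, image_tensor = load_image("sample.jpg")
# image_np keeps the original H x W x 3 pixels for drawing, while image_tensor
# is the normalized CHW input fed to predict() further below.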


def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor,
             phrases: List[str]) -> Tuple[np.ndarray, List]:
    """Draw labelled boxes on the image and return it together with the
    centre pixel coordinates of every detected object."""
    h, w, _ = image_source.shape
    # Boxes arrive as normalized cxcywh; scale to pixels and convert to xyxy.
    boxes = boxes * torch.Tensor([w, h, w, h])
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    detections = sv.Detections(xyxy=xyxy)

    labels = [
        f"{phrase} {logit:.2f}"
        for phrase, logit
        in zip(phrases, logits)
    ]

    # Centre point of each box, used as the "location" of the object.
    coordinates = [
        (int((left + right) / 2), int((top + bottom) / 2))
        for left, top, right, bottom in xyxy
    ]

    # Pair each phrase with its confidence (as a percentage) and its box centre,
    # e.g. ["drawer: 42%", (cx, cy)].
    object_coordinates = [
        [" ".join(label.split()[:-1]) + f": {float(label.split()[-1]) * 100:.0f}%", coordinate]
        for label, coordinate in zip(labels, coordinates)
    ]

    box_annotator = sv.BoxAnnotator()
    annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
    annotated_frame = box_annotator.annotate(
        scene=annotated_frame, detections=detections, labels=labels)
    return annotated_frame, object_coordinates


# Detection thresholds used when querying GroundingDINO.
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25

# Torch device selected by the project's helper (passed to predict below).
device = setup_device()


def locate_object(objects, image):
    """Run GroundingDINO on `image` for the text prompt `objects` and return
    the annotated frame plus the labelled centre coordinates."""
    image_source, image = load_image(image)

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        device=device,
        caption=objects,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )

    annotated_frame, object_coordinates = annotate(
        image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    # Mark the centre of the first detection; guard against an empty result.
    if object_coordinates:
        cv2.circle(annotated_frame, object_coordinates[0][1], 2, (255, 0, 0), 2)
    cv2.imwrite("detected_objects.png", annotated_frame)
    return annotated_frame, object_coordinates


if __name__ == "__main__":
    frame, coord = locate_object(
        "drawer", "https://images.nationalgeographic.org/image/upload/v1638890052/EducationHub/photos/robots-3d-landing-page.jpg")
    print(coord)
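    # Hedged usage sketch: the same call works on a local file, because load_image
    # only downloads when the argument starts with "http". "sample.jpg" below is a
    # hypothetical placeholder path, so this branch only runs if such a file exists.
    import os
    if os.path.exists("sample.jpg"):
        local_frame, local_coord = locate_object("drawer", "sample.jpg")
        print(local_coord)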