from CircumSpect.object_detection.groundingdino.util.inference import load_model, predict
import CircumSpect.object_detection.groundingdino.datasets.transforms as T
from torchvision.ops import box_convert
from utils import setup_device
from typing import Tuple, List
import supervision as sv
from io import BytesIO
from PIL import Image
import numpy as np
import requests
import torch
import cv2
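
# Load the GroundingDINO detector once at import time so repeated calls to
# locate_object() reuse the same weights.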
model = load_model("./CircumSpect/object_detection/groundingdino/config/GroundingDINO_SwinT_OGC.py",
"./CircumSpect/object_detection/weights/groundingdino_swint_ogc.pth")


def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
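    """Load an image from a local path or an HTTP(S) URL.

    Returns the raw RGB image as a NumPy array together with the resized,
    normalised tensor that GroundingDINO expects.
    """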
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
    if image_path.startswith("http"):
        image_source = Image.open(
            BytesIO(requests.get(image_path).content)).convert("RGB")
    else:
        image_source = Image.open(image_path).convert("RGB")
image = np.asarray(image_source)
image_transformed, _ = transform(image_source, None)
return image, image_transformed


def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> Tuple[np.ndarray, list]:
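    """Draw labelled boxes on the image and compute each box's centre.

    Returns the annotated frame (BGR, ready for OpenCV) and a list of
    ["<phrase>: <confidence>", (x, y)] pairs, one per detection.
    """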
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
detections = sv.Detections(xyxy=xyxy)
labels = [
f"{phrase} {logit:.2f}"
for phrase, logit
in zip(phrases, logits)
]
    # Compute the centre of each bounding box.
    coordinates = []
    for left, top, right, bottom in detections.xyxy:
        coordinates.append((int((left + right) / 2), int((top + bottom) / 2)))
    # Pair each "<phrase>: <confidence>" label with its box centre. The score
    # is a raw 0-1 confidence, so no "%" suffix is appended.
    object_coordinates = [
        [" ".join(label.split()[:-1]) + ": " + label.split()[-1], coordinate]
        for label, coordinate in zip(labels, coordinates)
    ]
box_annotator = sv.BoxAnnotator()
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
annotated_frame = box_annotator.annotate(
scene=annotated_frame, detections=detections, labels=labels)
return annotated_frame, object_coordinates
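

# Confidence cut-offs: boxes scoring below BOX_THRESHOLD and phrase matches
# scoring below TEXT_THRESHOLD are discarded by predict().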
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
device = setup_device()


def locate_object(objects, image):
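    """Detect the objects described by the text prompt `objects` in `image`
    (a local path or URL), write an annotated copy to detected_objects.png,
    and return the annotated frame plus the detections' centre coordinates.
    """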
image_source, image = load_image(image)
boxes, logits, phrases = predict(
model=model,
image=image,
device=device,
caption=objects,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
)
annotated_frame, object_coordinates = annotate(
image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    # Mark the centre of the first detection; guard against no matches.
    if object_coordinates:
        cv2.circle(annotated_frame, object_coordinates[0][1], 2, (255, 0, 0), 2)
cv2.imwrite("detected_objects.png", annotated_frame)
return annotated_frame, object_coordinates


if __name__ == "__main__":
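    # Example: locate a "drawer" in a remote image and print its coordinates.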
frame, coord = locate_object(
"drawer", "https://images.nationalgeographic.org/image/upload/v1638890052/EducationHub/photos/robots-3d-landing-page.jpg")
print(coord)