Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from pathlib import Path | |
from models.common import DetectMultiBackend | |
from utils.dataloaders import LoadImages | |
from utils.general import (non_max_suppression, scale_boxes, check_img_size) | |
from utils.plots import Annotator, colors | |
from utils.torch_utils import select_device | |
import cv2 | |
@torch.no_grad()  # Inference only: skip autograd bookkeeping for the forward pass and NMS
def predict_image(image_path, weights=r"yolov9/yolov9_vinbigData.pt", conf_thres=0.25, iou_thres=0.45,
                  output_dir='pages/output_yolov9', device='cpu'):
    """Run the detector on ``image_path`` and save annotated images to ``output_dir``.

    The model is loaded from ``weights`` on every call. Each item yielded by
    ``LoadImages(image_path, ...)`` is run through the model, filtered with
    non-max suppression, drawn with its boxes and ``"<class name> <confidence>"``
    labels, and written to ``output_dir`` under the input file's base name.
    Inputs that share a base name therefore overwrite each other.

    Args:
        image_path: Source handed to ``LoadImages``.
        weights: Path of the weights file handed to ``DetectMultiBackend``.
        conf_thres: Confidence threshold forwarded to ``non_max_suppression``.
        iou_thres: IoU threshold forwarded to ``non_max_suppression``.
        output_dir: Directory for the annotated images; created if missing.
        device: Device spec handed to ``select_device`` (default ``'cpu'``).

    Returns:
        None. Results are written to disk with ``cv2.imwrite``.
    """
    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size((640, 640), s=stride)  # Inference size
    dataset = LoadImages(image_path, img_size=imgsz, stride=stride, auto=pt)

    # The output directory does not depend on the image, so create it once.
    os.makedirs(output_dir, exist_ok=True)

    model.warmup(imgsz=(1, 3, *imgsz))  # Warmup model
    for path, im, im0s, _, _ in dataset:
        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # Expand for batch dim

        # Inference
        pred = model(im)
        # If the model returns a list, keep only its first element
        if isinstance(pred, list):
            pred = pred[0]

        # Apply non-max suppression
        pred = non_max_suppression(pred, conf_thres, iou_thres, max_det=1000)

        # Process predictions
        for det in pred:  # Per image
            im0 = im0s.copy()  # Draw on a copy so the loader's image is left untouched
            annotator = Annotator(im0, line_width=3, example=str(names))

            # len() rather than truthiness: a multi-element tensor has no unambiguous bool value
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                # Draw bounding boxes and labels on image
                for *xyxy, conf, cls in reversed(det):
                    label = f'{names[int(cls)]} {conf:.2f}'
                    annotator.box_label(xyxy, label, color=colors(int(cls), True))

            # Save the annotated image under the input's file name
            output_path = os.path.join(output_dir, Path(path).name)
            im0 = annotator.result()
            cv2.imwrite(output_path, im0)