PKaushik commited on
Commit
9067b6a
1 Parent(s): f42fe08
Files changed (1) hide show
  1. yolov6/data/vis_dataset.py +58 -0
yolov6/data/vis_dataset.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Description: visualize yolo label image.
3
+
4
+ import argparse
5
+ import os
6
+ import cv2
7
+ import numpy as np
8
+
9
+ IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]
10
+
11
+
12
+ def main(args):
13
+ img_dir, label_dir, class_names = args.img_dir, args.label_dir, args.class_names
14
+
15
+ label_map = dict()
16
+ for class_id, classname in enumerate(class_names):
17
+ label_map[class_id] = classname
18
+
19
+ for file in os.listdir(img_dir):
20
+ if file.split('.')[-1] not in IMG_FORMATS:
21
+ print(f'[Warning]: Non-image file {file}')
22
+ continue
23
+ img_path = os.path.join(img_dir, file)
24
+ label_path = os.path.join(label_dir, file[: file.rindex('.')] + '.txt')
25
+
26
+ try:
27
+ img_data = cv2.imread(img_path)
28
+ height, width, _ = img_data.shape
29
+ color = [tuple(np.random.choice(range(256), size=3)) for i in class_names]
30
+ thickness = 2
31
+
32
+ with open(label_path, 'r') as f:
33
+ for bbox in f:
34
+ cls, x_c, y_c, w, h = [float(v) if i > 0 else int(v) for i, v in enumerate(bbox.split('\n')[0].split(' '))]
35
+
36
+ x_tl = int((x_c - w / 2) * width)
37
+ y_tl = int((y_c - h / 2) * height)
38
+ cv2.rectangle(img_data, (x_tl, y_tl), (x_tl + int(w * width), y_tl + int(h * height)), tuple([int(x) for x in color[cls]]), thickness)
39
+ cv2.putText(img_data, label_map[cls], (x_tl, y_tl - 10), cv2.FONT_HERSHEY_COMPLEX, 1, tuple([int(x) for x in color[cls]]), thickness)
40
+
41
+ cv2.imshow('image', img_data)
42
+ cv2.waitKey(0)
43
+ except Exception as e:
44
+ print(f'[Error]: {e} {img_path}')
45
+ print('======All Done!======')
46
+
47
+
48
+ if __name__ == '__main__':
49
+ parser = argparse.ArgumentParser()
50
+ parser.add_argument('--img_dir', default='VOCdevkit/voc_07_12/images')
51
+ parser.add_argument('--label_dir', default='VOCdevkit/voc_07_12/labels')
52
+ parser.add_argument('--class_names', default=['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
53
+ 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'])
54
+
55
+ args = parser.parse_args()
56
+ print(args)
57
+
58
+ main(args)