|
import os.path as osp |
|
import xml.etree.ElementTree as ET |
|
|
|
import mmcv |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from .builder import DATASETS |
|
from .custom import CustomDataset |
|
|
|
|
|
@DATASETS.register_module()
class XMLDataset(CustomDataset):
    """XML dataset for detection.

    The annotation file is read with ``mmcv.list_from_file`` and each entry
    is treated as an image id. For an id ``img_id``, the image is expected at
    ``<img_prefix>/JPEGImages/<img_id>.jpg`` and its annotation at
    ``<img_prefix>/Annotations/<img_id>.xml``.

    Args:
        min_size (int | float, optional): The minimum size of bounding
            boxes in the images. If the width or height of a bounding box
            is less than ``min_size``, it is added to the ignored field.
            Default: None (no size-based ignoring).
    """

    def __init__(self, min_size=None, **kwargs):
        assert self.CLASSES or kwargs.get(
            'classes', None), 'CLASSES in `XMLDataset` can not be None.'
        super(XMLDataset, self).__init__(**kwargs)
        # Contiguous integer label for every class name, in CLASSES order.
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        self.min_size = min_size

    def _get_xml_root(self, img_id):
        """Parse the annotation XML of one image and return its root.

        Args:
            img_id (str): Image id as listed in the annotation file.

        Returns:
            xml.etree.ElementTree.Element: Root element of
                ``<img_prefix>/Annotations/<img_id>.xml``.
        """
        xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml')
        return ET.parse(xml_path).getroot()

    def load_annotations(self, ann_file):
        """Load annotation from XML style ann_file.

        Args:
            ann_file (str): Path of the file listing image ids.

        Returns:
            list[dict]: One dict per image with keys ``id``, ``filename``
                (relative to ``img_prefix``), ``width`` and ``height``.
        """
        data_infos = []
        img_ids = mmcv.list_from_file(ann_file)
        for img_id in img_ids:
            filename = f'JPEGImages/{img_id}.jpg'
            root = self._get_xml_root(img_id)
            size = root.find('size')
            if size is not None:
                width = int(size.find('width').text)
                height = int(size.find('height').text)
            else:
                # No <size> node: read the dimensions from the image itself.
                # The context manager closes the file handle afterwards,
                # which avoids leaking one descriptor per such image.
                img_path = osp.join(self.img_prefix, 'JPEGImages',
                                    f'{img_id}.jpg')
                with Image.open(img_path) as img:
                    width, height = img.size
            data_infos.append(
                dict(id=img_id, filename=filename, width=width, height=height))
        return data_infos

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without annotation.

        Args:
            min_size (int): Images whose shorter side is smaller than this
                are dropped. Default: 32.

        Returns:
            list[int]: Indices into ``self.data_infos`` of the images kept.
        """
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) < min_size:
                continue
            if self.filter_empty_gt:
                # Keep the image only if at least one object belongs to a
                # known class. NOTE(review): this method is presumably
                # invoked from CustomDataset.__init__, i.e. before
                # ``self.cat2label`` exists, so membership is checked
                # against ``self.CLASSES`` instead.
                root = self._get_xml_root(img_info['id'])
                if any(obj.find('name').text in self.CLASSES
                       for obj in root.findall('object')):
                    valid_inds.append(i)
            else:
                valid_inds.append(i)
        return valid_inds

    @staticmethod
    def _to_arrays(bboxes, labels):
        """Convert box/label lists into arrays.

        Empty inputs yield arrays of shape (0, 4) and (0,). Non-empty boxes
        have 1 subtracted from every coordinate. NOTE(review): presumably
        this converts 1-based pixel coordinates to 0-based; confirm against
        the dataset convention.
        """
        if not bboxes:
            return np.zeros((0, 4)), np.zeros((0, ))
        return np.array(bboxes, ndmin=2) - 1, np.array(labels)

    def get_ann_info(self, idx):
        """Get annotation from XML file by index.

        Objects whose class is not in ``self.CLASSES`` are skipped. An object
        is placed in the ignore arrays when it is marked ``difficult``, or
        when ``self.min_size`` is set and the box width or height is below it.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index, with keys ``bboxes``
                (float32, shape (n, 4)), ``labels`` (int64, shape (n,)),
                ``bboxes_ignore`` (float32, shape (m, 4)) and
                ``labels_ignore`` (int64, shape (m,)). Boxes are ordered
                ``[xmin, ymin, xmax, ymax]``.
        """
        root = self._get_xml_root(self.data_infos[idx]['id'])
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            name = obj.find('name').text
            if name not in self.CLASSES:
                continue
            label = self.cat2label[name]
            difficult = obj.find('difficult')
            difficult = 0 if difficult is None else int(difficult.text)
            bnd_box = obj.find('bndbox')
            # Coordinates may be stored as float strings (e.g. "12.0").
            # They are parsed as float and then truncated to int.
            bbox = [
                int(float(bnd_box.find(key).text))
                for key in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            ignore = False
            if self.min_size:
                # Size-based ignoring is only meaningful for training data.
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                ignore = w < self.min_size or h < self.min_size
            if difficult or ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        bboxes, labels = self._to_arrays(bboxes, labels)
        bboxes_ignore, labels_ignore = self._to_arrays(bboxes_ignore,
                                                       labels_ignore)
        ann = dict(
            bboxes=bboxes.astype(np.float32),
            labels=labels.astype(np.int64),
            bboxes_ignore=bboxes_ignore.astype(np.float32),
            labels_ignore=labels_ignore.astype(np.int64))
        return ann

    def get_cat_ids(self, idx):
        """Get category ids in XML file by index.

        Objects whose class is not in ``self.CLASSES`` are skipped;
        duplicates are kept, in XML order.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """
        root = self._get_xml_root(self.data_infos[idx]['id'])
        names = (obj.find('name').text for obj in root.findall('object'))
        return [self.cat2label[name] for name in names if name in self.CLASSES]
|
|