# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import warnings
from typing import Any, Dict, List, Tuple, Union

import torch


class Instances:
    """
    This class represents a list of instances in an image.
    It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
    All fields must have the same ``__len__`` which is the number of instances.

    All other (non-field) attributes of this class are considered private:
    they must start with '_' and are not modifiable by a user.

    Some basic usage:

    1. Set/get/check a field:

       .. code-block:: python

          instances.gt_boxes = Boxes(...)
          print(instances.pred_masks)  # a tensor of shape (N, H, W)
          print('gt_masks' in instances)

    2. ``len(instances)`` returns the number of instances
    3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
       and returns a new :class:`Instances`.
       Typically, ``indices`` is an integer vector of indices,
       or a binary mask of length ``num_instances``

       .. code-block:: python

          category_3_detections = instances[instances.pred_classes == 3]
          confident_detections = instances[instances.scores > 0.9]
    """

    def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
        """
        Args:
            image_size (height, width): the spatial size of the image.
            kwargs: fields to add to this `Instances`.
        """
        self._image_size = image_size
        self._fields: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.set(k, v)

    @property
    def image_size(self) -> Tuple[int, int]:
        """
        Returns:
            tuple: height, width
        """
        return self._image_size

    def __setattr__(self, name: str, val: Any) -> None:
        # Names starting with "_" are real (private) attributes; every other
        # name is stored as a field through `set`, which validates its length.
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self.set(name, val)

    def __getattr__(self, name: str) -> Any:
        # `__getattr__` only runs when normal lookup fails. The explicit
        # "_fields" check prevents infinite recursion when `_fields` itself
        # has not been set yet (evaluating `self._fields` would re-enter here).
        if name == "_fields" or name not in self._fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self._fields[name]

    def set(self, name: str, value: Any) -> None:
        """
        Set the field named `name` to `value`.
        The length of `value` must be the number of instances,
        and must agree with other existing fields in this object.

        Raises:
            AssertionError: if fields already exist and ``len(value)`` differs
                from ``len(self)``.
        """
        # Warnings emitted while computing the length are recorded and
        # discarded (presumably tracer warnings -- TODO confirm).
        with warnings.catch_warnings(record=True):
            data_len = len(value)
        if self._fields:
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
        self._fields[name] = value

    def has(self, name: str) -> bool:
        """
        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def remove(self, name: str) -> None:
        """
        Remove the field called `name`.

        Raises:
            KeyError: if no field called `name` exists.
        """
        del self._fields[name]

    def get(self, name: str) -> Any:
        """
        Returns the field called `name`.

        Raises:
            KeyError: if no field called `name` exists.
        """
        return self._fields[name]

    def get_fields(self) -> Dict[str, Any]:
        """
        Returns:
            dict: a dict which maps names (str) to data of the fields

        Modifying the returned dict will modify this instance.
        """
        return self._fields

    # Tensor-like methods
    def to(self, *args: Any, **kwargs: Any) -> "Instances":
        """
        Returns:
            Instances: a new object in which every field that has a ``to``
            method is replaced by ``field.to(*args, **kwargs)``; fields
            without that method are carried over unchanged.
        """
        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            if hasattr(v, "to"):
                v = v.to(*args, **kwargs)
            ret.set(k, v)
        return ret

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
        """
        Args:
            item: an index-like object and will be used to index all the fields.

        Returns:
            Instances: a new object where all fields are indexed by `item`.
            An ``int`` index yields an `Instances` of length 1, not a bare
            element.

        Raises:
            IndexError: if `item` is an ``int`` outside ``[-len(self), len(self))``.
        """
        if type(item) is int:
            if item >= len(self) or item < -len(self):
                raise IndexError("Instances index out of range!")
            else:
                # A step of len(self) selects exactly one element while
                # keeping the result sequence-like, so every field of the
                # returned object still has a length (of 1).
                item = slice(item, None, len(self))

        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            ret.set(k, v[item])
        return ret

    def __len__(self) -> int:
        """
        Returns:
            int: the length of the first field.

        Raises:
            NotImplementedError: if this object has no fields.
        """
        for v in self._fields.values():
            # use __len__ because len() has to be int and is not friendly to tracing
            return v.__len__()
        raise NotImplementedError("Empty Instances does not support __len__!")

    def __iter__(self):
        """Iteration is deliberately unsupported; always raises NotImplementedError."""
        raise NotImplementedError("`Instances` object is not iterable!")

    @staticmethod
    def cat(instance_lists: List["Instances"]) -> "Instances":
        """
        Concatenate several `Instances` field by field.

        The field names are taken from the first element; each field is
        merged with ``torch.cat`` (tensors), list chaining (lists), or the
        field type's own ``cat`` (any type that defines one).

        Args:
            instance_lists (list[Instances])

        Returns:
            Instances: the concatenation. When the list holds a single
            element, that same object is returned (not a copy).

        Raises:
            AssertionError: if the list is empty, contains a non-`Instances`
                element, or the (non-tensor) image sizes differ.
            ValueError: if a field's type cannot be concatenated.
            KeyError: if a later element lacks a field of the first one.
        """
        assert all(isinstance(i, Instances) for i in instance_lists)
        assert len(instance_lists) > 0
        if len(instance_lists) == 1:
            return instance_lists[0]

        image_size = instance_lists[0].image_size
        if not isinstance(image_size, torch.Tensor):  # could be a tensor in tracing
            for i in instance_lists[1:]:
                assert i.image_size == image_size
        ret = Instances(image_size)
        for k in instance_lists[0]._fields.keys():
            values = [i.get(k) for i in instance_lists]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                values = torch.cat(values, dim=0)
            elif isinstance(v0, list):
                values = list(itertools.chain(*values))
            elif hasattr(type(v0), "cat"):
                values = type(v0).cat(values)
            else:
                raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
            ret.set(k, values)
        return ret

    def __str__(self) -> str:
        """
        Returns:
            str: a summary with the instance count, image size and all fields.
            An object without fields is reported with ``num_instances=0``.
        """
        # `__len__` raises for an object without fields; a string
        # representation must never fail, so report 0 in that case.
        num_instances = len(self) if self._fields else 0
        fields = ", ".join(f"{k}: {v}" for k, v in self._fields.items())
        return (
            f"{self.__class__.__name__}("
            f"num_instances={num_instances}, "
            f"image_height={self._image_size[0]}, "
            f"image_width={self._image_size[1]}, "
            f"fields=[{fields}])"
        )

    __repr__ = __str__