Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
def all_to_onehot(masks, labels):
    """Convert an index mask into a stack of per-label binary masks.

    Args:
        masks: numpy array of label indices. Any rank is accepted
            (e.g. H*W or T*H*W).
        labels: sequence of label values; one output channel is produced
            per entry, in the given order.

    Returns:
        A uint8 numpy array of shape (len(labels), *masks.shape) in which
        channel k is 1 wherever masks == labels[k] and 0 elsewhere.
    """
    # Prepending the label axis to masks.shape works for every rank, so the
    # 2-D and 3-D cases do not need to be handled separately.
    onehot = np.zeros((len(labels), *masks.shape), dtype=np.uint8)
    for channel, label in enumerate(labels):
        onehot[channel] = (masks == label).astype(np.uint8)
    return onehot
class MaskMapper:
    """
    Convert an indexed mask to a one-hot representation.

    It also takes care of remapping non-continuous indices.
    It has two modes:
        1. Default. Only masks with new indices are supposed to go into the remapper.
        This is also the case for YouTubeVOS.
        i.e., regions with index 0 are not "background", but "don't care".

        2. Exhaustive. Regions with index 0 are considered "background".
        Every single pixel is considered to be "labeled".
    """

    def __init__(self):
        self.clear_labels()

    def clear_labels(self):
        """Forget every label seen so far and return to the initial state."""
        # original label values, in the order of their one-hot channels
        self.labels = []
        # original label value -> 1-based internal object index
        self.remappings = {}
        # if coherent, no mapping is required
        self.coherent = True

    def convert_mask(self, mask, exhaustive=False):
        """Register the labels found in mask and return its one-hot encoding.

        Args:
            mask: index representation, H*W numpy array.
            exhaustive: if False, every non-zero label in mask must be new
                (not registered by a previous call); if True, previously
                registered labels may appear again.

        Returns:
            A tuple (onehot, new_mapped_labels). onehot is a float tensor of
            shape num_objects*H*W with one channel per label registered so
            far, including labels from earlier calls. new_mapped_labels holds
            the object indices associated with this call: every index
            1..num_objects in exhaustive mode, otherwise only those of the
            labels added by this call.

        Raises:
            AssertionError: in non-exhaustive mode, if mask contains a label
                that was already registered.
        """
        # int64 rather than uint8, so that labels above 255 are not wrapped
        # around to a different (or zero) label
        labels = np.unique(mask).astype(np.int64)
        labels = labels[labels != 0].tolist()

        # set iteration order is arbitrary; sorting keeps the assignment of
        # labels to channels deterministic and ascending
        new_labels = sorted(set(labels) - set(self.labels))
        if not exhaustive and len(new_labels) != len(labels):
            # raised explicitly so the check is not stripped by `python -O`
            raise AssertionError('Old labels found in non-exhaustive mode')

        # add new remappings
        num_known = len(self.labels)
        for offset, label in enumerate(new_labels):
            mapped = num_known + offset + 1
            self.remappings[label] = mapped
            if mapped != label:
                self.coherent = False

        num_total = num_known + len(new_labels)
        if exhaustive:
            new_mapped_labels = range(1, num_total + 1)
        elif self.coherent:
            new_mapped_labels = new_labels
        else:
            new_mapped_labels = range(num_known + 1, num_total + 1)

        self.labels.extend(new_labels)
        mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()

        # mask num_objects*H*W
        return mask, new_mapped_labels

    def remap_index_mask(self, mask):
        """Map internal object indices in mask back to the original labels.

        Args:
            mask: index representation, H*W numpy array.

        Returns:
            mask itself (not a copy) when the mapping is coherent; otherwise
            a new array of the same shape and dtype in which each internal
            index is replaced by its original label and every other value
            is 0.
        """
        if self.coherent:
            return mask

        new_mask = np.zeros_like(mask)
        for label, mapped in self.remappings.items():
            new_mask[mask == mapped] = label
        return new_mask