Spaces:
Runtime error
Runtime error
File size: 3,688 Bytes
3dac99f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import copy
import numpy as np
import torch
from fvcore.transforms import HFlipTransform
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from detectron2.data.detection_utils import read_image
from detectron2.modeling import DatasetMapperTTA
# Names exported by `from <this module> import *`: only the TTA wrapper class.
__all__ = [
"SemanticSegmentorWithTTA",
]
class SemanticSegmentorWithTTA(nn.Module):
    """
    A SemanticSegmentor with test-time augmentation enabled.
    Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.

    For every input, the wrapped model is run once per augmented version
    produced by ``tta_mapper``; the resulting ``"sem_seg"`` outputs are summed
    and divided by the number of augmentations. Outputs whose transform list
    contains an ``HFlipTransform`` are flipped along dim 2 before being summed.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
        """
        Args:
            cfg (CfgNode): config; a clone is stored on ``self.cfg``.
            model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
                A ``DistributedDataParallel`` wrapper is unwrapped to its
                underlying ``.module``.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
                NOTE(review): the value is stored on the instance, but the
                inference code in this class runs one augmented image per
                model call and never reads it.
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.cfg = cfg.clone()
        self.model = model
        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`SemanticSegmentor.forward`

        Args:
            batched_inputs (iterable[dict]): dataset dicts. Each must contain
                either an "image" tensor or a "file_name" to read it from.

        Returns:
            list[dict]: one ``{"sem_seg": tensor}`` dict per input, in order.
        """
        return [
            self._inference_one_image(self._maybe_read_image(x))
            for x in batched_inputs
        ]

    def _maybe_read_image(self, dataset_dict):
        """
        Return a shallow copy of ``dataset_dict`` that is guaranteed to carry
        "image", "height" and "width" entries.

        The caller's dict is never mutated. If "image" is absent, it is read
        from "file_name" (which is removed from the copy) using
        ``self.model.input_format`` and transposed with ``(2, 0, 1)`` to CHW.
        Missing "height" / "width" entries are filled from ``image.shape[1]``
        and ``image.shape[2]`` respectively.
        """
        ret = copy.copy(dataset_dict)
        if "image" not in ret:
            image = read_image(ret.pop("file_name"), self.model.input_format)
            ret["image"] = torch.from_numpy(
                np.ascontiguousarray(image.transpose(2, 0, 1))
            )  # CHW
        # Always take the shape from the dict entry: the image may have been
        # supplied by the caller rather than read above.
        image = ret["image"]
        # Fill each dimension independently so a dict that provides only one
        # of the two keys does not fail later with a KeyError.
        if "height" not in ret:
            ret["height"] = image.shape[1]
        if "width" not in ret:
            ret["width"] = image.shape[2]
        return ret

    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor

        Returns:
            dict: one output dict, ``{"sem_seg": averaged_prediction}``

        Raises:
            ValueError: if the TTA mapper yields no augmented inputs, so there
                is nothing to average.
        """
        # The parameter keeps its original name (`input`) for backward
        # compatibility with keyword callers; the loop uses a distinct name.
        augmented_inputs, tfms = self._get_augmented_inputs(input)

        summed_predictions = None
        count_predictions = 0
        with torch.no_grad():
            for aug_input, tfm in zip(augmented_inputs, tfms):
                prediction = self.model([aug_input])[0].pop("sem_seg")
                if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                    # Undo the horizontal flip so predictions line up with the
                    # un-flipped ones (assumes dim 2 is width — TODO confirm).
                    prediction = prediction.flip(dims=[2])
                if summed_predictions is None:
                    summed_predictions = prediction
                else:
                    summed_predictions += prediction
                count_predictions += 1

        if summed_predictions is None:
            raise ValueError(
                "tta_mapper produced no augmented inputs; cannot run TTA inference."
            )
        return {"sem_seg": summed_predictions / count_predictions}

    def _get_augmented_inputs(self, input):
        """
        Run the TTA mapper on one dataset dict.

        Returns:
            tuple: ``(augmented_inputs, tfms)`` where ``tfms[i]`` is the
            "transforms" entry popped from ``augmented_inputs[i]``.
        """
        augmented_inputs = self.tta_mapper(input)
        tfms = [x.pop("transforms") for x in augmented_inputs]
        return augmented_inputs, tfms
|