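# NOTE(review): the lines below look like residue from a web file viewer,
# not module source; kept as comments so the file parses. Confirm and drop.
# xichen98cn's picture
# Upload 135 files
# 3dac99f verified
# raw
# history blame
# 3.78 kB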
from typing import Dict
from torch import nn
from detectron2.config import configurable
from detectron2.layers import ShapeSpec
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from ..transformer_decoder.frozenseg_transformer_decoder import build_transformer_decoder
from ..pixel_decoder.msdeformattn import build_pixel_decoder
@SEM_SEG_HEADS_REGISTRY.register()
class FrozenSegHead(nn.Module):
    """Segmentation head that chains a pixel decoder with a transformer predictor.

    The head is registered in ``SEM_SEG_HEADS_REGISTRY``. At forward time the
    pixel decoder's ``forward_features`` output is passed to the transformer
    predictor together with extra entries read from the ``features`` mapping
    (``text_classifier``, ``num_templates``, ``sam_embedding`` and ``sam``).
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()

        # Feature names are stored ordered by ascending stride.
        by_stride = sorted(input_shape.items(), key=lambda item: item[1].stride)
        self.in_features = [name for name, _ in by_stride]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Translate ``cfg`` into the keyword arguments expected by ``__init__``.

        Args:
            cfg: config node; the ``MODEL.SEM_SEG_HEAD`` and
                ``MODEL.MASK_FORMER`` sections are read.
            input_shape: mapping from feature name to its ``ShapeSpec``.

        Returns:
            dict of constructor keyword arguments, including a freshly built
            pixel decoder and transformer predictor.

        Raises:
            NotImplementedError: if ``cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE``
                is anything other than ``"multi_scale_pixel_decoder"``.
        """
        transformer_in_feature = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE

        # figure out in_channels to transformer predictor
        if transformer_in_feature != "multi_scale_pixel_decoder":
            raise NotImplementedError(
                f"Unsupported TRANSFORMER_IN_FEATURE: {transformer_in_feature!r}"
            )
        transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM

        head_cfg = cfg.MODEL.SEM_SEG_HEAD
        return {
            # Only the features listed in IN_FEATURES are kept for the head itself.
            "input_shape": {
                k: v for k, v in input_shape.items() if k in head_cfg.IN_FEATURES
            },
            "ignore_value": head_cfg.IGNORE_VALUE,
            "num_classes": head_cfg.NUM_CLASSES,
            # The pixel decoder builder receives the unfiltered ``input_shape``.
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": head_cfg.LOSS_WEIGHT,
            "transformer_in_feature": transformer_in_feature,
            "transformer_predictor": build_transformer_decoder(
                cfg,
                transformer_predictor_in_channels,
                mask_classification=True,
            ),
        }

    def forward(self, features, mask=None):
        """Run the head; thin wrapper that delegates to :meth:`layers`."""
        return self.layers(features, mask)

    def layers(self, features, mask=None):
        """Decode pixel features and run the transformer predictor on them.

        Args:
            features: mapping passed as-is to
                ``pixel_decoder.forward_features``; it must also provide the
                keys ``"text_classifier"``, ``"num_templates"``,
                ``"sam_embedding"`` and ``"sam"``.
            mask: forwarded unchanged to the predictor as its third
                positional argument.

        Returns:
            Whatever the transformer predictor returns.

        Raises:
            NotImplementedError: if ``self.transformer_in_feature`` is not
                ``"multi_scale_pixel_decoder"``.
        """
        # Fail before running the pixel decoder: an unsupported setting can
        # never produce a prediction, so there is no point paying for it.
        if self.transformer_in_feature != "multi_scale_pixel_decoder":
            raise NotImplementedError(
                f"Unsupported transformer_in_feature: {self.transformer_in_feature!r}"
            )

        mask_features, multi_scale_features, sam_fpn = (
            self.pixel_decoder.forward_features(features)
        )
        return self.predictor(
            multi_scale_features,
            mask_features,
            mask,
            text_classifier=features["text_classifier"],
            num_templates=features["num_templates"],
            sam_embedding=features['sam_embedding'],
            sam=features["sam"],
            sam_fpn=sam_fpn,
        )