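"""Training-time wrapper around SAM-2 for language-prompted segmentation.

Builds a SAM-2 model via Hydra, swaps in the project's extended ``SAM2Base``
(which accepts language embeddings as prompts), and exposes helpers to
extract backbone features and decode masks from those embeddings.
"""
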
import os.path

import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from mmengine.model import BaseModule

from vlm.utils import load_checkpoint_with_prefix, load_state_dict_to_model

BASE_DIR = 'pretrained/'


class SAM2TrainRunner(BaseModule):
def __init__(
self,
cfg_path: str = "sam2_hiera_l.yaml",
ckpt_path: str = "sam2_hiera_large.pt",
hydra_overrides_extra=None,
apply_postprocessing=True,
):
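        """Build the SAM-2 model from its Hydra config and load the checkpoint.

        Args:
            cfg_path: Hydra config name registered by ``third_parts.sam2``.
            ckpt_path: checkpoint filename, resolved relative to ``BASE_DIR``.
            hydra_overrides_extra: optional list of extra Hydra overrides.
            apply_postprocessing: kept for parity with the upstream SAM-2
                builder; the corresponding overrides are commented out below.
        """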
super().__init__(init_cfg=None)
        # Imported for its side effect: registering the SAM-2 Hydra config module
        # so that `compose(config_name=cfg_path)` below can resolve the YAML config.
        import third_parts.sam2  # noqa: F401
if hydra_overrides_extra is None:
hydra_overrides_extra = []
        hydra_overrides = [
            # Extension: swap in the project's SAM2Base subclass, which accepts
            # language (LLM) prompt embeddings
            "++model._target_=projects.llava_sam2.models.extension.SAM2Base",
        ]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
# "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
# "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
# "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory
            # encoder, so that the encoded masks are exactly what users see from clicking
# "++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
# "++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)
# Read config and init model
cfg = compose(config_name=cfg_path, overrides=hydra_overrides)
OmegaConf.resolve(cfg)
sam2_model = instantiate(cfg.model, _recursive_=True)
state_dict = load_checkpoint_with_prefix(os.path.join(BASE_DIR, ckpt_path))
load_state_dict_to_model(sam2_model, state_dict)
self.sam2_model = sam2_model
self.hidden_dim = self.sam2_model.hidden_dim
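        # ImageNet channel mean/std, applied in preprocess_image below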
self.img_mean = (0.485, 0.456, 0.406)
self.img_std = (0.229, 0.224, 0.225)

    def preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
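        """Scale a 0-255 RGB tensor to [0, 1] and normalize with ImageNet stats.

        ``image`` is expected to be (..., 3, H, W); the mean/std broadcast
        over any leading dimensions.
        """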
image = image / 255.
img_mean = torch.tensor(self.img_mean, dtype=image.dtype, device=image.device)[:, None, None]
img_std = torch.tensor(self.img_std, dtype=image.dtype, device=image.device)[:, None, None]
image -= img_mean
image /= img_std
return image

    def inject_language_embd(self, sam_states, language_embd, nf_nobj=None):
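        """Decode masks by feeding language embeddings to the SAM-2 mask heads.

        Rebuilds the (B, C, H, W) feature maps from the flattened backbone
        tokens in ``sam_states``, adds the no-memory embedding, and runs
        ``_forward_sam_heads`` with ``language_embd`` as the prompt. If
        ``nf_nobj`` (presumably ``(num_frames, num_objects)``) is given, the
        flattened batch of masks is reshaped back to that layout.
        """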
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(sam_states['current_vision_feats'][:-1], sam_states['feat_sizes'][:-1])
]
B = sam_states['current_vision_feats'][-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = sam_states['feat_sizes'][-1]
if self.sam2_model.directly_add_no_mem_embed:
# directly add no-mem embedding (instead of using the transformer encoder)
pix_feat_with_mem = sam_states['current_vision_feats'][-1] + self.sam2_model.no_mem_embed
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
else:
raise NotImplementedError("directly add no memory embedding is not implemented")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            _, _, _, low_res_masks, high_res_masks, obj_ptr, _ = self.sam2_model._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=None,
mask_inputs=None,
high_res_features=high_res_features,
multimask_output=self.sam2_model._use_multimask(is_init_cond_frame=True, point_inputs=None),
                # Inject the language embeddings as the prompt, if provided
language_embd=language_embd,
)
if nf_nobj is not None:
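            # drop the singleton mask channel and restore the (num_frames, num_objects) layout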
pred_masks = low_res_masks.squeeze(1)
pred_masks = pred_masks.unflatten(0, nf_nobj)
else:
pred_masks = low_res_masks
return pred_masks

    def get_sam2_embeddings(self, images, expand_size=1):
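        """Run the SAM-2 image encoder and prepare its multi-level features.

        With ``expand_size > 1``, each image's features are repeated along the
        batch dimension, e.g. to pair one frame with several prompts.
        """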
        # Step 1: run the image encoder backbone over the images
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
feats = self.sam2_model.forward_image(images)
if expand_size > 1:
# feats['vision_features'] = feats['vision_features'][:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
for i, feat in enumerate(feats["backbone_fpn"]):
feats["backbone_fpn"][i] = feat[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
for i, pos in enumerate(feats["vision_pos_enc"]):
pos = pos[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
feats["vision_pos_enc"][i] = pos
        # Step 2: flatten the multi-level features into the (HW, B, C) token
        # layout expected by the SAM-2 heads
_, current_vision_feats, current_vision_pos_embeds, feat_sizes = self.sam2_model._prepare_backbone_features(feats)
return {
"current_vision_feats": current_vision_feats,
"current_vision_pos_embeds": current_vision_pos_embeds,
"feat_sizes": feat_sizes,
}

    def forward(self, batch):
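        # This runner is driven through get_sam2_embeddings / inject_language_embd;
        # it is not meant to be called as a standalone module.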
raise NotImplementedError
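

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the training pipeline). Assumes the SAM-2
# config/checkpoint exist under `pretrained/`, a CUDA device is available, and
# that the extended SAM2Base accepts a (num_objects, num_tokens, hidden_dim)
# language embedding; adjust shapes to the actual contract of
# projects.llava_sam2.models.extension.SAM2Base.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    runner = SAM2TrainRunner().cuda().eval()
    # One dummy 1024x1024 RGB frame in the 0-255 range the runner expects.
    image = torch.rand(1, 3, 1024, 1024, device="cuda") * 255
    sam_states = runner.get_sam2_embeddings(runner.preprocess_image(image))
    # Hypothetical prompt: one object described by a single language token.
    language_embd = torch.randn(1, 1, runner.hidden_dim, device="cuda")
    with torch.no_grad():
        pred_masks = runner.inject_language_embd(sam_states, language_embd)
    print(pred_masks.shape)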