# ------------------------------------------------------------------------ # Grounding DINO # url: https://github.com/IDEA-Research/GroundingDINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Conditional DETR model and criterion classes. # Copyright (c) 2021 Microsoft. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from DETR (https://github.com/facebookresearch/detr) # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # ------------------------------------------------------------------------ # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) # Copyright (c) 2020 SenseTime. All Rights Reserved. # ------------------------------------------------------------------------ import copy from typing import List import torchvision.transforms.functional as vis_F from torchvision.transforms import InterpolationMode import torch import torch.nn.functional as F from torch import nn from torchvision.ops.boxes import nms from torchvision.ops import roi_align from transformers import ( AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast, ) from groundingdino.util import box_ops, get_tokenlizer from groundingdino.util.misc import ( NestedTensor, accuracy, get_world_size, interpolate, inverse_sigmoid, is_dist_avail_and_initialized, nested_tensor_from_tensor_list, ) from groundingdino.util.utils import get_phrases_from_posmap from groundingdino.util.visualizer import COCOVisualizer from groundingdino.util.vl_utils import create_positive_map_from_span from ..registry import MODULE_BUILD_FUNCS from .backbone import build_backbone from .bertwarper import ( BertModelWarper, generate_masks_with_special_tokens, generate_masks_with_special_tokens_and_transfer_map, ) from .transformer import build_transformer from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss from .matcher import build_matcher import numpy as np import matplotlib.pyplot as plt from matplotlib.patches import Rectangle from groundingdino.util.visualizer import renorm def numpy_2_cv2(np_img): if np.min(np_img) < 0: raise Exception("image min is less than 0. Img min: " + str(np.min(np_img))) if np.max(np_img) > 1: raise Exception("image max is greater than 1. Img max: " + str(np.max(np_img))) np_img = (np_img * 255).astype(np.uint8) # Need to somehow ensure image is in RGB format. Note this line shows up in SAM demo: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) cv2_image = np.asarray(np_img) return cv2_image def vis_exemps(image, exemp, f_name): plt.imshow(image) plt.gca().add_patch( Rectangle( (exemp[0], exemp[1]), exemp[2] - exemp[0], exemp[3] - exemp[1], edgecolor="red", facecolor="none", lw=1, ) ) plt.savefig(f_name) plt.close() class GroundingDINO(nn.Module): """This is the Cross-Attention Detector module that performs object detection""" def __init__( self, backbone, transformer, num_queries, aux_loss=False, iter_update=False, query_dim=2, num_feature_levels=1, nheads=8, # two stage two_stage_type="no", # ['no', 'standard'] dec_pred_bbox_embed_share=True, two_stage_class_embed_share=True, two_stage_bbox_embed_share=True, num_patterns=0, dn_number=100, dn_box_noise_scale=0.4, dn_label_noise_ratio=0.5, dn_labelbook_size=100, text_encoder_type="bert-base-uncased", sub_sentence_present=True, max_text_len=256, ): """Initializes the model. Parameters: backbone: torch module of the backbone to be used. See backbone.py transformer: torch module of the transformer architecture. See transformer.py num_queries: number of object queries, ie detection slot. This is the maximal number of objects Conditional DETR can detect in a single image. For COCO, we recommend 100 queries. aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. """ super().__init__() self.num_queries = num_queries self.transformer = transformer self.hidden_dim = hidden_dim = transformer.d_model self.num_feature_levels = num_feature_levels self.nheads = nheads self.max_text_len = max_text_len self.sub_sentence_present = sub_sentence_present # setting query dim self.query_dim = query_dim assert query_dim == 4 # visual exemplar cropping self.feature_map_proj = nn.Conv2d((256 + 512 + 1024), hidden_dim, kernel_size=1) # for dn training self.num_patterns = num_patterns self.dn_number = dn_number self.dn_box_noise_scale = dn_box_noise_scale self.dn_label_noise_ratio = dn_label_noise_ratio self.dn_labelbook_size = dn_labelbook_size # bert self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type) self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type) self.bert.pooler.dense.weight.requires_grad_(False) self.bert.pooler.dense.bias.requires_grad_(False) self.bert = BertModelWarper(bert_model=self.bert) self.feat_map = nn.Linear( self.bert.config.hidden_size, self.hidden_dim, bias=True ) nn.init.constant_(self.feat_map.bias.data, 0) nn.init.xavier_uniform_(self.feat_map.weight.data) # freeze # special tokens self.specical_tokens = self.tokenizer.convert_tokens_to_ids( ["[CLS]", "[SEP]", ".", "?"] ) # prepare input projection layers if num_feature_levels > 1: num_backbone_outs = len(backbone.num_channels) input_proj_list = [] for _ in range(num_backbone_outs): in_channels = backbone.num_channels[_] input_proj_list.append( nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), ) ) for _ in range(num_feature_levels - num_backbone_outs): input_proj_list.append( nn.Sequential( nn.Conv2d( in_channels, hidden_dim, kernel_size=3, stride=2, padding=1 ), nn.GroupNorm(32, hidden_dim), ) ) in_channels = hidden_dim self.input_proj = nn.ModuleList(input_proj_list) else: assert ( two_stage_type == "no" ), "two_stage_type should be no if num_feature_levels=1 !!!" self.input_proj = nn.ModuleList( [ nn.Sequential( nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), ) ] ) self.backbone = backbone self.aux_loss = aux_loss self.box_pred_damping = box_pred_damping = None self.iter_update = iter_update assert iter_update, "Why not iter_update?" # prepare pred layers self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share # prepare class & box embed _class_embed = ContrastiveEmbed() _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: box_embed_layerlist = [ _bbox_embed for i in range(transformer.num_decoder_layers) ] else: box_embed_layerlist = [ copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers) ] class_embed_layerlist = [ _class_embed for i in range(transformer.num_decoder_layers) ] self.bbox_embed = nn.ModuleList(box_embed_layerlist) self.class_embed = nn.ModuleList(class_embed_layerlist) self.transformer.decoder.bbox_embed = self.bbox_embed self.transformer.decoder.class_embed = self.class_embed # two stage self.two_stage_type = two_stage_type assert two_stage_type in [ "no", "standard", ], "unknown param {} of two_stage_type".format(two_stage_type) if two_stage_type != "no": if two_stage_bbox_embed_share: assert dec_pred_bbox_embed_share self.transformer.enc_out_bbox_embed = _bbox_embed else: self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed) if two_stage_class_embed_share: assert dec_pred_bbox_embed_share self.transformer.enc_out_class_embed = _class_embed else: self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed) self.refpoint_embed = None self._reset_parameters() def _reset_parameters(self): # init input_proj for proj in self.input_proj: nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) def init_ref_points(self, use_num_queries): self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim) def add_exemplar_tokens(self, tokenized, text_dict, exemplar_tokens, labels): input_ids = tokenized["input_ids"] device = input_ids.device new_input_ids = [] encoded_text = text_dict["encoded_text"] new_encoded_text = [] text_token_mask = text_dict["text_token_mask"] new_text_token_mask = [] position_ids = text_dict["position_ids"] text_self_attention_masks = text_dict["text_self_attention_masks"] for sample_ind in range(len(labels)): label = labels[sample_ind][0] exemplars = exemplar_tokens[sample_ind] label_count = -1 assert len(input_ids[sample_ind]) == len(position_ids[sample_ind]) for token_ind in range(len(input_ids[sample_ind])): input_id = input_ids[sample_ind][token_ind] if (input_id not in self.specical_tokens) and ( token_ind == 0 or (input_ids[sample_ind][token_ind - 1] in self.specical_tokens) ): label_count += 1 if label_count == label: # Get the index where to insert the exemplar tokens. ind_to_insert_exemplar = token_ind while ( input_ids[sample_ind][ind_to_insert_exemplar] not in self.specical_tokens ): ind_to_insert_exemplar += 1 break # Handle no text case. if label_count == -1: ind_to_insert_exemplar = 1 # * token indicates exemplar. new_input_ids.append( torch.cat( [ input_ids[sample_ind][:ind_to_insert_exemplar], torch.tensor([1008] * exemplars.shape[0]).to(device), input_ids[sample_ind][ind_to_insert_exemplar:], ] ) ) new_encoded_text.append( torch.cat( [ encoded_text[sample_ind][:ind_to_insert_exemplar, :], exemplars, encoded_text[sample_ind][ind_to_insert_exemplar:, :], ] ) ) new_text_token_mask.append( torch.full((len(new_input_ids[sample_ind]),), True).to(device) ) tokenized["input_ids"] = torch.stack(new_input_ids) print(tokenized["input_ids"]) ( text_self_attention_masks, position_ids, _, ) = generate_masks_with_special_tokens_and_transfer_map( tokenized, self.specical_tokens, None ) return { "encoded_text": torch.stack(new_encoded_text), "text_token_mask": torch.stack(new_text_token_mask), "position_ids": position_ids, "text_self_attention_masks": text_self_attention_masks, } def combine_features(self, features): (bs, c, h, w) = ( features[0].decompose()[0].shape[-4], features[0].decompose()[0].shape[-3], features[0].decompose()[0].shape[-2], features[0].decompose()[0].shape[-1], ) x = torch.cat( [ F.interpolate( feat.decompose()[0], size=(h, w), mode="bilinear", align_corners=True, ) for feat in features ], dim=1, ) x = self.feature_map_proj(x) return x def forward( self, samples: NestedTensor, exemplar_images: NestedTensor, exemplars: List, labels, targets: List = None, cropped=False, orig_img=None, crop_width=0, crop_height=0, **kw, ): """The forward expects a NestedTensor, which consists of: - samples.tensor: batched images, of shape [batch_size x 3 x H x W] - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels It returns a dict with the following elements: - "pred_logits": the classification logits (including no-object) for all queries. Shape= [batch_size x num_queries x num_classes] - "pred_boxes": The normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These values are normalized in [0, 1], relative to the size of each individual image (disregarding possible padding). See PostProcess for information on how to retrieve the unnormalized bounding box. - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of dictionnaries containing the two above keys for each decoder layer. """ if targets is None: captions = kw["captions"] else: captions = [t["caption"] for t in targets] # encoder texts tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to( samples.device ) one_hot_token = tokenized ( text_self_attention_masks, position_ids, cate_to_token_mask_list, ) = generate_masks_with_special_tokens_and_transfer_map( tokenized, self.specical_tokens, self.tokenizer ) if text_self_attention_masks.shape[1] > self.max_text_len: text_self_attention_masks = text_self_attention_masks[ :, : self.max_text_len, : self.max_text_len ] position_ids = position_ids[:, : self.max_text_len] tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len] tokenized["attention_mask"] = tokenized["attention_mask"][ :, : self.max_text_len ] tokenized["token_type_ids"] = tokenized["token_type_ids"][ :, : self.max_text_len ] # extract text embeddings if self.sub_sentence_present: tokenized_for_encoder = { k: v for k, v in tokenized.items() if k != "attention_mask" } tokenized_for_encoder["attention_mask"] = text_self_attention_masks tokenized_for_encoder["position_ids"] = position_ids else: tokenized_for_encoder = tokenized bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768 encoded_text = self.feat_map( bert_output["last_hidden_state"] ) # bs, 195, d_model text_token_mask = tokenized.attention_mask.bool() # bs, 195 # text_token_mask: True for nomask, False for mask # text_self_attention_masks: True for nomask, False for mask if encoded_text.shape[1] > self.max_text_len: encoded_text = encoded_text[:, : self.max_text_len, :] text_token_mask = text_token_mask[:, : self.max_text_len] position_ids = position_ids[:, : self.max_text_len] text_self_attention_masks = text_self_attention_masks[ :, : self.max_text_len, : self.max_text_len ] text_dict = { "encoded_text": encoded_text, # bs, 195, d_model "text_token_mask": text_token_mask, # bs, 195 "position_ids": position_ids, # bs, 195 "text_self_attention_masks": text_self_attention_masks, # bs, 195,195 } if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples) if not cropped: features, poss = self.backbone(samples) features_exemp, _ = self.backbone(exemplar_images) combined_features = self.combine_features(features_exemp) # Get visual exemplar tokens. bs = len(exemplars) num_exemplars = exemplars[0].shape[0] print(exemplars) print(num_exemplars) if num_exemplars > 0: exemplar_tokens = ( roi_align( combined_features, boxes=exemplars, output_size=(1, 1), spatial_scale=(1 / 8), aligned=True, ) .squeeze(-1) .squeeze(-1) .reshape(bs, num_exemplars, -1) ) else: exemplar_tokens = None else: features, poss = self.backbone(samples) (h, w) = ( samples.decompose()[0][0].shape[1], samples.decompose()[0][0].shape[2], ) (orig_img_h, orig_img_w) = orig_img.shape[1], orig_img.shape[2] bs = len(samples.decompose()[0]) exemp_imgs = [] new_exemplars = [] ind = 0 for exemp in exemplars[0]: center_x = (exemp[0] + exemp[2]) / 2 center_y = (exemp[1] + exemp[3]) / 2 start_x = max(int(center_x - crop_width / 2), 0) end_x = min(int(center_x + crop_width / 2), orig_img_w) start_y = max(int(center_y - crop_height / 2), 0) end_y = min(int(center_y + crop_height / 2), orig_img_h) scale_x = w / (end_x - start_x) scale_y = h / (end_y - start_y) exemp_imgs.append( vis_F.resize( orig_img[:, start_y:end_y, start_x:end_x], (h, w), interpolation=InterpolationMode.BICUBIC, ) ) new_exemplars.append( [ (exemp[0] - start_x) * scale_x, (exemp[1] - start_y) * scale_y, (exemp[2] - start_x) * scale_x, (exemp[3] - start_y) * scale_y, ] ) vis_exemps( renorm(exemp_imgs[-1].cpu()).permute(1, 2, 0).numpy(), [coord.item() for coord in new_exemplars[-1]], str(ind) + ".jpg", ) vis_exemps( renorm(orig_img.cpu()).permute(1, 2, 0).numpy(), [coord.item() for coord in exemplars[0][ind]], "orig-" + str(ind) + ".jpg", ) ind += 1 exemp_imgs = nested_tensor_from_tensor_list(exemp_imgs) features_exemp, _ = self.backbone(exemp_imgs) combined_features = self.combine_features(features_exemp) new_exemplars = [ torch.tensor(exemp).unsqueeze(0).cuda() for exemp in new_exemplars ] # Get visual exemplar tokens. exemplar_tokens = ( roi_align( combined_features, boxes=new_exemplars, output_size=(1, 1), spatial_scale=(1 / 8), aligned=True, ) .squeeze(-1) .squeeze(-1) .reshape(3, 256) ) exemplar_tokens = torch.stack([exemplar_tokens] * bs) if exemplar_tokens is not None: text_dict = self.add_exemplar_tokens( tokenized, text_dict, exemplar_tokens, labels ) srcs = [] masks = [] for l, feat in enumerate(features): src, mask = feat.decompose() srcs.append(self.input_proj[l](src)) masks.append(mask) assert mask is not None if self.num_feature_levels > len(srcs): _len_srcs = len(srcs) for l in range(_len_srcs, self.num_feature_levels): if l == _len_srcs: src = self.input_proj[l](features[-1].tensors) else: src = self.input_proj[l](srcs[-1]) m = samples.mask mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to( torch.bool )[0] pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) srcs.append(src) masks.append(mask) poss.append(pos_l) input_query_bbox = input_query_label = attn_mask = dn_meta = None hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer( srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict ) # deformable-detr-like anchor update outputs_coord_list = [] for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate( zip(reference[:-1], self.bbox_embed, hs) ): layer_delta_unsig = layer_bbox_embed(layer_hs) layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) layer_outputs_unsig = layer_outputs_unsig.sigmoid() outputs_coord_list.append(layer_outputs_unsig) outputs_coord_list = torch.stack(outputs_coord_list) outputs_class = torch.stack( [ layer_cls_embed(layer_hs, text_dict) for layer_cls_embed, layer_hs in zip(self.class_embed, hs) ] ) out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]} # Used to calculate losses bs, len_td = text_dict["text_token_mask"].shape out["text_mask"] = torch.zeros(bs, self.max_text_len, dtype=torch.bool).to( samples.device ) for b in range(bs): for j in range(len_td): if text_dict["text_token_mask"][b][j] == True: out["text_mask"][b][j] = True # for intermediate outputs if self.aux_loss: out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord_list) out["token"] = one_hot_token # # for encoder output if hs_enc is not None: # prepare intermediate outputs interm_coord = ref_enc[-1] interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict) out["interm_outputs"] = { "pred_logits": interm_class, "pred_boxes": interm_coord, } out["interm_outputs_for_matching_pre"] = { "pred_logits": interm_class, "pred_boxes": init_box_proposal, } # outputs['pred_logits'].shape # torch.Size([4, 900, 256]) # outputs['pred_boxes'].shape # torch.Size([4, 900, 4]) # outputs['text_mask'].shape # torch.Size([256]) # outputs['text_mask'] # outputs['aux_outputs'][0].keys() # dict_keys(['pred_logits', 'pred_boxes', 'one_hot', 'text_mask']) # outputs['aux_outputs'][img_idx] # outputs['token'] # # outputs['interm_outputs'].keys() # dict_keys(['pred_logits', 'pred_boxes', 'one_hot', 'text_mask']) # outputs['interm_outputs_for_matching_pre'].keys() # dict_keys(['pred_logits', 'pred_boxes']) # outputs['one_hot'].shape # torch.Size([4, 900, 256]) return out @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_coord): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. return [ {"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) ] class SetCriterion(nn.Module): def __init__(self, matcher, weight_dict, focal_alpha, focal_gamma, losses): """Create the criterion. Parameters: matcher: module able to compute a matching between targets and proposals weight_dict: dict containing as key the names of the losses and as values their relative weight. losses: list of all the losses to be applied. See get_loss for list of available losses. focal_alpha: alpha in Focal Loss """ super().__init__() self.matcher = matcher self.weight_dict = weight_dict self.losses = losses self.focal_alpha = focal_alpha self.focal_gamma = focal_gamma @torch.no_grad() def loss_cardinality(self, outputs, targets, indices, num_boxes): """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients """ pred_logits = outputs["pred_logits"] device = pred_logits.device tgt_lengths = torch.as_tensor( [len(v["labels"]) for v in targets], device=device ) # Count the number of predictions that are NOT "no-object" (which is the last class) card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) losses = {"cardinality_error": card_err} return losses def loss_boxes(self, outputs, targets, indices, num_boxes): """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. """ assert "pred_boxes" in outputs idx = self._get_src_permutation_idx(indices) src_boxes = outputs["pred_boxes"][idx] target_boxes = torch.cat( [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0 ) loss_bbox = F.l1_loss(src_boxes[:, :2], target_boxes[:, :2], reduction="none") losses = {} losses["loss_bbox"] = loss_bbox.sum() / num_boxes loss_giou = 1 - torch.diag( box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes), ) ) losses["loss_giou"] = loss_giou.sum() / num_boxes # calculate the x,y and h,w loss with torch.no_grad(): losses["loss_xy"] = loss_bbox[..., :2].sum() / num_boxes losses["loss_hw"] = loss_bbox[..., 2:].sum() / num_boxes return losses def token_sigmoid_binary_focal_loss(self, outputs, targets, indices, num_boxes): pred_logits = outputs["pred_logits"] new_targets = outputs["one_hot"].to(pred_logits.device) text_mask = outputs["text_mask"] assert new_targets.dim() == 3 assert pred_logits.dim() == 3 # batch x from x to bs, n, _ = pred_logits.shape alpha = self.focal_alpha gamma = self.focal_gamma if text_mask is not None: # ODVG: each sample has different mask text_mask = text_mask.repeat(1, pred_logits.size(1)).view( outputs["text_mask"].shape[0], -1, outputs["text_mask"].shape[1] ) pred_logits = torch.masked_select(pred_logits, text_mask) new_targets = torch.masked_select(new_targets, text_mask) new_targets = new_targets.float() p = torch.sigmoid(pred_logits) ce_loss = F.binary_cross_entropy_with_logits( pred_logits, new_targets, reduction="none" ) p_t = p * new_targets + (1 - p) * (1 - new_targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * new_targets + (1 - alpha) * (1 - new_targets) loss = alpha_t * loss total_num_pos = 0 for batch_indices in indices: total_num_pos += len(batch_indices[0]) num_pos_avg_per_gpu = max(total_num_pos, 1.0) loss = loss.sum() / num_pos_avg_per_gpu losses = {"loss_ce": loss} return losses def _get_src_permutation_idx(self, indices): # permute predictions following indices batch_idx = torch.cat( [torch.full_like(src, i) for i, (src, _) in enumerate(indices)] ) src_idx = torch.cat([src for (src, _) in indices]) return batch_idx, src_idx def _get_tgt_permutation_idx(self, indices): # permute targets following indices batch_idx = torch.cat( [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)] ) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): loss_map = { "labels": self.token_sigmoid_binary_focal_loss, "cardinality": self.loss_cardinality, "boxes": self.loss_boxes, } assert loss in loss_map, f"do you really want to compute {loss} loss?" return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) def forward(self, outputs, targets, cat_list, caption, return_indices=False): """This performs the loss computation. Parameters: outputs: dict of tensors, see the output specification of the model for the format targets: list of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the losses applied, see each loss' doc return_indices: used for vis. if True, the layer0-5 indices will be returned as well. """ device = next(iter(outputs.values())).device one_hot = torch.zeros( outputs["pred_logits"].size(), dtype=torch.int64 ) # torch.Size([bs, 900, 256]) token = outputs["token"] label_map_list = [] indices = [] for j in range(len(cat_list)): # bs label_map = [] for i in range(len(cat_list[j])): label_id = torch.tensor([i]) per_label = create_positive_map_exemplar( token["input_ids"][j], label_id, [101, 102, 1012, 1029] ) label_map.append(per_label) label_map = torch.stack(label_map, dim=0).squeeze(1) label_map_list.append(label_map) for j in range(len(cat_list)): # bs for_match = { "pred_logits": outputs["pred_logits"][j].unsqueeze(0), "pred_boxes": outputs["pred_boxes"][j].unsqueeze(0), } inds = self.matcher(for_match, [targets[j]], label_map_list[j]) indices.extend(inds) # indices : A list of size batch_size, containing tuples of (index_i, index_j) where: # - index_i is the indices of the selected predictions (in order) # - index_j is the indices of the corresponding selected targets (in order) # import pdb; pdb.set_trace() tgt_ids = [v["labels"].cpu() for v in targets] # len(tgt_ids) == bs for i in range(len(indices)): tgt_ids[i] = tgt_ids[i][indices[i][1]] one_hot[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to(torch.long) outputs["one_hot"] = one_hot if return_indices: indices0_copy = indices indices_list = [] # Compute the average number of target boxes accross all nodes, for normalization purposes num_boxes_list = [len(t["labels"]) for t in targets] num_boxes = sum(num_boxes_list) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device) if is_dist_avail_and_initialized(): torch.distributed.all_reduce(num_boxes) num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() # Compute all the requested losses losses = {} for loss in self.losses: losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. if "aux_outputs" in outputs: for idx, aux_outputs in enumerate(outputs["aux_outputs"]): indices = [] for j in range(len(cat_list)): # bs aux_output_single = { "pred_logits": aux_outputs["pred_logits"][j].unsqueeze(0), "pred_boxes": aux_outputs["pred_boxes"][j].unsqueeze(0), } inds = self.matcher( aux_output_single, [targets[j]], label_map_list[j] ) indices.extend(inds) one_hot_aux = torch.zeros( outputs["pred_logits"].size(), dtype=torch.int64 ) tgt_ids = [v["labels"].cpu() for v in targets] for i in range(len(indices)): tgt_ids[i] = tgt_ids[i][indices[i][1]] one_hot_aux[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to( torch.long ) aux_outputs["one_hot"] = one_hot_aux aux_outputs["text_mask"] = outputs["text_mask"] if return_indices: indices_list.append(indices) for loss in self.losses: kwargs = {} l_dict = self.get_loss( loss, aux_outputs, targets, indices, num_boxes, **kwargs ) l_dict = {k + f"_{idx}": v for k, v in l_dict.items()} losses.update(l_dict) # interm_outputs loss if "interm_outputs" in outputs: interm_outputs = outputs["interm_outputs"] indices = [] for j in range(len(cat_list)): # bs interm_output_single = { "pred_logits": interm_outputs["pred_logits"][j].unsqueeze(0), "pred_boxes": interm_outputs["pred_boxes"][j].unsqueeze(0), } inds = self.matcher( interm_output_single, [targets[j]], label_map_list[j] ) indices.extend(inds) one_hot_aux = torch.zeros(outputs["pred_logits"].size(), dtype=torch.int64) tgt_ids = [v["labels"].cpu() for v in targets] for i in range(len(indices)): tgt_ids[i] = tgt_ids[i][indices[i][1]] one_hot_aux[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to( torch.long ) interm_outputs["one_hot"] = one_hot_aux interm_outputs["text_mask"] = outputs["text_mask"] if return_indices: indices_list.append(indices) for loss in self.losses: kwargs = {} l_dict = self.get_loss( loss, interm_outputs, targets, indices, num_boxes, **kwargs ) l_dict = {k + f"_interm": v for k, v in l_dict.items()} losses.update(l_dict) if return_indices: indices_list.append(indices0_copy) return losses, indices_list return losses class PostProcess(nn.Module): """This module converts the model's output into the format expected by the coco api""" def __init__( self, num_select=100, text_encoder_type="text_encoder_type", nms_iou_threshold=-1, use_coco_eval=False, args=None, ) -> None: super().__init__() self.num_select = num_select self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type) if args.use_coco_eval: from pycocotools.coco import COCO coco = COCO(args.coco_val_path) category_dict = coco.loadCats(coco.getCatIds()) cat_list = [item["name"] for item in category_dict] else: cat_list = args.label_list caption = " . ".join(cat_list) + " ." tokenized = self.tokenizer(caption, padding="longest", return_tensors="pt") label_list = torch.arange(len(cat_list)) pos_map = create_positive_map(tokenized, label_list, cat_list, caption) # build a mapping from label_id to pos_map if args.use_coco_eval: id_map = { 0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19, 18: 20, 19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28, 26: 31, 27: 32, 28: 33, 29: 34, 30: 35, 31: 36, 32: 37, 33: 38, 34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44, 40: 46, 41: 47, 42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55, 50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63, 58: 64, 59: 65, 60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75, 66: 76, 67: 77, 68: 78, 69: 79, 70: 80, 71: 81, 72: 82, 73: 84, 74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90, } new_pos_map = torch.zeros((91, 256)) for k, v in id_map.items(): new_pos_map[v] = pos_map[k] pos_map = new_pos_map self.nms_iou_threshold = nms_iou_threshold self.positive_map = pos_map @torch.no_grad() def forward(self, outputs, target_sizes, not_to_xyxy=False, test=False): """Perform the computation Parameters: outputs: raw outputs of the model target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch For evaluation, this must be the original image size (before any data augmentation) For visualization, this should be the image size after data augment, but before padding """ num_select = self.num_select out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] prob_to_token = out_logits.sigmoid() pos_maps = self.positive_map.to(prob_to_token.device) for label_ind in range(len(pos_maps)): if pos_maps[label_ind].sum() != 0: pos_maps[label_ind] = pos_maps[label_ind] / pos_maps[label_ind].sum() prob_to_label = prob_to_token @ pos_maps.T assert len(out_logits) == len(target_sizes) assert target_sizes.shape[1] == 2 prob = prob_to_label topk_values, topk_indexes = torch.topk( prob.view(prob.shape[0], -1), num_select, dim=1 ) scores = topk_values topk_boxes = torch.div(topk_indexes, prob.shape[2], rounding_mode="trunc") labels = topk_indexes % prob.shape[2] if not_to_xyxy: boxes = out_bbox else: boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) # if test: # assert not not_to_xyxy # boxes[:,:,2:] = boxes[:,:,2:] - boxes[:,:,:2] boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) # and from relative [0, 1] to absolute [0, height] coordinates img_h, img_w = target_sizes.unbind(1) scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) boxes = boxes * scale_fct[:, None, :] if self.nms_iou_threshold > 0: item_indices = [ nms(b, s, iou_threshold=self.nms_iou_threshold) for b, s in zip(boxes, scores) ] results = [ {"scores": s[i], "labels": l[i], "boxes": b[i]} for s, l, b, i in zip(scores, labels, boxes, item_indices) ] else: results = [ {"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes) ] results = [ {"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes) ] return results @MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino") def build_groundingdino(args): device = torch.device(args.device) backbone = build_backbone(args) transformer = build_transformer(args) dn_labelbook_size = args.dn_labelbook_size dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share sub_sentence_present = args.sub_sentence_present model = GroundingDINO( backbone, transformer, num_queries=args.num_queries, aux_loss=args.aux_loss, iter_update=True, query_dim=4, num_feature_levels=args.num_feature_levels, nheads=args.nheads, dec_pred_bbox_embed_share=dec_pred_bbox_embed_share, two_stage_type=args.two_stage_type, two_stage_bbox_embed_share=args.two_stage_bbox_embed_share, two_stage_class_embed_share=args.two_stage_class_embed_share, num_patterns=args.num_patterns, dn_number=0, dn_box_noise_scale=args.dn_box_noise_scale, dn_label_noise_ratio=args.dn_label_noise_ratio, dn_labelbook_size=dn_labelbook_size, text_encoder_type=args.text_encoder_type, sub_sentence_present=sub_sentence_present, max_text_len=args.max_text_len, ) matcher = build_matcher(args) # prepare weight dict weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef} weight_dict["loss_giou"] = args.giou_loss_coef clean_weight_dict_wo_dn = copy.deepcopy(weight_dict) clean_weight_dict = copy.deepcopy(weight_dict) # TODO this is a hack if args.aux_loss: aux_weight_dict = {} for i in range(args.dec_layers - 1): aux_weight_dict.update( {k + f"_{i}": v for k, v in clean_weight_dict.items()} ) weight_dict.update(aux_weight_dict) if args.two_stage_type != "no": interm_weight_dict = {} try: no_interm_box_loss = args.no_interm_box_loss except: no_interm_box_loss = False _coeff_weight_dict = { "loss_ce": 1.0, "loss_bbox": 1.0 if not no_interm_box_loss else 0.0, "loss_giou": 1.0 if not no_interm_box_loss else 0.0, } try: interm_loss_coef = args.interm_loss_coef except: interm_loss_coef = 1.0 interm_weight_dict.update( { k + f"_interm": v * interm_loss_coef * _coeff_weight_dict[k] for k, v in clean_weight_dict_wo_dn.items() } ) weight_dict.update(interm_weight_dict) # losses = ['labels', 'boxes', 'cardinality'] losses = ["labels", "boxes"] criterion = SetCriterion( matcher=matcher, weight_dict=weight_dict, focal_alpha=args.focal_alpha, focal_gamma=args.focal_gamma, losses=losses, ) criterion.to(device) postprocessors = { "bbox": PostProcess( num_select=args.num_select, text_encoder_type=args.text_encoder_type, nms_iou_threshold=args.nms_iou_threshold, args=args, ) } return model, criterion, postprocessors def create_positive_map(tokenized, tokens_positive, cat_list, caption): """construct a map such that positive_map[i,j] = True iff box i is associated to token j""" positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) for j, label in enumerate(tokens_positive): start_ind = caption.find(cat_list[label]) end_ind = start_ind + len(cat_list[label]) - 1 beg_pos = tokenized.char_to_token(start_ind) try: end_pos = tokenized.char_to_token(end_ind) except: end_pos = None if end_pos is None: try: end_pos = tokenized.char_to_token(end_ind - 1) if end_pos is None: end_pos = tokenized.char_to_token(end_ind - 2) except: end_pos = None # except Exception as e: # print("beg:", beg, "end:", end) # print("token_positive:", tokens_positive) # # print("beg_pos:", beg_pos, "end_pos:", end_pos) # raise e # if beg_pos is None: # try: # beg_pos = tokenized.char_to_token(beg + 1) # if beg_pos is None: # beg_pos = tokenized.char_to_token(beg + 2) # except: # beg_pos = None # if end_pos is None: # try: # end_pos = tokenized.char_to_token(end - 2) # if end_pos is None: # end_pos = tokenized.char_to_token(end - 3) # except: # end_pos = None if beg_pos is None or end_pos is None: continue if beg_pos < 0 or end_pos < 0: continue if beg_pos > end_pos: continue # assert beg_pos is not None and end_pos is not None positive_map[j, beg_pos : end_pos + 1].fill_(1) return positive_map def create_positive_map_exemplar(input_ids, label, special_tokens): tokens_positive = torch.zeros(256, dtype=torch.float) count = -1 for token_ind in range(len(input_ids)): input_id = input_ids[token_ind] if (input_id not in special_tokens) and ( token_ind == 0 or (input_ids[token_ind - 1] in special_tokens) ): count += 1 if count == label: ind_to_insert_ones = token_ind while input_ids[ind_to_insert_ones] not in special_tokens: tokens_positive[ind_to_insert_ones] = 1 ind_to_insert_ones += 1 break return tokens_positive