import torch
import torchvision
from einops import rearrange
from torch import nn
from yolox.models.yolo_head import YOLOXHead
from yolox.utils.boxes import xyxy2cxcywh, cxcywh2xyxy
from yolox.utils.demo_utils import nms
# import matplotlib.pyplot as plt
# import seaborn as sns
import numpy as np
import logging
from open_flamingo.src.gcn import GCN
from transformers import LogitsProcessorList

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d %I:%M:%S',
)


# class PositionEncodingModule(nn.Module):
#     def __init__(self, dim, pos_dim=128):
#         super().__init__()
#         self.encode = nn.Sequential(
#             nn.Linear(5, pos_dim // 2),
#             nn.BatchNorm1d(pos_dim // 2),
#             nn.GELU(),
#             nn.Linear(pos_dim // 2, pos_dim),
#             nn.BatchNorm1d(pos_dim),
#             nn.GELU(),
#         )
#         self.merge = nn.Sequential(
#             nn.Linear(dim + pos_dim, dim),
#             nn.BatchNorm1d(dim),
#             nn.GELU(),
#         )
#
#     def forward(self, x, box):
#         box = self.encode(box)
#         x = torch.cat([x, box], dim=-1)
#         x = self.merge(x)
#         return x


# class PositionEncodingModule(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.encode = nn.Sequential(
#             nn.Linear(5, dim),
#             nn.GELU(),
#         )
#
#     def forward(self, x, box):
#         box = self.encode(box)
#         x = x + box
#         return x


# class PositionEncodingModule2(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.encode = nn.Sequential(
#             nn.Linear(5 + dim, dim),
#             nn.ELU(),
#         )
#
#     def forward(self, x, box):
#         x = torch.cat([x, box], dim=-1)
#         x = self.encode(x)
#         return x


# class RelationHead(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.encode = nn.Sequential(
#             nn.LayerNorm(dim),
#             nn.Linear(dim, 128),
#             nn.ELU(),
#         )
#         self.classifier = nn.Linear(256, 51)
#
#     def forward(self, x1, x2):
#         x1 = self.encode(x1)
#         x2 = self.encode(x2)
#         x = torch.cat([x1, x2], dim=-1)
#         x = self.classifier(x)
#         return x


class Flamingo(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        image_end_token_id: int,
        visual_token_id: int,
        previsual_token_id: int,
        box_token_id: int,
        prebox_token_id: int,
        nothing_token_id: int,
        endofobject_token_id: int,
        vis_dim: int,
        vis_embed_size: int,
        lang_dim: int,
        hidden_state_dim: int,
        image_size: int,
        patch_size: int,
        use_media_placement_augmentation: bool = False,
        add_visual_token: bool = False,
        add_pe: bool = False,
        add_relation: bool = False,
        use_format_v2: bool = False,
        roi_align: bool = False,
        roi_output_size: int = 4,
        apply_mask: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): HF CLIPModel
            lang_encoder (nn.Module): HF causal language model
            eoc_token_id (int): Token id for the eos token
            media_token_id (int): Token id for <|#image#|>
            vis_dim (int): Dimension of the visual features. Visual features are
                projected to match this shape along the last dimension.
            use_media_placement_augmentation (bool, optional): Whether to randomly assign images
                to the preceding or following text in training. Defaults to False.
""" super().__init__() self.image_end_token_id = image_end_token_id self.eoc_token_id = eoc_token_id self.media_token_id = media_token_id self.use_media_placement_augmentation = use_media_placement_augmentation self.vis_dim = vis_dim self.lang_dim = lang_dim # inner_dim = self.lang_dim * 4 # self.vis_proj = nn.Sequential( # nn.LayerNorm(self.vis_dim), # nn.Linear(self.vis_dim, inner_dim, bias=False), # nn.GELU(), # nn.Linear(inner_dim, self.lang_dim, bias=False), # ) self.vis_proj = nn.Linear(self.vis_dim, self.lang_dim) self.vision_encoder = vision_encoder self.num_positions = vis_embed_size self.lang_encoder = lang_encoder self.lang_encoder.init_flamingo( media_token_id=media_token_id, use_media_placement_augmentation=self.use_media_placement_augmentation, ) first_layer = self.lang_encoder._get_decoder_layers()[0] first_layer.add_visual_token = add_visual_token first_layer.visual_token_id = visual_token_id first_layer.media_token_id = media_token_id first_layer.box_token_id = box_token_id # first_layer.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None # assert not (add_pe and add_relation) # self.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None # first_layer.pos_enc = self.pos_enc self.box_token_id = box_token_id self.prebox_token_id = prebox_token_id self.media_token_id = media_token_id self.visual_token_id = visual_token_id self.previsual_token_id = previsual_token_id self.hidden_state_dim = hidden_state_dim self.image_size = image_size self.patch_size = patch_size self.patch_num = self.image_size // self.patch_size self.detection_head = YOLOXHead( num_classes=1, strides=[patch_size], in_channels=[self.hidden_state_dim + self.lang_dim], ) self.use_format_v2 = use_format_v2 self.nothing_token_id = nothing_token_id self.roi_align = roi_align self.roi_output_size = roi_output_size if roi_align else None self.apply_mask = apply_mask self.endofobject_token_id = endofobject_token_id def _get_detection_batch( self, visual_token_id, previsual_token_id, input_ids: torch.Tensor, hidden_states: torch.Tensor, added_bbox_list, box_num = 100, ): select_mask = torch.logical_or(input_ids == visual_token_id, input_ids == previsual_token_id) visual_token_position = select_mask.nonzero() visual_token_hidden_states = hidden_states[select_mask] prev_batch_idx = -1 media_idx = [] cnt = 0 assert len(visual_token_hidden_states) == len(visual_token_position) if len(added_bbox_list) != len(visual_token_position): msg = f"ERROR: {len(added_bbox_list)}:{len(visual_token_position)}\n{added_bbox_list}\n{visual_token_position}" logging.info(msg) alpha = 0.0 else: alpha = 1.0 visual_batches = [] previsual_batches = [] for (batch_idx, idx), visual_token_hidden_state, bbox in zip( visual_token_position, visual_token_hidden_states, added_bbox_list, ): # ! VERY IMPORTANT BUG ! bbox = bbox.clone() # ! VERY IMPORTANT BUG ! batch_idx = batch_idx.item() idx = idx.item() if batch_idx != prev_batch_idx: prev_batch_idx = batch_idx this_input_ids = input_ids[batch_idx] cnt += len(media_idx) media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist() for i in range(len(media_idx)): if i == len(media_idx) - 1 or idx > media_idx[i] and idx < media_idx[i+1]: break image_index = cnt + i size = int(self.image_embedding[image_index].shape[0] ** 0.5) image_embedding = self.image_embedding[image_index] # inplace xyxy2cxcywh # print(bbox) # TODO: CHECK self.image_size. Is it 224? 
            bbox = xyxy2cxcywh(bbox) * self.image_size
            # print(bbox)
            concat_image_visual_embedding = torch.cat(
                [image_embedding, visual_token_hidden_state.unsqueeze(0).repeat(image_embedding.shape[0], 1)],
                dim=-1,
            ).reshape(size, size, -1)
            label = torch.cat([torch.zeros(bbox.shape[0], 1, device=bbox.device), bbox], dim=-1)
            label = torch.cat(
                [label, torch.zeros(box_num - label.shape[0], label.shape[1], device=label.device)],
                dim=0,
            )
            if input_ids[batch_idx, idx] == previsual_token_id:
                previsual_batches.append([concat_image_visual_embedding, label])
            elif input_ids[batch_idx, idx] == visual_token_id:
                visual_batches.append([concat_image_visual_embedding, label])
            else:
                logging.info(f"WARNING... NOT visual nor previsual. it is {input_ids[batch_idx, idx]}")
        return visual_batches, previsual_batches, alpha, alpha

    def get_detection_losses(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        added_bbox_list,
        box_num=100,
    ):
        visual_token_batches, previsual_token_batches, alpha1, alpha2 = self._get_detection_batch(
            visual_token_id=self.visual_token_id,
            previsual_token_id=self.previsual_token_id,
            input_ids=input_ids,
            hidden_states=hidden_states,
            added_bbox_list=added_bbox_list,
            box_num=box_num,
        )
        loss_dict = []
        for batches, alpha in zip([visual_token_batches, previsual_token_batches], [alpha1, alpha2]):
            # x: [B, C, H, W]
            if len(batches) != 0:
                x = torch.cat([batch[0].unsqueeze(0) for batch in batches], dim=0).permute(0, 3, 1, 2)
                labels = torch.cat([batch[1].unsqueeze(0) for batch in batches], dim=0)
            else:
                x = None
                labels = None
            if x is not None:
                losses = self.detection_head(xin=[x], labels=labels)
                loss, loss_iou, loss_obj, loss_cls, loss_l1, _ = losses
            else:
                loss = torch.tensor(0.0).cuda()
                loss_iou = loss
                loss_obj = loss
                loss_cls = loss
                loss_l1 = loss
            loss_dict.append(dict(
                loss=loss * alpha,
                loss_iou=loss_iou * alpha,
                loss_obj=loss_obj * alpha,
                loss_cls=loss_cls * alpha,
                loss_l1=loss_l1 * alpha,
            ))
        ret_loss = {}
        for key in loss_dict[0].keys():
            ret_loss[key] = 0.0
            for d in loss_dict:
                ret_loss[key] += d[key]
        return ret_loss, loss_dict

    def get_detection_result(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        nms_thr: float = 0.45,
        score_thr: float = 0.01,
        debug_id: int = 0,
        debug_mode: bool = False,
    ):
        assert len(input_ids) == 1, "only batch size = 1 is supported yet"
        # assert len(self.image_embedding) == 1, "only one image is supported yet"
        # assert (input_ids[..., -1] == self.visual_token_id).all(), "the last token should be visual token"
        visual_token_hidden_state = hidden_states[..., -1, :]
        boxes_list = []
        scores_list = []
        for image_embedding in self.image_embedding:
            size = int(image_embedding.shape[0] ** 0.5)
            x = torch.cat(
                [image_embedding, visual_token_hidden_state.repeat(image_embedding.shape[0], 1)],
                dim=-1,
            ).reshape(size, size, -1).unsqueeze(0).permute(0, 3, 1, 2)
            with torch.no_grad():
                outputs = self.detection_head(xin=[x], labels=None)
            boxes = outputs[0, :, :4].cpu().numpy()
            scores = outputs[0, :, 4].cpu().numpy()
            scores_mask = scores > score_thr
            boxes = boxes[scores_mask]
            boxes = cxcywh2xyxy(boxes)
            scores = scores[scores_mask]
            keep = nms(boxes, scores, nms_thr=nms_thr)
            boxes = boxes[keep]
            scores = scores[keep]
            if debug_mode:
                obj_heatmap = outputs[0, :, -2].reshape(size, size).cpu().numpy()
                import matplotlib.pyplot as plt
                import seaborn as sns
                plt.figure()
                sns_plot = sns.heatmap(obj_heatmap)
                plt.savefig(f"heatmap_{debug_id}.jpg")
                debug_id += 1
            boxes_list.append(boxes)
            scores_list.append(scores)
        if len(boxes_list) == 1:
            boxes_list = boxes_list[0]
            scores_list = scores_list[0]
        return boxes_list, scores_list
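
    # Worked example (editorial illustration, not in the original code): with the standard
    # xyxy -> cxcywh convention used by yolox.utils.boxes, a normalized box
    # [0.25, 0.25, 0.75, 0.75] scaled by image_size = 224 becomes [112, 112, 112, 112]
    # (center_x, center_y, width, height in pixels). That is the label layout fed to the
    # YOLOX head in _get_detection_batch, while get_detection_result converts the head's
    # cxcywh predictions back to xyxy with cxcywh2xyxy before NMS.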

    def _condition_attention(self, loc_list=None):
        for i in range(len(self.lang_encoder.gpt_neox.layers)):
            self.lang_encoder.gpt_neox.layers[i].decoder_layer.attention.loc_list = loc_list

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        use_cached_vision_x: bool = False,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
        image_nums=None,
        image_start_index_list=None,
        added_bbox_list=None,
        add_box: bool = False,
        relations=None,
        debug_mode: bool = False,
    ):
        """
        Forward pass of Flamingo.

        Args:
            vision_x (torch.Tensor): Vision input, shape (B, T_img, F, C, H, W) with F=1
            lang_x (torch.Tensor): Language input ids, shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            clear_conditioned_layers: if True, clear the conditioned layers once the forward pass
                is completed. Set this to False if the same set of images will be reused in
                another subsequent forward pass.
            past_key_values: pre-computed values to pass to the language model.
                See past_key_values documentation in Hugging Face CausalLM models.
            use_cache: whether to use cached key values. See use_cache documentation in
                Hugging Face CausalLM models.
        """
        self.valid = True
        self.lang_encoder.loc_list = None
        if use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when use_cached_vision_x is True."
            assert self.lang_encoder.is_conditioned()
        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            self._encode_vision_x(
                vision_x=vision_x,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list if add_box else None,
                input_ids=lang_x,
                relations=relations,
            )
        if self.apply_mask:
            if self.roi_align:
                attend_length = 1 + self.roi_output_size ** 2
            else:
                attend_length = 2
            prebox_loc = (lang_x == self.prebox_token_id).nonzero()
            loc_list = []
            for (x, y) in prebox_loc:
                x = x.item()
                y = y.item()
                for yy in range(y + 1, lang_x.shape[1]):
                    if lang_x[x, yy] == self.endofobject_token_id:
                        # [batch_idx, [previsual:prebox], [object:endofobject-1]]
                        loc_list.append([x, [y - attend_length + 1, y], [y + 1, yy - 1]])
                        break  # pair each prebox with the first following end-of-object token
            self._condition_attention(loc_list=loc_list)
        else:
            self._condition_attention(None)
        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=True,
        )
        if vision_x is None:
            output['loss'][0] += 0.0 * self.vis_proj(
                self.vision_encoder.visual(
                    torch.randn(1, 3, 224, 224, device=lang_x.device, dtype=output['loss'].dtype)
                )[1]
            ).mean()
        hidden_states = output["hidden_states"][-1]
        if self.training and added_bbox_list is not None:
            detection_losses, loss_dict = self.get_detection_losses(
                input_ids=lang_x,
                hidden_states=hidden_states,
                added_bbox_list=added_bbox_list,
            )
            output["detection_losses"] = detection_losses
            output["loss_dict"] = loss_dict
        elif labels is None:
            boxes, scores = self.get_detection_result(
                input_ids=lang_x,
                hidden_states=hidden_states,
                debug_id=self.debug_id if hasattr(self, "debug_id") else None,
                debug_mode=debug_mode,
            )
            output["boxes"] = boxes
            output["scores"] = scores
        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()
        self._condition_attention(None)
        return output
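
    # Illustrative note (editorial, not from the original source): with apply_mask=True and
    # roi_align=False, attend_length == 2, so a prebox token at position y whose matching
    # end-of-object token sits at position yy yields the loc_list entry
    #     [batch_idx, [y - 1, y], [y + 1, yy - 1]]
    # i.e. the previsual/prebox span and the span of object words conditioned on it. With
    # roi_align=True the first span instead covers attend_length = 1 + roi_output_size ** 2
    # positions, matching the ROI-aligned visual tokens.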

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        added_bbox_list=None,
        num_beams=1,
        max_new_tokens=None,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        no_repeat_ngram_size=0,
        prefix_allowed_tokens_fn=None,
        length_penalty=1.0,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
        bad_words_ids=None,
        force_words_ids=None,
        image_start_index_list=None,
        image_nums=None,
        min_length=None,
        return_dict_in_generate=False,
        output_hidden_states=False,
        output_scores=False,
        logits_processor_list=None,
        eos_token_id=None,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x (torch.Tensor): Vision input, shape (B, T_img, F, C, H, W).
                Images in the same chunk are collated along T_img, and frames are collated along F.
                Currently only F=1 is supported (single-frame videos).
            lang_x (torch.Tensor): Language input, shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            num_beams (int, optional): Number of beams. Defaults to 1.
            max_new_tokens (int, optional): Maximum number of new tokens. Defaults to None.
            temperature (float, optional): Temperature. Defaults to 1.0.
            top_k (int, optional): Top k. Defaults to 0.
            top_p (float, optional): Top p. Defaults to 1.0.
            no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
            length_penalty (float, optional): Length penalty. Defaults to 1.0.
            num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
            do_sample (bool, optional): Whether to sample. Defaults to False.
            early_stopping (bool, optional): Early stopping. Defaults to False.

        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        if num_beams > 1:
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)
            image_start_index_list = torch.tensor(image_start_index_list).repeat_interleave(num_beams, dim=0).tolist()
            image_nums = torch.tensor(image_nums).repeat_interleave(num_beams, dim=0).tolist()
            if added_bbox_list is not None and len(added_bbox_list) != 0:
                added_bbox_list = added_bbox_list * num_beams
        self._encode_vision_x(
            vision_x=vision_x,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            num_beams=num_beams,
            added_bbox_list=added_bbox_list,
            input_ids=lang_x.repeat_interleave(num_beams, dim=0),
        )
        if logits_processor_list is not None:
            assert isinstance(logits_processor_list, list)
            logits_processor_list = LogitsProcessorList(logits_processor_list)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=self.eoc_token_id if eos_token_id is None else eos_token_id,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            min_length=min_length,
            length_penalty=length_penalty,
            logits_processor=logits_processor_list,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
        )
        self.lang_encoder.clear_conditioned_layers()
        return output

    def _get_data_list_and_visual_tokens(
        self,
        all_box_list,
        box_token_id,
        prebox_token_id,
        input_ids,
        vision_x,
        nothing_embedding=None,
    ):
        box_locations = (torch.logical_or(input_ids == box_token_id, input_ids == prebox_token_id)).nonzero()
        prev_batch_idx = -1
        media_idx = []
        cnt = 0
        data_list = []
        visual_tokens = []
        if len(all_box_list) != len(box_locations):
            logging.info(f"WARNING. len(all_box_list) != len(box_locations) {len(all_box_list)} vs {len(box_locations)}")
            self.valid = False
        for III, (batch_idx, idx) in enumerate(box_locations):
            batch_idx = batch_idx.item()
            idx = idx.item()
            if batch_idx != prev_batch_idx:
                prev_batch_idx = batch_idx
                this_input_ids = input_ids[batch_idx]
                cnt += len(media_idx)
                media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
            for i in range(len(media_idx)):
                if i == len(media_idx) - 1 or (idx > media_idx[i] and idx < media_idx[i + 1]):
                    break
            image_index = cnt + i
            size = int(vision_x[image_index].shape[0] ** 0.5)
            image_feature = vision_x[image_index].reshape(size, size, -1)
            try:
                raw_xyxy = all_box_list[III]
            except IndexError:
                logging.info("out of scope for all_box_list")
                raw_xyxy = all_box_list[-1]
            region_xyxy = np.array(raw_xyxy) * size
            x1, y1, x2, y2 = region_xyxy.astype(int).clip(0, size - 1).tolist()
            x2 = max(x1, x2)
            y2 = max(y1, y2)
            if x1 + y1 + x2 + y2 == 0.0 and nothing_embedding is not None:
                visual_token = nothing_embedding
            else:
                if self.roi_align:
                    visual_token = torchvision.ops.roi_align(
                        image_feature.permute(2, 0, 1).unsqueeze(0),
                        [torch.tensor(region_xyxy.astype(np.float32)).unsqueeze(0).cuda()],
                        output_size=self.roi_output_size,
                        spatial_scale=1.0,
                    )
                    visual_token = visual_token.squeeze(0).flatten(1).permute(1, 0)
                else:
                    visual_token = image_feature[y1:y2 + 1, x1:x2 + 1].reshape(-1, image_feature.shape[-1]).mean(0)
            box = torch.tensor([0] + raw_xyxy, device=visual_token.device, dtype=visual_token.dtype)
            data_list.append([visual_token, box, batch_idx, idx, i])
            visual_tokens.append(visual_token)
        return data_list, visual_tokens

    def _encode_vision_x(
        self,
        vision_x: torch.Tensor,
        image_nums=None,
        image_start_index_list=None,
        added_bbox_list=None,
        num_beams=None,
        input_ids=None,
        relations=None,
    ):
        """
        Compute media tokens from vision input by passing it through the vision encoder
        and conditioning the language model.

        Args:
            vision_x (torch.Tensor): Vision input, shape (B, T_img, F, C, H, W).
                Images in the same chunk are collated along T_img, and frames are collated along F.
                Currently only F=1 is supported (single-frame videos).

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        if hasattr(self.vision_encoder, "visual"):
            vision_x = self.vision_encoder.visual(vision_x)[1]
        else:
            vision_x = self.vision_encoder(vision_x).flatten(2).permute(0, 2, 1)
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        # print(vision_x[0, 0, 0])

        # # DEBUG HERE
        # if torch.distributed.get_rank() == 0:
        #     import pdb; pdb.set_trace()
        # else:
        #     torch.distributed.barrier()

        vision_x = vision_x.mean(2)
        # vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)
        # vision_x = self.vis_proj(vision_x) + self.vis_position_embedding(self.vis_position_ids).unsqueeze(0)
        vision_x = self.vis_proj(vision_x).squeeze(1)
        self.image_embedding = vision_x

        data_list = None
        visual_tokens = None
        if added_bbox_list is not None and input_ids is not None:
            all_box_list = added_bbox_list[0].tolist()
            for box_list in added_bbox_list[1:]:
                all_box_list.extend(box_list.tolist())
            data_list, visual_tokens = self._get_data_list_and_visual_tokens(
                all_box_list=all_box_list,
                box_token_id=self.box_token_id,
                prebox_token_id=self.prebox_token_id,
                input_ids=input_ids,
                vision_x=vision_x,
                nothing_embedding=self.lang_encoder.gpt_neox.embed_in(
                    torch.tensor(self.nothing_token_id).to(self.lang_encoder.gpt_neox.embed_in.weight.device)
                ) if self.nothing_token_id is not None else None,
            )

        first_layer = self.lang_encoder._get_decoder_layers()[0]
        first_layer.condition_vis_x(
            vision_x,
            image_nums,
            image_start_index_list,
            num_beams=num_beams,
            visual_tokens=visual_tokens,
            data_list=[[d[2], d[3]] for d in data_list] if data_list is not None else data_list,
        )
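

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch (editorial addition, not part of the model):
# it only illustrates the shape bookkeeping performed in _encode_vision_x.
# The dummy tensors, the patch count `v`, and the plain nn.Linear projection
# are hypothetical stand-ins for the real vision encoder and vis_proj.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    b, T, F = 2, 1, 1                       # batch size, images per sample, frames per image
    v, vis_dim, lang_dim = 256, 1024, 2048  # hypothetical patch count and feature widths

    # (B, T_img, F, C, H, W) input is flattened before the vision encoder ...
    dummy_vision_x = torch.randn(b, T, F, 3, 224, 224)
    flat = rearrange(dummy_vision_x, "b T F c h w -> (b T F) c h w")

    # ... which conceptually returns one feature vector per image patch.
    patch_feats = torch.randn(flat.shape[0], v, vis_dim)

    # Restore the (b, T, F, v, d) layout, average over frames, project to the
    # language-model width, and drop the singleton T dimension, mirroring the
    # rearrange / mean / vis_proj / squeeze steps in _encode_vision_x.
    patch_feats = rearrange(patch_feats, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
    patch_feats = patch_feats.mean(2)
    demo_proj = nn.Linear(vis_dim, lang_dim)
    image_embedding = demo_proj(patch_feats).squeeze(1)
    print(image_embedding.shape)            # torch.Size([2, 256, 2048])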