import torch
import torchvision
from einops import rearrange
from torch import nn
from yolox.models.yolo_head import YOLOXHead
from yolox.utils.boxes import xyxy2cxcywh, cxcywh2xyxy
from yolox.utils.demo_utils import nms
# import matplotlib.pyplot as plt
# import seaborn as sns
import numpy as np
import logging
from open_flamingo.src.gcn import GCN
from transformers import LogitsProcessorList

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d %I:%M:%S',
)
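
# This module defines a Flamingo-style vision-language model with grounding support:
# CLIP visual features are projected into the language model's embedding space,
# box/visual placeholder tokens are filled with region features, and a YOLOX head
# attached to the final hidden states predicts boxes for referred objects.
# The commented-out modules below appear to be earlier experiments with
# position/relation encodings, kept here for reference.
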
# class PositionEncodingModule(nn.Module):
#     def __init__(self, dim, pos_dim=128):
#         super().__init__()
#         self.encode = nn.Sequential(
#             nn.Linear(5, pos_dim // 2),
#             nn.BatchNorm1d(pos_dim // 2),
#             nn.GELU(),
#             nn.Linear(pos_dim // 2, pos_dim),
#             nn.BatchNorm1d(pos_dim),
#             nn.GELU(),
#         )
#         self.merge = nn.Sequential(
#             nn.Linear(dim + pos_dim, dim),
#             nn.BatchNorm1d(dim),
#             nn.GELU(),
#         )
#
#     def forward(self, x, box):
#         box = self.encode(box)
#         x = torch.cat([x, box], dim=-1)
#         x = self.merge(x)
#         return x


# class PositionEncodingModule(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.encode = nn.Sequential(
#             nn.Linear(5, dim),
#             nn.GELU(),
#         )
#
#     def forward(self, x, box):
#         box = self.encode(box)
#         x = x + box
#         return x


# class PositionEncodingModule2(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.encode = nn.Sequential(
#             nn.Linear(5 + dim, dim),
#             nn.ELU(),
#         )
#
#     def forward(self, x, box):
#         x = torch.cat([x, box], dim=-1)
#         x = self.encode(x)
#         return x


# class RelationHead(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.encode = nn.Sequential(
#             nn.LayerNorm(dim),
#             nn.Linear(dim, 128),
#             nn.ELU(),
#         )
#         self.classifier = nn.Linear(256, 51)
#
#     def forward(self, x1, x2):
#         x1 = self.encode(x1)
#         x2 = self.encode(x2)
#         x = torch.cat([x1, x2], dim=-1)
#         x = self.classifier(x)
#         return x
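
# Flamingo wraps a CLIP vision encoder and a GPT-NeoX-based language encoder
# (patched via init_flamingo) and adds:
#   * a linear projection from vision features to the LM hidden size (vis_proj),
#   * special token ids for <|#image#|>, visual/previsual, box/prebox, etc.,
#   * a single-class YOLOX detection head operating on the concatenation of
#     image patch embeddings and the hidden state of a visual token.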
class Flamingo(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        image_end_token_id: int,
        visual_token_id: int,
        previsual_token_id: int,
        box_token_id: int,
        prebox_token_id: int,
        nothing_token_id: int,
        endofobject_token_id: int,
        vis_dim: int,
        vis_embed_size: int,
        lang_dim: int,
        hidden_state_dim: int,
        image_size: int,
        patch_size: int,
        use_media_placement_augmentation: bool = False,
        add_visual_token: bool = False,
        add_pe: bool = False,
        add_relation: bool = False,
        use_format_v2: bool = False,
        roi_align: bool = False,
        roi_output_size: int = 4,
        apply_mask: bool = False,
    ):
""" | |
Args: | |
vision_encoder (nn.Module): HF CLIPModel | |
lang_encoder (nn.Module): HF causal language model | |
eoc_token_id (int): Token id for eos token | |
media_token_id (int): Token id for <|#image#|> | |
vis_dim (int): Dimension of the visual features. | |
Visual features are projected to match this shape along the last dimension. | |
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. | |
use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False. | |
""" | |
        super().__init__()
        self.image_end_token_id = image_end_token_id
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.vis_dim = vis_dim
        self.lang_dim = lang_dim
        # inner_dim = self.lang_dim * 4
        # self.vis_proj = nn.Sequential(
        #     nn.LayerNorm(self.vis_dim),
        #     nn.Linear(self.vis_dim, inner_dim, bias=False),
        #     nn.GELU(),
        #     nn.Linear(inner_dim, self.lang_dim, bias=False),
        # )
        self.vis_proj = nn.Linear(self.vis_dim, self.lang_dim)
        self.vision_encoder = vision_encoder
        self.num_positions = vis_embed_size
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            use_media_placement_augmentation=self.use_media_placement_augmentation,
        )
        first_layer = self.lang_encoder._get_decoder_layers()[0]
        first_layer.add_visual_token = add_visual_token
        first_layer.visual_token_id = visual_token_id
        first_layer.media_token_id = media_token_id
        first_layer.box_token_id = box_token_id
        # first_layer.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
        # assert not (add_pe and add_relation)
        # self.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
        # first_layer.pos_enc = self.pos_enc
        self.box_token_id = box_token_id
        self.prebox_token_id = prebox_token_id
        self.media_token_id = media_token_id
        self.visual_token_id = visual_token_id
        self.previsual_token_id = previsual_token_id
        self.hidden_state_dim = hidden_state_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.patch_num = self.image_size // self.patch_size
        self.detection_head = YOLOXHead(
            num_classes=1,
            strides=[patch_size],
            in_channels=[self.hidden_state_dim + self.lang_dim],
        )
        self.use_format_v2 = use_format_v2
        self.nothing_token_id = nothing_token_id
        self.roi_align = roi_align
        self.roi_output_size = roi_output_size if roi_align else None
        self.apply_mask = apply_mask
        self.endofobject_token_id = endofobject_token_id
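
    # _get_detection_batch pairs each <|#visual#|>/<|#previsual#|> token with its
    # image and its ground-truth boxes, building (feature map, label) pairs for the
    # YOLOX head. If the number of boxes does not match the number of visual tokens,
    # alpha is set to 0.0 so the resulting losses are effectively masked out.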
    def _get_detection_batch(
        self,
        visual_token_id,
        previsual_token_id,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        added_bbox_list,
        box_num=100,
    ):
        select_mask = torch.logical_or(input_ids == visual_token_id, input_ids == previsual_token_id)
        visual_token_position = select_mask.nonzero()
        visual_token_hidden_states = hidden_states[select_mask]
        prev_batch_idx = -1
        media_idx = []
        cnt = 0
        assert len(visual_token_hidden_states) == len(visual_token_position)
        if len(added_bbox_list) != len(visual_token_position):
            msg = f"ERROR: {len(added_bbox_list)}:{len(visual_token_position)}\n{added_bbox_list}\n{visual_token_position}"
            logging.info(msg)
            alpha = 0.0
        else:
            alpha = 1.0
        visual_batches = []
        previsual_batches = []
        for (batch_idx, idx), visual_token_hidden_state, bbox in zip(
            visual_token_position, visual_token_hidden_states, added_bbox_list,
        ):
            # ! VERY IMPORTANT BUG !
            bbox = bbox.clone()
            # ! VERY IMPORTANT BUG !
            batch_idx = batch_idx.item()
            idx = idx.item()
            if batch_idx != prev_batch_idx:
                prev_batch_idx = batch_idx
                this_input_ids = input_ids[batch_idx]
                cnt += len(media_idx)
                media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
            for i in range(len(media_idx)):
                if i == len(media_idx) - 1 or (idx > media_idx[i] and idx < media_idx[i + 1]):
                    break
            image_index = cnt + i
            size = int(self.image_embedding[image_index].shape[0] ** 0.5)
            image_embedding = self.image_embedding[image_index]
            # inplace xyxy2cxcywh
            # print(bbox)
            # TODO: CHECK self.image_size. Is it 224?
            bbox = xyxy2cxcywh(bbox) * self.image_size
            # print(bbox)
            concat_image_visual_embedding = torch.cat(
                [image_embedding, visual_token_hidden_state.unsqueeze(0).repeat(image_embedding.shape[0], 1)],
                dim=-1,
            ).reshape(size, size, -1)
            label = torch.cat([torch.zeros(bbox.shape[0], 1, device=bbox.device), bbox], dim=-1)
            label = torch.cat([label, torch.zeros(box_num - label.shape[0], label.shape[1], device=label.device)], dim=0)
            if input_ids[batch_idx, idx] == previsual_token_id:
                previsual_batches.append([concat_image_visual_embedding, label])
            elif input_ids[batch_idx, idx] == visual_token_id:
                visual_batches.append([concat_image_visual_embedding, label])
            else:
                logging.info(f"WARNING... NOT visual nor previsual. it is {input_ids[batch_idx, idx]}")
        return visual_batches, previsual_batches, alpha, alpha
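
    # get_detection_losses runs the YOLOX head separately on the visual-token and
    # previsual-token batches and sums the per-term losses, scaling each batch by
    # the alpha returned from _get_detection_batch.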
    def get_detection_losses(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        added_bbox_list,
        box_num=100,
    ):
        visual_token_batches, previsual_token_batches, alpha1, alpha2 = self._get_detection_batch(
            visual_token_id=self.visual_token_id,
            previsual_token_id=self.previsual_token_id,
            input_ids=input_ids,
            hidden_states=hidden_states,
            added_bbox_list=added_bbox_list,
            box_num=box_num,
        )
        loss_dict = []
        for batches, alpha in zip([visual_token_batches, previsual_token_batches], [alpha1, alpha2]):
            # x: [B, C, H, W]
            if len(batches) != 0:
                x = torch.cat([batch[0].unsqueeze(0) for batch in batches], dim=0).permute(0, 3, 1, 2)
                labels = torch.cat([batch[1].unsqueeze(0) for batch in batches], dim=0)
            else:
                x = None
                labels = None
            if x is not None:
                losses = self.detection_head(xin=[x], labels=labels)
                loss, loss_iou, loss_obj, loss_cls, loss_l1, _ = losses
            else:
                loss = torch.tensor(0.0).cuda()
                loss_iou = loss
                loss_obj = loss
                loss_cls = loss
                loss_l1 = loss
            loss_dict.append(dict(
                loss=loss * alpha,
                loss_iou=loss_iou * alpha,
                loss_obj=loss_obj * alpha,
                loss_cls=loss_cls * alpha,
                loss_l1=loss_l1 * alpha,
            ))
        ret_loss = {}
        for key in loss_dict[0].keys():
            ret_loss[key] = 0.0
            for d in loss_dict:
                ret_loss[key] += d[key]
        return ret_loss, loss_dict
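
    # get_detection_result decodes boxes at inference time (batch size 1 only):
    # the hidden state of the last token is broadcast over the image patch grid,
    # passed through the YOLOX head, and the raw cxcywh outputs are converted to
    # xyxy, score-thresholded, and NMS-filtered.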
    def get_detection_result(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        nms_thr: float = 0.45,
        score_thr: float = 0.01,
        debug_id: int = 0,
        debug_mode: bool = False,
    ):
        assert len(input_ids) == 1, "only batch size = 1 is supported yet"
        # assert len(self.image_embedding) == 1, "only one image is supported yet"
        # assert (input_ids[..., -1] == self.visual_token_id).all(), "the last token should be visual token"
        visual_token_hidden_state = hidden_states[..., -1, :]
        boxes_list = []
        scores_list = []
        for image_embedding in self.image_embedding:
            size = int(image_embedding.shape[0] ** 0.5)
            x = torch.cat(
                [image_embedding, visual_token_hidden_state.repeat(image_embedding.shape[0], 1)],
                dim=-1,
            ).reshape(size, size, -1).unsqueeze(0).permute(0, 3, 1, 2)
            with torch.no_grad():
                outputs = self.detection_head(xin=[x], labels=None)
            boxes = outputs[0, :, :4].cpu().numpy()
            scores = outputs[0, :, 4].cpu().numpy()
            scores_mask = scores > score_thr
            boxes = boxes[scores_mask]
            boxes = cxcywh2xyxy(boxes)
            scores = scores[scores_mask]
            keep = nms(boxes, scores, nms_thr=nms_thr)
            boxes = boxes[keep]
            scores = scores[keep]
            if debug_mode:
                obj_heatmap = outputs[0, :, -2].reshape(size, size).cpu().numpy()
                import matplotlib.pyplot as plt
                import seaborn as sns
                plt.figure()
                sns_plot = sns.heatmap(obj_heatmap)
                plt.savefig(f"heatmap_{debug_id}.jpg")
                debug_id += 1
            boxes_list.append(boxes)
            scores_list.append(scores)
        if len(boxes_list) == 1:
            boxes_list = boxes_list[0]
            scores_list = scores_list[0]
        return boxes_list, scores_list
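
    # _condition_attention pushes a list of attention constraints into every
    # GPT-NeoX decoder layer; passing None clears the constraints.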
    def _condition_attention(self, loc_list=None):
        for i in range(len(self.lang_encoder.gpt_neox.layers)):
            self.lang_encoder.gpt_neox.layers[i].decoder_layer.attention.loc_list = loc_list
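
    # In forward(), when apply_mask is enabled, loc_list entries of the form
    # [batch_idx, [previsual_start, prebox], [object_start, object_end]] restrict
    # each grounded object span to attend to its own box tokens. attend_length is
    # 1 + roi_output_size**2 when RoI-Align features span multiple positions,
    # otherwise 2 (previsual + prebox).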
    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        use_cached_vision_x: bool = False,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
        image_nums=None,
        image_start_index_list=None,
        added_bbox_list=None,
        add_box: bool = False,
        relations=None,
        debug_mode: bool = False,
    ):
""" | |
Forward pass of Flamingo. | |
Args: | |
vision_x (torch.Tensor): Vision input | |
shape (B, T_img, F, C, H, W) with F=1 | |
lang_x (torch.Tensor): Language input ids | |
shape (B, T_txt) | |
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. | |
labels (torch.Tensor, optional): Labels. Defaults to None. | |
clear_conditioned_layers: if True, clear the conditioned layers | |
once the foward pass is completed. Set this to false if the | |
same set of images will be reused in another subsequent | |
forward pass. | |
past_key_values: pre-computed values to pass to language model. | |
See past_key_values documentation in Hugging Face | |
CausalLM models. | |
use_cache: whether to use cached key values. See use_cache | |
documentation in Hugging Face CausalLM models. | |
""" | |
        self.valid = True
        self.lang_encoder.loc_list = None
        if use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when use_cached_vision_x is True."
            assert self.lang_encoder.is_conditioned()
        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            self._encode_vision_x(
                vision_x=vision_x,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list if add_box else None,
                input_ids=lang_x,
                relations=relations,
            )
        if self.apply_mask:
            if self.roi_align:
                attend_length = 1 + self.roi_output_size ** 2
            else:
                attend_length = 2
            prebox_loc = (lang_x == self.prebox_token_id).nonzero()
            loc_list = []
            for (x, y) in prebox_loc:
                x = x.item()
                y = y.item()
                for yy in range(y + 1, lang_x.shape[1]):
                    if lang_x[x, yy] == self.endofobject_token_id:
                        # [batch_idx, [previsual:prebox], [object:endofobject-1]]
                        loc_list.append([x, [y - attend_length + 1, y], [y + 1, yy - 1]])
                        # only pair the box with the first endofobject token after it
                        break
            self._condition_attention(loc_list=loc_list)
        else:
            self._condition_attention(None)
        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=True,
        )
        if vision_x is None:
            # dummy term so vision parameters stay in the graph when no image is given
            output['loss'][0] += 0.0 * self.vis_proj(
                self.vision_encoder.visual(
                    torch.randn(1, 3, 224, 224, device=lang_x.device, dtype=output['loss'].dtype)
                )[1]
            ).mean()
        hidden_states = output["hidden_states"][-1]
        if self.training and added_bbox_list is not None:
            detection_losses, loss_dict = self.get_detection_losses(
                input_ids=lang_x,
                hidden_states=hidden_states,
                added_bbox_list=added_bbox_list,
            )
            output["detection_losses"] = detection_losses
            output["loss_dict"] = loss_dict
        elif labels is None:
            boxes, scores = self.get_detection_result(
                input_ids=lang_x,
                hidden_states=hidden_states,
                debug_id=self.debug_id if hasattr(self, "debug_id") else None,
                debug_mode=debug_mode,
            )
            output["boxes"] = boxes
            output["scores"] = scores
        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()
        self._condition_attention(None)
        return output
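
    # generate() conditions the language model on the images (and any provided
    # boxes) and then delegates to the Hugging Face generate API. With beam search,
    # vision inputs, image bookkeeping lists, and boxes are repeated per beam so
    # that every beam sees the same conditioning.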
    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        added_bbox_list=None,
        num_beams=1,
        max_new_tokens=None,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        no_repeat_ngram_size=0,
        prefix_allowed_tokens_fn=None,
        length_penalty=1.0,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
        bad_words_ids=None,
        force_words_ids=None,
        image_start_index_list=None,
        image_nums=None,
        min_length=None,
        return_dict_in_generate=False,
        output_hidden_states=False,
        output_scores=False,
        logits_processor_list=None,
        eos_token_id=None,
    ):
""" | |
Generate text conditioned on vision and language inputs. | |
Args: | |
vision_x (torch.Tensor): Vision input | |
shape (B, T_img, F, C, H, W) | |
images in the same chunk are collated along T_img, and frames are collated along F | |
currently only F=1 is supported (single-frame videos) | |
lang_x (torch.Tensor): Language input | |
shape (B, T_txt) | |
max_length (int, optional): Maximum length of the output. Defaults to None. | |
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. | |
num_beams (int, optional): Number of beams. Defaults to 1. | |
max_new_tokens (int, optional): Maximum new tokens. Defaults to None. | |
temperature (float, optional): Temperature. Defaults to 1.0. | |
top_k (int, optional): Top k. Defaults to 0. | |
top_p (float, optional): Top p. Defaults to 1.0. | |
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. | |
length_penalty (float, optional): Length penalty. Defaults to 1.0. | |
num_return_sequences (int, optional): Number of return sequences. Defaults to 1. | |
do_sample (bool, optional): Do sample. Defaults to False. | |
early_stopping (bool, optional): Early stopping. Defaults to False. | |
Returns: | |
torch.Tensor: lang_x with generated tokens appended to it | |
""" | |
        if num_beams > 1:
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)
            image_start_index_list = torch.tensor(image_start_index_list).repeat_interleave(num_beams, dim=0).tolist()
            image_nums = torch.tensor(image_nums).repeat_interleave(num_beams, dim=0).tolist()
            if added_bbox_list is not None and len(added_bbox_list) != 0:
                added_bbox_list = added_bbox_list * num_beams
        self._encode_vision_x(
            vision_x=vision_x,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            num_beams=num_beams,
            added_bbox_list=added_bbox_list,
            input_ids=lang_x.repeat_interleave(num_beams, dim=0),
        )
        if logits_processor_list is not None:
            assert isinstance(logits_processor_list, list)
            logits_processor_list = LogitsProcessorList(logits_processor_list)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=(self.eoc_token_id) if eos_token_id is None else eos_token_id,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            min_length=min_length,
            length_penalty=length_penalty,
            logits_processor=logits_processor_list,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
        )
        self.lang_encoder.clear_conditioned_layers()
        return output
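
    # _get_data_list_and_visual_tokens extracts a region feature for every
    # <|#box#|>/<|#prebox#|> token: either an RoI-Align crop of the patch grid or
    # the mean of the patches covered by the box. An all-zero box is mapped to the
    # embedding of the "nothing" token when one is provided.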
    def _get_data_list_and_visual_tokens(
        self,
        all_box_list,
        box_token_id,
        prebox_token_id,
        input_ids,
        vision_x,
        nothing_embedding=None,
    ):
        box_locations = (torch.logical_or(input_ids == box_token_id, input_ids == prebox_token_id)).nonzero()
        prev_batch_idx = -1
        media_idx = []
        cnt = 0
        data_list = []
        visual_tokens = []
        if len(all_box_list) != len(box_locations):
            logging.info(f"WARNING. len(all_box_list) != len(box_locations) {len(all_box_list)} vs {len(box_locations)}")
            self.valid = False
        for III, (batch_idx, idx) in enumerate(box_locations):
            batch_idx = batch_idx.item()
            idx = idx.item()
            if batch_idx != prev_batch_idx:
                prev_batch_idx = batch_idx
                this_input_ids = input_ids[batch_idx]
                cnt += len(media_idx)
                media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
            for i in range(len(media_idx)):
                if i == len(media_idx) - 1 or (idx > media_idx[i] and idx < media_idx[i + 1]):
                    break
            image_index = cnt + i
            size = int(vision_x[image_index].shape[0] ** 0.5)
            image_feature = vision_x[image_index].reshape(size, size, -1)
            try:
                raw_xyxy = all_box_list[III]
            except IndexError:
                logging.info("out of scope for all_box_list")
                raw_xyxy = all_box_list[-1]
            region_xyxy = np.array(raw_xyxy) * size
            x1, y1, x2, y2 = region_xyxy.astype(int).clip(0, size - 1).tolist()
            x2 = max(x1, x2)
            y2 = max(y1, y2)
            if x1 + y1 + x2 + y2 == 0.0 and nothing_embedding is not None:
                visual_token = nothing_embedding
            else:
                if self.roi_align:
                    visual_token = torchvision.ops.roi_align(
                        image_feature.permute(2, 0, 1).unsqueeze(0),
                        [torch.tensor(region_xyxy.astype(np.float32)).unsqueeze(0).cuda()],
                        output_size=self.roi_output_size,
                        spatial_scale=1.0,
                    )
                    visual_token = visual_token.squeeze(0).flatten(1).permute(1, 0)
                else:
                    visual_token = image_feature[y1:y2 + 1, x1:x2 + 1].reshape(-1, image_feature.shape[-1]).mean(0)
            box = torch.tensor([0] + raw_xyxy, device=visual_token.device, dtype=visual_token.dtype)
            data_list.append([visual_token, box, batch_idx, idx, i])
            visual_tokens.append(visual_token)
        return data_list, visual_tokens
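
    # _encode_vision_x caches the projected patch embeddings in self.image_embedding
    # (used later by the detection head) and conditions the first decoder layer on
    # them, together with any per-box visual tokens.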
    def _encode_vision_x(self, vision_x: torch.Tensor, image_nums=None, image_start_index_list=None, added_bbox_list=None, num_beams=None, input_ids=None, relations=None):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        if hasattr(self.vision_encoder, "visual"):
            vision_x = self.vision_encoder.visual(vision_x)[1]
        else:
            vision_x = self.vision_encoder(vision_x).flatten(2).permute(0, 2, 1)
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        # print(vision_x[0,0,0])
        # # DEBUG HERE
        # if torch.distributed.get_rank() == 0:
        #     import pdb; pdb.set_trace()
        # else:
        #     torch.distributed.barrier()
        vision_x = vision_x.mean(2)
        # vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)
        # vision_x = self.vis_proj(vision_x) + self.vis_position_embedding(self.vis_position_ids).unsqueeze(0)
        vision_x = self.vis_proj(vision_x).squeeze(1)
        self.image_embedding = vision_x

        data_list = None
        visual_tokens = None
        if added_bbox_list is not None and input_ids is not None:
            all_box_list = added_bbox_list[0].tolist()
            for box_tensor in added_bbox_list[1:]:
                all_box_list.extend(box_tensor.tolist())
            data_list, visual_tokens = self._get_data_list_and_visual_tokens(
                all_box_list=all_box_list,
                box_token_id=self.box_token_id,
                prebox_token_id=self.prebox_token_id,
                input_ids=input_ids,
                vision_x=vision_x,
                nothing_embedding=self.lang_encoder.gpt_neox.embed_in(
                    torch.tensor(self.nothing_token_id).to(self.lang_encoder.gpt_neox.embed_in.weight.device)
                ) if self.nothing_token_id is not None else None,
            )

        first_layer = self.lang_encoder._get_decoder_layers()[0]
        first_layer.condition_vis_x(
            vision_x,
            image_nums,
            image_start_index_list,
            num_beams=num_beams,
            visual_tokens=visual_tokens,
            data_list=[[d[2], d[3]] for d in data_list] if data_list is not None else data_list,
        )
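
    # Rough usage (a sketch, assuming the training/inference scripts elsewhere in
    # the repo build the encoders and tokenizer; names here are illustrative):
    #   training:  model(vision_x, lang_x, labels=labels, image_nums=...,
    #              image_start_index_list=..., added_bbox_list=boxes, add_box=True)
    #              -> output["loss"], output["detection_losses"]
    #   inference: model.generate(vision_x, lang_x, ...) for text, or a forward
    #              pass ending at a visual token to read output["boxes"]/["scores"].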