import torch
import torchvision
from einops import rearrange
from torch import nn
from yolox.models.yolo_head import YOLOXHead
from yolox.utils.boxes import xyxy2cxcywh, cxcywh2xyxy
from yolox.utils.demo_utils import nms
# import matplotlib.pyplot as plt
# import seaborn as sns
import numpy as np
import logging
from open_flamingo.src.gcn import GCN
from transformers import LogitsProcessorList
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(message)s',
datefmt='%m/%d %I:%M:%S',
)
# class PositionEncodingModule(nn.Module):
# def __init__(self, dim, pos_dim=128):
# super().__init__()
# self.encode = nn.Sequential(
# nn.Linear(5, pos_dim // 2),
# nn.BatchNorm1d(pos_dim // 2),
# nn.GELU(),
# nn.Linear(pos_dim // 2, pos_dim),
# nn.BatchNorm1d(pos_dim),
# nn.GELU(),
# )
# self.merge = nn.Sequential(
# nn.Linear(dim + pos_dim, dim),
# nn.BatchNorm1d(dim),
# nn.GELU(),
# )
# def forward(self, x, box):
# box = self.encode(box)
# x = torch.cat([x, box], dim=-1)
# x = self.merge(x)
# return x
# class PositionEncodingModule(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.encode = nn.Sequential(
# nn.Linear(5, dim),
# nn.GELU(),
# )
# def forward(self, x, box):
# box = self.encode(box)
# x = x + box
# return x
# class PositionEncodingModule2(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.encode = nn.Sequential(
# nn.Linear(5 + dim, dim),
# nn.ELU(),
# )
# def forward(self, x, box):
# x = torch.cat([x, box], dim=-1)
# x = self.encode(x)
# return x
# class RelationHead(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.encode = nn.Sequential(
# nn.LayerNorm(dim),
# nn.Linear(dim, 128),
# nn.ELU(),
# )
# self.classifier = nn.Linear(256, 51)
# def forward(self, x1, x2):
# x1 = self.encode(x1)
# x2 = self.encode(x2)
# x = torch.cat([x1, x2], dim=-1)
# x = self.classifier(x)
# return x
class Flamingo(nn.Module):
def __init__(
self,
vision_encoder: nn.Module,
lang_encoder: nn.Module,
eoc_token_id: int,
media_token_id: int,
image_end_token_id: int,
visual_token_id: int,
previsual_token_id: int,
box_token_id: int,
prebox_token_id: int,
nothing_token_id: int,
endofobject_token_id: int,
vis_dim: int,
vis_embed_size: int,
lang_dim: int,
hidden_state_dim: int,
image_size: int,
patch_size: int,
use_media_placement_augmentation: bool = False,
add_visual_token: bool = False,
add_pe: bool = False,
add_relation: bool = False,
use_format_v2: bool = False,
roi_align: bool = False,
roi_output_size: int = 4,
apply_mask: bool = False,
):
"""
Args:
vision_encoder (nn.Module): HF CLIPModel
lang_encoder (nn.Module): HF causal language model
eoc_token_id (int): Token id for eos token
media_token_id (int): Token id for <|#image#|>
vis_dim (int): Dimension of the visual features.
Visual features are projected to match this shape along the last dimension.
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False.
"""
super().__init__()
self.image_end_token_id = image_end_token_id
self.eoc_token_id = eoc_token_id
self.media_token_id = media_token_id
self.use_media_placement_augmentation = use_media_placement_augmentation
self.vis_dim = vis_dim
self.lang_dim = lang_dim
# inner_dim = self.lang_dim * 4
# self.vis_proj = nn.Sequential(
# nn.LayerNorm(self.vis_dim),
# nn.Linear(self.vis_dim, inner_dim, bias=False),
# nn.GELU(),
# nn.Linear(inner_dim, self.lang_dim, bias=False),
# )
self.vis_proj = nn.Linear(self.vis_dim, self.lang_dim)
self.vision_encoder = vision_encoder
self.num_positions = vis_embed_size
self.lang_encoder = lang_encoder
self.lang_encoder.init_flamingo(
media_token_id=media_token_id,
use_media_placement_augmentation=self.use_media_placement_augmentation,
)
first_layer = self.lang_encoder._get_decoder_layers()[0]
first_layer.add_visual_token = add_visual_token
first_layer.visual_token_id = visual_token_id
first_layer.media_token_id = media_token_id
first_layer.box_token_id = box_token_id
# first_layer.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
# assert not (add_pe and add_relation)
# self.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
# first_layer.pos_enc = self.pos_enc
self.box_token_id = box_token_id
self.prebox_token_id = prebox_token_id
self.media_token_id = media_token_id
self.visual_token_id = visual_token_id
self.previsual_token_id = previsual_token_id
self.hidden_state_dim = hidden_state_dim
self.image_size = image_size
self.patch_size = patch_size
self.patch_num = self.image_size // self.patch_size
self.detection_head = YOLOXHead(
num_classes=1,
strides=[patch_size],
in_channels=[self.hidden_state_dim + self.lang_dim],
)
self.use_format_v2 = use_format_v2
self.nothing_token_id = nothing_token_id
self.roi_align = roi_align
self.roi_output_size = roi_output_size if roi_align else None
self.apply_mask = apply_mask
self.endofobject_token_id = endofobject_token_id
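# Illustrative sketch (commented out, not used by the model): the YOLOX head above
# consumes a single feature map over the patch grid, where each cell concatenates the
# projected patch embedding (lang_dim) with the LM hidden state of the grounding token
# (hidden_state_dim). The names reuse the __init__ arguments; the concrete grid size
# is only an assumption for illustration.
#
# grid = image_size // patch_size                        # e.g. 224 // 14 = 16
# feat = torch.randn(1, hidden_state_dim + lang_dim, grid, grid)
# head = YOLOXHead(num_classes=1, strides=[patch_size],
#                  in_channels=[hidden_state_dim + lang_dim]).eval()
# with torch.no_grad():
#     preds = head(xin=[feat], labels=None)              # decoded boxes, one prediction per grid cell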
def _get_detection_batch(
self,
visual_token_id,
previsual_token_id,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
added_bbox_list,
box_num = 100,
):
select_mask = torch.logical_or(input_ids == visual_token_id, input_ids == previsual_token_id)
visual_token_position = select_mask.nonzero()
visual_token_hidden_states = hidden_states[select_mask]
prev_batch_idx = -1
media_idx = []
cnt = 0
assert len(visual_token_hidden_states) == len(visual_token_position)
if len(added_bbox_list) != len(visual_token_position):
msg = f"ERROR: {len(added_bbox_list)}:{len(visual_token_position)}\n{added_bbox_list}\n{visual_token_position}"
logging.info(msg)
alpha = 0.0
else:
alpha = 1.0
visual_batches = []
previsual_batches = []
for (batch_idx, idx), visual_token_hidden_state, bbox in zip(
visual_token_position, visual_token_hidden_states, added_bbox_list,
):
# IMPORTANT: clone first; xyxy2cxcywh below modifies its input in place and would
# otherwise corrupt the caller's added_bbox_list
bbox = bbox.clone()
batch_idx = batch_idx.item()
idx = idx.item()
if batch_idx != prev_batch_idx:
prev_batch_idx = batch_idx
this_input_ids = input_ids[batch_idx]
cnt += len(media_idx)
media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
for i in range(len(media_idx)):
if i == len(media_idx) - 1 or (media_idx[i] < idx < media_idx[i + 1]):
break
image_index = cnt + i
size = int(self.image_embedding[image_index].shape[0] ** 0.5)
image_embedding = self.image_embedding[image_index]
# xyxy2cxcywh is in-place (hence the clone above); scale normalized boxes to pixel coordinates
# TODO: CHECK self.image_size. Is it 224?
bbox = xyxy2cxcywh(bbox) * self.image_size
concat_image_visual_embedding = torch.cat([image_embedding, visual_token_hidden_state.unsqueeze(0).repeat(image_embedding.shape[0], 1)], dim=-1).reshape(size, size, -1)
label = torch.cat([torch.zeros(bbox.shape[0], 1, device=bbox.device), bbox], dim=-1)
label = torch.cat([label, torch.zeros(box_num - label.shape[0], label.shape[1], device=label.device)], dim=0)
if input_ids[batch_idx, idx] == previsual_token_id:
previsual_batches.append([concat_image_visual_embedding, label])
elif input_ids[batch_idx, idx] == visual_token_id:
visual_batches.append([concat_image_visual_embedding, label])
else:
logging.info(f"WARNING... NOT visual nor previsual. it is {input_ids[batch_idx, idx]}")
return visual_batches, previsual_batches, alpha, alpha
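# Label layout sketch for the batches built above (assumption: added_bbox_list holds
# normalized xyxy boxes; image_size 224 and box_num 100 are only examples):
#
# bbox = torch.tensor([[0.25, 0.25, 0.75, 0.75]])                      # normalized xyxy
# bbox = xyxy2cxcywh(bbox.clone()) * 224                               # -> pixel cxcywh [112, 112, 112, 112]
# label = torch.cat([torch.zeros(len(bbox), 1), bbox], dim=-1)         # class id 0 in column 0
# label = torch.cat([label, torch.zeros(100 - len(label), 5)], dim=0)  # zero-pad to [box_num, 5]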
def get_detection_losses(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
added_bbox_list,
box_num = 100,
):
visual_token_batches, previsual_token_batches, alpha1, alpha2 = self._get_detection_batch(
visual_token_id=self.visual_token_id,
previsual_token_id=self.previsual_token_id,
input_ids=input_ids,
hidden_states=hidden_states,
added_bbox_list=added_bbox_list,
box_num=box_num,
)
loss_dict = []
for batches, alpha in zip([visual_token_batches, previsual_token_batches], [alpha1, alpha2]):
# x: [B, C, H, W]
if len(batches) != 0:
x = torch.cat([batch[0].unsqueeze(0) for batch in batches], dim=0).permute(0,3,1,2)
labels = torch.cat([batch[1].unsqueeze(0) for batch in batches], dim=0)
else:
x = None
labels = None
if x is not None:
losses = self.detection_head(xin=[x], labels=labels)
loss, loss_iou, loss_obj, loss_cls, loss_l1, _ = losses
else:
loss = torch.tensor(0.0, device=hidden_states.device)
loss_iou = loss
loss_obj = loss
loss_cls = loss
loss_l1 = loss
loss_dict.append(dict(
loss=loss * alpha,
loss_iou=loss_iou * alpha,
loss_obj=loss_obj * alpha,
loss_cls=loss_cls * alpha,
loss_l1=loss_l1 * alpha,
))
ret_loss = {}
for key in loss_dict[0].keys():
ret_loss[key] = 0.0
for d in loss_dict:
ret_loss[key] += d[key]
return ret_loss, loss_dict
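# Usage sketch for the losses above (hypothetical training-loop code; `model`, `out`,
# `lang_x` and `added_bbox_list` are assumptions, not defined in this file):
#
# hidden = out["hidden_states"][-1]                 # [B, T_txt, hidden_state_dim]
# det_losses, per_type = model.get_detection_losses(
#     input_ids=lang_x, hidden_states=hidden, added_bbox_list=added_bbox_list)
# total_loss = out["loss"] + det_losses["loss"]     # LM loss + alpha-weighted detection loss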
def get_detection_result(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
nms_thr: float = 0.45,
score_thr: float = 0.01,
debug_id: int = 0,
debug_mode: bool = False,
):
assert len(input_ids) == 1, "only batch size 1 is supported for now"
# assert len(self.image_embedding) == 1, "only one image is supported yet"
# assert (input_ids[..., -1] == self.visual_token_id).all(), "the last token should be visual token"
visual_token_hidden_state = hidden_states[..., -1, :]
boxes_list = []
scores_list = []
for image_embedding in self.image_embedding:
size = int(image_embedding.shape[0] ** 0.5)
x = torch.cat([image_embedding, visual_token_hidden_state.repeat(image_embedding.shape[0], 1)], dim=-1).reshape(size, size, -1).unsqueeze(0).permute(0,3,1,2)
with torch.no_grad():
outputs = self.detection_head(xin=[x], labels=None)
boxes = outputs[0,:,:4].cpu().numpy()
scores = outputs[0,:,4].cpu().numpy()
scores_mask = scores > score_thr
boxes = boxes[scores_mask]
boxes = cxcywh2xyxy(boxes)
scores = scores[scores_mask]
keep = nms(boxes, scores, nms_thr=nms_thr)
boxes = boxes[keep]
scores = scores[keep]
if debug_mode:
obj_heatmap = outputs[0,:, -2].reshape(size, size).cpu().numpy()
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure()
sns_plot = sns.heatmap(obj_heatmap)
plt.savefig(f"heatmap_{debug_id}.jpg")
debug_id += 1
boxes_list.append(boxes)
scores_list.append(scores)
if len(boxes_list) == 1:
boxes_list = boxes_list[0]
scores_list = scores_list[0]
return boxes_list, scores_list
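# Post-processing sketch mirroring the steps above: threshold on the objectness score,
# convert cxcywh -> xyxy, then NMS (thresholds shown are this method's defaults):
#
# outputs = self.detection_head(xin=[x], labels=None)     # [1, n_anchors, 4 + 1 + num_classes]
# keep_mask = outputs[0, :, 4].cpu().numpy() > 0.01
# boxes = cxcywh2xyxy(outputs[0, :, :4].cpu().numpy()[keep_mask])
# scores = outputs[0, :, 4].cpu().numpy()[keep_mask]
# keep = nms(boxes, scores, nms_thr=0.45)
# boxes, scores = boxes[keep], scores[keep]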
def _condition_attention(self, loc_list=None):
# broadcast the (prebox span, object span) locations gathered in forward() to every
# decoder self-attention layer; pass None to clear the constraint
for i in range(len(self.lang_encoder.gpt_neox.layers)):
self.lang_encoder.gpt_neox.layers[i].decoder_layer.attention.loc_list = loc_list
def forward(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
use_cached_vision_x: bool = False,
clear_conditioned_layers: bool = True,
past_key_values=None,
use_cache: bool = False,
image_nums=None,
image_start_index_list=None,
added_bbox_list=None,
add_box: bool = False,
relations=None,
debug_mode: bool = False,
):
"""
Forward pass of Flamingo.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W) with F=1
lang_x (torch.Tensor): Language input ids
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
labels (torch.Tensor, optional): Labels. Defaults to None.
clear_conditioned_layers: if True, clear the conditioned layers
once the forward pass is completed. Set this to false if the
same set of images will be reused in another subsequent
forward pass.
past_key_values: pre-computed values to pass to language model.
See past_key_values documentation in Hugging Face
CausalLM models.
use_cache: whether to use cached key values. See use_cache
documentation in Hugging Face CausalLM models.
use_cached_vision_x: if True, reuse the previously conditioned vision
features; vision_x must then be None.
image_nums / image_start_index_list: per-sample image counts and the
start positions of their visual tokens, used to condition the decoder layers.
added_bbox_list: normalized xyxy boxes aligned with the box/prebox tokens
in lang_x; used for visual-token conditioning and the detection losses.
add_box: whether to condition the language model on added_bbox_list.
"""
self.valid = True
self.lang_encoder.loc_list = None
if use_cached_vision_x:
# Case: use cached; vision_x should be cached and other
# vision-related inputs should not be provided.
assert (
vision_x is None
), "Expect vision_x to be None when use_cached_vision_x is True."
assert self.lang_encoder.is_conditioned()
else:
# Case: do not use caching (i.e. this is a standard forward pass);
self._encode_vision_x(
vision_x=vision_x,
image_nums=image_nums,
image_start_index_list=image_start_index_list,
added_bbox_list=added_bbox_list if add_box else None,
input_ids=lang_x,
relations=relations,
)
if self.apply_mask:
if self.roi_align:
attend_length = 1 + self.roi_output_size ** 2
else:
attend_length = 2
prebox_loc = (lang_x == self.prebox_token_id).nonzero()
loc_list = []
for (x, y) in prebox_loc:
x = x.item()
y = y.item()
for yy in range(y+1, lang_x.shape[1]):
if lang_x[x, yy] == self.endofobject_token_id:
# [batch_idx, [previsual:prebox], [object:endofobject-1]]
loc_list.append([x, [y-attend_length+1, y], [y+1, yy-1]])
break
self._condition_attention(loc_list=loc_list)
else:
self._condition_attention(None)
output = self.lang_encoder(
input_ids=lang_x,
attention_mask=attention_mask,
labels=labels,
past_key_values=past_key_values,
use_cache=use_cache,
output_hidden_states=True,
)
if vision_x is None:
# no fresh vision input in this step: add a zero-weighted dummy term so the vision
# encoder and vis_proj still appear in the autograd graph (e.g. to avoid
# unused-parameter errors under DDP)
output['loss'][0] += 0.0 * self.vis_proj(self.vision_encoder.visual(torch.randn(1, 3, 224, 224, device=lang_x.device, dtype=output['loss'].dtype))[1]).mean()
hidden_states = output["hidden_states"][-1]
if self.training and added_bbox_list is not None:
detection_losses, loss_dict = self.get_detection_losses(
input_ids=lang_x,
hidden_states=hidden_states,
added_bbox_list=added_bbox_list,
)
output["detection_losses"] = detection_losses
output["loss_dict"] = loss_dict
elif labels is None:
boxes, scores = self.get_detection_result(
input_ids=lang_x,
hidden_states=hidden_states,
debug_id=self.debug_id if hasattr(self, "debug_id") else 0,
debug_mode=debug_mode,
)
output["boxes"] = boxes
output["scores"] = scores
if clear_conditioned_layers:
self.lang_encoder.clear_conditioned_layers()
self._condition_attention(None)
return output
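# Illustrative training-step sketch (caller-side; `model` and the input tensors are
# assumptions, not part of this file):
#
# vision_x: [B, T_img, 1, 3, image_size, image_size]; lang_x / attention_mask / labels: [B, T_txt]
# out = model(vision_x=vision_x, lang_x=lang_x, attention_mask=attention_mask,
#             labels=labels, added_bbox_list=added_bbox_list, add_box=True,
#             image_nums=image_nums, image_start_index_list=image_start_index_list)
# loss = out["loss"] + out["detection_losses"]["loss"]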
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
added_bbox_list=None,
num_beams=1,
max_new_tokens=None,
temperature=1.0,
top_k=0,
top_p=1.0,
no_repeat_ngram_size=0,
prefix_allowed_tokens_fn=None,
length_penalty=1.0,
num_return_sequences=1,
do_sample=False,
early_stopping=False,
bad_words_ids=None,
force_words_ids=None,
image_start_index_list=None,
image_nums=None,
min_length=None,
return_dict_in_generate=False,
output_hidden_states=False,
output_scores=False,
logits_processor_list=None,
eos_token_id=None,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
images in the same chunk are collated along T_img, and frames are collated along F
currently only F=1 is supported (single-frame videos)
lang_x (torch.Tensor): Language input
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
num_beams (int, optional): Number of beams. Defaults to 1.
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
temperature (float, optional): Temperature. Defaults to 1.0.
top_k (int, optional): Top k. Defaults to 0.
top_p (float, optional): Top p. Defaults to 1.0.
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
length_penalty (float, optional): Length penalty. Defaults to 1.0.
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
do_sample (bool, optional): Do sample. Defaults to False.
early_stopping (bool, optional): Early stopping. Defaults to False.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
if num_beams > 1:
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
image_start_index_list = torch.tensor(image_start_index_list).repeat_interleave(num_beams, dim=0).tolist()
image_nums = torch.tensor(image_nums).repeat_interleave(num_beams, dim=0).tolist()
if added_bbox_list is not None and len(added_bbox_list) != 0:
added_bbox_list = added_bbox_list * num_beams
self._encode_vision_x(vision_x=vision_x, image_nums=image_nums, image_start_index_list=image_start_index_list, num_beams=num_beams, added_bbox_list=added_bbox_list, input_ids=lang_x.repeat_interleave(num_beams, dim=0))
if logits_processor_list is not None:
assert isinstance(logits_processor_list, list)
logits_processor_list = LogitsProcessorList(logits_processor_list)
output = self.lang_encoder.generate(
input_ids=lang_x,
attention_mask=attention_mask,
eos_token_id=(self.eoc_token_id) if eos_token_id is None else eos_token_id,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
min_length=min_length,
length_penalty=length_penalty,
logits_processor=logits_processor_list,
return_dict_in_generate=return_dict_in_generate,
output_scores=output_scores,
)
self.lang_encoder.clear_conditioned_layers()
return output
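# Illustrative generation sketch (caller-side; `model` and `tokenizer` are assumptions):
#
# generated = model.generate(
#     vision_x=vision_x, lang_x=lang_x, attention_mask=attention_mask,
#     image_nums=image_nums, image_start_index_list=image_start_index_list,
#     added_bbox_list=None, num_beams=1, max_new_tokens=32)
# text = tokenizer.decode(generated[0])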
def _get_data_list_and_visual_tokens(
self,
all_box_list,
box_token_id,
prebox_token_id,
input_ids,
vision_x,
nothing_embedding = None,
):
box_locations = (torch.logical_or(input_ids == box_token_id, input_ids == prebox_token_id)).nonzero()
prev_batch_idx = -1
media_idx = []
cnt = 0
data_list = []
visual_tokens = []
if len(all_box_list) != len(box_locations):
logging.info(f"WARNING. len(all_box_list) != len(box_locations) {len(all_box_list)} vs {len(box_locations)}")
self.valid = False
for III, (batch_idx, idx) in enumerate(box_locations):
batch_idx = batch_idx.item()
idx = idx.item()
if batch_idx != prev_batch_idx:
prev_batch_idx = batch_idx
this_input_ids = input_ids[batch_idx]
cnt += len(media_idx)
media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
for i in range(len(media_idx)):
if i == len(media_idx) - 1 or (media_idx[i] < idx < media_idx[i + 1]):
break
image_index = cnt + i
size = int(vision_x[image_index].shape[0] ** 0.5)
image_feature = vision_x[image_index].reshape(size, size, -1)
try:
raw_xyxy = all_box_list[III]
except IndexError:
logging.info("index out of range for all_box_list; falling back to the last box")
raw_xyxy = all_box_list[-1]
region_xyxy = np.array(raw_xyxy) * size
x1, y1, x2, y2 = region_xyxy.astype(int).clip(0, size-1).tolist()
x2 = max(x1, x2)
y2 = max(y1, y2)
if x1 + y1 + x2 + y2 == 0.0 and nothing_embedding is not None:
visual_token = nothing_embedding
else:
if self.roi_align:
visual_token = torchvision.ops.roi_align(
image_feature.permute(2, 0, 1).unsqueeze(0),
[torch.tensor(region_xyxy.astype(np.float32)).unsqueeze(0).cuda()],
output_size=self.roi_output_size,
spatial_scale=1.0,
)
visual_token = visual_token.squeeze(0).flatten(1).permute(1, 0)
else:
visual_token = image_feature[y1:y2+1, x1:x2+1].reshape(-1, image_feature.shape[-1]).mean(0)
box = torch.tensor([0] + raw_xyxy, device=visual_token.device, dtype=visual_token.dtype)
data_list.append([visual_token, box, batch_idx, idx, i])
visual_tokens.append(visual_token)
return data_list, visual_tokens
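# Region-pooling sketch for a single box token (the same two paths as above; shapes are
# assumptions with size = patch-grid side length and C = lang_dim):
#
# feat = image_feature.permute(2, 0, 1).unsqueeze(0)           # [1, C, size, size]
# roi = torch.tensor([[x1, y1, x2, y2]], dtype=torch.float32, device=feat.device)
# tok = torchvision.ops.roi_align(feat, [roi], output_size=4, spatial_scale=1.0)
# tok = tok.squeeze(0).flatten(1).permute(1, 0)                # [output_size**2, C]
# # ...or, without roi_align, mean-pool the covered patches into a single vector:
# tok = image_feature[y1:y2+1, x1:x2+1].reshape(-1, image_feature.shape[-1]).mean(0)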
def _encode_vision_x(self, vision_x: torch.Tensor, image_nums=None, image_start_index_list=None, added_bbox_list=None, num_beams=None, input_ids=None, relations=None):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
b, T, F = vision_x.shape[:3]
assert F == 1, "Only single frame supported"
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
if hasattr(self.vision_encoder, "visual"):
vision_x = self.vision_encoder.visual(vision_x)[1]
else:
vision_x = self.vision_encoder(vision_x).flatten(2).permute(0, 2, 1)
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
# print(vision_x[0,0,0])
# # DEBUG HERE
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
# else:
# torch.distributed.barrier()
vision_x = vision_x.mean(2)
# vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
# vision_x = self.vis_proj(vision_x) + self.vis_position_embedding(self.vis_position_ids).unsqueeze(0)
vision_x = self.vis_proj(vision_x).squeeze(1)
self.image_embedding = vision_x
data_list = None
visual_tokens = None
if added_bbox_list is not None and input_ids is not None:
all_box_list = added_bbox_list[0].tolist()
for boxes in added_bbox_list[1:]:
all_box_list.extend(boxes.tolist())
data_list, visual_tokens = self._get_data_list_and_visual_tokens(
all_box_list=all_box_list,
box_token_id=self.box_token_id,
prebox_token_id=self.prebox_token_id,
input_ids=input_ids,
vision_x=vision_x,
nothing_embedding=self.lang_encoder.gpt_neox.embed_in(torch.tensor(self.nothing_token_id).to(self.lang_encoder.gpt_neox.embed_in.weight.device)) if self.nothing_token_id is not None else None,
)
first_layer = self.lang_encoder._get_decoder_layers()[0]
first_layer.condition_vis_x(vision_x, image_nums, image_start_index_list, num_beams=num_beams, visual_tokens=visual_tokens, data_list=[[d[2], d[3]] for d in data_list] if data_list is not None else data_list)
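# Shape walk-through for _encode_vision_x above (illustrative; v = vis_embed_size):
#
# vision_x: [b, T_img, F=1, 3, H, W]
#   -> vision encoder:      [(b*T_img*F), v, vis_dim]
#   -> rearrange + mean(F): [b, T_img, v, vis_dim]
#   -> vis_proj:            [b, T_img, v, lang_dim]
#   -> squeeze(1):          [b, v, lang_dim]   (the squeeze assumes one image slot per row;
#                                               images are then indexed flatly downstream)
# The result is cached as self.image_embedding and handed to the first decoder layer via
# condition_vis_x, together with the per-box visual tokens computed above.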