import random
import torch
import torch.nn as nn
import numpy as np
from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive


class FlamingoLayer(nn.Module):
    """Wrapper around a decoder layer that splices conditioned visual features into its input hidden states."""

    def __init__(self, decoder_layer):
        super().__init__()
        self.decoder_layer = decoder_layer
        # Conditioning state, populated later via the condition_* methods.
        self.vis_x = None
        self.image_nums = None
        self.image_start_index_list = None
        self.media_locations = None
        self.add_visual_token = False
        self.input_ids = None
        self.num_beams = None
        self.visual_tokens = None
        self.data_list = None
        self.attend_previous = None
    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Borrowed this idea from this implementation of Flamingo: https://github.com/dhansmair/flamingo-mini/
    def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, num_beams=None, visual_tokens=None, data_list=None):
        # vis_x: per-image visual features to write into the hidden states.
        # image_nums: number of images for each batch sample.
        # image_start_index_list: per-sample start positions at which each image's features are written.
        # visual_tokens / data_list: extra embeddings and their (batch_index, position) targets.
        self.vis_x = vis_x
        self.image_nums = image_nums
        self.image_start_index_list = image_start_index_list
        self.num_beams = num_beams
        self.visual_tokens = visual_tokens
        self.data_list = data_list
        self.input_ids = None

    def condition_media_locations(self, media_locations):
        self.media_locations = media_locations

    def condition_attend_previous(self, attend_previous):
        self.attend_previous = attend_previous
    def forward(
        self,
        hidden_states,  # alignment with hugging face name
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        if self.media_locations is None:
            raise ValueError("media_locations must be conditioned before forward pass")

        if self.vis_x is not None:
            if self.training:
                # Training: overwrite each image's placeholder span in the text
                # sequence with that image's visual features.
                single_length = self.vis_x.shape[-2]
                image_nums = self.image_nums
                image_start_index_list = self.image_start_index_list
                image_nums = [0] + np.cumsum(image_nums).tolist()
                for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
                    for index in start_indices:
                        if image_num_begin < image_num_end:
                            hidden_states[i, index:index + single_length] = self.vis_x[image_num_begin]
                            image_num_begin += 1
                if self.visual_tokens is not None and len(self.visual_tokens) != 0:
                    for i, (x, y) in enumerate(self.data_list):
                        if len(self.visual_tokens[i].shape) > 1:
                            # print(self.visual_tokens[i].shape[0], "embedding")
                            hidden_states[x, y + 1 - self.visual_tokens[i].shape[0]:y + 1] = self.visual_tokens[i]
                        else:
                            # print(self.visual_tokens[i].shape[0], "embedding")
                            hidden_states[x, y] = self.visual_tokens[i]
            elif not self.training:
                # Inference: only splice visual features on the first forward pass,
                # i.e. before any key/value cache has been built.
                if (
                    ("past_key_value" in decoder_layer_kwargs and decoder_layer_kwargs["past_key_value"] is None) or
                    ("layer_past" in decoder_layer_kwargs and decoder_layer_kwargs["layer_past"] is None)
                ):
                    single_length = self.vis_x.shape[-2]
                    image_nums = self.image_nums
                    image_start_index_list = self.image_start_index_list
                    image_nums = [0] + np.cumsum(image_nums).tolist()
                    for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
                        for index in start_indices:
                            if image_num_begin < image_num_end:
                                hidden_states[i, index:index + single_length] = self.vis_x[image_num_begin]
                                image_num_begin += 1
                    if self.visual_tokens is not None and len(self.visual_tokens) != 0:
                        for i, (x, y) in enumerate(self.data_list):
                            # print(x, y, self.visual_tokens[i].shape)
                            if len(self.visual_tokens[i].shape) > 1:
                                # print(self.visual_tokens[i].shape[0], "embedding")
                                hidden_states[x, y + 1 - self.visual_tokens[i].shape[0]:y + 1] = self.visual_tokens[i]
                            else:
                                # print(self.visual_tokens[i].shape[0], "embedding")
                                hidden_states[x, y] = self.visual_tokens[i]

        hidden_states = self.decoder_layer(
            hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs
        )
        return hidden_states


class FlamingoLMMixin(nn.Module):
    """
    Mixin that wraps a language model's decoder layers in FlamingoLayer so they can be
    conditioned on visual features.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)
    def init_flamingo(
        self,
        media_token_id,
        use_media_placement_augmentation,
    ):
        """
        Initialize Flamingo by wrapping each decoder layer in a FlamingoLayer and storing
        the media token id used to compute the media locations.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()]
            )
        )
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.initialized_flamingo = True
    def forward(self, *input, **kwargs):
        """Condition the Flamingo layers on the media locations before forward()."""
        if not self.initialized_flamingo:
            raise ValueError(
                "Flamingo layers are not initialized. Please call `init_flamingo` first."
            )

        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
        media_locations = input_ids == self.media_token_id
        attend_previous = (
            (random.random() < 0.5) if self.use_media_placement_augmentation else True
        )

        # Locate the decoder layers for the concrete architecture and condition them.
        if (
            "gpt2" in self.__class__.__name__.lower()
            or "codegen" in self.__class__.__name__.lower()
        ):
            for layer in self.transformer.h:
                layer.condition_media_locations(media_locations)
                layer.condition_attend_previous(attend_previous)
        elif "gptneox" in self.__class__.__name__.lower():
            for layer in self.gpt_neox.layers:
                layer.condition_media_locations(media_locations)
                layer.condition_attend_previous(attend_previous)
        else:
            for layer in self.get_decoder().layers:
                layer.condition_media_locations(media_locations)
                layer.condition_attend_previous(attend_previous)

        return super().forward(
            *input, **kwargs
        )  # Call the other parent's forward method
    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(l.is_conditioned() for l in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)
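

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library: it exercises the conditioning path
    # with tiny stand-in modules. "DummyDecoderLayer" and "DummyLM" are hypothetical
    # placeholders for a Hugging Face decoder layer / language model, and the shapes
    # passed to condition_vis_x are inferred from the indexing in FlamingoLayer.forward.
    # Because of the relative imports above, run this as a module within its package
    # (python -m <package>.<module>).
    class DummyDecoderLayer(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states, attention_mask=None, **kwargs):
            return self.proj(hidden_states)

    class DummyLM(nn.Module):
        def __init__(self, vocab_size=128, dim=16, num_layers=2):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.layers = nn.ModuleList(DummyDecoderLayer(dim) for _ in range(num_layers))

        def get_decoder(self):
            # FlamingoLMMixin.forward falls back to `self.get_decoder().layers`
            # for architectures other than GPT-2 / CodeGen / GPT-NeoX.
            return self

        def forward(self, input_ids=None, attention_mask=None):
            hidden_states = self.embed(input_ids)
            for layer in self.layers:
                hidden_states = layer(hidden_states, attention_mask=attention_mask)
            return hidden_states

    class DummyFlamingoLM(FlamingoLMMixin, DummyLM):
        pass

    model = DummyFlamingoLM()
    model.set_decoder_layers_attr_name("layers")
    model.init_flamingo(media_token_id=99, use_media_placement_augmentation=False)

    # One image for the single sample; its 4 visual tokens are written into the
    # sequence starting at position 1 (right after the media token at position 0).
    vis_x = torch.randn(1, 4, 16)  # [num_images, tokens_per_image, hidden_dim]
    for layer in model._get_decoder_layers():
        layer.condition_vis_x(vis_x, image_nums=[1], image_start_index_list=[[1]])

    input_ids = torch.tensor([[99, 1, 2, 3, 4, 5, 6, 7]])
    with torch.no_grad():
        out = model(input_ids=input_ids)
    print(out.shape)  # expected: torch.Size([1, 8, 16])
    model.clear_conditioned_layers()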