import random import torch import torch.nn as nn import numpy as np from .helpers import GatedCrossAttentionBlock from .utils import getattr_recursive, setattr_recursive class FlamingoLayer(nn.Module): def __init__(self, decoder_layer): super().__init__() self.decoder_layer = decoder_layer self.vis_x = None self.image_nums = None self.image_start_index_list = None self.media_locations = None self.add_visual_token = False self.input_ids = None def is_conditioned(self) -> bool: """Check whether the layer is conditioned.""" return self.vis_x is not None # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/) def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, num_beams=None, visual_tokens=None, data_list=None): self.vis_x = vis_x self.image_nums = image_nums self.image_start_index_list = image_start_index_list self.num_beams = num_beams self.visual_tokens = visual_tokens self.data_list = data_list self.input_ids = None def condition_media_locations(self, media_locations): self.media_locations = media_locations def condition_attend_previous(self, attend_previous): self.attend_previous = attend_previous def forward( self, hidden_states, # alignment with hugging face name attention_mask=None, **decoder_layer_kwargs, ): if self.media_locations is None: raise ValueError("media_locations must be conditioned before forward pass") if self.vis_x is not None: if self.training: single_length = self.vis_x.shape[-2] image_nums = self.image_nums image_start_index_list = self.image_start_index_list image_nums = [0] + np.cumsum(image_nums).tolist() for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)): for index in start_indices: if image_num_begin < image_num_end: hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin] image_num_begin += 1 if self.visual_tokens is not None and len(self.visual_tokens) != 0: for i, (x, y) in enumerate(self.data_list): if len(self.visual_tokens[i].shape) > 1: # print(self.visual_tokens[i].shape[0], "embedding") hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i] else: # print(self.visual_tokens[i].shape[0], "embedding") hidden_states[x, y] = self.visual_tokens[i] elif not self.training: if ( ("past_key_value" in decoder_layer_kwargs and decoder_layer_kwargs["past_key_value"] is None) or ("layer_past" in decoder_layer_kwargs and decoder_layer_kwargs["layer_past"] is None) ): single_length = self.vis_x.shape[-2] image_nums = self.image_nums image_start_index_list = self.image_start_index_list image_nums = [0] + np.cumsum(image_nums).tolist() for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)): for index in start_indices: if image_num_begin < image_num_end: hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin] image_num_begin += 1 if self.visual_tokens is not None and len(self.visual_tokens) != 0: for i, (x, y) in enumerate(self.data_list): # import pdb; pdb.set_trace() # print(x, y, self.visual_tokens[i].shape) if len(self.visual_tokens[i].shape) > 1: # print(self.visual_tokens[i].shape[0], "embedding") hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i] else: # print(self.visual_tokens[i].shape[0], "embedding") hidden_states[x, y] = self.visual_tokens[i] hidden_states = self.decoder_layer( hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs ) return hidden_states class FlamingoLMMixin(nn.Module): """ Mixin to add cross-attention layers to a language model. """ def set_decoder_layers_attr_name(self, decoder_layers_attr_name): self.decoder_layers_attr_name = decoder_layers_attr_name def _get_decoder_layers(self): return getattr_recursive(self, self.decoder_layers_attr_name) def _set_decoder_layers(self, value): setattr_recursive(self, self.decoder_layers_attr_name, value) def init_flamingo( self, media_token_id, use_media_placement_augmentation, ): """ Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. """ self._set_decoder_layers( nn.ModuleList( [FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()] ) ) self.media_token_id = media_token_id self.use_media_placement_augmentation = use_media_placement_augmentation self.initialized_flamingo = True def forward(self, *input, **kwargs): """Condition the Flamingo layers on the media locations before forward()""" if not self.initialized_flamingo: raise ValueError( "Flamingo layers are not initialized. Please call `init_flamingo` first." ) input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0] media_locations = input_ids == self.media_token_id attend_previous = ( (random.random() < 0.5) if self.use_media_placement_augmentation else True ) if ( "gpt2" in self.__class__.__name__.lower() or "codegen" in self.__class__.__name__.lower() ): for layer in self.transformer.h: layer.condition_media_locations(media_locations) layer.condition_attend_previous(attend_previous) elif "gptneox" in self.__class__.__name__.lower(): for layer in self.gpt_neox.layers: layer.condition_media_locations(media_locations) layer.condition_attend_previous(attend_previous) else: for layer in self.get_decoder().layers: layer.condition_media_locations(media_locations) layer.condition_attend_previous(attend_previous) return super().forward( *input, **kwargs ) # Call the other parent's forward method def is_conditioned(self) -> bool: """Check whether all decoder layers are already conditioned.""" return all(l.is_conditioned() for l in self._get_decoder_layers()) def clear_conditioned_layers(self): for layer in self._get_decoder_layers(): layer.condition_vis_x(None) layer.condition_media_locations(None) layer.condition_attend_previous(None)