import os
import pdb
import random
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import ViTModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
import transformers.models.opt.modeling_opt as modeling_opt
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig

from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """Look up position embeddings derived from `attention_mask`.

        Positions are the running count of attended tokens along dim 1, so
        masked-out entries end up at position -1 (before the offset is added).

        Args:
            attention_mask: mask indexed as [bsz x seqlen]; it is cast to long
                and cumulatively summed along dim 1.
            past_key_values_length: number of leading positions to drop, i.e.
                positions already covered by cached key/values.

        Returns:
            The embeddings of the remaining positions (shifted by `self.offset`).
        """
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = torch.cumsum(attention_mask, dim=1)
        positions = (positions.type_as(attention_mask) * attention_mask).long() - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        return super().forward(positions + self.offset)


class OPTDecoder(modeling_opt.OPTDecoder):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`].

    On top of the plain OPT decoder this class builds a `perceiver_resampler` for image
    features and one gated cross-attention block for every layer index that is a multiple
    of `config.cross_attn_every` (the remaining slots hold `None`).

    Args:
        config: OPTConfig. Besides the standard OPT fields, `__init__` reads
            `media_token_id`, `id_perceiver`, `perceiver_depth`, `perceiver_num_latents`,
            `inp_dim`, `cross_attn_every`, `only_attend_immediate_media` and `finetune_LM`.
    """

    def __init__(self, config: OPTConfig):
        # Deliberately skips modeling_opt.OPTDecoder.__init__: all sub-modules are built here.
        OPTPreTrainedModel.__init__(self, config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size
        self.media_token_id = config.media_token_id

        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)

        # Projections are only needed when the word-embedding width differs from the hidden
        # width. `project_out` is created before `project_in` to keep the original
        # construction (and therefore RNG / registration) order.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.final_layer_norm = None

        dim_head = config.hidden_size // config.num_attention_heads

        # Module that maps raw image features to what the cross-attention blocks consume:
        # a PerceiverResampler by default, or (with `id_perceiver`) a pass-through / a
        # single bias-free linear projection from `inp_dim` to `hidden_size`.
        if not config.id_perceiver:
            self.perceiver_resampler = PerceiverResampler(
                dim=config.hidden_size,
                depth=config.perceiver_depth,
                dim_head=dim_head,
                heads=config.num_attention_heads,
                num_latents=config.perceiver_num_latents,
                inp_dim=config.inp_dim,
            )
        elif config.inp_dim is None:
            self.perceiver_resampler = nn.Identity()
        else:
            self.perceiver_resampler = nn.Linear(config.inp_dim, config.hidden_size, bias=False)

        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

        # One slot per decoder layer; only every `cross_attn_every`-th slot (starting at 0)
        # holds a cross-attention block, the others are None.
        self.gated_attn_layers = nn.ModuleList(
            [
                GatedCrossAttentionBlock(
                    dim=config.hidden_size,
                    dim_head=dim_head,
                    heads=config.num_attention_heads,
                    only_attend_immediate_media=config.only_attend_immediate_media,
                )
                if not (ind % config.cross_attn_every)
                else None
                for ind in range(config.num_hidden_layers)
            ]
        )

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

        # in flamingo mode, freeze everything but perceiver and gated cross attention
        if not config.finetune_LM:
            freeze_all_layers_(self)
            unfreeze_all_layers_(self.perceiver_resampler)
            for cross_attn in self.gated_attn_layers:
                if exists(cross_attn):
                    unfreeze_all_layers_(cross_attn)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values=None,
        image_embeds=None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)

                Required whenever `pixel_values` or `image_embeds` is given, because the positions of
                `media_token_id` are read from it.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)

                When omitted, an all-ones mask covering the cached positions plus the current tokens is used.
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used
                (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            use_cache (`bool`, *optional*):
                Whether to collect and return the per-layer key/value states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            pixel_values (*optional*):
                Raw images to encode with the attached `img_encoder`. A 4-D tensor gets a singleton axis inserted
                at dim 1; the first two axes are then treated as `(batch, images-per-sample)`. Mutually exclusive
                with `image_embeds`.
            image_embeds (*optional*):
                Precomputed image features; passed through `perceiver_resampler` and then to the gated
                cross-attention blocks. Mutually exclusive with `pixel_values`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, if images are given
                without `input_ids`, or if `head_mask` does not have one row per layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            # Take the batch size from the embeddings so this path no longer dereferences
            # the (None) `input_ids`.
            batch = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        flamingo_mode = exists(pixel_values) or exists(image_embeds)

        # derive the media token ids (as a boolean tensor), for calculating the masked cross attention
        if flamingo_mode:
            if input_ids is None:
                raise ValueError(
                    "`input_ids` is required to locate media tokens when `pixel_values` or `image_embeds` is given"
                )
            media_locations = input_ids == self.media_token_id

        assert not (exists(pixel_values) and exists(image_embeds))

        # encode images into embeddings
        # with the img_encoder attached to this decoder
        # it can also accept precomputed image embeddings
        if exists(pixel_values):
            # `img_encoder` is attached from outside (see FlamingoForCausalLM); use getattr so
            # a decoder without one reports the assertion below instead of an AttributeError.
            img_encoder = getattr(self, 'img_encoder', None)
            assert exists(img_encoder), 'img_encoder must be passed in for automatic image encoding'
            if len(pixel_values.shape) == 4:
                pixel_values = torch.unsqueeze(pixel_values, 1)
            # fold the per-sample image axis into the batch axis for the encoder
            pixel_values = rearrange(pixel_values, 'b t ... -> (b t) ...')

            # the image encoder is run without tracking gradients
            with torch.no_grad():
                if getattr(img_encoder, 'vision_model', None) is not None:
                    image_outputs = img_encoder.vision_model(
                        pixel_values=pixel_values,
                        output_hidden_states=True, return_dict=True)
                else:
                    image_outputs = img_encoder(
                        pixel_values=pixel_values,
                        output_hidden_states=True, return_dict=True)

            image_embeds = image_outputs['last_hidden_state']
            image_embeds = rearrange(image_embeds, '(b t) ... -> b t ...', b=batch)

        if exists(image_embeds):
            image_embeds = self.perceiver_resampler(image_embeds)

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            # The mask has to span the cached positions as well as the current tokens:
            # `embed_positions` slices off the first `past_key_values_length` entries and the
            # causal mask built below is `past + current` wide.
            mask_shape = (inputs_embeds.shape[0], past_key_values_length + inputs_embeds.shape[1])
            attention_mask = torch.ones(mask_shape, dtype=torch.bool, device=inputs_embeds.device)
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            # the gated cross-attention (where present) runs before the layer's self-attention block
            flamingo_cross_attn = self.gated_attn_layers[idx]
            if exists(flamingo_cross_attn) and exists(image_embeds):
                hidden_states = flamingo_cross_attn(
                    hidden_states,
                    image_embeds,
                    media_locations=media_locations,
                )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # the cache entry follows the attention weights when those are returned
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class OPTModel(modeling_opt.OPTModel):
    """OPT model whose `decoder` is the cross-attention-augmented `OPTDecoder` defined in this module."""

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()


class OPTForCausalLM(modeling_opt.OPTForCausalLM):
    """Causal-LM head on top of this module's `OPTModel`; `forward` is inherited from the parent class."""

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        OPTPreTrainedModel.__init__(self, config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()


def set_default_if_nonexist(config, key, value):
    """Set `config.<key>` to `value` when the attribute is missing or is `None`.

    NOTE: an attribute explicitly set to `None` is treated the same as a missing
    one and is overwritten as well.

    Returns:
        The same `config` object (mutated in place).
    """
    if getattr(config, key, None) is None:
        setattr(config, key, value)
    return config


def setup_default_flamingo_configs(config):
    """Fill in the Flamingo-specific config fields that are not set yet and return `config`."""
    set_default_if_nonexist(config, 'perceiver_depth', 2)
    set_default_if_nonexist(config, 'perceiver_num_latents', 64)
    set_default_if_nonexist(config, 'cross_attn_every', 3)
    set_default_if_nonexist(config, 'only_attend_immediate_media', True)
    set_default_if_nonexist(config, 'media_token_id', 50265)
    set_default_if_nonexist(config, 'inp_dim', 768)
    set_default_if_nonexist(config, 'finetune_LM', True)
    set_default_if_nonexist(config, 'id_perceiver', False)
    return config


class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
    """OPT causal LM with gated cross-attention over image features.

    The constructor fills in default Flamingo config values, builds the model, then loads
    the `facebook/dino-vitb16` ViT and attaches it to the decoder as a frozen image encoder.
    """

    _keys_to_ignore_on_load_missing = [
        r"lm_head.weight",
    ]

    def __init__(self, config):
        OPTPreTrainedModel.__init__(self, config)
        # mutates `config` in place, so `self.config` sees the defaults as well
        config = setup_default_flamingo_configs(config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        self.model.decoder.img_encoder = None
        self.loss_fct = CrossEntropyLoss()

        # attached after post_init(), i.e. after the weight initialisation pass above
        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
        self.setup_vis_encoder(dino_model)

    def setup_vis_encoder(self, img_encoder):
        """Attach `img_encoder` to the decoder and freeze it via `freeze_all_layers_`."""
        self.model.decoder.img_encoder = img_encoder
        freeze_all_layers_(img_encoder)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            *args, **kwargs:
                Forwarded unchanged to the decoder; this is how `pixel_values` / `image_embeds` reach it.

        Returns:
            A [`CausalLMOutputWithPast`] when `return_dict` is true, otherwise a tuple
            `(loss?, logits, *decoder_outputs[1:])` where `loss` is only present if `labels` was given.

        Example:

        ```python
        >>> from transformers import GPT2Tokenizer, OPTForCausalLM

        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            *args, **kwargs)

        logits = self.lm_head(outputs[0]).contiguous()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss = self.loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )