from transformers import (
    PreTrainedModel,
    VisionEncoderDecoderModel,
    VisionEncoderDecoderConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoConfig,
)
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from torch import nn
from .configuration_clipcap import CLIPEncoderDecoderConfig
from typing import Optional, Tuple, Union
import torch
import gc
import os
import tempfile


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right, prepending `decoder_start_token_id` and
    replacing any `-100` label-padding values with `pad_token_id`.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    # Replace possible -100 values in labels by `pad_token_id`.
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

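# Illustrative example (added for clarity): with pad_token_id=0 and
# decoder_start_token_id=2, labels [[5, -100, -100]] become decoder inputs
# [[2, 5, 0]]: the start token is prepended, the last label is dropped, and
# the remaining -100 padding is mapped to the pad id.
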
class Encoder(nn.Module):
    """Wraps the CLIP vision tower and visual projection so they can act as the encoder
    of an encoder-decoder model: each image is encoded into a single projected embedding."""

    main_input_name = 'pixel_values'

    def __init__(self):
        super().__init__()
        clip = AutoModel.from_pretrained('openai/clip-vit-base-patch32')
        self.vision_model = clip.vision_model
        self.visual_projection = clip.visual_projection
        self.config = clip.vision_model.config
        # The outputs of this module live in CLIP's projection space, not the vision
        # tower's hidden space, so report the projection dim as the hidden size.
        self.config.hidden_size = clip.config.projection_dim

    def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=False, **kwargs):
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled [CLS] output of the vision transformer.
        pooled_output = vision_outputs[1]
        # Project into CLIP's joint embedding space and add a sequence dimension of length 1.
        image_features = self.visual_projection(pooled_output).view(pooled_output.size(0), 1, -1)
        return BaseModelOutput(last_hidden_state=image_features)

    def get_output_embeddings(self):
        # Required by the generation utilities; the encoder has no output embeddings.
        pass

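# Shape sketch (an illustration, not part of the original module): with the
# 'openai/clip-vit-base-patch32' checkpoint, pixel_values of shape
# (batch, 3, 224, 224) yield a last_hidden_state of shape (batch, 1, 512),
# since that checkpoint's projection_dim is 512.
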
class CLIPEncoderDecoderModel(PreTrainedModel):
    config_class = CLIPEncoderDecoderConfig
    base_model_prefix = "clip_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config=None,
        encoder=None,
        decoder=None,
    ):
        if config is None:
            raise ValueError("A `config` has to be provided to initialize `CLIPEncoderDecoderModel`.")
        config.tie_word_embeddings = False
        super().__init__(config)

        # The encoder is always rebuilt from the CLIP checkpoint via the `Encoder`
        # wrapper above; a passed-in `encoder` argument is ignored.
        encoder = Encoder()
        encoder_hidden_size = encoder.config.hidden_size

        if decoder is None:
            decoder = AutoModelForCausalLM.from_config(config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        # Keep the sub-model configs in sync with the composite config.
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # Project the CLIP embedding into the decoder's hidden space for cross-attention.
        self.enc_to_dec_proj = nn.Linear(encoder_hidden_size, self.decoder.config.hidden_size)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> PreTrainedModel:
        # Split kwargs into encoder- and decoder-specific arguments
        # (prefixed with "encoder_" / "decoder_") and shared ones.
        kwargs_encoder = {
            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # Remove the prefixed arguments so only shared kwargs remain.
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_encoder:
                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
                )

                # The encoder must behave as a plain encoder, not as a decoder with cross-attention.
                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False

                kwargs_encoder["config"] = encoder_config

            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_decoder:
                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
                )

                # The decoder must run as a causal LM with cross-attention over the encoder output.
                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True

                kwargs_decoder["config"] = decoder_config

            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        # Build the composite config from the two sub-configs.
        config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

        config.tie_word_embeddings = False
        return cls(encoder=encoder, decoder=decoder, config=config)

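    # Illustrative call (a sketch; the checkpoint names are examples, not requirements):
    #
    #   model = CLIPEncoderDecoderModel.from_encoder_decoder_pretrained(
    #       "openai/clip-vit-base-patch32", "gpt2"
    #   )
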
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Route extra kwargs: unprefixed arguments go to the encoder,
        # arguments prefixed with "decoder_" go to the decoder.
        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project the CLIP embedding into the decoder's hidden size.
        encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # The encoder output is a single embedding per image, so no attention mask is needed.
        encoder_attention_mask = None

        # If labels are given but no decoder inputs, derive the decoder inputs by
        # shifting the labels one position to the right (teacher forcing).
        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute the language-modeling loss. The decoder inputs are the labels shifted
        # right, so the logits at position i already predict labels[:, i].
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

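    # Data-flow note (added for clarity): each image becomes a single projected CLIP
    # embedding of shape (batch, 1, hidden); `enc_to_dec_proj` maps it to the decoder's
    # hidden size, and the decoder conditions on this one-token "prefix" through
    # cross-attention while generating the caption autoregressively.
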
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def resize_token_embeddings(self, *args, **kwargs):
        raise NotImplementedError(
            "Resizing the embedding layers via CLIPEncoderDecoderModel directly is not supported. Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        # Cache re-ordering during beam search is delegated to the decoder.
        return self.decoder._reorder_cache(past_key_values, beam_idx)
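

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes the public
    # 'openai/clip-vit-base-patch32' and 'gpt2' checkpoints, a GPT-2 tokenizer, and a
    # transformers version in which PreTrainedModel still provides `.generate()` and
    # accepts generation settings on the model config.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = CLIPEncoderDecoderModel.from_encoder_decoder_pretrained(
        "openai/clip-vit-base-patch32", "gpt2"
    )

    # GPT-2 has no dedicated start/pad tokens, so reuse BOS/EOS for generation.
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id

    # A random "image" keeps the example self-contained; in practice pixel_values
    # would come from a CLIP image processor applied to a real image.
    pixel_values = torch.randn(1, 3, 224, 224)
    generated_ids = model.generate(pixel_values=pixel_values, max_new_tokens=10)
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))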