# NeTI/models/neti_clip_text_encoder.py
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPEncoder
from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask

from models.net_clip_text_embedding import NeTICLIPTextEmbeddings
from utils.types import NeTIBatch


class NeTICLIPTextModel(CLIPTextModel):
    """ Modification of CLIPTextModel to use our NeTI mapper for computing the embeddings of the concept. """

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = NeTICLIPTextTransformer(config)
        self.post_init()

    def forward(self, input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                batch: Optional[NeTIBatch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
        return self.text_model.forward(
            batch=batch,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
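

# Usage sketch: the wrapper above is called either with plain `input_ids` (standard
# CLIP text encoding) or with a `NeTIBatch` carrying the token ids and the
# placeholder-token id of the learned concept. `text_encoder` and `neti_batch` below
# are placeholder names for illustration:
#
#     outputs, outputs_with_bypass = text_encoder(batch=neti_batch)
#
# Both call styles return a pair whose second element is None whenever no bypass
# vector is produced.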


class NeTICLIPTextTransformer(CLIPTextTransformer):
    """ Modification of CLIPTextTransformer to use our NeTI mapper for computing the embeddings of the concept. """

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config=config)
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = NeTICLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(self, input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                batch: Optional[NeTIBatch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bypass_output = None
        if input_ids is not None:  # Regular embedding logic
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            hidden_states, _ = self.embeddings(input_ids=input_ids, position_ids=position_ids)
        ###########################
        # NeTI logic
        ###########################
        elif batch is not None:
            input_shape = batch.input_ids.size()
            batch.input_ids = batch.input_ids.view(-1, input_shape[-1])
            hidden_states, bypass_output = self.embeddings(batch=batch, position_ids=position_ids)
        else:
            raise ValueError("You have to specify either batch or input_ids!")
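
        # At this point `hidden_states` holds the embeddings for the full token
        # sequence. `bypass_output` stays None on the regular path; on the NeTI path
        # the embeddings module may additionally return a bypass vector for the
        # placeholder token, which is merged into the encoder output further below.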
        bsz, seq_len = input_shape

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state_with_bypass = last_hidden_state.clone()

        ###############################################
        # NeTI logic - compute the scaled bypass output
        ###############################################
        if bypass_output is not None:
            learnable_idxs = (batch.input_ids == batch.placeholder_token_id).nonzero(as_tuple=True)[1]
            existing_state = last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs]
            bypass_output = bypass_output / bypass_output.norm(dim=1, keepdim=True) \
                            * existing_state.norm(dim=1, keepdim=True)
            new_state = existing_state + 0.2 * bypass_output
            new_state = new_state.to(dtype=hidden_states.dtype)
            last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs] = new_state
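
            # Note on the block above: the bypass vector is rescaled to the norm of
            # the existing token state before being mixed in, so the fixed 0.2
            # coefficient controls its relative contribution and the merged state
            # keeps a comparable magnitude.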

        last_hidden_state = self.final_layer_norm(last_hidden_state)
        last_hidden_state_with_bypass = self.final_layer_norm(last_hidden_state_with_bypass)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        if input_ids is not None:
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
            ]
            pooled_output_with_bypass = last_hidden_state_with_bypass[
                torch.arange(last_hidden_state_with_bypass.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
            ]
        elif batch is not None:
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
            ]
            pooled_output_with_bypass = last_hidden_state_with_bypass[
                torch.arange(last_hidden_state_with_bypass.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
            ]
        else:
            raise ValueError("You have to specify either batch or input_ids!")

        if bypass_output is not None:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            ), BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state_with_bypass,
                pooler_output=pooled_output_with_bypass,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        else:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            ), None
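

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original module. The checkpoint
    # name below is an assumption; any CLIP text checkpoint with a compatible config
    # should work. No NeTI mapper weights are loaded here, so only the regular
    # `input_ids` path is exercised; that path returns `(outputs, None)`, since no
    # NeTIBatch (and hence no bypass vector) is involved.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = NeTICLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    input_ids = tokenizer("a photo of a cat", padding="max_length", return_tensors="pt").input_ids
    with torch.no_grad():
        outputs, outputs_with_bypass = text_encoder(input_ids=input_ids)
    print(outputs.last_hidden_state.shape)  # torch.Size([1, 77, 768]) for this checkpoint
    print(outputs_with_bypass)              # None: no bypass without a NeTIBatch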