from functools import partial
from typing import Any, List, Optional, Mapping, Callable
from collections import OrderedDict
from argparse import Namespace

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T
import PIL

import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer

from .configuration_emu import EmuConfig
from .constants import *
from .modeling_llama import LlamaForCausalLM
from .visual import EVAVisionTransformer


class EmuPreTrainedModel(PreTrainedModel):
    config_class = EmuConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["LlamaDecoderLayer", "Block"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class EmuForClsAndRegression(EmuPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.lm = LlamaForCausalLM(config=config)
        self.lm.model.embed_tokens.padding_idx = config.pad_token_id

    def get_num_layers(self):
        return len(self.lm.model.layers)


class EmuModel(EmuPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        vision_config = Namespace(**config.vision_config)
        self.visual = EVAVisionTransformer(
            img_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
            embed_dim=vision_config.width,
            depth=vision_config.layers,
            num_heads=vision_config.width // vision_config.head_width,
            mlp_ratio=vision_config.mlp_ratio,
            qkv_bias=vision_config.qkv_bias,
            drop_path_rate=vision_config.drop_path_rate,
            norm_layer=partial(nn.LayerNorm, eps=vision_config.layer_norm_eps),
            xattn=vision_config.xattn,
            postnorm=vision_config.postnorm,
        )
        self.decoder = EmuForClsAndRegression(config)
        self.gradient_checkpointing = False

        self.n_query = vision_config.n_query
        self.v_query = vision_config.v_query

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype

    @torch.no_grad()
    def encode_image(self, image: torch.Tensor, *, n_query=None):
        n_query = n_query if n_query is not None else self.n_query

        image_embeds = self.visual(image)
        # drop the [CLS] token, keeping only the patch tokens
        image_embeds = image_embeds[:, 1:, :]
        b, n, c = image_embeds.shape
        sqrt_n = int(n ** 0.5)
        # fold the flat patch sequence back into its 2D grid
        image_embeds = image_embeds.permute(0, 2, 1).view(b, c, sqrt_n, sqrt_n)

        # average-pool the grid down to n_query tokens
        stride = int(sqrt_n // (n_query ** 0.5))
        image_embeds = F.avg_pool2d(image_embeds, kernel_size=(stride, stride), stride=stride)
        image_embeds = image_embeds.view(b, c, -1).permute(0, 2, 1).contiguous()
        return image_embeds
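
# A quick sketch of the pooling arithmetic in `EmuModel.encode_image` above.
# The concrete shapes here are illustrative assumptions (a 16x16 patch grid
# and n_query=64), not values read from any particular config:
#
#     x = vit(image)                    # (b, 1 + 256, c): [CLS] + 16*16 patches
#     x = x[:, 1:, :]                   # drop [CLS]          -> (b, 256, c)
#     x = x.permute(0, 2, 1).view(b, c, 16, 16)
#     stride = int(16 // (64 ** 0.5))   # n_query = 64        -> stride 2
#     x = F.avg_pool2d(x, (stride, stride), stride)           # (b, c, 8, 8)
#     x = x.view(b, c, -1).permute(0, 2, 1)                   # (b, 64, c)
#
# i.e. the 256 patch tokens are average-pooled 4-to-1 into 64 visual tokens.
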
class EmuForCausalLM(EmuPreTrainedModel):
    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.model = EmuModel(config)

        # LM to EVA
        self.project_down = nn.Linear(config.hidden_size, config.d_model, bias=False)
        # EVA to LM
        self.project_up = nn.Linear(config.d_model, config.hidden_size, bias=False)

        self.n_query = self.model.n_query
        self.v_query = self.model.v_query

        self.image_placeholder = DEFAULT_IMG_TOKEN + DEFAULT_IMAGE_TOKEN * self.n_query + DEFAULT_IMG_END_TOKEN
        # temporarily borrow [gIMG] as the video frame feature placeholder
        self.video_placeholder = DEFAULT_IMG_TOKEN + DEFAULT_gIMG_TOKEN * self.v_query + DEFAULT_IMG_END_TOKEN

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        num_beams=5,
        max_new_tokens=10,
        min_len=1,
        do_sample=False,
        penalty_alpha=None,
        top_p=None,
        top_k=None,
        temperature=None,
        length_penalty=-1,
        repetition_penalty=1.0,
        **kwargs,
    ):
        text_embeds = self.model.decoder.lm.model.embed_tokens(input_ids).to("cuda")

        if image is not None:
            prompt_image_embeds = self.model.encode_image(image, n_query=self.n_query)
            _, _, c = prompt_image_embeds.shape
            prompt_image_embeds = prompt_image_embeds.view(-1, c)
            prompt_image_embeds = self.project_up(prompt_image_embeds)
            # splice the projected image features into the embedding sequence
            # at the image placeholder token positions
            image_idx = (input_ids == IMAGE)
            text_embeds[image_idx] = prompt_image_embeds.to(text_embeds.device)

        if video is not None:
            prompt_video_embeds = self.model.encode_image(video, n_query=self.v_query)
            _, _, c = prompt_video_embeds.shape
            prompt_video_embeds = prompt_video_embeds.view(-1, c)
            prompt_video_embeds = self.project_up(prompt_video_embeds)
            video_idx = (input_ids == VIDEO)
            text_embeds[video_idx] = prompt_video_embeds.to(text_embeds.device)

        outputs = self.model.decoder.lm.generate(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            min_length=min_len,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            penalty_alpha=penalty_alpha,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            **kwargs,
        )

        return outputs

    def prepare_image_input(self, images):
        image_size: int = self.config.vision_config['image_size']
        transform = T.Compose(
            [
                T.Resize(
                    (image_size, image_size), interpolation=T.InterpolationMode.BICUBIC
                ),
                T.ToTensor(),
                T.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
            ]
        )
        images = [transform(image) for image in images]
        return torch.stack(images, 0)

    def _prepare_chat_template(self, text, system_msg=""):
        text = [
            system_msg + USER_TOKEN + ": " + t + ASSISTANT_TOKEN + ":"
            for t in text
        ]
        return text

    def prepare_text_input(
        self,
        text: List[str],
        tokenizer: PreTrainedTokenizer,
        image_placeholder: str = DEFAULT_IMG_PLACEHOLDER,
        video_placeholder: str = DEFAULT_VID_PLACEHOLDER,
    ):
        # expand each user-facing placeholder into the model's internal
        # image/video token sequence before tokenization
        text = [
            t.replace(image_placeholder, self.image_placeholder)
             .replace(video_placeholder, self.video_placeholder)
            for t in text
        ]
        input_ids = tokenizer(text, padding="longest", return_tensors="pt")
        return input_ids

    def build_input_ids(
        self,
        text: List[str],
        tokenizer: PreTrainedTokenizer,
        image: Optional[List["PIL.Image.Image"]] = None,
        video: Optional[List["PIL.Image.Image"]] = None,
        system_msg: str = "",
        to_cuda: bool = True,
    ):
        if self.config.model_version == "chat":
            text = self._prepare_chat_template(text, system_msg)

        if image is not None:
            image = self.prepare_image_input(image)
        if video is not None:
            video = self.prepare_image_input(video)

        inputs = self.prepare_text_input(text, tokenizer)
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        if to_cuda:
            input_ids = input_ids.to("cuda")
            attention_mask = attention_mask.to("cuda")
            if image is not None:
                image = image.to("cuda")
            if video is not None:
                video = video.to("cuda")

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'image': image,
            'video': video,
        }
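

# Usage sketch (illustrative only; not executed on import). The checkpoint
# path and image file below are hypothetical placeholders, and loading via
# `EmuForCausalLM.from_pretrained` assumes a directory holding this model's
# config and weights.

if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer

    ckpt = "path/to/emu-checkpoint"  # hypothetical checkpoint directory
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = EmuForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16).to("cuda").eval()

    # one DEFAULT_IMG_PLACEHOLDER per input image; build_input_ids expands it
    # into the model's internal image-token sequence
    image = Image.open("example.jpg").convert("RGB")
    inputs = model.build_input_ids(
        text=[DEFAULT_IMG_PLACEHOLDER + "Describe this image."],
        tokenizer=tokenizer,
        image=[image],
    )
    output_ids = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])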