from transformers import AutoModelForCausalLM, AutoTokenizer import open_clip import torch from .flamingo import Flamingo from .flamingo_lm import FlamingoLMMixin from .utils import extend_instance import logging import random import time def create_model_and_transforms( clip_vision_encoder_path: str, clip_vision_encoder_pretrained: str, lang_encoder_path: str, tokenizer_path: str, use_local_files: bool = False, decoder_layers_attr_name: str = None, location_token_num: int = 1000, checkpoint_activations: bool = False, freeze_vision_encoder: bool = False, lora: bool = False, lora_r: int = 16, fix_ffn: bool = False, add_visual_token: bool = False, add_box: bool = False, add_pe: bool = False, add_relation: bool = False, use_format_v2: bool = False, use_sam: str = None, enhance_data: bool = False, roi_align: bool = False, roi_output_size: int = 4, apply_mask: bool = False, **flamingo_kwargs, ): """ Initialize a Flamingo model from a pretrained vision encoder and language encoder. Appends special tokens to the tokenizer and freezes backbones. Args: clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") lang_encoder_path (str): path to pretrained language encoder tokenizer_path (str): path to pretrained tokenizer cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. use_local_files (bool, optional): whether to use local files. Defaults to False. decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. Returns: Flamingo: Flamingo model from pretrained vision and language encoders Image processor: Pipeline to preprocess input images Tokenizer: A tokenizer for the language model """ if use_sam is None: no_success = True while no_success: try: vision_encoder, _, image_processor = open_clip.create_model_and_transforms( clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained ) no_success = False except: logging.info("retry creating vision_encoder") time.sleep(random.random() * 5) # set the vision encoder to output the visual features vision_encoder.visual.output_tokens = True # delete text encoder part del vision_encoder.transformer del vision_encoder.text_projection del vision_encoder.token_embedding del vision_encoder.ln_final del vision_encoder.positional_embedding del vision_encoder.logit_scale vision_encoder.visual.proj = None vision_encoder.visual.ln_post = torch.nn.Identity() else: from segment_anything import SamPredictor, sam_model_registry assert use_sam == "vit_l" sam = sam_model_registry[use_sam](checkpoint="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_256x256.pth") del sam.prompt_encoder del sam.mask_decoder sam.image_encoder.neck = torch.nn.Identity() vision_encoder = sam.image_encoder from open_clip.transform import image_transform image_processor = image_transform( 256, is_train=False, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), ) text_tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, local_files_only=use_local_files ) # add Flamingo special tokens to the tokenizer additional_special_tokens = ["<|#image#|>", "<|#endofimage#|>"] if add_visual_token: additional_special_tokens += ["<|#visual#|>", "<|#object#|>"] if add_box: additional_special_tokens += ["<|#box#|>", "<|#endofobject#|>", "<|#attr#|>", "<|#endofattr#|>"] if use_format_v2: additional_special_tokens += ["<|#previsual#|>", "<|#prebox#|>"] if enhance_data: additional_special_tokens += ["<|#NOTHING#|>"] text_tokenizer.add_special_tokens( {"additional_special_tokens": additional_special_tokens} ) if text_tokenizer.pad_token is None: # Issue: GPT models don't have a pad token, which we use to # modify labels for the loss. text_tokenizer.add_special_tokens({"pad_token": ""}) lang_encoder = AutoModelForCausalLM.from_pretrained( lang_encoder_path, local_files_only=use_local_files ) extend_instance(lang_encoder, FlamingoLMMixin) if decoder_layers_attr_name is None: decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) lang_encoder.resize_token_embeddings(len(text_tokenizer)) lang_encoder_name = lang_encoder.__class__.__name__.lower() if checkpoint_activations: from fairscale.nn.checkpoint import checkpoint_wrapper if use_sam is None: for i in range(len(vision_encoder.visual.transformer.resblocks)): vision_encoder.visual.transformer.resblocks[i] = checkpoint_wrapper( vision_encoder.visual.transformer.resblocks[i], offload_to_cpu=False, ) else: for i in range(len(vision_encoder.blocks)): vision_encoder.blocks[i] = checkpoint_wrapper( vision_encoder.blocks[i], offload_to_cpu=False, ) if "opt" in lang_encoder_name: for i in range(len(lang_encoder.model.decoder.layers)): lang_encoder.model.decoder.layers[i] = checkpoint_wrapper( lang_encoder.model.decoder.layers[i], offload_to_cpu=False, ) elif "codegen" in lang_encoder_name: for i in range(len(lang_encoder.transformer.h)): lang_encoder.transformer.h[i] = checkpoint_wrapper( lang_encoder.transformer.h[i], offload_to_cpu=False, ) elif "llama" in lang_encoder_name: for i in range(len(lang_encoder.model.layers)): lang_encoder.model.layers[i] = checkpoint_wrapper( lang_encoder.model.layers[i], offload_to_cpu=False, ) elif "gptneo" in lang_encoder_name: for i in range(len(lang_encoder.gpt_neox.layers)): lang_encoder.gpt_neox.layers[i] = checkpoint_wrapper( lang_encoder.gpt_neox.layers[i], offload_to_cpu=False, ) else: raise ValueError(f"unknown model {lang_encoder_name}") if use_sam is None: vis_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"] image_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["image_size"] patch_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["patch_size"] else: # SAM config vis_dim = 1024 image_size = 256 patch_size = 16 assert image_size % patch_size == 0 vis_embed_size = (image_size // patch_size) ** 2 if lora: from peft import LoraConfig, TaskType from peft import get_peft_model if "codegen" in lang_encoder_name: lang_target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"] elif "opt" in lang_encoder_name: lang_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"] elif "llama" in lang_encoder_name: lang_target_modules = ["k_proj", "v_proj", "q_proj", "o_proj", "gate_proj", "down_proj", "up_proj"] else: raise NotImplementedError lang_peft_config = LoraConfig( task_type="CAUSAL_LM", r=16, lora_alpha=16, target_modules=lang_target_modules, lora_dropout=0.05, bias="none", ) lang_encoder = get_peft_model(lang_encoder, lang_peft_config) lang_encoder.print_trainable_parameters() if fix_ffn: if "opt" in lang_encoder_name: for i in range(len(lang_encoder.model.decoder.layers)): lang_encoder.model.decoder.layers[i].requires_grad_(False) lang_encoder.model.decoder.layers[i].self_attn.requires_grad_(True) else: raise NotImplementedError lang_dim = int(lang_encoder.config.hidden_size) if not lora else int(lang_encoder.base_model.model.config.hidden_size) if hasattr(lang_encoder.config, "word_embed_proj_dim"): hidden_state_dim = lang_encoder.config.word_embed_proj_dim else: hidden_state_dim = lang_encoder.config.hidden_size model = Flamingo( vision_encoder=vision_encoder, lang_encoder=lang_encoder, eoc_token_id=text_tokenizer.encode(text_tokenizer.eos_token)[-1], media_token_id=text_tokenizer.encode("<|#image#|>")[-1], image_end_token_id=text_tokenizer.encode("<|#endofimage#|>")[-1], visual_token_id=text_tokenizer.encode("<|#visual#|>")[-1] if add_visual_token else None, previsual_token_id=text_tokenizer.encode("<|#previsual#|>")[-1] if add_visual_token else None, box_token_id=text_tokenizer.encode("<|#box#|>")[-1] if add_box else None, prebox_token_id=text_tokenizer.encode("<|#prebox#|>")[-1] if add_box else None, nothing_token_id=text_tokenizer.encode("<|#NOTHING#|>")[-1] if enhance_data else None, endofobject_token_id=text_tokenizer.encode("<|#endofobject#|>")[-1], vis_dim=vis_dim, vis_embed_size=vis_embed_size, lang_dim=lang_dim, image_size=image_size, patch_size=patch_size, hidden_state_dim=hidden_state_dim, add_visual_token=add_visual_token, add_pe=add_pe, add_relation=add_relation, use_format_v2=use_format_v2, roi_align=roi_align, roi_output_size=roi_output_size, apply_mask=apply_mask, **flamingo_kwargs, ) if freeze_vision_encoder: print("freeze vision encoder") model.vision_encoder.requires_grad_(False) print( f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" ) return model, image_processor, text_tokenizer, vis_embed_size def _infer_decoder_layers_attr_name(model): for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES: if k.lower() in model.__class__.__name__.lower(): return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k] raise ValueError( f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually." ) __KNOWN_DECODER_LAYERS_ATTR_NAMES = { "opt": "model.decoder.layers", # "gptneo": "transformer.h", "gptj": "transformer.h", "gpt-j": "transformer.h", "pythia": "gpt_neox.layers", "gptneox": "gpt_neox.layers", "llama": "model.layers", "llamaforcausallm": "model.layers", "gpt2": "transformer.h", "codegen": "transformer.h", }