from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip
import torch
from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance
import logging
import random
import time
def create_model_and_transforms(
clip_vision_encoder_path: str,
clip_vision_encoder_pretrained: str,
lang_encoder_path: str,
tokenizer_path: str,
use_local_files: bool = False,
decoder_layers_attr_name: str = None,
location_token_num: int = 1000,
checkpoint_activations: bool = False,
freeze_vision_encoder: bool = False,
lora: bool = False,
lora_r: int = 16,
fix_ffn: bool = False,
add_visual_token: bool = False,
add_box: bool = False,
add_pe: bool = False,
add_relation: bool = False,
use_format_v2: bool = False,
use_sam: str = None,
enhance_data: bool = False,
roi_align: bool = False,
roi_output_size: int = 4,
apply_mask: bool = False,
**flamingo_kwargs,
):
"""
Initialize a Flamingo model from a pretrained vision encoder and language encoder.
Appends special tokens to the tokenizer and freezes backbones.
Args:
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
lang_encoder_path (str): path to pretrained language encoder
tokenizer_path (str): path to pretrained tokenizer
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
use_local_files (bool, optional): whether to use local files. Defaults to False.
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
Image processor: Pipeline to preprocess input images
Tokenizer: A tokenizer for the language model
"""
    if use_sam is None:
        # Retry on transient failures (e.g. weight download hiccups) when building the CLIP backbone
        no_success = True
        while no_success:
            try:
                vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
                    clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
                )
                no_success = False
            except Exception:
                logging.info("retry creating vision_encoder")
                time.sleep(random.random() * 5)
# set the vision encoder to output the visual features
vision_encoder.visual.output_tokens = True
# delete text encoder part
del vision_encoder.transformer
del vision_encoder.text_projection
del vision_encoder.token_embedding
del vision_encoder.ln_final
del vision_encoder.positional_embedding
del vision_encoder.logit_scale
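        # return un-projected patch tokens: disable the CLIP output projection and bypass the final LayerNorm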
vision_encoder.visual.proj = None
vision_encoder.visual.ln_post = torch.nn.Identity()
else:
from segment_anything import SamPredictor, sam_model_registry
assert use_sam == "vit_l"
sam = sam_model_registry[use_sam](checkpoint="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_256x256.pth")
del sam.prompt_encoder
del sam.mask_decoder
sam.image_encoder.neck = torch.nn.Identity()
vision_encoder = sam.image_encoder
from open_clip.transform import image_transform
image_processor = image_transform(
256,
is_train=False,
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
)
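    # Tokenizer: load it and register the Flamingo / grounding special tokens used below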
text_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, local_files_only=use_local_files
)
# add Flamingo special tokens to the tokenizer
additional_special_tokens = ["<|#image#|>", "<|#endofimage#|>"]
if add_visual_token:
additional_special_tokens += ["<|#visual#|>", "<|#object#|>"]
if add_box:
additional_special_tokens += ["<|#box#|>", "<|#endofobject#|>", "<|#attr#|>", "<|#endofattr#|>"]
if use_format_v2:
additional_special_tokens += ["<|#previsual#|>", "<|#prebox#|>"]
if enhance_data:
additional_special_tokens += ["<|#NOTHING#|>"]
text_tokenizer.add_special_tokens(
{"additional_special_tokens": additional_special_tokens}
)
    if text_tokenizer.pad_token is None:
        # GPT-style models don't define a pad token; add one so padded positions can be masked out of the loss
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
lang_encoder = AutoModelForCausalLM.from_pretrained(
lang_encoder_path, local_files_only=use_local_files
)
extend_instance(lang_encoder, FlamingoLMMixin)
if decoder_layers_attr_name is None:
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
lang_encoder.resize_token_embeddings(len(text_tokenizer))
lang_encoder_name = lang_encoder.__class__.__name__.lower()
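    # Optionally wrap transformer blocks in fairscale's checkpoint_wrapper to trade recomputation for activation memory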
if checkpoint_activations:
from fairscale.nn.checkpoint import checkpoint_wrapper
if use_sam is None:
for i in range(len(vision_encoder.visual.transformer.resblocks)):
vision_encoder.visual.transformer.resblocks[i] = checkpoint_wrapper(
vision_encoder.visual.transformer.resblocks[i],
offload_to_cpu=False,
)
else:
for i in range(len(vision_encoder.blocks)):
vision_encoder.blocks[i] = checkpoint_wrapper(
vision_encoder.blocks[i],
offload_to_cpu=False,
)
if "opt" in lang_encoder_name:
for i in range(len(lang_encoder.model.decoder.layers)):
lang_encoder.model.decoder.layers[i] = checkpoint_wrapper(
lang_encoder.model.decoder.layers[i],
offload_to_cpu=False,
)
elif "codegen" in lang_encoder_name:
for i in range(len(lang_encoder.transformer.h)):
lang_encoder.transformer.h[i] = checkpoint_wrapper(
lang_encoder.transformer.h[i],
offload_to_cpu=False,
)
elif "llama" in lang_encoder_name:
for i in range(len(lang_encoder.model.layers)):
lang_encoder.model.layers[i] = checkpoint_wrapper(
lang_encoder.model.layers[i],
offload_to_cpu=False,
)
elif "gptneo" in lang_encoder_name:
for i in range(len(lang_encoder.gpt_neox.layers)):
lang_encoder.gpt_neox.layers[i] = checkpoint_wrapper(
lang_encoder.gpt_neox.layers[i],
offload_to_cpu=False,
)
else:
raise ValueError(f"unknown model {lang_encoder_name}")
    if use_sam is None:
        vision_cfg = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]
        vis_dim = vision_cfg["width"]
        image_size = vision_cfg["image_size"]
        patch_size = vision_cfg["patch_size"]
else:
# SAM config
vis_dim = 1024
image_size = 256
patch_size = 16
assert image_size % patch_size == 0
vis_embed_size = (image_size // patch_size) ** 2
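    # Optionally add LoRA adapters to the language model via peft; target projection modules depend on the architecture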
if lora:
from peft import LoraConfig, TaskType
from peft import get_peft_model
if "codegen" in lang_encoder_name:
lang_target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
elif "opt" in lang_encoder_name:
lang_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
elif "llama" in lang_encoder_name:
lang_target_modules = ["k_proj", "v_proj", "q_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
else:
raise NotImplementedError
        lang_peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=lora_r, lora_alpha=16,
            target_modules=lang_target_modules,
            lora_dropout=0.05, bias="none",
        )
lang_encoder = get_peft_model(lang_encoder, lang_peft_config)
lang_encoder.print_trainable_parameters()
if fix_ffn:
if "opt" in lang_encoder_name:
for i in range(len(lang_encoder.model.decoder.layers)):
lang_encoder.model.decoder.layers[i].requires_grad_(False)
lang_encoder.model.decoder.layers[i].self_attn.requires_grad_(True)
else:
raise NotImplementedError
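    # lang_dim: transformer hidden size (unwrap the peft base model when LoRA is enabled);
    # hidden_state_dim: output width, which differs from hidden_size for OPT models with word_embed_proj_dim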
lang_dim = int(lang_encoder.config.hidden_size) if not lora else int(lang_encoder.base_model.model.config.hidden_size)
if hasattr(lang_encoder.config, "word_embed_proj_dim"):
hidden_state_dim = lang_encoder.config.word_embed_proj_dim
else:
hidden_state_dim = lang_encoder.config.hidden_size
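    # Assemble the Flamingo wrapper; each special-token ID is the last sub-token produced by encoding that token string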
model = Flamingo(
vision_encoder=vision_encoder,
lang_encoder=lang_encoder,
eoc_token_id=text_tokenizer.encode(text_tokenizer.eos_token)[-1],
media_token_id=text_tokenizer.encode("<|#image#|>")[-1],
image_end_token_id=text_tokenizer.encode("<|#endofimage#|>")[-1],
visual_token_id=text_tokenizer.encode("<|#visual#|>")[-1] if add_visual_token else None,
previsual_token_id=text_tokenizer.encode("<|#previsual#|>")[-1] if add_visual_token else None,
box_token_id=text_tokenizer.encode("<|#box#|>")[-1] if add_box else None,
prebox_token_id=text_tokenizer.encode("<|#prebox#|>")[-1] if add_box else None,
nothing_token_id=text_tokenizer.encode("<|#NOTHING#|>")[-1] if enhance_data else None,
endofobject_token_id=text_tokenizer.encode("<|#endofobject#|>")[-1],
vis_dim=vis_dim,
vis_embed_size=vis_embed_size,
lang_dim=lang_dim,
image_size=image_size,
patch_size=patch_size,
hidden_state_dim=hidden_state_dim,
add_visual_token=add_visual_token,
add_pe=add_pe,
add_relation=add_relation,
use_format_v2=use_format_v2,
roi_align=roi_align,
roi_output_size=roi_output_size,
apply_mask=apply_mask,
**flamingo_kwargs,
)
if freeze_vision_encoder:
print("freeze vision encoder")
model.vision_encoder.requires_grad_(False)
print(
f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
)
return model, image_processor, text_tokenizer, vis_embed_size
def _infer_decoder_layers_attr_name(model):
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
if k.lower() in model.__class__.__name__.lower():
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
    raise ValueError(
        "Could not infer the attribute name of the decoder's transformer block nn.ModuleList "
        "from the model class name; please pass decoder_layers_attr_name explicitly."
    )
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
"opt": "model.decoder.layers",
# "gptneo": "transformer.h",
"gptj": "transformer.h",
"gpt-j": "transformer.h",
"pythia": "gpt_neox.layers",
"gptneox": "gpt_neox.layers",
"llama": "model.layers",
"llamaforcausallm": "model.layers",
"gpt2": "transformer.h",
"codegen": "transformer.h",
}
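
# Minimal usage sketch. The CLIP checkpoint names mirror the docstring examples; the
# LLaMA path and the extra keyword arguments are illustrative assumptions, not values
# taken from this repo:
#
#     model, image_processor, text_tokenizer, vis_embed_size = create_model_and_transforms(
#         clip_vision_encoder_path="ViT-B-32",
#         clip_vision_encoder_pretrained="laion2b_s32b_b79k",
#         lang_encoder_path="path/to/llama",
#         tokenizer_path="path/to/llama",
#         add_visual_token=True,
#         add_box=True,
#         use_format_v2=True,
#         cross_attn_every_n_layers=4,
#     )
#     pixel_values = image_processor(pil_image).unsqueeze(0)  # preprocess one PIL image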