from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip
import torch
from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance
import logging
import random
import time


def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    location_token_num: int = 1000,
    checkpoint_activations: bool = False,
    freeze_vision_encoder: bool = False,
    lora: bool = False,
    lora_r: int = 16,
    fix_ffn: bool = False,
    add_visual_token: bool = False,
    add_box: bool = False,
    add_pe: bool = False,
    add_relation: bool = False,
    use_format_v2: bool = False,
    use_sam: str = None,
    enhance_data: bool = False,
    roi_align: bool = False,
    roi_output_size: int = 4,
    apply_mask: bool = False,
    **flamingo_kwargs,
):
""" | |
Initialize a Flamingo model from a pretrained vision encoder and language encoder. | |
Appends special tokens to the tokenizer and freezes backbones. | |
Args: | |
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") | |
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") | |
lang_encoder_path (str): path to pretrained language encoder | |
tokenizer_path (str): path to pretrained tokenizer | |
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. | |
use_local_files (bool, optional): whether to use local files. Defaults to False. | |
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. | |
Returns: | |
Flamingo: Flamingo model from pretrained vision and language encoders | |
Image processor: Pipeline to preprocess input images | |
Tokenizer: A tokenizer for the language model | |
""" | |
    if use_sam is None:
        no_success = True
        while no_success:
            try:
                vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
                    clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
                )
                no_success = False
            except:
                logging.info("retry creating vision_encoder")
                time.sleep(random.random() * 5)
        # set the vision encoder to output the visual features
        vision_encoder.visual.output_tokens = True
        # delete text encoder part
        del vision_encoder.transformer
        del vision_encoder.text_projection
        del vision_encoder.token_embedding
        del vision_encoder.ln_final
        del vision_encoder.positional_embedding
        del vision_encoder.logit_scale
        vision_encoder.visual.proj = None
        vision_encoder.visual.ln_post = torch.nn.Identity()
    else:
        from segment_anything import SamPredictor, sam_model_registry
        assert use_sam == "vit_l"
        sam = sam_model_registry[use_sam](checkpoint="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_256x256.pth")
        del sam.prompt_encoder
        del sam.mask_decoder
        sam.image_encoder.neck = torch.nn.Identity()
        vision_encoder = sam.image_encoder
        from open_clip.transform import image_transform
        image_processor = image_transform(
            256,
            is_train=False,
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
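    # Note: with output_tokens=True, the CLIP branch's visual tower returns per-patch
    # tokens (in addition to the pooled feature), while the SAM branch exposes
    # sam.image_encoder's per-patch features directly, since its projection neck was
    # replaced with an Identity above.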
    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path, local_files_only=use_local_files
    )
    # add Flamingo special tokens to the tokenizer
    additional_special_tokens = ["<|#image#|>", "<|#endofimage#|>"]
    if add_visual_token:
        additional_special_tokens += ["<|#visual#|>", "<|#object#|>"]
    if add_box:
        additional_special_tokens += ["<|#box#|>", "<|#endofobject#|>", "<|#attr#|>", "<|#endofattr#|>"]
    if use_format_v2:
        additional_special_tokens += ["<|#previsual#|>", "<|#prebox#|>"]
    if enhance_data:
        additional_special_tokens += ["<|#NOTHING#|>"]
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": additional_special_tokens}
    )
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path, local_files_only=use_local_files
    )
    extend_instance(lang_encoder, FlamingoLMMixin)
    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))
    lang_encoder_name = lang_encoder.__class__.__name__.lower()
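    # Optionally wrap transformer blocks with FairScale's checkpoint_wrapper; activation
    # checkpointing frees block activations during the forward pass and recomputes them
    # in backward, trading extra compute for lower peak memory.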
    if checkpoint_activations:
        from fairscale.nn.checkpoint import checkpoint_wrapper
        if use_sam is None:
            for i in range(len(vision_encoder.visual.transformer.resblocks)):
                vision_encoder.visual.transformer.resblocks[i] = checkpoint_wrapper(
                    vision_encoder.visual.transformer.resblocks[i],
                    offload_to_cpu=False,
                )
        else:
            for i in range(len(vision_encoder.blocks)):
                vision_encoder.blocks[i] = checkpoint_wrapper(
                    vision_encoder.blocks[i],
                    offload_to_cpu=False,
                )
        if "opt" in lang_encoder_name:
            for i in range(len(lang_encoder.model.decoder.layers)):
                lang_encoder.model.decoder.layers[i] = checkpoint_wrapper(
                    lang_encoder.model.decoder.layers[i],
                    offload_to_cpu=False,
                )
        elif "codegen" in lang_encoder_name:
            for i in range(len(lang_encoder.transformer.h)):
                lang_encoder.transformer.h[i] = checkpoint_wrapper(
                    lang_encoder.transformer.h[i],
                    offload_to_cpu=False,
                )
        elif "llama" in lang_encoder_name:
            for i in range(len(lang_encoder.model.layers)):
                lang_encoder.model.layers[i] = checkpoint_wrapper(
                    lang_encoder.model.layers[i],
                    offload_to_cpu=False,
                )
        elif "gptneo" in lang_encoder_name:
            for i in range(len(lang_encoder.gpt_neox.layers)):
                lang_encoder.gpt_neox.layers[i] = checkpoint_wrapper(
                    lang_encoder.gpt_neox.layers[i],
                    offload_to_cpu=False,
                )
        else:
            raise ValueError(f"unknown model {lang_encoder_name}")
    if use_sam is None:
        vis_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"]
        image_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["image_size"]
        patch_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["patch_size"]
    else:
        # SAM config
        vis_dim = 1024
        image_size = 256
        patch_size = 16
    assert image_size % patch_size == 0
    vis_embed_size = (image_size // patch_size) ** 2
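    # For example, the SAM-style 256x256 input with 16x16 patches gives
    # vis_embed_size = (256 // 16) ** 2 = 256 visual embeddings per image,
    # while ViT-B-32 at 224x224 would give (224 // 32) ** 2 = 49.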
    if lora:
        from peft import LoraConfig, TaskType
        from peft import get_peft_model
        if "codegen" in lang_encoder_name:
            lang_target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
        elif "opt" in lang_encoder_name:
            lang_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
        elif "llama" in lang_encoder_name:
            lang_target_modules = ["k_proj", "v_proj", "q_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
        else:
            raise NotImplementedError
        lang_peft_config = LoraConfig(
            task_type="CAUSAL_LM",
            r=lora_r, lora_alpha=16,
            target_modules=lang_target_modules,
            lora_dropout=0.05, bias="none",
        )
        lang_encoder = get_peft_model(lang_encoder, lang_peft_config)
        lang_encoder.print_trainable_parameters()
    if fix_ffn:
        if "opt" in lang_encoder_name:
            for i in range(len(lang_encoder.model.decoder.layers)):
                lang_encoder.model.decoder.layers[i].requires_grad_(False)
                lang_encoder.model.decoder.layers[i].self_attn.requires_grad_(True)
        else:
            raise NotImplementedError
    lang_dim = int(lang_encoder.config.hidden_size) if not lora else int(lang_encoder.base_model.model.config.hidden_size)
    if hasattr(lang_encoder.config, "word_embed_proj_dim"):
        hidden_state_dim = lang_encoder.config.word_embed_proj_dim
    else:
        hidden_state_dim = lang_encoder.config.hidden_size
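    # Some OPT checkpoints project word embeddings to a smaller dimension than the
    # transformer width (e.g. facebook/opt-350m uses word_embed_proj_dim=512 with
    # hidden_size=1024), so the output hidden states follow word_embed_proj_dim when present.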
    model = Flamingo(
        vision_encoder=vision_encoder,
        lang_encoder=lang_encoder,
        eoc_token_id=text_tokenizer.encode(text_tokenizer.eos_token)[-1],
        media_token_id=text_tokenizer.encode("<|#image#|>")[-1],
        image_end_token_id=text_tokenizer.encode("<|#endofimage#|>")[-1],
        visual_token_id=text_tokenizer.encode("<|#visual#|>")[-1] if add_visual_token else None,
        previsual_token_id=text_tokenizer.encode("<|#previsual#|>")[-1] if add_visual_token else None,
        box_token_id=text_tokenizer.encode("<|#box#|>")[-1] if add_box else None,
        prebox_token_id=text_tokenizer.encode("<|#prebox#|>")[-1] if add_box else None,
        nothing_token_id=text_tokenizer.encode("<|#NOTHING#|>")[-1] if enhance_data else None,
        endofobject_token_id=text_tokenizer.encode("<|#endofobject#|>")[-1],
        vis_dim=vis_dim,
        vis_embed_size=vis_embed_size,
        lang_dim=lang_dim,
        image_size=image_size,
        patch_size=patch_size,
        hidden_state_dim=hidden_state_dim,
        add_visual_token=add_visual_token,
        add_pe=add_pe,
        add_relation=add_relation,
        use_format_v2=use_format_v2,
        roi_align=roi_align,
        roi_output_size=roi_output_size,
        apply_mask=apply_mask,
        **flamingo_kwargs,
    )
    if freeze_vision_encoder:
        print("freeze vision encoder")
        model.vision_encoder.requires_grad_(False)
    print(
        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )
    return model, image_processor, text_tokenizer, vis_embed_size


def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    # "gptneo": "transformer.h",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "gptneox": "gpt_neox.layers",
    "llama": "model.layers",
    "llamaforcausallm": "model.layers",
    "gpt2": "transformer.h",
    "codegen": "transformer.h",
}
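

# Usage sketch (illustrative only): the CLIP name and pretrained tag below come from the
# docstring example, the OPT model id is an assumed placeholder, and any extra
# Flamingo-specific options would be forwarded through **flamingo_kwargs.
#
#   model, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
#       clip_vision_encoder_path="ViT-B-32",
#       clip_vision_encoder_pretrained="laion2b_s32b_b79k",
#       lang_encoder_path="facebook/opt-125m",
#       tokenizer_path="facebook/opt-125m",
#       add_visual_token=True,
#       add_box=True,
#       use_format_v2=True,
#   )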