# Source: Hugging Face Hub repository BAAI/EVA-CLIP-8B-448
# (tags: Feature Extraction, Transformers, PyTorch, clip, custom_code).
# File: EVA-CLIP-8B-448/convert_evaclip_8b_448_pytorch_to_hf.py
# Uploaded by ryanzhangfan ("Upload 15 files", commit bf6e2be, verified).
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Part of the code was taken from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py
import argparse
import os, sys
sys.path.append(os.getcwd())
import torch
from PIL import Image
from transformers import AutoModel, AutoConfig
from transformers import CLIPImageProcessor, pipeline, CLIPTokenizer
from EVA_CLIP_8B_448.configuration_evaclip import EvaCLIPConfig
from EVA_CLIP_8B_448.modeling_evaclip import EvaCLIPModel
# Substring rewrite rules from original EVA-CLIP parameter names to EvaCLIPModel
# names. rename_state_dict applies them one after another in insertion order, and
# each rule sees the output of the previous ones, so the ORDER IS SIGNIFICANT:
#   * "norm." must run before "ln_final" (otherwise "final_layer_norm." would be
#     rewritten again by "norm.");
#   * the specific "visual.blocks"/"visual.head" and "text.text_projection"/
#     "text.transformer.resblocks" rules must run before the generic "visual."/"text."
#     prefix rules.
KEYS_TO_MODIFY_MAPPING = {
    # embeddings
    "cls_token": "embeddings.class_embedding",
    "pos_embed": "embeddings.position_embedding.weight",
    "patch_embed.proj": "embeddings.patch_embedding",
    ".positional_embedding": ".embeddings.position_embedding.weight",
    ".token_embedding": ".embeddings.token_embedding",
    # text projection
    "text.text_projection": "text_projection.weight",
    # MLP
    "mlp.c_fc": "mlp.fc1",
    "mlp.c_proj": "mlp.fc2",
    # attention projections / biases
    ".proj.": ".out_proj.",
    "q_bias": "q_proj.bias",
    "v_bias": "v_proj.bias",
    "out.": "out_proj.",
    # per-layer norms
    "norm1": "layer_norm1",
    "norm2": "layer_norm2",
    "ln_1": "layer_norm1",
    "ln_2": "layer_norm2",
    # attention module name
    "attn": "self_attn",
    # final norms
    "norm.": "post_layernorm.",
    "ln_final": "final_layer_norm",
    # tower / layer-stack prefixes (generic prefixes last)
    "visual.blocks": "vision_model.encoder.layers",
    "text.transformer.resblocks": "text_model.encoder.layers",
    "visual.head": "visual_projection",
    "visual.": "vision_model.",
    "text.": "text_model.",
}
def _split_qkv(fused):
    """Split a tensor stacked as [q; k; v] along dim 0 into three slices.

    Returns ``(query, key, value)``. Each of the first two slices has
    ``fused.size(0) // 3`` rows; any remainder rows end up in the value slice.
    """
    third = fused.size(0) // 3
    return fused[:third], fused[third : third * 2], fused[third * 2 :]


def rename_state_dict(state_dict, key_mapping=None):
    """Translate an original EVA-CLIP state dict into EvaCLIPModel parameter names.

    Each key is rewritten by applying every ``old -> new`` substring replacement in
    ``key_mapping`` in insertion order (later rules see the output of earlier ones,
    so rule order matters). The renamed entries are then post-processed:

    * keys containing ``text_projection`` get the transposed tensor (``value.T``);
      presumably the original stores it as (in, out) while a linear ``.weight``
      expects (out, in) -- confirm against the modeling code;
    * fused ``attn.qkv`` tensors are split into ``q_proj``/``k_proj``/``v_proj``;
    * fused ``attn.in_proj_*`` tensors are split into ``q_proj.*``/``k_proj.*``/``v_proj.*``;
    * ``class_embedding`` is reduced with ``value[0, 0, :]``;
    * the vision position embedding drops its leading axis (``value[0, :, :]``).

    Args:
        state_dict: mapping of original parameter names to tensors.
        key_mapping: optional ordered ``{old_substring: new_substring}`` rules.
            Defaults to the module-level ``KEYS_TO_MODIFY_MAPPING``.

    Returns:
        A new dict of renamed tensors (split or reshaped where noted above).
    """
    rules = KEYS_TO_MODIFY_MAPPING if key_mapping is None else key_mapping
    converted = {}
    for original_key, value in state_dict.items():
        key = original_key
        for old, new in rules.items():
            # str.replace is a no-op when `old` is absent, so no membership test is needed.
            key = key.replace(old, new)

        if "text_projection" in key:
            converted[key] = value.T
        elif "attn.qkv" in key:
            q_part, k_part, v_part = _split_qkv(value)
            converted[key.replace("qkv", "q_proj")] = q_part
            converted[key.replace("qkv", "k_proj")] = k_part
            converted[key.replace("qkv", "v_proj")] = v_part
        elif "attn.in_proj" in key:
            q_part, k_part, v_part = _split_qkv(value)
            converted[key.replace("in_proj_", "q_proj.")] = q_part
            converted[key.replace("in_proj_", "k_proj.")] = k_part
            converted[key.replace("in_proj_", "v_proj.")] = v_part
        elif "class_embedding" in key:
            converted[key] = value[0, 0, :]
        elif "vision_model.embeddings.position_embedding" in key:
            converted[key] = value[0, :, :]
        else:
            converted[key] = value
    return converted
def save_model_and_config(pytorch_dump_folder_path, hf_model, transformers_config):
    """Write the converted model and its config into ``pytorch_dump_folder_path``.

    Calls ``save_pretrained`` on the model first, then on the config.
    """
    for artifact in (hf_model, transformers_config):
        artifact.save_pretrained(pytorch_dump_folder_path)
def check_loaded_model(pytorch_dump_folder_path, tokenizer, processor, image, captions):
    """Reload the saved model from disk and print zero-shot classification scores.

    Loads config and model with ``AutoConfig``/``AutoModel`` (``trust_remote_code=True``,
    so custom code from ``pytorch_dump_folder_path`` is executed -- only use trusted
    folders). Then it runs a ``zero-shot-image-classification`` pipeline on ``image``
    against ``captions`` and prints the result.
    """
    reloaded_config = AutoConfig.from_pretrained(pytorch_dump_folder_path, trust_remote_code=True)
    reloaded_model = AutoModel.from_pretrained(
        pytorch_dump_folder_path,
        config=reloaded_config,
        trust_remote_code=True,
    )
    classifier = pipeline(
        task="zero-shot-image-classification",
        model=reloaded_model,
        tokenizer=tokenizer,
        image_processor=processor,
    )
    scores = classifier(image, candidate_labels=captions)
    print(f"text_probs loaded hf_model using pipeline: {scores}")
def convert_evaclip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, image_path, save=False):
    """Convert an original EVA-CLIP-8B-448 checkpoint into an ``EvaCLIPModel``.

    Steps:
      1. Build a 448px CLIP image processor and tokenize three fixed captions.
      2. Load the original state dict with ``torch.load``, rename it with
         ``rename_state_dict`` and load it strictly into a fresh ``EvaCLIPModel``.
      3. Run a zero-shot sanity check (normalized features, softmax over 100 * cosine
         similarity) and print the caption probabilities.
      4. If ``save`` is true, save the model and config to ``pytorch_dump_folder_path``
         and re-check the reloaded model through the ``pipeline`` API.

    Args:
        checkpoint_path: path to the original checkpoint (loaded on CPU).
        pytorch_dump_folder_path: output folder. The CLIP tokenizer is also loaded from
            here, so it must already contain the tokenizer files.
        config_path: path passed to ``EvaCLIPConfig.from_pretrained``.
        image_path: image used for the sanity check.
        save: whether to save and re-check the converted model.
    """
    processor = CLIPImageProcessor(size={"shortest_edge": 448}, do_center_crop=True, crop_size=448)
    print(f"processor={processor}")
    # copy() forces the pixel data to load, so the file handle can be closed right away.
    with Image.open(image_path) as opened_image:
        image = opened_image.copy()
    captions = ["a diagram", "a dog", "a cat"]

    tokenizer = CLIPTokenizer.from_pretrained(pytorch_dump_folder_path)
    input_ids = tokenizer(captions, return_tensors="pt", padding=True).input_ids
    # The processor is already configured for 448px; image processors take no `padding` argument.
    input_pixels = processor(images=image, return_tensors="pt").pixel_values
    print("input_pixels.shape", input_pixels.shape)

    transformers_config = EvaCLIPConfig.from_pretrained(config_path)
    hf_model = EvaCLIPModel(transformers_config)
    # NOTE: torch.load unpickles the file; only convert checkpoints from trusted sources.
    pt_model_state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = rename_state_dict(pt_model_state_dict)
    hf_model.load_state_dict(state_dict, strict=True)
    # The renamed tensors are views of the original ones; drop both references so the
    # source weights can be freed before the forward pass on this very large model.
    del pt_model_state_dict, state_dict
    # A directly constructed module starts in training mode (from_pretrained would
    # call eval() for us); switch to eval so the sanity check runs in inference mode.
    hf_model.eval()

    with torch.no_grad():
        image_features = hf_model.encode_image(input_pixels)
        text_features = hf_model.encode_text(input_ids)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        label_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    print(f"hf_model label probs: {label_probs}")

    if save:
        save_model_and_config(pytorch_dump_folder_path, hf_model, transformers_config)
        check_loaded_model(pytorch_dump_folder_path, tokenizer, processor, image, captions)
if __name__ == "__main__":
    # Command-line entry point: parse paths and run the conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="EVA_CLIP_8B_448",
        type=str,
        help="Path to the output PyTorch model. The CLIP tokenizer is also loaded from this folder.",
    )
    parser.add_argument(
        "--checkpoint_path",
        default="EVA_CLIP_8B_psz14_plus_s0.6B.pt",
        type=str,
        help="Path to the original EVA-CLIP PyTorch checkpoint (.pt) to convert",
    )
    parser.add_argument(
        "--config_path",
        default='EVA_CLIP_8B_448',
        type=str,
        help="Path to the hf config.json (or the folder containing it) of the model to convert",
    )
    parser.add_argument("--image_path", default='EVA_CLIP_8B_448/CLIP.png', type=str, help="Path to image")
    parser.add_argument(
        "--save",
        default=False,
        action="store_true",
        help="Save the model and config to the pytorch_dump_folder_path. Default is False.",
    )
    args = parser.parse_args()
    convert_evaclip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.image_path, args.save)