# Code modified from https://github.com/openai/CLIP
import json
import os
from pathlib import Path
from typing import Union, List
import urllib.request
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, InterpolationMode
from tqdm import tqdm
from clip import _tokenizer
from clip.model import convert_weights, CLIP, restore_model

__all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"]

_MODELS = {
"ViT-B-16": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt",
"ViT-L-14": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt",
"RN50": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt",
}

_MODEL_INFO = {
"ViT-B-16": {
"struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese",
"input_resolution": 224
},
"ViT-L-14": {
"struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese",
"input_resolution": 224
},
"RN50": {
"struct": "RN50@RBT3-chinese",
"input_resolution": 224
},
}


def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
return download_target


def _convert_image_to_rgb(image):
return image.convert("RGB")


def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())


def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
download_root: str = None, vision_model_name: str = None, text_model_name: str = None, input_resolution: int = None):
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
elif os.path.isfile(name):
        assert vision_model_name and text_model_name and input_resolution, \
            "Please specify 'vision_model_name', 'text_model_name' and 'input_resolution' when loading from a local checkpoint file"
model_path = name
model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, 'rb') as opened_file:
# loading saved checkpoint
checkpoint = torch.load(opened_file, map_location="cpu")
model = create_model(model_name, checkpoint)
if str(device) == "cpu":
model.float()
else:
model.to(device)
return model, image_transform(model_input_resolution)
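
# Example usage (a sketch, not part of the library API: "demo.jpg" is a placeholder
# path, and encode_image / encode_text are assumed to be exposed by the CLIP class
# as in the upstream Chinese-CLIP code):
#
#   from PIL import Image
#   model, preprocess = load_from_name("ViT-B-16", device="cpu")
#   model.eval()
#   image = preprocess(Image.open("demo.jpg")).unsqueeze(0)
#   text = tokenize(["一只猫", "一只狗"])
#   with torch.no_grad():
#       image_features = model.encode_image(image)
#       text_features = model.encode_text(text)
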
def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
bert_path=None, use_flash_attention=False):
"""Load CLIP and BERT model weights
"""
bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None
clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None
restore_model(model, clip_state_dict, bert_state_dict, use_flash_attention).to(device)
if str(device) == "cpu":
model.float()
return model
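
# Example usage (a sketch; the checkpoint paths below are placeholders):
#
#   model, preprocess = load_from_name("ViT-B-16", device="cpu")
#   model = load(model, device="cpu",
#                clip_path="path/to/vision_weights.pt",
#                bert_path="path/to/text_weights.pt")
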
def tokenize(texts: Union[str, List[str]], context_length: int = 52) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all baseline models use 52 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
all_tokens = []
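    # Wrap each text as [CLS] + token ids + [SEP], truncating the token ids so the
    # full sequence fits within context_length.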
    for text in texts:
        token_ids = _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[:context_length - 2]
        all_tokens.append([_tokenizer.vocab['[CLS]']] + token_ids + [_tokenizer.vocab['[SEP]']])
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
assert len(tokens) <= context_length
result[i, :len(tokens)] = torch.tensor(tokens)
return result
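
# Example (a sketch): tokenize(["一只猫", "一只狗"]) returns a LongTensor of shape
# [2, 52]; each row starts with the [CLS] id, ends its text with the [SEP] id, and
# is zero-padded up to context_length.
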
def _convert_to_rgb(image):
return image.convert('RGB')


def image_transform(image_size=224):
transform = Compose([
Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
return transform
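
# Example usage (a sketch; "demo.jpg" is a placeholder path and Pillow is assumed
# to be installed):
#
#   from PIL import Image
#   preprocess = image_transform(224)
#   img = preprocess(Image.open("demo.jpg"))  # float tensor of shape [3, 224, 224]
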
def create_model(model_name, checkpoint=None):
vision_model, text_model = model_name.split('@')
# Initialize the model.
vision_model_config_file = Path(
__file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
print('Loading vision model config from', vision_model_config_file)
assert os.path.exists(vision_model_config_file)
text_model_config_file = Path(
__file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
print('Loading text model config from', text_model_config_file)
assert os.path.exists(text_model_config_file)
with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
model_info = json.load(fv)
for k, v in json.load(ft).items():
model_info[k] = v
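        # ResNet configs may store 'vision_layers' as a string (e.g. "[3,4,6,3]");
        # eval converts it back into a Python list.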
if isinstance(model_info['vision_layers'], str):
model_info['vision_layers'] = eval(model_info['vision_layers'])
print('Model info', model_info)
model = CLIP(**model_info)
convert_weights(model)
if checkpoint:
sd = checkpoint["state_dict"]
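        # Checkpoints saved from DataParallel/DistributedDataParallel prefix parameter
        # names with 'module.'; strip the prefix (and drop the unused bert.pooler
        # weights) before loading.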
if next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
model.load_state_dict(sd)
return model
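
# Example usage (a sketch; the checkpoint path is a placeholder):
#
#   checkpoint = torch.load("path/to/QA-CLIP-base.pt", map_location="cpu")
#   model = create_model("ViT-B-16@RoBERTa-wwm-ext-base-chinese", checkpoint)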