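# QA-CLIP / clip /
# (Hugging Face file-viewer residue kept for provenance: "Upload 30 files",
#  "history blame contribute delete", "No virus", "6.72 kB")
# Code modified from OpenAI CLIP (https://github.com/openai/CLIP)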
import json
import os
from pathlib import Path
from typing import Union, List
import urllib
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, InterpolationMode
from tqdm import tqdm
from clip import _tokenizer
from clip.model import convert_weights, CLIP, restore_model
__all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"]
# Registry of downloadable checkpoints: model name -> checkpoint URL.
# NOTE(review): the URLs are empty in this copy of the file; they must be
# filled in before `load_from_name` can download anything by name.
_MODELS = {
    "ViT-B-16": "",
    "ViT-L-14": "",
    "RN50": "",
}

# Per-model metadata read by `load_from_name`:
#   "struct"           -> "<vision model>@<text model>", split on '@' by `create_model`
#   "input_resolution" -> image side length handed to `image_transform`
_MODEL_INFO = {
    "ViT-B-16": {
        "struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese",
        "input_resolution": 224,
    },
    "ViT-L-14": {
        "struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese",
        "input_resolution": 224,
    },
    "RN50": {
        "struct": "RN50@RBT3-chinese",
        "input_resolution": 224,
    },
}
def _download(url: str, root: str):
    """Download ``url`` into directory ``root`` and return the local file path.

    The file name is the basename of ``url``. If a regular file with that
    name already exists in ``root`` it is returned as-is without touching the
    network (no checksum is verified).

    Parameters
    ----------
    url : str
        Location of the file to fetch.
    root : str
        Directory to store the file in; created if missing.

    Returns
    -------
    str
        Path of the downloaded (or already cached) file.

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        return download_target

    # Stream into a temporary sibling file and rename only on success, so an
    # interrupted transfer is never mistaken for a complete cached download.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            content_length = source.info().get("Content-Length")
            # Servers may omit Content-Length; tqdm accepts total=None.
            total = int(content_length) if content_length is not None else None
            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                      unit_divisor=1024) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))
    except BaseException:
        if os.path.isfile(partial_target):
            os.remove(partial_target)
        raise

    os.replace(partial_target, download_target)
    return download_target
def _convert_image_to_rgb(image):
    """Return ``image.convert("RGB")``; ``image`` must expose a ``convert`` method."""
    return image.convert("RGB")
def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())
def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
                   download_root: str = None, vision_model_name: str = None, text_model_name: str = None,
                   input_resolution: int = None):
    """Build a model from a registered name or a local checkpoint file.

    Parameters
    ----------
    name : str
        Either a key of ``_MODELS`` (the checkpoint is then fetched with
        ``_download``) or the path of an existing checkpoint file.
    device : Union[str, torch.device]
        Device the model is placed on.
    download_root : str
        Directory for downloaded checkpoints; defaults to ``~/.cache/clip``.
    vision_model_name, text_model_name, input_resolution
        Required only when ``name`` is a file path; they describe the
        architecture, since it cannot be looked up in ``_MODEL_INFO``.

    Returns
    -------
    tuple
        ``(model, transform)`` where ``transform`` is
        ``image_transform(<input resolution>)``.

    Raises
    ------
    RuntimeError
        If ``name`` is neither a registered model nor an existing file.
    AssertionError
        If ``name`` is a file and any of the three architecture arguments is
        missing.
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
        model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
    elif os.path.isfile(name):
        assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
        model_path = name
        model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        # loading saved checkpoint
        checkpoint = torch.load(opened_file, map_location="cpu")

    model = create_model(model_name, checkpoint)
    # NOTE(review): the bodies of this branch were lost in this copy of the
    # file; restored as "full precision on CPU, otherwise move to the device".
    # Confirm against upstream.
    if str(device) == "cpu":
        model.float()
    else:
        model.to(device)
    return model, image_transform(model_input_resolution)
def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
         bert_path=None, use_flash_attention=False):
    """Load CLIP and BERT model weights

    Parameters
    ----------
    model
        Model instance handed to ``restore_model``.
    device : Union[str, torch.device]
        Device the restored model is moved to.
    clip_path, bert_path
        Optional checkpoint paths; each is read with ``torch.load`` onto the
        CPU, and ``None`` is passed to ``restore_model`` when a path is not
        given.
    use_flash_attention : bool
        Forwarded unchanged to ``restore_model``.

    Returns
    -------
    The ``model`` argument.
    """
    bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None
    clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None

    # presumably restore_model returns `model` itself (updated in place), since
    # `model` is what gets returned below -- TODO confirm in clip.model
    restore_model(model, clip_state_dict, bert_state_dict, use_flash_attention).to(device)

    # NOTE(review): the body of this branch was lost in this copy of the file;
    # restored as a cast to full precision on CPU. Confirm against upstream.
    if str(device) == "cpu":
        model.float()
    return model
def tokenize(texts: Union[str, List[str]], context_length: int = 52) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all baseline models use 52 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    cls_id = _tokenizer.vocab['[CLS]']
    sep_id = _tokenizer.vocab['[SEP]']

    all_tokens = []
    for text in texts:
        # Reserve two positions for [CLS] and [SEP]; longer inputs are truncated.
        token_ids = _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[:context_length - 2]
        all_tokens.append([cls_id] + token_ids + [sep_id])

    # Positions past the end of each sequence stay 0.
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        assert len(tokens) <= context_length
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
def _convert_to_rgb(image):
    """Return ``image.convert('RGB')``; ``image`` must expose a ``convert`` method."""
    return image.convert('RGB')
def image_transform(image_size=224):
    """Build the image preprocessing pipeline for a square input size.

    Parameters
    ----------
    image_size : int
        Side length; images are resized to ``(image_size, image_size)`` with
        bicubic interpolation.

    Returns
    -------
    A ``Compose`` that resizes, converts to RGB, converts to a tensor and
    normalizes with the fixed per-channel mean/std below.
    """
    # NOTE(review): the RGB-conversion and ToTensor steps and the closing
    # bracket were lost in this copy of the file. They are restored here
    # because Normalize operates on tensors and both helpers are otherwise
    # unused in the module -- confirm against upstream.
    transform = Compose([
        Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        _convert_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    return transform
def create_model(model_name, checkpoint=None):
    """Instantiate a ``CLIP`` model from its JSON configs and optional weights.

    Parameters
    ----------
    model_name : str
        ``"<vision model>@<text model>"``. Each half (with ``/`` replaced by
        ``-``) names a JSON file under ``model_configs/`` next to this module.
    checkpoint : dict, optional
        Loaded checkpoint; its ``"state_dict"`` entry is loaded into the
        model. When the first key starts with ``module``, the ``module.``
        prefix is stripped from every key and keys containing
        ``bert.pooler`` are dropped.

    Returns
    -------
    The constructed ``CLIP`` model.

    Raises
    ------
    AssertionError
        If either config file does not exist.
    """
    vision_model, text_model = model_name.split('@')
    # Initialize the model.
    vision_model_config_file = Path(
        __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(
        __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    # Text-model settings are merged over the vision-model settings.
    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        for k, v in json.load(ft).items():
            model_info[k] = v
    if isinstance(model_info['vision_layers'], str):
        # SECURITY NOTE(review): eval() executes arbitrary code from the config
        # file. Acceptable only because the configs ship with the package;
        # consider ast.literal_eval if configs can come from elsewhere.
        model_info['vision_layers'] = eval(model_info['vision_layers'])
    print('Model info', model_info)
    model = CLIP(**model_info)
    # NOTE(review): restored call -- `convert_weights` is imported by this
    # module but no call survives in this copy of the file. Confirm against
    # upstream.
    convert_weights(model)
    if checkpoint:
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
        # Without this the processed state dict was computed and then discarded.
        model.load_state_dict(sd)
    return model