# cover/models/xclip_backbone.py
import copy
import math
from collections import OrderedDict
from typing import Tuple, Union
import clip
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
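# Usage sketch for DropPath (illustrative values, not COVER settings): stochastic depth zeroes
# whole samples during training and rescales the survivors; at eval time it is the identity.
#   dp = DropPath(drop_prob=0.1)
#   x = torch.randn(4, 197, 768)             # (batch, tokens, width)
#   dp.train(); y = dp(x)                    # same shape; ~10% of samples zeroed, rest scaled by 1/0.9
#   dp.eval(); assert torch.equal(dp(x), x)  # no-op at inference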
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
# orig_type = x.dtype
# ret = super().forward(x.type(torch.float32))
# return ret.type(orig_type)
return super().forward(x)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(
self, d_model: int, n_head: int, attn_mask: torch.Tensor = None,
):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head,)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None
else None
)
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
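# Shape sketch (illustrative sizes): nn.MultiheadAttention defaults to batch_first=False, so the
# blocks in this file operate on sequence-first tensors.
#   blk = ResidualAttentionBlock(d_model=512, n_head=8)
#   tokens = torch.randn(77, 4, 512)   # (seq_len, batch, width)
#   out = blk(tokens)                  # -> (77, 4, 512)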
class Transformer(nn.Module):
def __init__(
self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(
*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
)
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisionTransformer(nn.Module):
def __init__(
self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
output_dim: int,
):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(
scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
)
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[
self.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
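# Shape sketch for the stock CLIP image tower (ViT-B/32-like sizes are assumptions, not COVER settings):
#   vit = VisionTransformer(input_resolution=224, patch_size=32, width=768,
#                           layers=12, heads=12, output_dim=512)
#   imgs = torch.randn(2, 3, 224, 224)
#   feats = vit(imgs)                  # -> (2, 512): only the projected CLS token is returned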
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
):
super().__init__()
self.context_length = context_length
# vision_heads = vision_width // 64
# self.visual = VisionTransformer(
# input_resolution=image_resolution,
# patch_size=vision_patch_size,
# width=vision_width,
# layers=vision_layers,
# heads=vision_heads,
# output_dim=embed_dim
# )
# self.transformer = Transformer(
# width=transformer_width,
# layers=transformer_layers,
# heads=transformer_heads,
# attn_mask=self.build_attention_mask()
# )
# self.vocab_size = vocab_size
# self.token_embedding = nn.Embedding(vocab_size, transformer_width)
# self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
# self.ln_final = LayerNorm(transformer_width)
# self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
# self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width ** -0.5) * (
(2 * self.transformer.layers) ** -0.5
)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
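# In this file CLIP mostly acts as a base class: its towers are commented out in __init__ and are
# rebuilt by XCLIP below. The original contract of forward(), kept for reference (shapes assumed):
#   logits_per_image, logits_per_text = clip_model(images, token_ids)
#   # both are (batch, batch) cosine-similarity matrices scaled by logit_scale.exp()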
class CrossFramelAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None,
droppath=0.0,
T=0,
):
super().__init__()
self.T = T
self.message_fc = nn.Linear(d_model, d_model)
self.message_ln = LayerNorm(d_model)
self.message_attn = nn.MultiheadAttention(d_model, n_head,)
self.attn = nn.MultiheadAttention(d_model, n_head,)
self.ln_1 = LayerNorm(d_model)
self.drop_path = DropPath(droppath) if droppath > 0.0 else nn.Identity()
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None
else None
)
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x):
l, bt, d = x.size()
b = bt // self.T
x = x.view(l, b, self.T, d)
msg_token = self.message_fc(x[0, :, :, :])
msg_token = msg_token.view(b, self.T, 1, d)
msg_token = msg_token.permute(1, 2, 0, 3).view(self.T, b, d)
msg_token = msg_token + self.drop_path(
self.message_attn(
self.message_ln(msg_token),
self.message_ln(msg_token),
self.message_ln(msg_token),
need_weights=False,
)[0]
)
msg_token = msg_token.view(self.T, 1, b, d).permute(1, 2, 0, 3)
x = torch.cat([x, msg_token], dim=0)
x = x.view(l + 1, -1, d)
x = x + self.drop_path(self.attention(self.ln_1(x)))
x = x[:l, :, :]
x = x + self.drop_path(self.mlp(self.ln_2(x)))
return x
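# Frame-message sketch (illustrative sizes; the number of frames folded into the batch must equal T):
#   blk = CrossFramelAttentionBlock(d_model=768, n_head=12, T=8)
#   tokens = torch.randn(50, 2 * 8, 768)   # (L, B*T, d): 2 clips x 8 frames, 49 patches + CLS
#   out = blk(tokens)                      # -> (50, 16, 768); per-frame message tokens built from the
#                                          #    CLS row are appended for attention, then stripped (x[:l])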
# Cross-frame variant used by X-CLIP; note that it shadows the plain Transformer defined above.
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None,
droppath=None,
use_checkpoint=False,
T=8,
):
super().__init__()
self.use_checkpoint = use_checkpoint
if droppath is None:
droppath = [0.0 for i in range(layers)]
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(
*[
CrossFramelAttentionBlock(width, heads, attn_mask, droppath[i], T)
for i in range(layers)
]
)
def forward(self, x: torch.Tensor):
if not self.use_checkpoint:
return self.resblocks(x)
else:
return checkpoint_sequential(self.resblocks, 3, x)
class CrossFrameCommunicationTransformer(nn.Module):
def __init__(
self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
output_dim: int,
droppath=None,
T=8,
use_checkpoint=False,
):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(
scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
)
self.ln_pre = LayerNorm(width)
## Attention Blocks
self.transformer = Transformer(
width, layers, heads, droppath=droppath, use_checkpoint=use_checkpoint, T=T,
)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def init_weights(self):
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[
self.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2)
x = self.transformer(x)
x = x.permute(1, 0, 2)
cls_x = self.ln_post(x[:, 0, :])
if self.proj is not None:
cls_x = cls_x @ self.proj
return cls_x, x[:, 1:, :]
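# Shape sketch (ViT-B/32-like sizes assumed; the batch axis carries B*T stacked frames):
#   cct = CrossFrameCommunicationTransformer(input_resolution=224, patch_size=32, width=768,
#                                            layers=12, heads=12, output_dim=512, T=8)
#   frames = torch.randn(2 * 8, 3, 224, 224)
#   cls_x, tokens = cct(frames)   # cls_x: (16, 512) projected CLS; tokens: (16, 49, 768) patch features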
class MultiHeadAttention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, q, k, v):
B, N, C = q.shape
B, M, C = k.shape
q = (
self.q_proj(q)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
k = (
self.k_proj(k)
.reshape(B, M, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
v = (
self.v_proj(v)
.reshape(B, M, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
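# Cross-attention shape sketch (illustrative sizes): queries and keys/values may have different lengths.
#   attn = MultiHeadAttention(dim=512, num_heads=8)
#   q = torch.randn(2, 5, 512)     # e.g. text features
#   kv = torch.randn(2, 49, 512)   # e.g. frame tokens
#   out = attn(q, kv, kv)          # -> (2, 5, 512)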
class PromptGeneratorLayer(nn.Module):
def __init__(
self, d_model, nhead, dropout=0.0,
):
super().__init__()
        self.cross_attn = MultiHeadAttention(d_model, nhead, proj_drop=dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model * 4),
QuickGELU(),
nn.Dropout(dropout),
nn.Linear(d_model * 4, d_model),
)
def forward(self, x, visual):
q = k = v = self.norm1(x)
x = x + self.cross_attn(q, visual, visual)
x = x + self.dropout(self.mlp(self.norm3(x)))
return x
class VideoSpecificPrompt(nn.Module):
def __init__(
self, layers=2, embed_dim=512, alpha=0.1,
):
super().__init__()
self.norm = nn.LayerNorm(embed_dim)
self.decoder = nn.ModuleList(
[PromptGeneratorLayer(embed_dim, embed_dim // 64) for _ in range(layers)]
)
self.alpha = nn.Parameter(torch.ones(embed_dim) * alpha)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, text, visual):
B, N, C = visual.shape
visual = self.norm(visual)
for layer in self.decoder:
            text = layer(text, visual)
        return self.alpha * text
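# Prompt-generation sketch (illustrative sizes): the visual tokens condition the text features, and
# the learnable alpha (initialised small) scales the returned residual.
#   prompt = VideoSpecificPrompt(layers=2, embed_dim=512, alpha=0.1)
#   text = torch.randn(2, 5, 512)      # one feature per text prompt
#   visual = torch.randn(2, 49, 512)   # frame tokens already projected to embed_dim
#   delta = prompt(text, visual)       # -> (2, 5, 512); added to the text features by the caller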
# Plain residual block used by the MIT below; it shadows the ResidualAttentionBlock defined earlier.
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = nn.LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = nn.LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None
else None
)
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class MultiframeIntegrationTransformer(nn.Module):
def __init__(
self, T, embed_dim=512, layers=1,
):
super().__init__()
self.T = T
transformer_heads = embed_dim // 64
self.positional_embedding = nn.Parameter(torch.empty(1, T, embed_dim))
trunc_normal_(self.positional_embedding, std=0.02)
self.resblocks = nn.Sequential(
*[
ResidualAttentionBlock(d_model=embed_dim, n_head=transformer_heads)
for _ in range(layers)
]
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Linear,)):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
def forward(self, x):
ori_x = x
x = x + self.positional_embedding
x = x.permute(1, 0, 2)
x = self.resblocks(x)
x = x.permute(1, 0, 2)
x = x.type(ori_x.dtype) + ori_x
return x.mean(dim=1, keepdim=False)
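# Temporal pooling sketch (T must equal the value the module was built with):
#   mit = MultiframeIntegrationTransformer(T=8, embed_dim=512, layers=1)
#   cls_per_frame = torch.randn(2, 8, 512)
#   video_feat = mit(cls_per_frame)    # -> (2, 512): residual temporal attention, then a mean over frames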
class XCLIP(CLIP):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
# video
T=8,
droppath=0.0,
mit_layers=1,
# prompt
prompts_alpha=1e-4,
prompts_layers=1,
# other
use_cache=True,
use_checkpoint=False,
):
super().__init__(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
)
self.prompts_generator = VideoSpecificPrompt(
layers=prompts_layers, embed_dim=embed_dim, alpha=prompts_alpha,
)
self.use_cache = use_cache
self.mit = MultiframeIntegrationTransformer(
T=T, embed_dim=embed_dim, layers=mit_layers,
)
dpr = (
[x.item() for x in torch.linspace(0, droppath, vision_layers)]
if droppath > 0.0
else None
)
vision_heads = vision_width // 64
self.visual = CrossFrameCommunicationTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim,
droppath=dpr,
T=T,
use_checkpoint=use_checkpoint,
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask(),
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width)
)
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.cache_text_features = None
self.prompts_visual_ln = LayerNorm(vision_width)
self.prompts_visual_proj = nn.Parameter(torch.randn(vision_width, embed_dim))
self.initialize_parameters()
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {"positional_embedding"}
def encode_image(self, image):
return self.visual(image)
def encode_text(self, text):
x = self.token_embedding(text)
eos_indx = text.argmax(dim=-1)
K, N1, C = x.shape
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), eos_indx] @ self.text_projection
x = x.reshape(K, -1)
return x
def encode_video(self, image):
b, t, c, h, w = image.size()
image = image.reshape(-1, c, h, w)
cls_features, img_features = self.encode_image(image)
img_features = self.prompts_visual_ln(img_features)
img_features = img_features @ self.prompts_visual_proj
cls_features = cls_features.view(b, t, -1)
img_features = img_features.view(b, t, -1, cls_features.shape[-1])
video_features = self.mit(cls_features)
return video_features, img_features
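    # encode_video sketch (sizes assume a ViT-B/32, 8-frame backbone): for image of shape (b, t, c, h, w),
    #   video_features: (b, embed_dim)        frame CLS tokens pooled by the MIT head
    #   img_features:   (b, t, n, embed_dim)  per-frame patch tokens projected for the prompt generator
    #   e.g. (2, 8, 3, 224, 224) -> (2, 512) and (2, 8, 49, 512)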
def forward(self, image, **kwargs):
image = rearrange(image, "b c t h w -> b t c h w")
video_features, _ = self.encode_video(image)
return video_features.reshape(*video_features.shape, 1, 1, 1)
def cache_text(self, text):
self.eval()
with torch.no_grad():
if self.cache_text_features is None:
self.cache_text_features = self.encode_text(text)
self.train()
return self.cache_text_features
def forward_original(self, image, text):
b = image.shape[0]
video_features, img_features = self.encode_video(image)
img_features = img_features.mean(dim=1, keepdim=False)
if self.use_cache:
text_features = self.cache_text(text)
else:
text_features = self.encode_text(text)
text_features = text_features.unsqueeze(0).expand(b, -1, -1)
text_features = text_features + self.prompts_generator(
text_features, img_features
)
video_features = video_features / video_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logit_scale = self.logit_scale.exp()
logits = torch.einsum("bd,bkd->bk", video_features, logit_scale * text_features)
return logits
def build_x_clip_model(
pretrained_path="./pretrained_weights/k400_32_8.pth",
droppath=0.0,
use_checkpoint=False,
logger=None,
prompts_alpha=1e-1,
prompts_layers=2,
use_cache=True,
mit_layers=4,
**kwargs,
):
state_dict = torch.load(pretrained_path, map_location="cpu")["model"]
T = int(pretrained_path.split("_")[-1].split(".")[0])
    print(f"build_x_clip_model: T={T} frames parsed from {pretrained_path}")
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[
k
for k in state_dict.keys()
if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
]
)
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round(
(state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [
len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith(f"visual.layer{b}")
)
)
for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round(
(state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
)
vision_patch_size = None
assert (
output_width ** 2 + 1
== state_dict["visual.attnpool.positional_embedding"].shape[0]
)
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(
set(
k.split(".")[2]
for k in state_dict
            if k.startswith("transformer.resblocks")
)
)
model = XCLIP(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
T=T,
droppath=droppath,
mit_layers=mit_layers,
prompts_alpha=prompts_alpha,
prompts_layers=prompts_layers,
use_checkpoint=use_checkpoint,
use_cache=use_cache,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
msg = model.load_state_dict(state_dict, strict=False)
return model.eval()
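
if __name__ == "__main__":
    # Minimal smoke test, assuming the default COVER checkpoint is present at
    # ./pretrained_weights/k400_32_8.pth (ViT-B/32 weights, 8 frames, 224x224 input).
    backbone = build_x_clip_model()
    clip_frames = torch.randn(1, 3, 8, 224, 224)  # (b, c, t, h, w), as expected by XCLIP.forward
    with torch.no_grad():
        feats = backbone(clip_frames)
    print(feats.shape)  # -> torch.Size([1, 512, 1, 1, 1]) for the ViT-B/32 checkpoint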