import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower
from timm.models.convnext import ConvNeXt

from ola_vlm.model.multimodal_encoder.openclip_utils import create_model_from_pretrained

from .base_encoder import BaseVisionTower, ProcessorWrapper

class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
        self,
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        output_dict: bool = False,
        drop_path: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        # Keep drop path fixed (disabled) during training unless explicitly requested.
        if not drop_path:
            print('Not using drop path during training.')
            vision_cfg['timm_drop_path'] = 0.0

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

def extract_res_interp(model_name):
    valid_model_prefixes = {
        "CLIP-convnext_large": "hf-hub:laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
        "CLIP-convnext_xxlarge": "hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
    }

    res = None
    interp = None

    for prefix in valid_model_prefixes:
        if model_name.split("/")[-1].startswith(prefix):
            base_model_name = valid_model_prefixes[prefix]
            break
    else:
        raise ValueError(f"Unknown vision tower: {model_name}")

    parts = model_name.split("-")
    for part in parts:
        if part.startswith("res"):
            res = int(part[3:])
        elif part.startswith("interp"):
            interp = int(part[6:])

    return base_model_name, res, interp
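
# Illustrative example (hypothetical name, not executed): a tower string such as
# "CLIP-convnext_large_d_320-res768-interp576" would resolve to
#   base_model_name = "hf-hub:laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup"
#   res = 768, interp = 576
# The prefix selects the pretrained hub checkpoint, "-resXXX" sets the input
# resolution, and "-interpYYY" sets the number of patches to interpolate to.
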
class CLIPConvNextVisionTower(BaseVisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        """
        Initialize the CLIPConvNextVisionTower.

        Args:
            vision_tower (str): The name of the vision tower model, in the format
                "CLIP-convnext_<size>...-resXXX-interpYYY".
            args (argparse.Namespace): The arguments parsed from the command line.
            delay_load (bool, optional): Whether to delay loading the model. Defaults to False.
        """
        super().__init__(vision_tower, args, delay_load)

        self.is_multi_stage = "multi-stage" in vision_tower
        base_model_name, res, interp = extract_res_interp(vision_tower)

        self.vision_tower_name = base_model_name
        self.ckpt_path = vision_tower.split("-res")[0]
        self._image_size = res if res is not None else 768
        self._interp_size = interp
        self._reduction = 32

        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.unfreeze_mm_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
        self.is_loaded = False

        if not delay_load:
            self.load_model()
        elif self.unfreeze_mm_vision_tower:
            self.load_model()
        else:
            assert "CLIP-convnext_large" in vision_tower or "CLIP-convnext_xxlarge" in vision_tower
            if "CLIP-convnext_large" in vision_tower:
                if "multi-stage" in vision_tower:
                    self._hidden_size = sum([192, 384, 768, 1536])
                else:
                    self._hidden_size = 1536
            else:
                if "multi-stage" in vision_tower:
                    self._hidden_size = sum([384, 768, 1536, 3072])
                else:
                    self._hidden_size = 3072
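    # Note: the hard-coded sizes above mirror the per-stage channel widths of
    # the two ConvNeXt trunks (large: 192/384/768/1536, xxlarge: 384/768/1536/3072).
    # A "multi-stage" tower concatenates features from all four stages, so its
    # hidden size is the sum; otherwise only the final stage width is used.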
    def load_model(self, device_map=None):
        """
        Load the CLIP-ConvNext model.
        """
        assert "clip-convnext" in self.vision_tower_name.lower()
        self.vision_model = "convnext"

        # Try to load the pretrained checkpoint; fall back to building the
        # model with load_ckpt=False if that fails.
        try:
            clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=True)
        except Exception:
            clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=False)

        # Resize the preprocessing transforms to the configured image size.
        processor.transforms[0].size = self._image_size
        processor.transforms[1].size = (self._image_size, self._image_size)
        self.image_processor = ProcessorWrapper(processor, height=self._image_size, width=self._image_size)

        self.vision_tower: ConvNeXt = clip_model.visual.trunk
        self.vision_tower.output_tokens = True

        feature_info = self.vision_tower.feature_info
        if self.is_multi_stage:
            self._hidden_size = sum([stage['num_chs'] for stage in feature_info])
        else:
            self._hidden_size = feature_info[-1]['num_chs']

        self.is_loaded = True
    def interpolate(self, image_forward_outs):
        """
        Interpolate the image features to the desired number of patches.

        Args:
            image_forward_outs (torch.Tensor): The output features from the vision tower.

        Returns:
            torch.Tensor: The interpolated image features.
        """
        if self._interp_size is None:
            return image_forward_outs

        image_features = F.interpolate(
            image_forward_outs.float(),
            size=(self.num_patches_per_side, self.num_patches_per_side),
            mode='bilinear',
            align_corners=False
        ).to(dtype=image_forward_outs.dtype)

        image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous()
        return image_features
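    # Shape note for `interpolate` above: it expects channel-first feature maps
    # of shape (B, C, H, W) and returns token-style features of shape
    # (B, num_patches, C), where num_patches equals `self._interp_size`
    # whenever interpolation is enabled.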
    def _forward(self, images):
        """
        Perform the forward pass of the CLIPConvNextVisionTower.

        Args:
            images (torch.Tensor): The input images.

        Returns:
            torch.Tensor: The output features from the vision tower.
        """
        image_features_stages = []
        x = self.vision_tower.stem(images.to(device=self.device, dtype=self.dtype))
        for stage in self.vision_tower.stages:
            x = stage(x)
            image_features_stages.append(x)
        image_features = self.vision_tower.norm_pre(x).contiguous()

        # if not self.is_multi_stage:
        #     image_features_stages = image_features_stages[-1:]
        # image_features_stages_rescaled = []
        # for image_features_single_stage in image_features_stages:
        #     image_features_single_stage_rescaled = self.interpolate(image_features_single_stage)
        #     image_features_stages_rescaled.append(image_features_single_stage_rescaled)
        # image_features = torch.cat(image_features_stages_rescaled, -1)

        image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous()
        return image_features
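    # Shape note for `_forward` above: the ConvNeXt trunk downsamples the input
    # by `self._reduction` (32x), so a (B, 3, H, W) batch yields features of
    # shape (B, (H // 32) * (W // 32), hidden_size). The multi-stage /
    # interpolation branch is currently commented out.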
    @property
    def image_size(self):
        return self._image_size

    @property
    def num_patches_per_side(self):
        """
        Get the number of patches per side.

        Returns:
            int: The number of patches per side.
        """
        if self._interp_size is None:
            return self._image_size // self._reduction
        else:
            return int(self._interp_size ** 0.5)
    @property
    def num_patches(self):
        """
        Get the total number of patches.
        Default: 576 for a 768x768 input with a reduction factor of 32.

        Returns:
            int: The total number of patches.
        """
        if self._interp_size is None:
            return (self._image_size // self._reduction) ** 2
        else:
            return self._interp_size
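

# Minimal usage sketch (illustrative only). It assumes that
# `BaseVisionTower.__init__` merely records its arguments and that a namespace
# carrying `mm_vision_select_layer` is enough to construct the tower; with
# `delay_load=True` no pretrained weights are downloaded. The tower name below
# is hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(mm_vision_select_layer=-2)
    tower = CLIPConvNextVisionTower(
        "CLIP-convnext_large_d_320-res768-interp576",
        args,
        delay_load=True,
    )
    print(tower.image_size)            # expected: 768
    print(tower.num_patches_per_side)  # expected: 24 (sqrt of 576)
    print(tower.num_patches)           # expected: 576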