from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower
from timm.models.convnext import ConvNeXt

from ola_vlm.model.multimodal_encoder.openclip_utils import create_model_from_pretrained
from .base_encoder import BaseVisionTower, ProcessorWrapper


class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
            drop_path: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        # Disable drop path during training unless explicitly requested.
        # `vision_cfg` is treated as a dict here; it is converted to a
        # CLIPVisionCfg inside `_build_vision_tower`.
        if not drop_path:
            print('Not using drop path during training.')
            vision_cfg['timm_drop_path'] = 0.0

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)


def extract_res_interp(model_name):
    """
    Split a vision tower name into its hf-hub model id and the optional
    -resXXX / -interpYYY suffixes (both default to None when absent).
    """
    valid_model_prefixes = {
        "CLIP-convnext_large": "hf-hub:laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
        "CLIP-convnext_xxlarge": "hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
    }

    res = None
    interp = None

    for prefix in valid_model_prefixes:
        if model_name.split("/")[-1].startswith(prefix):
            base_model_name = valid_model_prefixes[prefix]
            break
    else:
        raise ValueError(f"Unknown vision tower: {model_name}")

    parts = model_name.split("-")
    for part in parts:
        if part.startswith("res"):
            res = int(part[3:])
        elif part.startswith("interp"):
            interp = int(part[6:])

    return base_model_name, res, interp
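
# A minimal sketch of how the parsing above behaves, assuming a hypothetical
# tower string (the -res / -interp suffixes are optional):
#   extract_res_interp("CLIP-convnext_large-res768-interp576")
#   -> ("hf-hub:laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup", 768, 576)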


class CLIPConvNextVisionTower(BaseVisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        """
        Initialize the CLIPConvNextVisionTower.

        Args:
            vision_tower (str): The name of the vision tower model, in the format
                "CLIP-convnext_{large,xxlarge}[-resXXX][-interpYYY]".
            args (argparse.Namespace): The arguments parsed from the command line.
            delay_load (bool, optional): Whether to delay loading the model. Defaults to False.
        """
        super().__init__(vision_tower, args, delay_load)

        self.is_multi_stage = "multi-stage" in vision_tower
        base_model_name, res, interp = extract_res_interp(vision_tower)

        self.vision_tower_name = base_model_name
        self.ckpt_path = vision_tower.split("-res")[0]
        self._image_size = res if res is not None else 768
        self._interp_size = interp
        self._reduction = 32

        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.unfreeze_mm_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
        self.is_loaded = False

        if not delay_load:
            self.load_model()
        elif self.unfreeze_mm_vision_tower:
            self.load_model()
        else:
            assert "CLIP-convnext_large" in vision_tower or "CLIP-convnext_xxlarge" in vision_tower
            if "CLIP-convnext_large" in vision_tower:
                if "multi-stage" in vision_tower:
                    self._hidden_size = sum([192, 384, 768, 1536])
                else:
                    self._hidden_size = 1536
            else:
                if "multi-stage" in vision_tower:
                    self._hidden_size = sum([384, 768, 1536, 3072])
                else:
                    self._hidden_size = 3072
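
    # Explanatory note (not from the original comments): the hard-coded values
    # above mirror the per-stage channel counts of the two ConvNeXt trunks
    # (192/384/768/1536 for convnext_large, 384/768/1536/3072 for convnext_xxlarge).
    # When "multi-stage" is in the tower name the hidden size is their sum,
    # matching the (currently commented-out) multi-stage concatenation path in
    # `_forward`; otherwise only the final stage's width is used.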

    def load_model(self, device_map=None):
        """
        Load the CLIP-ConvNext model.
        """
        assert "clip-convnext" in self.vision_tower_name.lower()
        self.vision_model = "convnext"
        try:
            clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=True)
        except Exception:
            # Fall back to building the model without loading the checkpoint.
            clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=False)

        # Override the preprocessing transforms to the requested image size.
        processor.transforms[0].size = self._image_size
        processor.transforms[1].size = (self._image_size, self._image_size)
        self.image_processor = ProcessorWrapper(processor, height=self._image_size, width=self._image_size)

        self.vision_tower: ConvNeXt = clip_model.visual.trunk
        self.vision_tower.output_tokens = True
        feature_info = self.vision_tower.feature_info

        if self.is_multi_stage:
            self._hidden_size = sum([stage['num_chs'] for stage in feature_info])
        else:
            self._hidden_size = feature_info[-1]['num_chs']
        self.is_loaded = True

    def interpolate(self, image_forward_outs):
        """
        Interpolate the image features to the desired number of patches.

        Args:
            image_forward_outs (torch.Tensor): The output features from the vision tower.

        Returns:
            torch.Tensor: The interpolated image features.
        """
        if self._interp_size is None:
            return image_forward_outs

        image_features = F.interpolate(
            image_forward_outs.float(),
            size=(self.num_patches_per_side, self.num_patches_per_side),
            mode='bilinear',
            align_corners=False
        ).to(dtype=image_forward_outs.dtype)

        image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous()
        return image_features
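
    # Shape sketch (assumes self._interp_size is set, e.g. 576):
    #   image_forward_outs: (B, C, H/32, W/32) feature map from the ConvNeXt trunk
    #   -> F.interpolate to (B, C, 24, 24), since num_patches_per_side = sqrt(576) = 24
    #   -> flatten + permute to a (B, 576, C) token sequence.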

    def _forward(self, images):
        """
        Perform the forward pass of the CLIPConvNextVisionTower.

        Args:
            images (torch.Tensor): The input images.

        Returns:
            torch.Tensor: The flattened output features from the final stage of the vision tower.
        """
        image_features_stages = []
        x = self.vision_tower.stem(images.to(device=self.device, dtype=self.dtype))
        for stage in self.vision_tower.stages:
            x = stage(x)
            image_features_stages.append(x)
        image_features = self.vision_tower.norm_pre(x).contiguous()

        # if not self.is_multi_stage:
        #     image_features_stages = image_features_stages[-1:]
        # image_features_stages_rescaled = []
        # for image_features_single_stage in image_features_stages:
        #     image_features_single_stage_rescaled = self.interpolate(image_features_single_stage)
        #     image_features_stages_rescaled.append(image_features_single_stage_rescaled)
        # image_features = torch.cat(image_features_stages_rescaled, -1)

        image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous()
        return image_features
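
    # Explanatory note on the default shapes: with _image_size = 768 and
    # _reduction = 32, the trunk produces a 768 // 32 = 24 x 24 feature map,
    # so _forward returns (B, 576, hidden_size) tokens
    # (hidden_size = 1536 for convnext_large, 3072 for convnext_xxlarge).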

    @property
    def image_size(self):
        return self._image_size

    @property
    def num_patches_per_side(self):
        """
        Get the number of patches per side.

        Returns:
            int: The number of patches per side.
        """
        if self._interp_size is None:
            return self._image_size // self._reduction
        else:
            return int(self._interp_size ** 0.5)

    @property
    def num_patches(self):
        """
        Get the total number of patches
        (576 for the default 768-px input when no interp override is set).

        Returns:
            int: The total number of patches.
        """
        if self._interp_size is None:
            return (self._image_size // self._reduction) ** 2
        else:
            return self._interp_size
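

# Minimal usage sketch (illustrative only, not part of the original module).
# The SimpleNamespace below stands in for the real training args; the tower
# string and the expected output shape are assumptions based on the code above.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
        unfreeze_mm_vision_tower=False,
    )
    # Builds the open_clip ConvNeXt trunk; downloads weights from the hf-hub on first use.
    tower = CLIPConvNextVisionTower("CLIP-convnext_large-res768", args, delay_load=False)

    images = torch.randn(2, 3, 768, 768)
    feats = tower._forward(images)
    print(feats.shape)  # expected: torch.Size([2, 576, 1536]) for convnext_large at 768 px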