# OLA-VLM/ola_vlm/model/multimodal_encoder/clip_convnext_encoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower
from timm.models.convnext import ConvNeXt

from ola_vlm.model.multimodal_encoder.openclip_utils import create_model_from_pretrained
from .base_encoder import BaseVisionTower, ProcessorWrapper

class CLIP(nn.Module):
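    """Thin CLIP wrapper that only instantiates the vision tower.

    The text tower of the original open_clip CLIP class is intentionally
    omitted; only `self.visual` is built, optionally with drop path disabled.
    """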
output_dict: torch.jit.Final[bool]
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
drop_path: bool = False,
):
super().__init__()
self.output_dict = output_dict
        # Disable stochastic depth (drop path) unless explicitly requested.
if not drop_path:
print('Not using drop path during training.')
vision_cfg['timm_drop_path'] = 0.0
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
def extract_res_interp(model_name):
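    """Resolve a tower name such as "CLIP-convnext_large-res768-interp256" into
    (hub model id, input resolution, number of interpolated patches).

    The "-resXXX" and "-interpYYY" suffixes are optional; missing values are
    returned as None.
    """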
valid_model_prefixes = {
"CLIP-convnext_large":"hf-hub:laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
"CLIP-convnext_xxlarge":"hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup"
}
res = None
interp = None
for prefix in valid_model_prefixes:
if model_name.split("/")[-1].startswith(prefix):
base_model_name = valid_model_prefixes[prefix]
break
else:
raise ValueError(f"Unknown vision tower: {model_name}")
parts = model_name.split("-")
for part in parts:
if part.startswith("res"):
res = int(part[3:])
elif part.startswith("interp"):
interp = int(part[6:])
return base_model_name, res, interp
class CLIPConvNextVisionTower(BaseVisionTower):
def __init__(self, vision_tower, args, delay_load=False):
"""
Initialize the CLIPConvNextTower.
Args:
            vision_tower (str): The name of the vision tower model, e.g.
                "CLIP-convnext_large-res768-interp256"; an optional "multi-stage"
                tag in the name enables multi-stage features.
args (argparse.Namespace): The arguments parsed from the command line.
delay_load (bool, optional): Whether to delay loading the model. Defaults to False.
"""
super().__init__(vision_tower, args, delay_load)
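        # "multi-stage" in the tower name selects concatenated features from all
        # four ConvNeXt stages (reflected in the hidden size computed below).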
self.is_multi_stage = "multi-stage" in vision_tower
base_model_name, res, interp = extract_res_interp(vision_tower)
self.vision_tower_name = base_model_name
self.ckpt_path = vision_tower.split("-res")[0]
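        # Fall back to a 768 px input resolution when no "-resXXX" suffix is given.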
self._image_size = res if res is not None else 768
self._interp_size = interp
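        # Total spatial downsampling of the ConvNeXt trunk (stem + 3 downsampling
        # stages => 32x), used to derive the feature-map side length.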
self._reduction = 32
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
self.unfreeze_mm_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
self.is_loaded = False
if not delay_load:
self.load_model()
elif self.unfreeze_mm_vision_tower:
self.load_model()
else:
assert "CLIP-convnext_large" in vision_tower or "CLIP-convnext_xxlarge" in vision_tower
if "CLIP-convnext_large" in vision_tower:
if "multi-stage" in vision_tower:
self._hidden_size = sum([192, 384, 768, 1536])
else:
self._hidden_size = 1536
else:
if "multi-stage" in vision_tower:
self._hidden_size = sum([384, 768, 1536, 3072])
else:
self._hidden_size = 3072
def load_model(self, device_map=None):
"""
Load the CLIP-ConvNext model.
"""
assert "clip-convnext" in self.vision_tower_name.lower()
self.vision_model = "convnext"
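        # Try to load the pretrained OpenCLIP checkpoint; if that fails, fall back
        # to building the model from its config without loading weights.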
        try:
            clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=True)
        except Exception:
            clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=False)
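        # Override the first two preprocessing transforms (typically Resize and
        # CenterCrop) so they match the requested input resolution.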
processor.transforms[0].size = self._image_size
processor.transforms[1].size = (self._image_size, self._image_size)
self.image_processor = ProcessorWrapper(processor, height=self._image_size, width=self._image_size)
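        # Keep only the timm ConvNeXt trunk of the CLIP visual tower; the CLIP
        # projection head is not used for feature extraction.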
self.vision_tower: ConvNeXt = clip_model.visual.trunk
self.vision_tower.output_tokens = True
feature_info = self.vision_tower.feature_info
if self.is_multi_stage:
self._hidden_size = sum([stage['num_chs'] for stage in feature_info])
else:
self._hidden_size = feature_info[-1]['num_chs']
self.is_loaded = True
def interpolate(self, image_forward_outs):
"""
Interpolate the image features to the desired number of patches.
Args:
image_forward_outs (torch.Tensor): The output features from the vision tower.
Returns:
torch.Tensor: The interpolated image features.
"""
if self._interp_size is None:
return image_forward_outs
image_features = F.interpolate(
image_forward_outs.float(),
size=(self.num_patches_per_side, self.num_patches_per_side),
mode='bilinear',
align_corners=False
).to(dtype=image_forward_outs.dtype)
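        # (B, C, H, W) -> (B, H*W, C): flatten the spatial grid into a token sequence.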
image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous()
return image_features
def _forward(self, images):
"""
Perform the forward pass of the CLIPConvNextTower.
Args:
images (torch.Tensor): The input images.
Returns:
torch.Tensor: The output features from the vision tower after interpolation.
"""
image_features_stages = []
x = self.vision_tower.stem(images.to(device=self.device, dtype=self.dtype))
for stage in self.vision_tower.stages:
x = stage(x)
image_features_stages.append(x)
image_features = self.vision_tower.norm_pre(x).contiguous()
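        # NOTE: the multi-stage concatenation below is currently disabled, so only
        # the final-stage features (after norm_pre) are returned.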
# if not self.is_multi_stage:
# image_features_stages = image_features_stages[-1:]
# image_features_stages_rescaled = []
# for image_features_single_stage in image_features_stages:
# image_features_single_stage_rescaled = self.interpolate(image_features_single_stage)
# image_features_stages_rescaled.append(image_features_single_stage_rescaled)
# image_features = torch.cat(image_features_stages_rescaled, -1)
image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous()
return image_features
@property
def image_size(self):
return self._image_size
@property
def num_patches_per_side(self):
"""
Get the number of patches per side.
Returns:
int: The number of patches per side.
"""
if self._interp_size is None:
return self._image_size // self._reduction
else:
return int(self._interp_size ** 0.5)
@property
def num_patches(self):
"""
        Get the total number of patches: `_interp_size` if set, otherwise
        (image_size // reduction) ** 2.
Returns:
int: The total number of patches.
"""
if self._interp_size is None:
return (self._image_size // self._reduction) ** 2
else:
return self._interp_size
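

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes the laion
# checkpoints on the Hugging Face hub are reachable and that BaseVisionTower
# exposes the `device`/`dtype` properties used in `_forward`; the `args`
# namespace below only mimics the attributes read in `__init__`.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
        unfreeze_mm_vision_tower=False,
    )
    tower = CLIPConvNextVisionTower("CLIP-convnext_large-res768-interp256", args)

    # A random tensor stands in for a batch of preprocessed images; in practice
    # images should go through tower.image_processor first.
    images = torch.randn(1, 3, tower.image_size, tower.image_size)
    with torch.no_grad():
        feats = tower._forward(images)
    print(feats.shape)  # (batch, num_tokens, hidden_size)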