from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower
from timm.models.convnext import ConvNeXt

from ola_vlm.model.multimodal_encoder.openclip_utils import create_model_from_pretrained
from .base_encoder import BaseVisionTower, ProcessorWrapper

class CLIP(nn.Module):
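    """Vision-only variant of open_clip's CLIP.

    Only the visual tower described by `vision_cfg` is built; `text_cfg` is kept in
    the signature for config compatibility, but no text tower is constructed.
    """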
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
            drop_path: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        # Optionally pin the timm drop-path (stochastic depth) rate to zero during training.
        if not drop_path:
            print('Not using drop path during training.')
            vision_cfg['timm_drop_path'] = 0.0

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

def extract_res_interp(model_name):
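    """
    Parse a vision tower name into (base_model_name, res, interp).

    A name such as "CLIP-convnext_large-res768-interp256" (illustrative) maps to the
    corresponding open_clip hub id, an input resolution of 768, and an interpolation
    size of 256; the "-resXXX" and "-interpYYY" parts are optional and default to None.

    Raises:
        ValueError: If the name does not start with a known CLIP-ConvNeXt prefix.
    """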
    valid_model_prefixes = {
        "CLIP-convnext_large":"hf-hub:laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
        "CLIP-convnext_xxlarge":"hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup"
    }


    res = None
    interp = None

    for prefix in valid_model_prefixes:
        if model_name.split("/")[-1].startswith(prefix):
            base_model_name = valid_model_prefixes[prefix]
            break
    else:
        raise ValueError(f"Unknown vision tower: {model_name}")

    parts = model_name.split("-")
    for part in parts:
        if part.startswith("res"):
            res = int(part[3:])
        elif part.startswith("interp"):
            interp = int(part[6:])
    return base_model_name, res, interp


class CLIPConvNextVisionTower(BaseVisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        """
        Initialize the CLIPConvNextTower.

        Args:
            vision_tower (str): The vision tower name, e.g. "CLIP-convnext_large-resXXX-interpYYY" or "CLIP-convnext_xxlarge-resXXX-interpYYY".
            args (argparse.Namespace): The arguments parsed from the command line.
            delay_load (bool, optional): Whether to delay loading the model. Defaults to False.
        """
        super().__init__(vision_tower, args, delay_load)

        self.is_multi_stage = "multi-stage" in vision_tower
        base_model_name, res, interp = extract_res_interp(vision_tower)
        self.vision_tower_name = base_model_name
        self.ckpt_path = vision_tower.split("-res")[0]
        self._image_size = res if res is not None else 768
        self._interp_size = interp
        self._reduction = 32

        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.unfreeze_mm_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
        self.is_loaded = False

        if not delay_load:
            self.load_model()
        elif self.unfreeze_mm_vision_tower:
            self.load_model()
        else:
            assert "CLIP-convnext_large" in vision_tower or "CLIP-convnext_xxlarge" in vision_tower
            if "CLIP-convnext_large" in vision_tower:
                if "multi-stage" in  vision_tower:
                    self._hidden_size = sum([192, 384, 768, 1536])
                else:
                    self._hidden_size = 1536
            else:
                if "multi-stage" in  vision_tower:
                    self._hidden_size = sum([384, 768, 1536, 3072])
                else:
                    self._hidden_size = 3072
    

    def load_model(self, device_map=None):
        """
        Load the CLIP-ConvNext model.
        """

        assert "clip-convnext" in self.vision_tower_name.lower()
        self.vision_model = "convnext"
        try:
            clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=True)
        except Exception:
            # Fall back to building the model without loading checkpoint weights.
            clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=False)
        # Point the preprocessing resize/crop at the configured input resolution.
        processor.transforms[0].size = self._image_size
        processor.transforms[1].size = (self._image_size, self._image_size)
        self.image_processor = ProcessorWrapper(processor, height=self._image_size, width=self._image_size)
 
        self.vision_tower: ConvNeXt = clip_model.visual.trunk
        self.vision_tower.output_tokens = True
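        # feature_info reports per-stage channel counts; the hidden size is either
        # the final stage's channels or the sum over all stages in multi-stage mode.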
        feature_info = self.vision_tower.feature_info
        if self.is_multi_stage:
            self._hidden_size = sum([stage['num_chs'] for stage in feature_info])
        else:
            self._hidden_size = feature_info[-1]['num_chs']
        self.is_loaded = True

    def interpolate(self, image_forward_outs):
        """
        Interpolate the image features to the desired number of patches.

        Args:
            image_forward_outs (torch.Tensor): The output features from the vision tower.

        Returns:
            torch.Tensor: The interpolated image features.
        """
        if self._interp_size is None:
            return image_forward_outs

        image_features = F.interpolate(
            image_forward_outs.float(),
            size=(self.num_patches_per_side, self.num_patches_per_side),
            mode='bilinear',
            align_corners=False
        ).to(dtype=image_forward_outs.dtype)
        image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous()
        return image_features

    def _forward(self, images):
        """
        Perform the forward pass of the CLIPConvNextTower.

        Args:
            images (torch.Tensor): The input images.

        Returns:
            torch.Tensor: The output features from the vision tower after interpolation.
        """
        image_features_stages = []
        x = self.vision_tower.stem(images.to(device=self.device, dtype=self.dtype))
        for stage in self.vision_tower.stages:
            x = stage(x)
            image_features_stages.append(x)
        image_features = self.vision_tower.norm_pre(x).contiguous()
        # Multi-stage aggregation (concatenating interpolated features from every
        # stage) is currently disabled; only the final-stage features are returned.
        # if not self.is_multi_stage:
        #     image_features_stages = image_features_stages[-1:]
        # image_features_stages_rescaled = []
        # for image_features_single_stage in image_features_stages:
        #     image_features_single_stage_rescaled = self.interpolate(image_features_single_stage)
        #     image_features_stages_rescaled.append(image_features_single_stage_rescaled)
        # image_features = torch.cat(image_features_stages_rescaled, -1)
        image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous()
        return image_features

    @property
    def image_size(self):
        return self._image_size

    @property
    def num_patches_per_side(self):
        """
        Get the number of patches per side.

        Returns:
            int: The number of patches per side.
        """
        if self._interp_size is None:
            return self._image_size // self._reduction
        else:
            return int(self._interp_size ** 0.5)

    @property
    def num_patches(self):
        """
        Get the total number of patches.

        Default: 256

        Returns:
            int: The total number of patches.
        """
        if self._interp_size is None:
            return (self._image_size // self._reduction) ** 2
        else:
            return self._interp_size
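

# --- Illustrative usage sketch (assumptions, not part of the library API) ---
# The argparse fields below are the ones this class reads; BaseVisionTower may
# require additional attributes, and the tower name is only an example.
#
#   import argparse
#   args = argparse.Namespace(
#       mm_vision_select_layer=-2,
#       mm_vision_select_feature="patch",
#       unfreeze_mm_vision_tower=False,
#   )
#   tower = CLIPConvNextVisionTower("CLIP-convnext_large-res768-interp256", args)
#   # With delay_load=False (the default) the backbone is loaded in __init__.
#   # tower._forward(images) returns final-stage features of shape (B, N, C),
#   # where N = (image_size // 32) ** 2 for square inputs.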