import torch
from torch.nn import functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

from CLIP import clip

from util.util import compose_text_with_templates

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ClipExtractor(torch.nn.Module):
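    """Frozen CLIP feature extractor.

    Wraps a pre-trained CLIP model and exposes helpers for embedding images
    (either via a single resize + normalize pass or via an ensemble of random
    augmentations), embedding text prompts composed with templates, and
    computing the CLIP model's self-similarity for an image.
    """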
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        model = clip.load(cfg["clip_model_name"], device=device)[0]
        self.model = model.eval().requires_grad_(False)

        self.clip_input_size = 224
        self.clip_normalize = T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )
        self.basic_transform = T.Compose(
            [
                # we added interpolation to the CLIP positional embedding, allowing it to work with arbitrary input resolutions.
                T.Resize(self.clip_input_size, max_size=380),
                self.clip_normalize,
            ]
        )
        # list of augmentations we apply before calculating the CLIP losses
        self.augs = T.Compose(
            [
                T.RandomHorizontalFlip(p=0.5),
                T.RandomApply(
                    [
                        T.RandomAffine(
                            degrees=15,
                            translate=(0.1, 0.1),
                            fill=cfg["clip_affine_transform_fill"],
                            interpolation=InterpolationMode.BILINEAR,
                        )
                    ],
                    p=0.8,
                ),
                T.RandomPerspective(
                    distortion_scale=0.4,
                    p=0.5,
                    interpolation=InterpolationMode.BILINEAR,
                    fill=cfg["clip_affine_transform_fill"],
                ),
                T.RandomApply([T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)], p=0.7),
                T.RandomGrayscale(p=0.15),
            ]
        )

        self.n_aug = cfg["n_aug"]

    def augment_input(self, input, n_aug=None, clip_input_size=None):
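        """Return a batch of n_aug augmented views of `input`: the first view is
        the resized full image, the rest are random crops covering 60-100% of each
        side resized to the same size; every view is passed through self.augs."""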
        if n_aug is None:
            n_aug = self.n_aug
        if clip_input_size is None:
            clip_input_size = self.clip_input_size

        cutouts = []
        cutout = T.Resize(clip_input_size, max_size=320)(input)
        cutout_h, cutout_w = cutout.shape[-2:]
        cutout = self.augs(cutout)
        cutouts.append(cutout)
        sideY, sideX = input.shape[2:4]
        for _ in range(n_aug - 1):
            s = torch.zeros(1).uniform_(0.6, 1).item()
            h = int(sideY * s)
            w = int(sideX * s)
            cutout = T.RandomCrop(size=(h, w))(input)
            cutout = T.Resize((cutout_h, cutout_w))(cutout)
            cutout = self.augs(cutout)
            cutouts.append(cutout)

        cutouts = torch.cat(cutouts)
        return cutouts

    def get_image_embedding(self, x, aug=True):
        """Embed an image batch with CLIP, either through the augmentation
        ensemble (one embedding per augmented view) or through a single
        resize pass."""
        if aug:
            # augment_input returns unnormalized views, so apply the CLIP
            # normalization here before encoding
            views = self.clip_normalize(self.augment_input(x))
        else:
            # basic_transform already includes the CLIP normalization
            views = self.basic_transform(x)
        return self.encode_image(views)

    def encode_image(self, x):
        return self.model.encode_image(x)

    def get_text_embedding(self, text, template, average_embeddings=False):
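        """Embed one or more text prompts with CLIP. Each prompt is composed
        with the given templates before tokenization; if average_embeddings is
        True, all resulting embeddings are averaged into a single vector."""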
        if isinstance(text, str):
            text = [text]
        embeddings = []
        for prompt in text:
            with torch.no_grad():
                embedding = self.model.encode_text(
                    clip.tokenize(compose_text_with_templates(prompt, template)).to(device)
                )
            embeddings.append(embedding)
        embeddings = torch.cat(embeddings)
        if average_embeddings:
            embeddings = embeddings.mean(dim=0, keepdim=True)
        return embeddings

    def get_self_sim(self, x):
        # self-similarity is computed by the modified CLIP model shipped with
        # this repo (calculate_self_sim is not part of the standard CLIP API)
        x = self.basic_transform(x)
        return self.model.calculate_self_sim(x)
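

if __name__ == "__main__":
    # Minimal usage sketch. The config keys below are the ones ClipExtractor
    # reads; the concrete values ("ViT-B/32", fill=0, n_aug=8) and the template
    # list passed to get_text_embedding are illustrative assumptions.
    cfg = {
        "clip_model_name": "ViT-B/32",
        "clip_affine_transform_fill": 0,
        "n_aug": 8,
    }
    extractor = ClipExtractor(cfg)

    image = torch.rand(1, 3, 256, 256, device=device)
    image_embeds = extractor.get_image_embedding(image)  # one embedding per augmented view
    text_embeds = extractor.get_text_embedding("a photo of a cat", ["{}"], average_embeddings=True)
    similarity = F.cosine_similarity(image_embeds, text_embeds).mean()
    print(image_embeds.shape, text_embeds.shape, similarity.item())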