File size: 2,100 Bytes
fa2f5fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from transformers import PreTrainedModel, PretrainedConfig
from .module import ConditionalViT

from sentence_transformers import SentenceTransformer


class CondViTConfig(PretrainedConfig):
    """Configuration for the text-conditioned ViT embedding model.

    ``CondViTForEmbedding`` (which uses this class as its ``config_class``)
    forwards the six ViT fields below to ``ConditionalViT``. It uses the
    ``lm_*`` and ``device`` fields to build its SentenceTransformer text encoder.

    Args:
        input_resolution: Forwarded to ``ConditionalViT``; presumably the input
            image side length in pixels — confirm against ``.module``.
        patch_size: Forwarded to ``ConditionalViT``; presumably the patch side
            length in pixels.
        width: Forwarded to ``ConditionalViT``; presumably the transformer
            hidden size.
        layers: Forwarded to ``ConditionalViT``; presumably the number of
            transformer blocks.
        heads: Forwarded to ``ConditionalViT``; presumably the number of
            attention heads.
        output_dim: Forwarded to ``ConditionalViT``; presumably the size of the
            produced embedding.
        n_categories: Stored on the config only; the model code in this file
            never reads it.
        lm_backbone: Model id handed to ``SentenceTransformer``.
        lm_revision: Revision handed to ``SentenceTransformer``. It pins the
            exact text-backbone snapshot that gets downloaded.
        device: Target device. A truthy value moves ``ConditionalViT`` there,
            and the value is also passed to ``SentenceTransformer``.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        lm_backbone: str = "sentence-transformers/sentence-t5-xl",
        lm_revision: str = "e0976ba9afd18be963c22c680367a3928c44fd22",
        device: str = "cpu",
        **kwargs
    ):
        # Insertion order mirrors the declaration order above, so attributes
        # land on the instance in the same sequence as explicit assignments.
        config_fields = {
            "input_resolution": input_resolution,
            "patch_size": patch_size,
            "width": width,
            "layers": layers,
            "heads": heads,
            "output_dim": output_dim,
            "n_categories": n_categories,
            "lm_backbone": lm_backbone,
            "lm_revision": lm_revision,
            "device": device,
        }
        for field_name, field_value in config_fields.items():
            setattr(self, field_name, field_value)

        # Model-specific fields are set first; the base initializer then
        # consumes whatever generic options arrived through **kwargs.
        super().__init__(**kwargs)


class CondViTForEmbedding(PreTrainedModel):
    """Image embedder whose ViT is conditioned on sentence embeddings of text.

    It holds two submodules:

    * ``condvit``: a ``ConditionalViT`` built from the config's ViT fields.
    * ``lm``: a ``SentenceTransformer`` loaded from ``config.lm_backbone`` at
      ``config.lm_revision``.

    ``forward`` encodes the optional texts with ``lm`` and feeds the result to
    ``condvit`` as its ``c`` argument.
    """

    config_class = CondViTConfig

    def __init__(self, config):
        super().__init__(config)

        # These config attributes map one-to-one onto ConditionalViT keywords.
        vit_arg_names = (
            "input_resolution",
            "patch_size",
            "width",
            "layers",
            "heads",
            "output_dim",
        )
        self.condvit = ConditionalViT(
            **{arg_name: getattr(config, arg_name) for arg_name in vit_arg_names}
        )

        target_device = config.device
        # Only move the ViT when a device is actually configured (truthy).
        # The text encoder always receives the value and decides for itself.
        if target_device:
            self.condvit.to(target_device)

        # NOTE(review): the text encoder is assigned as a module attribute. It
        # therefore presumably becomes a registered submodule whose weights are
        # saved alongside the ViT's — confirm that is intended.
        self.lm = SentenceTransformer(
            config.lm_backbone,
            revision=config.lm_revision,
            device=target_device,
        )

    def forward(self, pixel_values, texts=None):
        """Embed ``pixel_values``, optionally conditioned on ``texts``.

        Args:
            pixel_values: Image tensor passed to ``ConditionalViT`` as ``imgs``.
                Its ``.device`` decides where the text embeddings are moved.
                Presumably shaped (batch, channels, H, W) — confirm against
                ``.module``.
            texts: Optional input to ``SentenceTransformer.encode``. When
                ``None``, the ViT runs with ``c=None``. NOTE(review): a single
                ``str`` makes ``encode`` return one unbatched embedding rather
                than a batch; confirm ``ConditionalViT`` accepts that.

        Returns:
            Whatever ``ConditionalViT`` returns for these inputs.
        """
        if texts is None:
            return self.condvit(imgs=pixel_values, c=None)

        # Request a torch tensor rather than a numpy array so it can be moved
        # to the image tensor's device directly.
        sentence_vectors = self.lm.encode(
            texts, convert_to_tensor=True, convert_to_numpy=False
        )
        return self.condvit(
            imgs=pixel_values, c=sentence_vectors.to(pixel_values.device)
        )