from transformers import PreTrainedModel, PretrainedConfig
from sentence_transformers import SentenceTransformer

from .module import ConditionalViT


class CondViTConfig(PretrainedConfig):
    """Configuration for the conditional ViT embedding model."""

    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        lm_backbone: str = "sentence-transformers/sentence-t5-xl",
        lm_revision: str = "e0976ba9afd18be963c22c680367a3928c44fd22",
        device: str = "cpu",
        **kwargs,
    ):
        # Vision transformer hyperparameters.
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.output_dim = output_dim
        self.n_categories = n_categories
        # Text encoder used to embed the conditioning strings, pinned to a
        # specific revision for reproducibility.
        self.lm_backbone = lm_backbone
        self.lm_revision = lm_revision
        self.device = device
        super().__init__(**kwargs)


class CondViTForEmbedding(PreTrainedModel):
    """Conditional ViT that embeds images, optionally conditioned on text."""

    config_class = CondViTConfig

    def __init__(self, config):
        super().__init__(config)
        self.condvit = ConditionalViT(
            input_resolution=config.input_resolution,
            patch_size=config.patch_size,
            width=config.width,
            layers=config.layers,
            heads=config.heads,
            output_dim=config.output_dim,
        )
        if config.device:
            self.condvit.to(config.device)
        # Sentence encoder that produces the conditioning vectors.
        self.lm = SentenceTransformer(
            config.lm_backbone, revision=config.lm_revision, device=config.device
        )

    def forward(self, pixel_values, texts=None):
        if texts is not None:
            # Encode the conditioning texts as a tensor batch and move it to
            # the same device as the image batch.
            text_embeddings = self.lm.encode(
                texts,
                convert_to_tensor=True,
                convert_to_numpy=False,
            )
            text_embeddings = text_embeddings.to(pixel_values.device)
        else:
            # No conditioning text: the ViT runs unconditioned.
            text_embeddings = None
        return self.condvit(imgs=pixel_values, c=text_embeddings)
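
# --- Usage sketch (illustrative, not part of the module) ---------------------
# A minimal smoke test, assuming ConditionalViT accepts a batch of RGB images
# shaped (B, 3, input_resolution, input_resolution) and returns embeddings of
# size output_dim. The example text is made up; instantiating the model also
# downloads the sentence-t5-xl backbone.
if __name__ == "__main__":
    import torch

    config = CondViTConfig()
    model = CondViTForEmbedding(config)

    pixel_values = torch.randn(1, 3, config.input_resolution, config.input_resolution)
    embeddings = model(pixel_values, texts=["a red knitted sweater"])
    print(embeddings.shape)  # expected: (1, config.output_dim)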