from sentence_transformers import SentenceTransformer
from transformers import PreTrainedModel, PretrainedConfig

from .module import ConditionalViT


class CondViTConfig(PretrainedConfig):
    """Configuration for the text-conditioned ViT embedding model."""

    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        lm_backbone: str = "sentence-transformers/sentence-t5-xl",
        lm_revision: str = "e0976ba9afd18be963c22c680367a3928c44fd22",
        device: str = "cpu",
        **kwargs
    ):
        # Vision transformer hyperparameters.
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.output_dim = output_dim
        self.n_categories = n_categories

        # Language model used to embed the conditioning texts, pinned to a
        # fixed revision for reproducibility.
        self.lm_backbone = lm_backbone
        self.lm_revision = lm_revision

        self.device = device

        super().__init__(**kwargs)
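
# Usage sketch for the config (illustrative; the "./condvit" path is a
# placeholder): it round-trips to disk like any other PretrainedConfig.
#
#     config = CondViTConfig(n_categories=12)
#     config.save_pretrained("./condvit")
#     config = CondViTConfig.from_pretrained("./condvit")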

class CondViTForEmbedding(PreTrainedModel):
    """Hugging Face wrapper pairing a ConditionalViT image encoder with a
    sentence-embedding language model that encodes the conditioning texts."""

    config_class = CondViTConfig

    def __init__(self, config):
        super().__init__(config)

        self.condvit = ConditionalViT(
            input_resolution=config.input_resolution,
            patch_size=config.patch_size,
            width=config.width,
            layers=config.layers,
            heads=config.heads,
            output_dim=config.output_dim,
        )
        if config.device:
            self.condvit.to(config.device)

        self.lm = SentenceTransformer(
            config.lm_backbone, revision=config.lm_revision, device=config.device
        )

    def forward(self, pixel_values, texts=None):
        # Embed the conditioning texts, if any, and move them onto the same
        # device as the image batch before the conditional forward pass.
        if texts is not None:
            text_embeddings = self.lm.encode(
                texts,
                convert_to_tensor=True,
                convert_to_numpy=False,
            )
            text_embeddings = text_embeddings.to(pixel_values.device)
        else:
            text_embeddings = None
        return self.condvit(imgs=pixel_values, c=text_embeddings)
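
# Minimal smoke-test sketch, not part of the model code. Assumptions beyond
# this file: ConditionalViT accepts float tensors of shape
# [batch, 3, input_resolution, input_resolution] and returns
# [batch, output_dim] embeddings (its forward lives in .module), and
# downloading the sentence-t5-xl backbone here is acceptable. Because of the
# relative import above, run this as a module: python -m <package>.<this_file>.
if __name__ == "__main__":
    import torch

    config = CondViTConfig(device="cpu")
    model = CondViTForEmbedding(config)

    pixel_values = torch.randn(
        2, 3, config.input_resolution, config.input_resolution
    )

    # Conditioned embeddings: one free-text condition per image.
    conditioned = model(pixel_values, texts=["a red dress", "a leather handbag"])

    # Unconditioned embeddings: texts=None skips the language-model pass.
    unconditioned = model(pixel_values)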