# text2live/models/clip_extractor.py
import torch
from torch.nn import functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from CLIP import clip
from util.util import compose_text_with_templates
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ClipExtractor(torch.nn.Module):
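    """Wraps a frozen CLIP model and exposes image/text embedding utilities.

    Images can be embedded either through a basic resize+normalize transform or
    through a stack of randomly augmented views (crops, flips, affine/perspective
    warps, color jitter) used when computing CLIP-based losses.
    """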
def __init__(self, cfg):
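        """Load the frozen CLIP model and build the resize/normalize and augmentation pipelines."""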
super().__init__()
self.cfg = cfg
model = clip.load(cfg["clip_model_name"], device=device)[0]
self.model = model.eval().requires_grad_(False)
self.clip_input_size = 224
self.clip_normalize = T.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
)
self.basic_transform = T.Compose(
[
                # we added interpolation to the CLIP positional embedding, allowing it to work with arbitrary resolutions.
T.Resize(self.clip_input_size, max_size=380),
self.clip_normalize,
]
)
# list of augmentations we apply before calculating the CLIP losses
self.augs = T.Compose(
[
T.RandomHorizontalFlip(p=0.5),
T.RandomApply(
[
T.RandomAffine(
degrees=15,
translate=(0.1, 0.1),
fill=cfg["clip_affine_transform_fill"],
interpolation=InterpolationMode.BILINEAR,
)
],
p=0.8,
),
T.RandomPerspective(
distortion_scale=0.4,
p=0.5,
interpolation=InterpolationMode.BILINEAR,
fill=cfg["clip_affine_transform_fill"],
),
T.RandomApply([T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)], p=0.7),
T.RandomGrayscale(p=0.15),
]
)
self.n_aug = cfg["n_aug"]
def augment_input(self, input, n_aug=None, clip_input_size=None):
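        """Produce a batch of augmented views of `input`.

        The first view is the resized full frame; the remaining `n_aug - 1` views are
        random crops covering 60-100% of each spatial dimension, resized back to the
        first view's resolution. Every view passes through the augmentation pipeline.
        """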
if n_aug is None:
n_aug = self.n_aug
if clip_input_size is None:
clip_input_size = self.clip_input_size
cutouts = []
cutout = T.Resize(clip_input_size, max_size=320)(input)
cutout_h, cutout_w = cutout.shape[-2:]
cutout = self.augs(cutout)
cutouts.append(cutout)
sideY, sideX = input.shape[2:4]
for _ in range(n_aug - 1):
            # sample a random crop scale in [0.6, 1) of the input's spatial size
            s = torch.zeros(1).uniform_(0.6, 1).item()
h = int(sideY * s)
w = int(sideX * s)
cutout = T.RandomCrop(size=(h, w))(input)
cutout = T.Resize((cutout_h, cutout_w))(cutout)
cutout = self.augs(cutout)
cutouts.append(cutout)
cutouts = torch.cat(cutouts)
return cutouts
def get_image_embedding(self, x, aug=True):
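        """Encode `x` with CLIP, optionally over multiple augmented views (one embedding per view)."""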
        if aug:
            # augment_input returns un-normalized views, so normalize them before encoding
            views = self.clip_normalize(self.augment_input(x))
        else:
            # basic_transform already applies clip_normalize, so do not normalize a second time
            views = self.basic_transform(x)
        image_embeds = self.encode_image(views)
        return image_embeds
def encode_image(self, x):
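        """Embed already-normalized images with the CLIP image encoder."""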
return self.model.encode_image(x)
def get_text_embedding(self, text, template, average_embeddings=False):
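        """Encode `text` after composing it with the prompt `template`(s); optionally average over templates."""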
        if isinstance(text, str):
            text = [text]
embeddings = []
for prompt in text:
with torch.no_grad():
embedding = self.model.encode_text(
clip.tokenize(compose_text_with_templates(prompt, template)).to(device)
)
embeddings.append(embedding)
embeddings = torch.cat(embeddings)
if average_embeddings:
embeddings = embeddings.mean(dim=0, keepdim=True)
return embeddings
def get_self_sim(self, x):
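        """Compute CLIP self-similarity for `x` via `calculate_self_sim`, a method provided by this repo's modified CLIP."""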
x = self.basic_transform(x)
return self.model.calculate_self_sim(x)
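

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# config values below are assumptions: any model name accepted by clip.load
# works, and the template list is a stand-in for the prompt templates defined
# in util.util (assumed to use "{}" placeholders).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = {
        "clip_model_name": "ViT-B/32",  # assumed CLIP model name
        "clip_affine_transform_fill": 0,  # fill value for affine/perspective padding
        "n_aug": 4,  # number of augmented views per image
    }
    extractor = ClipExtractor(cfg)
    image = torch.rand(1, 3, 256, 256, device=device)  # dummy RGB batch in [0, 1]
    image_embeds = extractor.get_image_embedding(image)  # (n_aug, embed_dim)
    text_embeds = extractor.get_text_embedding(
        "red car", template=["a photo of a {}."], average_embeddings=True
    )  # (1, embed_dim)
    print(image_embeds.shape, text_embeds.shape)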