Spaces:
Runtime error
Runtime error
from transformers import CLIPModel, CLIPTokenizer | |
import os | |
import json | |
import argparse | |
from random import shuffle, seed | |
import string | |
# non-standard dependencies: | |
import h5py | |
from six.moves import cPickle | |
import numpy as np | |
import torch | |
import torchvision.models as models | |
import skimage.io | |
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |
from PIL import Image | |
from torch import nn | |
class CLIPScore(nn.Module): | |
def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False): | |
super(CLIPScore, self).__init__() | |
# from transformers import CLIPModel, CLIPTokenizer | |
self.clip_model = CLIPModel.from_pretrained( | |
'openai/clip-vit-base-patch32') | |
self.tokenizer = CLIPTokenizer.from_pretrained( | |
'openai/clip-vit-base-patch32') | |
self.clip_model.eval() | |
self.clipscore_w = clipscore_w | |
self.image_transform = self._transform(image_size) | |
self.mode = mode | |
assert mode in ['clip_s', 'refclip_s'] | |
self.use_grammar = use_grammar | |
self.joint_out = joint_out | |
if self.use_grammar and joint_out is False: | |
self.grammar_score_head = nn.Sequential( | |
nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False), | |
nn.ReLU(), | |
nn.Linear(self.clip_model.projection_dim, 2, bias=False) | |
) | |
def _transform(self, n_px): | |
return Compose([ | |
Resize(n_px, interpolation=Image.BICUBIC), | |
CenterCrop(n_px), | |
lambda image: image.convert("RGB"), | |
ToTensor(), | |
Normalize((0.48145466, 0.4578275, 0.40821073), | |
(0.26862954, 0.26130258, 0.27577711)), | |
]) | |
def load_image(self, image_path): | |
image = Image.open(image_path) | |
return image | |
# @torch.no_grad() | |
def image_extract(self, image): | |
if isinstance(image, str): | |
image = self.load_image(image) | |
if not isinstance(image, torch.Tensor): | |
image = self.image_transform(image) | |
img_tensor = image.view(-1, 3, 224, 224) | |
device = next(self.clip_model.parameters()).device | |
img_tensor = img_tensor.to(device) | |
clip_model = self.clip_model | |
img_feat = clip_model.vision_model(img_tensor).pooler_output | |
img_feat = clip_model.visual_projection(img_feat) | |
img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) | |
return img_feat | |
# @torch.no_grad() | |
def text_extract(self, text, prompt="A photo depicts", proj_norm=True): | |
if isinstance(text, str): | |
text_batch = [" ".join([prompt, text])] | |
elif isinstance(text, list): | |
text_batch = [" ".join([prompt, txt]) for txt in text] | |
if isinstance(text, tuple) and isinstance(text[0], torch.Tensor): | |
input_ids, attention_mask = text | |
else: | |
input_text = text_batch | |
tokenized = self.tokenizer( | |
input_text, return_tensors='pt', padding=True, truncation=True) | |
input_ids = tokenized.input_ids | |
attention_mask = tokenized.attention_mask | |
clip_model = self.clip_model | |
device = next(self.clip_model.parameters()).device | |
input_ids = input_ids.to(device) | |
attention_mask = attention_mask.to(device) | |
text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output | |
if proj_norm: | |
text_feat = clip_model.text_projection(text_feat) | |
text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) | |
return text_feat | |
# @torch.no_grad() | |
def calc_clip_s(self, img_feat, text_feat): | |
return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1)) | |
# @torch.no_grad() | |
def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None): | |
if clip_s is None: | |
clip_s = self.calc_clip_s(img_feat, text_feat) | |
B, dim = img_feat.size() | |
ref_text_feat = ref_text_feat.view(B, -1, dim) | |
K = ref_text_feat.size(1) | |
text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1) | |
assert ref_text_feat.size() == text_feat.size( | |
), (ref_text_feat.size(), text_feat.size()) | |
ref_score = self.calc_clip_s(text_feat, ref_text_feat) | |
if ref_text_mask is not None: | |
if not isinstance(ref_text_mask, torch.Tensor): | |
ref_text_mask = torch.tensor( | |
ref_text_mask, dtype=ref_score.dtype, device=ref_score.device) | |
ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K) | |
ref_score = ref_score.view(B, K).max(dim=1).values | |
assert clip_s.size() == (B,) | |
assert clip_s.size() == ref_score.size() | |
# harmonic mean | |
refclip_s = 2 / (1 / clip_s + 1 / ref_score) | |
return refclip_s | |
def forward(self, | |
images=None, text=None, | |
img_feat=None, text_feat=None, | |
ref_text=None, ref_text_feat=None, ref_text_mask=None, | |
prompt="A photo depicts", | |
mode=None): | |
if img_feat is None: | |
img_feat = self.image_extract(images) | |
img_feat = img_feat.view(-1, 512) | |
B = img_feat.size(0) | |
if text_feat is None: | |
text_feat = self.text_extract(text, prompt=prompt) | |
text_feat = text_feat.view(-1, 512) | |
if mode is None: | |
mode = self.mode | |
assert mode in ['clip_s', 'refclip_s'] | |
if mode == 'clip_s': | |
clip_s = self.calc_clip_s(img_feat, text_feat) | |
return clip_s | |
elif mode == 'refclip_s': | |
if ref_text_feat is None: | |
ref_text_feat = self.text_extract(ref_text, prompt=prompt) | |
ref_text_feat = ref_text_feat.view(-1, 512) | |
refclip_s = self.calc_refclip_s( | |
img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask) | |
return refclip_s | |
def train_step(self, | |
images=None, text=None, | |
img_feat=None, text_feat=None, | |
neg_text=None, neg_text_feat=None, | |
# ref_text=None, ref_text_feat=None, ref_text_mask=None, | |
prompt="A photo depicts", | |
# return_loss=True, | |
**kwargs): | |
if img_feat is None: | |
img_feat = self.image_extract(images) | |
img_feat = img_feat.view(-1, 512) | |
B = img_feat.size(0) | |
if text_feat is None: | |
text_feat = self.text_extract(text, prompt=prompt, proj_norm=False) | |
text_cont_feat = self.clip_model.text_projection(text_feat) | |
text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) | |
text_cont_feat = text_cont_feat.view(B, 512) | |
# cosine similarity as logits | |
logit_scale = self.clip_model.logit_scale.exp() | |
logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale | |
# logits_per_image = logits_per_text.T | |
clip_loss = clip_loss_fn(logits_per_text) | |
# negative sampling | |
pos_text_feat = text_feat.view(B, 512) | |
neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512) | |
grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) | |
# 2B, 1 | |
grammar_text_logit = self.grammar_score_head(grammar_text_feat) | |
grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B) | |
grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels) | |
grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False) | |
grammar_pos_pred = grammar_pred[:B] | |
grammar_neg_pred = grammar_pred[B:] | |
# grammar_acc = (grammar_pred == grammar_labels).float().mean() | |
out = { | |
'clip_loss': clip_loss, | |
'grammar_loss': grammar_loss, | |
'img_feat': img_feat, | |
'text_feat': text_cont_feat, | |
'neg_text_feat': neg_text_feat, | |
'grammar_pos_pred': grammar_pos_pred, | |
'grammar_neg_pred': grammar_neg_pred, | |
} | |
return out | |
# contrastive loss function, adapted from | |
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html | |
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor: | |
neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim)) | |
return -neg_ce.mean() | |
def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor: | |
caption_loss = contrastive_loss(similarity, dim=0) | |
image_loss = contrastive_loss(similarity, dim=1) | |
return (caption_loss + image_loss) / 2.0 | |
# class CLIPScore(nn.Module): | |
# def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s'): | |
# super(CLIPScore, self).__init__() | |
# # from transformers import CLIPModel, CLIPTokenizer | |
# self.clip_model = CLIPModel.from_pretrained( | |
# 'openai/clip-vit-base-patch32') | |
# self.tokenizer = CLIPTokenizer.from_pretrained( | |
# 'openai/clip-vit-base-patch32') | |
# self.clip_model.eval() | |
# self.clipscore_w = clipscore_w | |
# self.image_transform = self._transform(image_size) | |
# self.mode = mode | |
# assert mode in ['clip_s', 'refclip_s'] | |
# def _transform(self, n_px): | |
# return Compose([ | |
# Resize(n_px, interpolation=Image.BICUBIC), | |
# CenterCrop(n_px), | |
# lambda image: image.convert("RGB"), | |
# ToTensor(), | |
# Normalize((0.48145466, 0.4578275, 0.40821073), | |
# (0.26862954, 0.26130258, 0.27577711)), | |
# ]) | |
# def load_image(self, image_path): | |
# image = Image.open(image_path) | |
# return image | |
# @torch.no_grad() | |
# def image_extract(self, image): | |
# if isinstance(image, str): | |
# image = self.load_image(image) | |
# if not isinstance(image, torch.Tensor): | |
# image = self.image_transform(image) | |
# img_tensor = image.view(-1, 3, 224, 224) | |
# device = next(self.clip_model.parameters()).device | |
# img_tensor = img_tensor.to(device) | |
# clip_model = self.clip_model | |
# img_feat = clip_model.vision_model(img_tensor).pooler_output | |
# img_feat = clip_model.visual_projection(img_feat) | |
# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) | |
# return img_feat | |
# @torch.no_grad() | |
# def text_extract(self, text, prompt="A photo depicts"): | |
# if isinstance(text, str): | |
# text_batch = [" ".join([prompt, text])] | |
# else: | |
# text_batch = [" ".join([prompt, txt]) for txt in text] | |
# input_text = text_batch | |
# tokenized = self.tokenizer( | |
# input_text, return_tensors='pt', padding=True) | |
# input_ids = tokenized.input_ids | |
# attention_mask = tokenized.attention_mask | |
# clip_model = self.clip_model | |
# device = next(self.clip_model.parameters()).device | |
# input_ids = input_ids.to(device) | |
# attention_mask = attention_mask.to(device) | |
# text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output | |
# text_feat = clip_model.text_projection(text_feat) | |
# text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) | |
# return text_feat | |
# @torch.no_grad() | |
# def calc_clip_s(self, img_feat, text_feat): | |
# return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1)) | |
# @torch.no_grad() | |
# def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None): | |
# if clip_s is None: | |
# clip_s = self.calc_clip_s(img_feat, text_feat) | |
# B, dim = img_feat.size() | |
# ref_text_feat = ref_text_feat.view(B, -1, dim) | |
# K = ref_text_feat.size(1) | |
# text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1) | |
# assert ref_text_feat.size() == text_feat.size(), (ref_text_feat.size(), text_feat.size()) | |
# ref_score = self.calc_clip_s(text_feat, ref_text_feat) | |
# if ref_text_mask is not None: | |
# if not isinstance(ref_text_mask, torch.Tensor): | |
# ref_text_mask = torch.tensor(ref_text_mask, dtype=ref_score.dtype, device=ref_score.device) | |
# ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K) | |
# ref_score = ref_score.view(B, K).max(dim=1).values | |
# assert clip_s.size() == (B,) | |
# assert clip_s.size() == ref_score.size() | |
# # harmonic mean | |
# refclip_s = 2 / (1 / clip_s + 1 / ref_score) | |
# return refclip_s | |
# @torch.no_grad() | |
# def forward(self, | |
# images=None, text=None, | |
# img_feat=None, text_feat=None, | |
# ref_text=None, ref_text_feat=None, ref_text_mask=None, | |
# prompt="A photo depicts", | |
# mode=None): | |
# if img_feat is None: | |
# img_feat = self.image_extract(images) | |
# img_feat = img_feat.view(-1, 512) | |
# if text_feat is None: | |
# text_feat = self.text_extract(text, prompt=prompt) | |
# text_feat = text_feat.view(-1, 512) | |
# if mode is None: | |
# mode = self.mode | |
# assert mode in ['clip_s', 'refclip_s'] | |
# if mode == 'clip_s': | |
# clip_s = self.calc_clip_s(img_feat, text_feat) | |
# return clip_s | |
# elif mode == 'refclip_s': | |
# if ref_text_feat is None: | |
# ref_text_feat = self.text_extract(ref_text, prompt=prompt) | |
# ref_text_feat = ref_text_feat.view(-1, 512) | |
# refclip_s = self.calc_refclip_s(img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask) | |
# return refclip_s | |