'''
Code adapted from Stitch it in Time by Tzaban et al.
https://github.com/rotemtzaban/STIT
'''
import os
from pathlib import Path

import clip
import numpy as np
import torch
from tqdm import tqdm

# Prompt templates from OpenAI's CLIP ImageNet zero-shot evaluation; averaging
# text embeddings over these makes a class direction robust to prompt wording.
imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

# (start, end) channel slices into the flattened 9088-dim StyleSpace vector.
# CONV_CODE_INDICES covers only the conv-modulation channels; FFHQ_CODE_INDICES
# covers all modulation layers of the FFHQ generator, indexed in the same order
# as reference_generator.modulation_layers.
CONV_CODE_INDICES = [(0, 512), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)]
FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
                    [(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
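
# Sanity check (illustrative): the conv slices cover 6048 style channels, and
# the full FFHQ StyleSpace code spans 9088 channels.
assert sum(end - start for start, end in CONV_CODE_INDICES) == 6048
assert sum(end - start for start, end in FFHQ_CODE_INDICES) == 9088
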

def zeroshot_classifier(model, classnames, templates, device):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights
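
# Example usage (sketch; assumes a CLIP model is available):
#   model, _ = clip.load("ViT-B/32", device="cpu")
#   weights = zeroshot_classifier(model, ["face", "smiling face"], imagenet_templates, "cpu")
#   weights.shape  # (embed_dim, num_classes): one prompt-averaged text embedding per class
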

def expand_to_full_dim(partial_tensor):
    # Scatter a conv-only style direction into the full 9088-channel StyleSpace
    # vector; channels outside CONV_CODE_INDICES stay zero.
    full_dim_tensor = torch.zeros(size=(1, 9088))
    start_idx = 0
    for conv_start, conv_end in CONV_CODE_INDICES:
        length = conv_end - conv_start
        full_dim_tensor[:, conv_start:conv_end] = partial_tensor[start_idx:start_idx + length]
        start_idx += length
    return full_dim_tensor
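
# Example (sketch): a 6048-channel conv-only direction becomes a zero-padded
# full StyleSpace vector.
#   partial = torch.randn(6048)
#   full = expand_to_full_dim(partial)
#   full.shape  # (1, 9088)
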

def get_direction(neutral_class, target_class, beta, di, clip_model=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if clip_model is None:
        clip_model, _ = clip.load("ViT-B/32", device=device)

    class_names = [neutral_class, target_class]
    class_weights = zeroshot_classifier(clip_model, class_names, imagenet_templates, device)

    # Text direction in CLIP space from the neutral class to the target class.
    dt = class_weights[:, 1] - class_weights[:, 0]
    dt = dt / dt.norm()

    dt = dt.float()
    di = di.float()

    # Per-channel relevance of the text direction, thresholded by beta so that
    # only strongly relevant style channels are edited.
    relevance = di @ dt
    mask = relevance.abs() > beta
    direction = relevance * mask
    direction_max = direction.abs().max()
    if direction_max > 0:
        direction = direction / direction_max
    else:
        raise ValueError(f'Beta value {beta} is too high for mapping from {neutral_class} to {target_class},'
                         f' try setting it to a lower value')

    return direction
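
# Example usage (sketch; `di` here is the precomputed channel-relevance tensor
# that maps StyleSpace conv channels to CLIP space -- its source and the file
# path below are assumptions, not part of this file):
#   di = torch.load('delta_i_ffhq.pt')
#   direction = get_direction('face', 'smiling face', beta=0.15, di=di)
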

def style_tensor_to_style_dict(style_tensor, reference_generator):
    # Split a flat (1, 9088) StyleSpace tensor into a per-modulation-layer dict.
    style_layers = reference_generator.modulation_layers

    style_dict = {}
    for layer_idx, layer in enumerate(style_layers):
        style_dict[layer] = style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]]
    return style_dict

def style_dict_to_style_tensor(style_dict, reference_generator):
    # Pack a per-modulation-layer style dict back into a flat (1, 9088) tensor.
    style_layers = reference_generator.modulation_layers

    style_tensor = torch.zeros(size=(1, 9088))
    for layer in style_dict:
        layer_idx = style_layers.index(layer)
        style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]
    return style_tensor
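
# The two helpers above are inverses of each other (sketch; `generator` stands
# for any generator exposing `modulation_layers` in FFHQ_CODE_INDICES order):
#   s_tensor = style_dict_to_style_tensor(s_dict, generator)
#   s_dict_roundtrip = style_tensor_to_style_dict(s_tensor, generator)
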

def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
    # Compute a CLIP-guided edit direction, expand it to the full StyleSpace
    # dimensionality, and apply it to the source style code with strength alpha.
    edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
    edit_full_dim = expand_to_full_dim(edit_direction)
    source_s = style_dict_to_style_tensor(source_latent, reference_generator)
    return source_s + alpha * edit_full_dim
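
# End-to-end sketch (illustrative; `generator`, `source_latent`, and `di` are
# assumed to come from the surrounding application: a StyleGAN2 generator
# exposing `modulation_layers`, a per-layer style dict from inversion, and the
# precomputed relevance tensor, respectively):
#   edited_s = project_code_with_styleclip(source_latent, 'face', 'smiling face',
#                                          alpha=4.0, beta=0.15,
#                                          reference_generator=generator, di=di)
#   edited_styles = style_tensor_to_style_dict(edited_s, generator)  # drive the generator with these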