# File: Holiday-StyleGAN-NADA/styleclip/styleclip_global.py
# Uploaded by mjdolan; duplicated from Gradio-Blocks/StyleGAN-NADA (commit 07998f9, 6.32 kB).
'''
Code adapted from Stitch it in Time by Tzaban et al.
https://github.com/rotemtzaban/STIT
'''
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
import os
import clip
# Prompt templates used for CLIP prompt ensembling: get_direction passes this
# list to zeroshot_classifier, which fills each '{}' with a class name, encodes
# every resulting sentence with CLIP and averages the normalised embeddings.
# The strings (including grammatical quirks such as 'a origami {}.') are kept
# verbatim on purpose: editing any entry changes the averaged text embedding.
# NOTE(review): this looks like the ImageNet prompt list from OpenAI's CLIP
# zero-shot notebook — confirm before modifying.
imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]
# Half-open [start, end) column ranges inside the flattened 9088-wide style
# tensor used throughout this module.
#
# CONV_CODE_INDICES: the ranges that expand_to_full_dim fills, in order, from a
# compact direction vector. Their widths sum to 6048, so the compact vector
# (the output of get_direction, i.e. one entry per row of `di`) is expected to
# have 6048 entries. It omits (512, 1024) and every range in the second list of
# FFHQ_CODE_INDICES.
# NOTE(review): the omitted ranges are presumably the toRGB modulations — confirm
# against the generator.
CONV_CODE_INDICES = [(0, 512), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)]
# FFHQ_CODE_INDICES: 26 ranges (18 + 8) that together tile [0, 9088). Entry i
# is the slice owned by reference_generator.modulation_layers[i] (see
# style_tensor_to_style_dict / style_dict_to_style_tensor), so the list order
# must match the generator's modulation_layers order, not column order.
FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
[(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
def zeroshot_classifier(model, classnames, templates, device):
    """Build one prompt-ensembled CLIP text embedding per class name.

    For every class name, each template is formatted with the name, the
    resulting sentences are tokenized and encoded with ``model.encode_text``,
    each sentence embedding is L2-normalised, the embeddings are averaged and
    the average is L2-normalised again. Runs under ``torch.no_grad()`` and
    shows a ``tqdm`` progress bar over the class names.

    Args:
        model: a CLIP model exposing ``encode_text``.
        classnames: iterable of class-name strings to embed.
        templates: format strings containing a single ``{}`` placeholder.
        device: device the token tensor (and the stacked result) is moved to.

    Returns:
        A tensor whose column ``j`` is the unit-norm embedding of
        ``classnames[j]`` (1-D per-class vectors stacked along ``dim=1``).
    """
    per_class = []
    with torch.no_grad():
        for name in tqdm(classnames):
            prompts = [tpl.format(name) for tpl in templates]
            tokens = clip.tokenize(prompts).to(device)
            encoded = model.encode_text(tokens)
            # Normalise each prompt embedding before averaging so every
            # template contributes equally to the ensemble.
            encoded /= encoded.norm(dim=-1, keepdim=True)
            ensemble = encoded.mean(dim=0)
            ensemble /= ensemble.norm()
            per_class.append(ensemble)
        stacked = torch.stack(per_class, dim=1).to(device)
    return stacked
def expand_to_full_dim(partial_tensor):
    """Scatter a compact direction vector into the full 9088-wide style layout.

    ``partial_tensor`` is read sequentially: its first ``end - start`` entries
    go to the first ``CONV_CODE_INDICES`` range, the next chunk to the second
    range, and so on. Columns outside ``CONV_CODE_INDICES`` stay zero. Entries
    past the total width of those ranges are never read.

    Args:
        partial_tensor: 1-D tensor, expected to hold exactly as many entries as
            the summed widths of ``CONV_CODE_INDICES``.

    Returns:
        A ``(1, 9088)`` tensor created with ``torch.zeros`` (torch's default
        dtype/device), with the copied values in the conv ranges.
    """
    expanded = torch.zeros(size=(1, 9088))
    cursor = 0
    for begin, stop in CONV_CODE_INDICES:
        next_cursor = cursor + (stop - begin)
        expanded[:, begin:stop] = partial_tensor[cursor:next_cursor]
        cursor = next_cursor
    return expanded
def get_direction(neutral_class, target_class, beta, di, clip_model=None):
    """Compute a StyleCLIP global edit direction from a text pair.

    The CLIP text direction ``dt`` (target minus neutral, prompt-ensembled with
    ``imagenet_templates``) is normalised and projected onto each row of ``di``.
    Channels whose absolute relevance does not exceed ``beta`` are zeroed, and
    the result is scaled so its largest absolute entry is 1.

    Args:
        neutral_class: text describing the source attribute.
        target_class: text describing the desired attribute.
        beta: disentanglement threshold on ``|di @ dt|``.
        di: tensor whose rows are per-channel CLIP-space directions. It must be
            matrix-multipliable with the CLIP text embedding.
            NOTE(review): presumably (6048, 512) for ViT-B/32 — confirm.
        clip_model: optional preloaded CLIP model. If omitted, ViT-B/32 is
            loaded on CUDA when available, else CPU.

    Returns:
        A tensor on ``di``'s device with one value per row of ``di``, with
        max absolute value 1.

    Raises:
        ValueError: if both prompts yield the same CLIP embedding (zero text
            direction), or if ``beta`` masks out every channel.
    """
    if clip_model is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        clip_model, _ = clip.load("ViT-B/32", device=device)
    else:
        # Tokenize onto the device the supplied model actually lives on; the
        # CUDA-availability heuristic breaks for a CPU model on a GPU host.
        try:
            device = next(clip_model.parameters()).device
        except (AttributeError, StopIteration):
            device = "cuda" if torch.cuda.is_available() else "cpu"

    class_names = [neutral_class, target_class]
    class_weights = zeroshot_classifier(clip_model, class_names, imagenet_templates, device)

    dt = class_weights[:, 1] - class_weights[:, 0]
    dt_norm = dt.norm()
    if dt_norm == 0:
        # Normalising a zero vector would produce NaNs, which previously
        # surfaced as a misleading "beta too high" error below.
        raise ValueError(f'Neutral class {neutral_class!r} and target class {target_class!r} '
                         f'produce identical CLIP text embeddings; no edit direction can be derived')
    # Match di's device so the projection below works regardless of where the
    # CLIP model ran.
    dt = (dt / dt_norm).float().to(di.device)
    di = di.float()

    relevance = di @ dt
    # Keep only channels strongly aligned with the text direction.
    mask = relevance.abs() > beta
    direction = relevance * mask
    direction_max = direction.abs().max()
    if direction_max > 0:
        direction = direction / direction_max
    else:
        raise ValueError(f'Beta value {beta} is too high for mapping from {neutral_class} to {target_class},'
                         f' try setting it to a lower value')
    return direction
def style_tensor_to_style_dict(style_tensor, refernce_generator):
    """Split a flat style tensor into a per-modulation-layer dictionary.

    Layer ``i`` of ``refernce_generator.modulation_layers`` receives the column
    slice ``FFHQ_CODE_INDICES[i]`` of ``style_tensor``. The slices are views,
    not copies.

    Args:
        style_tensor: tensor with at least two dims, sliced along dim 1.
            NOTE(review): presumably shaped (batch, 9088) — confirm.
        refernce_generator: object exposing ``modulation_layers``. The
            misspelled parameter name is kept for keyword-argument callers.

    Returns:
        dict mapping each modulation-layer object to its slice, in
        ``modulation_layers`` order.
    """
    layers = refernce_generator.modulation_layers
    return {
        layer: style_tensor[:, FFHQ_CODE_INDICES[idx][0]:FFHQ_CODE_INDICES[idx][1]]
        for idx, layer in enumerate(layers)
    }
def style_dict_to_style_tensor(style_dict, reference_generator):
    """Pack a per-layer style dictionary into one flat ``(1, 9088)`` tensor.

    Each key is located in ``reference_generator.modulation_layers`` via
    ``list.index``, and its value is written into the matching
    ``FFHQ_CODE_INDICES`` column range. Columns for layers absent from the dict
    stay zero.

    Args:
        style_dict: mapping from modulation-layer objects to style codes.
        reference_generator: object exposing ``modulation_layers``.

    Returns:
        A ``(1, 9088)`` tensor created with ``torch.zeros`` (torch's default
        dtype/device).

    Raises:
        ValueError: from ``list.index`` if a key is not one of the
            generator's modulation layers.
    """
    ordered_layers = reference_generator.modulation_layers
    packed = torch.zeros(size=(1, 9088))
    for module, code in style_dict.items():
        begin, finish = FFHQ_CODE_INDICES[ordered_layers.index(module)]
        packed[:, begin:finish] = code
    return packed
def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
    """Apply a text-driven StyleCLIP global edit to a style-space code.

    Steps:
        1. Derive a compact edit direction with ``get_direction``.
        2. Scatter it to the full 9088-wide layout with ``expand_to_full_dim``.
        3. Add it, scaled by ``alpha``, to the flattened source styles.

    Args:
        source_latent: per-layer style dictionary, as accepted by
            ``style_dict_to_style_tensor``.
        source_class: text describing the current attribute.
        target_class: text describing the desired attribute.
        alpha: edit strength multiplier.
        beta: disentanglement threshold forwarded to ``get_direction``.
        reference_generator: generator providing ``modulation_layers``.
        di: per-channel CLIP-space direction matrix forwarded to
            ``get_direction``.
        clip_model: optional preloaded CLIP model.

    Returns:
        The edited flat style tensor, ``source + alpha * direction``.
    """
    full_direction = expand_to_full_dim(
        get_direction(source_class, target_class, beta, di, clip_model)
    )
    base_style = style_dict_to_style_tensor(source_latent, reference_generator)
    return base_style + alpha * full_direction