# Borrowed from ip-adapter resampler.py.
# https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
import math

import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
from torch import einsum
from adaface.util import gen_gradient_scaler
from adaface.arc2face_models import CLIPTextModelWrapper
def reshape_tensor(x, num_heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, num_heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2).contiguous()
    # Keep the head dim separate: the output shape is (bs, n_heads, length, dim_per_head).
    x = x.reshape(bs, num_heads, length, -1)
    return x
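
# Illustrative shape check (an added sketch, not part of the original module; the
# _demo_* helper and the example sizes are assumptions): reshape_tensor() splits the
# feature dim into heads and moves the head dim forward.
def _demo_reshape_tensor():
    x = torch.randn(2, 16, 512)          # (bs, length, width)
    y = reshape_tensor(x, num_heads=8)   # (bs, n_heads, length, dim_per_head)
    assert y.shape == (2, 8, 16, 64)
    return y.shape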

# FFN. Added a Dropout layer at the end, so that it can still load the old ckpt.
def FeedForward(dim, mult=4, p_dropout=0.1):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
        nn.Dropout(p_dropout),
    )

# IP-Adapter FaceID class. Only used in knn-faces.py.
# From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py
class IP_MLPProjModel(nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = nn.Sequential(
            nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
            nn.GELU(),
            nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x
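
# Illustrative usage sketch (added, not from the original file; sizes are arbitrary):
# IP_MLPProjModel maps one 512-dim face ID embedding to num_tokens cross-attention tokens.
def _demo_ip_mlp_proj():
    proj = IP_MLPProjModel(cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4)
    id_embeds = torch.randn(2, 512)
    tokens = proj(id_embeds)             # [2, 4, 768]
    assert tokens.shape == (2, 4, 768)
    return tokens.shape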

# group_dim: the tensor dimension that corresponds to the multiple groups.
class LearnedSoftAggregate(nn.Module):
    def __init__(self, num_feat, group_dim, keepdim=False):
        super(LearnedSoftAggregate, self).__init__()
        self.group_dim = group_dim
        # num_feat = 1: element-wise score function & softmax.
        # num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor.
        self.num_feat = num_feat
        self.feat2score = nn.Linear(num_feat, 1, bias=False)
        self.keepdim = keepdim

    def forward(self, x, score_basis=None):
        # If there's only one mode, do nothing.
        if x.shape[self.group_dim] == 1:
            if self.keepdim:
                return x
            else:
                return x.squeeze(self.group_dim)

        # Assume the last dim of x is the feature dim.
        if score_basis is None:
            score_basis = x

        if self.num_feat == 1:
            mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
        else:
            mode_scores = self.feat2score(score_basis)
        attn_probs = mode_scores.softmax(dim=self.group_dim)
        x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
        return x_aggr
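
# Illustrative usage sketch (added, not from the original file): LearnedSoftAggregate
# computes a learned softmax score per mode and returns the weighted sum over the
# group_dim, collapsing [BS, num_modes, D] into [BS, D].
def _demo_learned_soft_aggregate():
    agg = LearnedSoftAggregate(num_feat=768, group_dim=1, keepdim=False)
    x = torch.randn(2, 4, 768)           # [BS, num_modes, D]
    y = agg(x)                           # [2, 768]: weighted sum over the 4 modes.
    assert y.shape == (2, 768)
    return y.shape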

def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
                    num_output_vecs, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Project to [BS, lora_rank * output_dim * num_modes].
        # It takes a huge param size: input_dim * lora_rank * output_dim * num_modes,
        # e.g., 512 * 32 * 768 * 4 = 50,331,648.
        nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
        # Reshape to [BS, num_modes, lora_rank, output_dim].
        Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # Aggregate [BS, num_modes, lora_rank, output_dim] -> [BS, lora_rank, output_dim].
        LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
            else Rearrange('b () q d -> b q d'),
        nn.Dropout(p_dropout),
        # Permute to [BS, output_dim, lora_rank].
        Rearrange('b q d -> b d q'),
        # Project to [BS, output_dim, num_output_vecs].
        nn.Linear(lora_rank, num_output_vecs, bias=False),
        # Permute to [BS, num_output_vecs, output_dim].
        Rearrange('b d q -> b q d'),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    )
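
# Illustrative shape walk-through (added, not from the original file; the example
# sizes match the comments above but are otherwise arbitrary).
def _demo_lora_expand_embs():
    expand = LoRA_ExpandEmbs(input_dim=512, lora_rank=32, output_dim=768,
                             num_modes=4, num_output_vecs=16)
    id_embs = torch.randn(2, 512)        # [BS, input_dim]
    out = expand(id_embs)                # [BS, num_output_vecs, output_dim]
    assert out.shape == (2, 16, 768)
    return out.shape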

def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Project to [BS, expansion_ratio * output_dim].
        nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
        # Reshape to [BS, expansion_ratio, output_dim].
        Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    )

# Input: [BS, N, D].
def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
    if output_dim == -1:
        output_dim = input_dim

    return nn.Sequential(
        nn.Linear(input_dim, output_dim * num_modes, bias=False),
        # Reshape to [BS, N, num_modes, output_dim].
        Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # If num_modes == 1, then simply remove the mode dim. Otherwise, aggregate the modes.
        LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
            else Rearrange('b n () d -> b n d'),
        nn.Dropout(p_dropout),
    )
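
# Illustrative usage sketch (added, not from the original file): MultimodeProjection
# keeps the token dim N and softly aggregates the num_modes projections per token.
def _demo_multimode_projection():
    proj = MultimodeProjection(input_dim=768, num_modes=4)
    x = torch.randn(2, 16, 768)          # [BS, N, D]
    y = proj(x)                          # [2, 16, 768]
    assert y.shape == (2, 16, 768)
    return y.shape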

# Low-rank to high-rank transformation.
def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Permute to [BS, output_dim, lora_rank].
        Rearrange('b q d -> b d q'),
        # Project to [BS, output_dim, hira_rank * num_modes].
        nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
        # Reshape and permute to [BS, num_modes, hira_rank, output_dim].
        Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim].
        LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
            else Rearrange('b () q d -> b q d'),
        nn.Dropout(p_dropout),
    )
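
# Illustrative usage sketch (added, not from the original file): Lora2Hira expands
# lora_rank vectors into hira_rank vectors while keeping the feature dim.
def _demo_lora2hira():
    lora2hira = Lora2Hira(lora_rank=32, hira_rank=64, output_dim=768, num_modes=4)
    x = torch.randn(2, 32, 768)          # [BS, lora_rank, output_dim]
    y = lora2hira(x)                     # [2, 64, 768]
    assert y.shape == (2, 64, 768)
    return y.shape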

class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.num_heads = num_heads
        inner_dim = dim_head * num_heads

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latent_queries):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent_queries (torch.Tensor): latent query features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latent_queries = self.norm2(latent_queries)

        b, l, _ = latent_queries.shape

        q = self.to_q(latent_queries)
        kv_input = torch.cat((x, latent_queries), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.num_heads)
        k = reshape_tensor(k, self.num_heads)
        v = reshape_tensor(v, self.num_heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = attn @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
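
# Illustrative usage sketch (added, not from the original file; sizes are arbitrary):
# the latent queries cross-attend to the image features (and to themselves), and the
# output has one vector per latent query.
def _demo_perceiver_attention():
    attn = PerceiverAttention(dim=768, dim_head=64, num_heads=8)
    x = torch.randn(2, 257, 768)              # image features
    latent_queries = torch.randn(2, 16, 768)  # latent queries
    out = attn(x, latent_queries)             # [2, 16, 768]
    assert out.shape == (2, 16, 768)
    return out.shape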

class CrossAttention(nn.Module):
    # output_dim is always the same as input_dim.
    # num_q only matters when q_aware_to_v is True.
    # If q_aware_to_v is False, query x in forward() is still usable.
    def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
                 identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
                 q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
                 identity_to_out=False, out_has_skip=False):
        super().__init__()
        dim_head = input_dim // num_heads
        inner_dim = dim_head * num_heads

        self.num_heads = num_heads
        self.q_aware_to_v = q_aware_to_v
        self.v_has_skip = v_has_skip
        self.to_q = nn.Sequential(
                        nn.Linear(input_dim, inner_dim, bias=False),
                        nn.LayerNorm(inner_dim, elementwise_affine=True)
                    ) if not identity_to_q else nn.Identity()
        self.to_k = nn.Sequential(
                        nn.Linear(input_dim, inner_dim, bias=False),
                        nn.LayerNorm(inner_dim, elementwise_affine=True)
                    ) if not identity_to_k else nn.Identity()

        self.v_repeat = v_repeat
        self.num_q_group = num_q_group = num_q // v_repeat  # 416 / 4 = 104.

        # If q_aware_to_v is True, then self.to_v consists of num_q_group grouped projections of input_dim to input_dim.
        # Otherwise, self.to_v is a single projection of input_dim to inner_dim.
        if q_aware_to_v:
            # all_q_mid: 104 * 64 = 6656.
            all_q_mid = num_q_group * q_aware_to_v_lora_rank
            self.to_v = nn.Sequential(
                # number of params: 768 * 6656 = 5,111,808.
                # Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656].
                # Each 768-dim vec is dispersed into 104 64-dim vecs.
                nn.Linear(input_dim, all_q_mid, bias=False),
                nn.LayerNorm(all_q_mid, elementwise_affine=True),
                # Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1.
                Rearrange('b n q -> b q n', q=all_q_mid),
                # Each q_aware_to_v projection has its own linear layer.
                # The total number of parameters will be 6656*768 = 5,111,808.
                # Output: [BS, 104*768, 16]. Each 64-dim feature is expanded to 768 dims.
                nn.Conv1d(
                    in_channels=all_q_mid,
                    out_channels=num_q_group * input_dim,
                    kernel_size=1,
                    groups=num_q_group,
                    bias=False,
                ),
                # Output: [BS, 104, 16, 768].
                Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
                nn.LayerNorm(input_dim, elementwise_affine=True),
            )
        else:
            self.to_v = nn.Sequential(
                            nn.Linear(input_dim, inner_dim, bias=False),
                            nn.LayerNorm(inner_dim, elementwise_affine=True)
                        ) if not identity_to_v else nn.Identity()

        if identity_to_out:
            assert not out_has_skip, "If identity_to_out is True, then out_has_skip must be False."
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(input_dim, input_dim, bias=False),
                nn.Dropout(p_dropout),
                nn.LayerNorm(inner_dim, elementwise_affine=True)
            )

        self.out_has_skip = out_has_skip
        self.attn_drop = nn.Dropout(p_dropout)

    def forward(self, x, context=None, attn_mat=None, return_attn=False):
        h = self.num_heads

        if context is None:
            context = x

        if attn_mat is None:
            # q: [BS, Q, D] -> [BS, Q, D].
            q = self.to_q(x)
            # k: [BS, L, D] -> [BS, L, D].
            k = self.to_k(context)
            # q: [6, 512, 128], k: [6, 17, 128].
            q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))

        if self.q_aware_to_v:
            # context: [BS, L, D]. v: [BS, Q, L, D].
            # There are effectively Q to_v projections.
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context.unsqueeze(1)
        else:
            # v: [BS, L, D].
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context

        #print(v.shape)

        if self.q_aware_to_v:
            # v: [6, 104, 17, 128].
            # v is query-specific, so there's an extra dim for the query.
            v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h).contiguous()
            # Each v corresponds to a query group of v_repeat queries.
            # So v is repeated v_repeat times along the query dim to match the number of queries.
            # v: [6, 104, 17, 128] -> [6, 416, 17, 128].
            v = v.repeat(1, self.v_repeat, 1, 1)
        else:
            v = rearrange(v, 'b n (h d) -> (b h) n d', h=h).contiguous()

        if attn_mat is None:
            scale = q.size(-1) ** -0.25

            sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)

            # sim: [6, 64, 17]. 6: bs 1 * h 6.
            # attention, what we cannot get enough of
            # NOTE: the normalization is done across tokens, not across pixels.
            # So for each pixel, the sum of attention scores across tokens is 1.
            attn = sim.softmax(dim=-1)
            attn = self.attn_drop(attn)
            #print(attn.std())
        else:
            attn = attn_mat

        if self.q_aware_to_v:
            # attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128].
            # out is combined with different attn weights and v for different queries.
            out = einsum('b i j, b i j d -> b i d', attn, v)
        else:
            # v: [6, 17, 128]. out: [6, 32, 128].
            out = einsum('b i j, b j d -> b i d', attn, v)

        # [6, 32, 128] -> [1, 32, 768].
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h).contiguous()

        if self.out_has_skip:
            out = self.to_out(out) + out
        else:
            out = self.to_out(out)

        if return_attn:
            return out, attn
        else:
            return out
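
# Illustrative usage sketch (added, not from the original file): with the default
# q_aware_to_v=True and num_q=416, x must carry exactly num_q query vectors, while
# context is a short sequence of conditioning tokens (17 here, an arbitrary choice).
def _demo_cross_attention():
    ca = CrossAttention(input_dim=768).eval()  # num_heads=6, num_q=416, v_repeat=4
    x = torch.randn(1, 416, 768)               # queries
    context = torch.randn(1, 17, 768)          # conditioning tokens
    out = ca(x, context)                       # [1, 416, 768]
    assert out.shape == (1, 416, 768)
    return out.shape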

class ImgPrompt2TextPrompt(nn.Module):
    def __init__(self, placeholder_is_bg, num_id_vecs, dtype=torch.float32, *args, **kwargs):
        super().__init__()
        self.N_ID = num_id_vecs
        # If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
        self.N_SFX = 0

        if not placeholder_is_bg:
            self.initialize_text_components(*args, **kwargs)

        # prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
        # prompt2token_proj has the same architecture as the original arc2face text encoder,
        # but is retrained to do the inverse mapping.
        # To be initialized in the subclass.
        self.prompt2token_proj = None
        self.dtype = dtype

    def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
        self.N_SFX = num_static_img_suffix_embs
        # We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
        # So it's OK if static_img_suffix_embs is larger than the required number num_static_img_suffix_embs.
        # This holds even if num_static_img_suffix_embs is 0.
        if hasattr(self, 'static_img_suffix_embs') and self.static_img_suffix_embs is not None:
            if self.static_img_suffix_embs.shape[1] == self.N_SFX:
                print(f"static_img_suffix_embs has been initialized with {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
            elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
                print(f"static_img_suffix_embs has been initialized with {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
                self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
            elif self.N_SFX > 0:
                # self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
                print(f"static_img_suffix_embs has been initialized with {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
                self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX])
            else:
                # self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
                print(f"static_img_suffix_embs has been initialized with {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
                self.static_img_suffix_embs = None
        else:
            if self.N_SFX > 0:
                # static_img_suffix_embs either does not exist or is None, so we initialize it from scratch.
                # (A too-short static_img_suffix_embs is also reinitialized above rather than reused and
                # extended; that situation should be very rare.)
                self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
            else:
                # If static_img_suffix_embs had been initialized, it stays None here,
                # i.e., it is erased from the SubjBasisGenerator instance.
                self.static_img_suffix_embs = None

    # Implement a separate initialization function, so that it can be called from SubjBasisGenerator
    # after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
    # ckpts which were not subclassed from ImgPrompt2TextPrompt.
    def initialize_text_components(self, max_prompt_length=77, num_id_vecs=16,
                                   num_static_img_suffix_embs=0, img_prompt_dim=768):
        self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
        self.max_prompt_length = max_prompt_length
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        # clip_text_embeddings: CLIPTextEmbeddings instance.
        clip_text_embeddings = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").text_model.embeddings
        # clip_text_embeddings() and clip_text_embeddings.token_embedding() differ in that
        # clip_text_embeddings() adds positional embeddings, while clip_text_embeddings.token_embedding() doesn't.
        # Adding positional embeddings seems to help somewhat.
        # pad_tokens: pad_token_id 49407 repeated 77 times.
        # pad_token_id is the EOS token; the BOS token is 49406.
        pad_tokens = torch.tensor([self.tokenizer.pad_token_id]).repeat(self.max_prompt_length)
        # pad_embeddings: [77, 768].
        # pad_embeddings is created on the CPU, but will be moved to the GPU automatically when needed.
        # Note: detach pad_embeddings from the computation graph, otherwise
        # deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
        self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach()

    # image prompt space -> text prompt space.
    # return_emb_types: a list of strings, each being one of
    # ['full', 'core', 'full_pad', 'full_half_pad'].
    def inverse_img_prompt_embs(self, face_prompt_embs, list_extra_words,
                                return_emb_types, hidden_state_layer_weights=None,
                                enable_static_img_suffix_embs=False):
        '''
        face_prompt_embs: (BS, self.N_ID, 768), in the image prompt space.
                          Only the core embeddings, no paddings.
        list_extra_words: None or [s_1, ..., s_BS], where each s_i is a string of extra words to be added to the prompt.
        '''

        if list_extra_words is not None:
            if len(list_extra_words) != len(face_prompt_embs):
                if len(face_prompt_embs) > 1:
                    print("Warn: list_extra_words has a different length from face_prompt_embs.")
                    if len(list_extra_words) == 1:
                        list_extra_words = list_extra_words * len(face_prompt_embs)
                    else:
                        breakpoint()
                else:
                    # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_comp_prompt_distillation.
                    # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
                    list_extra_words = list_extra_words[:1]

            for extra_words in list_extra_words:
                assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
            # 16 or 4 ", " are placeholders for face_prompt_embs.
            prompt_templates = [ "photo of a " + ", " * self.N_ID + list_extra_words[i] for i in range(len(list_extra_words)) ]
        else:
            # 16 or 4 ", " are placeholders for face_prompt_embs.
            # No extra words are added to the prompt. So we add 2 more ", " to the template to keep
            # the number of tokens roughly the same as when extra words are added.
            prompt_templates = [ "photo of a " + ", " * (self.N_ID + 2) for _ in range(len(face_prompt_embs)) ]

        # This step should be quite fast, and there's no need to cache the input_ids.
        # input_ids: [BS, 77].
        input_ids = self.tokenizer(
                        prompt_templates,
                        truncation=True,
                        padding="max_length",
                        max_length=self.max_prompt_length,
                        return_tensors="pt",
                    ).input_ids.to(face_prompt_embs.device)

        face_prompt_embs_orig_dtype = face_prompt_embs.dtype
        face_prompt_embs = face_prompt_embs.to(self.dtype)

        ID_END    = 4 + self.N_ID
        PAD_BEGIN = ID_END + self.N_SFX + 2

        # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
        token_embs = self.prompt2token_proj(input_ids=input_ids, return_token_embs=True)
        # token 4: first ", " in the template prompt.
        # Replace the embeddings of the 16 or 4 placeholder ", " with face_prompt_embs.
        token_embs[:, 4:ID_END] = face_prompt_embs
        # Only when do_unet_distill == True do we append the static image suffix embeddings.
        # Otherwise, the static image suffix embeddings are ignored,
        # and token_embs[:, ID_END:ID_END+self.N_SFX] are the filler embeddings of the
        # extra ", " in the template prompt.
        if enable_static_img_suffix_embs and self.N_SFX > 0:
            # Put the static image suffix embeddings right after face_prompt_embs.
            token_embs[:, ID_END:ID_END+self.N_SFX] = self.static_img_suffix_embs[:, :self.N_SFX]

        # This call does the ordinary CLIP text encoding pass.
        prompt_embeds = self.prompt2token_proj(
                            input_ids=input_ids,
                            input_token_embs=token_embs,
                            hidden_state_layer_weights=hidden_state_layer_weights,
                            return_token_embs=False
                        )[0]

        # Restore the original dtype of prompt_embeds: float16 -> float32.
        prompt_embeds = prompt_embeds.to(face_prompt_embs_orig_dtype)
        # token 4: first ", " in the template prompt.
        # When N_ID == 16,
        # prompt_embeds 4:20 are the most important 16 embeddings that contain the subject's identity.
        # 20:22 are the embeddings of the (at most) two extra words.
        # [N, 77, 768] -> [N, 16, 768]
        if enable_static_img_suffix_embs:
            core_prompt_embs = prompt_embeds[:, 4:ID_END+self.N_SFX]
        else:
            core_prompt_embs = prompt_embeds[:, 4:ID_END]
        if list_extra_words is not None:
            # [N, 16, 768] -> [N, 18, 768]
            extra_words_embs = prompt_embeds[:, ID_END+self.N_SFX:PAD_BEGIN]
            core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)

        returned_prompt_embs = []
        for emb_type in return_emb_types:
            if emb_type == 'full':
                returned_prompt_embs.append(prompt_embeds)
            elif emb_type == 'full_half_pad':
                prompt_embeds2 = prompt_embeds.clone()
                # PAD_BEGIN is 22 or 10. Also exclude the last EOS token.
                # So we subtract (PAD_BEGIN + 1) from max_prompt_length.
                PADS = self.max_prompt_length - PAD_BEGIN - 1
                if PADS >= 2:
                    # Fill half of the remaining embeddings with pad embeddings.
                    prompt_embeds2[:, PAD_BEGIN:PAD_BEGIN+PADS//2] = self.pad_embeddings[PAD_BEGIN:PAD_BEGIN+PADS//2]
                returned_prompt_embs.append(prompt_embeds2)
            elif emb_type == 'full_pad':
                prompt_embeds2 = prompt_embeds.clone()
                # Replace the PAD_BEGIN-th to the second-to-last embeddings with pad embeddings.
                # Skip replacing the last embedding, which might have a special role.
                # (Although all padding tokens are the same EOS, the last token might acquire special semantics
                # due to its special position.)
                prompt_embeds2[:, PAD_BEGIN:-1] = self.pad_embeddings[PAD_BEGIN:-1]
                returned_prompt_embs.append(prompt_embeds2)
            elif emb_type == 'full_zeroed_extra':
                prompt_embeds2 = prompt_embeds.clone()
                # Only add two pad embeddings. The remaining embeddings are set to 0.
                # Make the positional embeddings align with the actual positions.
                prompt_embeds2[:, 22:24] = self.pad_embeddings[22:24]
                prompt_embeds2[:, 24:-1] = 0
                returned_prompt_embs.append(prompt_embeds2)
            elif emb_type == 'core':
                returned_prompt_embs.append(core_prompt_embs)
            else:
                breakpoint()

        return returned_prompt_embs
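
# Added note (an illustrative summary, not from the original file): the index arithmetic
# above assumes the template "photo of a " + ", " * N_ID [+ extra words], which CLIP
# tokenizes as <BOS> photo of a (tokens 0:4), then the N_ID placeholder ", " tokens
# (4:ID_END), then N_SFX optional static suffix slots, then 2 extra-word slots, then
# padding up to the final EOS. The helper below merely reproduces that arithmetic for
# the default subject case and is a hypothetical addition.
def _demo_prompt_layout(N_ID=16, N_SFX=0, max_prompt_length=77):
    ID_END    = 4 + N_ID                           # 20 when N_ID == 16.
    PAD_BEGIN = ID_END + N_SFX + 2                 # 22 when N_SFX == 0.
    PADS      = max_prompt_length - PAD_BEGIN - 1  # 54 pad positions before the last EOS.
    return ID_END, PAD_BEGIN, PADS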

class SubjBasisGenerator(ImgPrompt2TextPrompt):
    def __init__(
        self,
        # Number of cross-attention heads of the bg prompt translator.
        # Taken as half of the 12 heads of OpenAI clip-vit-large-patch14:
        # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
        num_bg_encoder_heads=6,
        # Number of subject input identity vectors (only when the subject is not a face),
        # or number of background input identity vectors (whether or not the subject is a face).
        # 257: 257 CLIP tokens.
        num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
        num_id_vecs=16,                             # num_id_vecs: subj: 16. bg: 4.
        num_static_img_suffix_embs: int = 0,        # Number of extra static learnable image embeddings appended to translated ID embeddings.
        bg_image_embedding_dim=1024,                # CLIP image hidden layer feature dimension, as per config.json above.
        obj_embedding_dim=384,                      # DINO object feature dimension for objects.
        output_dim=768,                             # CLIP text embedding input dimension.
        placeholder_is_bg: bool = False,            # Whether the placeholder is for the image background tokens.
        prompt2token_proj_grad_scale: float = 0.4,  # Gradient scale for prompt2token_proj.
        learnable_hidden_state_weights_scheme: str = 'per-layer',   # none, per-layer.
        bg_prompt_translator_has_to_out_proj: bool = False,         # Whether the prompt_trans_layers have a to_out projection.
    ):
        # If not placeholder_is_bg, the superclass calls initialize_text_components().
        super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs, max_prompt_length=77,
                         num_static_img_suffix_embs=num_static_img_suffix_embs, img_prompt_dim=output_dim)

        self.placeholder_is_bg = placeholder_is_bg
        self.num_out_embs = self.N_ID + self.N_SFX
        self.output_dim = output_dim
        # num_nonface_in_id_vecs should be the number of core ID embs, 16.
        # However, in that case pos_embs is not used, so it doesn't matter if it's set incorrectly.
        self.num_nonface_in_id_vecs = num_nonface_in_id_vecs['bg'] if placeholder_is_bg else num_nonface_in_id_vecs['subj']
        self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj

        if not self.placeholder_is_bg:
            # [1, 384] -> [1, 16, 768].
            # TODO: use CLIPTextModelWrapper as obj_proj_in.
            self.obj_proj_in = ExpandEmbs(obj_embedding_dim, output_dim, expansion_ratio=self.num_nonface_in_id_vecs)

            # ** prompt2token_proj does the actual job: **
            # it is the inverse projection that maps from faceid2img_prompt_embs to adaface_prompt_embs.
            # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
            # If self.placeholder_is_bg: prompt2token_proj is set to None.
            # Attention dropout could be enabled via clip_dropout_config to increase robustness;
            # it is currently disabled (clip_dropout_config is None).
            clip_dropout_config = None  # CLIPTextConfig.from_pretrained('openai/clip-vit-large-patch14', attention_dropout=0.05, dropout=0.05)
            self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14',
                                                                          config=clip_dropout_config)
            self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
            self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
            print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
            # If prompt2token_proj_grad_scale is 0, freeze all params in prompt2token_proj.
            # Otherwise, only freeze the token and positional embeddings of the original CLIPTextModel.
            self.freeze_prompt2token_proj()

            # These multipliers are relative to the original CLIPTextModel.
            self.prompt2token_proj_attention_multipliers = [1] * 12
            self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
            self.bg_proj_in = None
            self.pos_embs = self.pos_embs_ln = self.latent_queries = self.latent_queries_ln = None
        else:
            # For background placeholders, face and object embeddings are not used, as they are foreground.
            self.obj_proj_in = None
            self.bg_proj_in = nn.Sequential(
                nn.Linear(bg_image_embedding_dim, output_dim, bias=False),
                nn.LayerNorm(output_dim),
            )

            self.pos_embs          = nn.Parameter(torch.zeros(1, self.num_nonface_in_id_vecs, output_dim))
            self.pos_embs_ln       = nn.LayerNorm(output_dim)
            self.latent_queries    = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
            self.latent_queries_ln = nn.LayerNorm(output_dim)

            identity_to_v   = False
            v_has_skip      = not identity_to_v                         # True
            identity_to_out = not bg_prompt_translator_has_to_out_proj  # True
            out_has_skip    = not identity_to_out                       # False
            # prompt_translator maps the CLIP image features (of the background) to the prompt embedding space.
            # It is only used during training when placeholder_is_bg is True.
            # prompt_translator has a to_v projection with a skip connection, and doesn't have a to_out projection.
            # dim=768, num_bg_encoder_heads=6.
            self.prompt_translator = \
                CrossAttention(input_dim=output_dim, num_heads=num_bg_encoder_heads, p_dropout=0.05,
                               identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
                               q_aware_to_v=False, v_has_skip=v_has_skip,
                               num_q=0,  # When not q_aware_to_v, num_q is not referenced.
                               identity_to_out=identity_to_out,
                               out_has_skip=out_has_skip)
            self.output_scale = output_dim ** -0.5
            '''
            prompt_translator: CLIPEncoder
            # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
            # CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
            (0): CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                    (k_proj): Linear(in_features=768, out_features=768, bias=True)
                    (v_proj): Linear(in_features=768, out_features=768, bias=True)
                    (q_proj): Linear(in_features=768, out_features=768, bias=True)
                    (out_proj): Linear(in_features=768, out_features=768, bias=True)
                )
                (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                    (activation_fn): QuickGELUActivation()
                    (fc1): Linear(in_features=768, out_features=3072, bias=True)
                    (fc2): Linear(in_features=3072, out_features=768, bias=True)
                )
                (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            )
            '''

        print(repr(self))

    # raw_id_embs: only used when the subject is not a face; in that case it's the DINO embeddings.
    # Otherwise, raw_id_embs is not used.
    # faceid2img_prompt_embs: [BS, 16, 768], the core ID prompt embeddings generated by ID2ImgPrompt.
    def forward(self, faceid2img_prompt_embs, clip_features=None, raw_id_embs=None, out_id_embs_cfg_scale=1.0,
                is_face=True, enable_static_img_suffix_embs=False):
        if not self.placeholder_is_bg:
            BS = faceid2img_prompt_embs.shape[0]
        else:
            # If bg, then faceid2img_prompt_embs is set to None, but clip_features is not None.
            BS = clip_features.shape[0]
            clip_features = clip_features.to(self.dtype)

        # No need to use raw_id_embs if placeholder_is_bg.
        if not self.placeholder_is_bg:
            if is_face:
                assert faceid2img_prompt_embs is not None
                # id2img_embs has been projected to the (modified) prompt embedding space
                # by ID2AdaPrompt::map_init_id_to_img_prompt_embs(). This prompt embedding space is modified because
                # the ID2ImgPrompt module (at least when it's arc2face) may have finetuned the
                # text encoder and the U-Net.
                # In embedding_manager: [BS, 16, 768] -> [BS, 77, 768].
                # faceid2img_prompt_embs is part of id2img_embs: [BS, 77, 768] -> [BS, 16, 768].
                # adaface_prompt_embs is projected to the prompt embedding space, which is the
                # original U-Net prompt embedding space.

                # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
                hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)

                # faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
                with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
                    # If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], i.e., the 16 identity tokens
                    # and (at most) two extra words in adaface_prompt_embs, without BOS and EOS.
                    # If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs.
                    # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
                    # ada_id_embs: [BS, 16, 768].
                    # return_emb_types: a list of strings, each being one of
                    # ['full', 'core', 'full_pad', 'full_half_pad'].
                    ada_id_embs, = \
                        self.inverse_img_prompt_embs(faceid2img_prompt_embs,
                                                     list_extra_words=None,
                                                     return_emb_types=['core'],
                                                     hidden_state_layer_weights=hidden_state_layer_weights,
                                                     enable_static_img_suffix_embs=enable_static_img_suffix_embs)
                ada_id_embs = self.prompt2token_proj_grad_scaler(ada_id_embs)
            elif raw_id_embs is not None:
                # id_embs: [BS, 384] -> [BS, 18, 768].
                # obj_proj_in is expected to project the DINO object features to
                # the token embedding space. So no need to use prompt2token_proj.
                id_embs = self.obj_proj_in(raw_id_embs)
            else:
                breakpoint()
        else:
            # Otherwise, the context is the ad-hoc CLIP image features.
            # id_embs: [BS, 257, 768].
            id_embs = self.bg_proj_in(clip_features)

        if self.placeholder_is_bg:
            id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
            latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
            # If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer can generate 257 embs,
            # and we take the first 16*4=64.
            # The output of prompt_translator has exactly num_out_embs tokens. id_embs_out: [BS, num_out_embs, 768].
            # prompt_translator would be better named bg_prompt_translator. It maps the bg features
            # to bg prompt embeddings.
            with torch.set_grad_enabled(self.training):
                id_embs_out = self.prompt_translator(latent_queries, id_embs)

            adaface_out_embs = id_embs_out * self.output_scale    # * 0.036
        else:
            adaface_out_embs = ada_id_embs
            # If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of ada_id_embs and pad_embeddings.
            if out_id_embs_cfg_scale != 1:
                # pad_embeddings: [77, 768] -> [16, 768] -> [1, 16, 768].
                # NOTE: Never do cfg on the static image suffix embeddings.
                # So we take the first self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
                # even if enable_static_img_suffix_embs=True.
                pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).to(ada_id_embs.device)
                adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
                                                  + pad_embeddings * (1 - out_id_embs_cfg_scale)

        return adaface_out_embs

    def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
        if learnable_hidden_state_weights_scheme == 'none':
            self.hidden_state_layer_weights = None
            # A grad scaler with alpha = 1 is nn.Identity(), which outputs None given None as input.
            self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
            print("hidden_state_layer_weights is set to None.")
        elif learnable_hidden_state_weights_scheme == 'per-layer':
            # Learnable weights of the last 3 layers, initialized to put more focus on the last layer.
            # 'per-layer': Different weights for different layers, but the same for different channels.
            # hidden_state_layer_weights: [3, 1].
            self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
                                                           requires_grad=True)
            # A gradient scaler of 5 makes the gradients on hidden_state_layer_weights 5 times larger.
            self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
            print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
        else:
            breakpoint()

    def extend_prompt2token_proj_attention(self, prompt2token_proj_attention_multipliers=None,
                                           begin_layer_idx=-1, end_layer_idx=-1, multiplier=1, perturb_std=0.1):
        if begin_layer_idx == -1:
            begin_layer_idx = 0
        if end_layer_idx == -1:
            end_layer_idx = 11

        if prompt2token_proj_attention_multipliers is None and multiplier == 1:
            print("prompt2token_proj_attention_multipliers are all 1. No extension is done.")
            return
        elif prompt2token_proj_attention_multipliers is None:
            # prompt2token_proj_attention_multipliers are relative to the current prompt2token_proj.
            prompt2token_proj_attention_multipliers = [1] * 12
            for i in range(begin_layer_idx, end_layer_idx+1):
                prompt2token_proj_attention_multipliers[i] = multiplier
        # Otherwise, use the given prompt2token_proj_attention_multipliers.

        num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(prompt2token_proj_attention_multipliers, perturb_std)
        # Update prompt2token_proj_attention_multipliers (relative to the original CLIPTextModel).
        for i in range(begin_layer_idx, end_layer_idx+1):
            self.prompt2token_proj_attention_multipliers[i] *= prompt2token_proj_attention_multipliers[i]
        print(f"{num_extended_layers} layers in prompt2token_proj_attention are extended by {prompt2token_proj_attention_multipliers}")
        return num_extended_layers

    def squeeze_prompt2token_proj_attention(self, prompt2token_proj_attention_divisors=None,
                                            begin_layer_idx=-1, end_layer_idx=-1, divisor=1):
        if begin_layer_idx == -1:
            begin_layer_idx = 0
        if end_layer_idx == -1:
            end_layer_idx = 11

        if prompt2token_proj_attention_divisors is None and divisor == 1:
            print("prompt2token_proj_attention_divisors are all 1. No squeezing is done.")
            return
        elif prompt2token_proj_attention_divisors is None:
            prompt2token_proj_attention_divisors = [1] * 12
            for i in range(begin_layer_idx, end_layer_idx+1):
                prompt2token_proj_attention_divisors[i] = divisor
        # Otherwise, use the given prompt2token_proj_attention_divisors.

        num_squeezed_layers = self.prompt2token_proj.squeeze_clip_attention_MKV_divisor(prompt2token_proj_attention_divisors)
        # Update prompt2token_proj_attention_multipliers (relative to the original CLIPTextModel).
        for i in range(begin_layer_idx, end_layer_idx+1):
            self.prompt2token_proj_attention_multipliers[i] //= prompt2token_proj_attention_divisors[i]
        print(f"{num_squeezed_layers} layers in prompt2token_proj_attention are squeezed by {prompt2token_proj_attention_divisors}")
        return num_squeezed_layers

    def freeze_prompt2token_proj(self):
        # Only applicable to the fg basis generator.
        # If bg, then prompt2token_proj is set to None, so there is no need to freeze it,
        # and we don't have to check below whether it's for subj or bg.
        if self.placeholder_is_bg:
            return

        if self.prompt2token_proj_grad_scale == 0:
            frozen_components_name = 'all'
            frozen_param_set = self.prompt2token_proj.named_parameters()
        else:
            frozen_components_name = 'token_pos_embeddings'
            frozen_param_set = self.prompt2token_proj.text_model.embeddings.named_parameters()

        if self.prompt2token_proj is not None:
            frozen_param_names = []
            for param_name, param in frozen_param_set:
                if param.requires_grad:
                    param.requires_grad = False
                    frozen_param_names.append(param_name)
                # If param is already frozen, then there's no need to freeze it again.
            print(f"{frozen_components_name}: {len(frozen_param_names)} params in Subj prompt2token_proj are frozen.")
            #print(f"Frozen parameters:\n{frozen_param_names}")

    def patch_old_subj_basis_generator_ckpt(self):
        # Fix compatibility with the previous version.
        if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
            self.bg_prompt_translator_has_to_out_proj = False
        if not hasattr(self, 'num_out_embs'):
            self.num_out_embs = -1
        if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
            self.N_ID = self.num_id_vecs
        if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
            self.num_nonface_in_id_vecs = self.N_ID
        if not hasattr(self, 'dtype'):
            self.dtype = torch.float32

        if self.placeholder_is_bg:
            if not hasattr(self, 'pos_embs') or self.pos_embs is None:
                self.pos_embs = nn.Parameter(torch.zeros(1, self.num_nonface_in_id_vecs, self.output_dim))
            if not hasattr(self, 'latent_queries') or self.latent_queries is None:
                self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, self.output_dim))
            # The background encoder doesn't require initializing text components.
        else:
            self.initialize_hidden_state_layer_weights('per-layer', 'cpu')
            if not hasattr(self, 'prompt2token_proj_attention_multipliers'):
                # Please manually set prompt2token_proj_attention_multipliers in the ckpt.
                breakpoint()
            self.initialize_text_components(max_prompt_length=77, num_id_vecs=self.N_ID,
                                            num_static_img_suffix_embs=self.N_SFX,
                                            img_prompt_dim=self.output_dim)

    def __repr__(self):
        type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
        return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
               f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"