# Borrowed from ip-adapter resampler.py.
# https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
import math

import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
from torch import einsum
from adaface.util import gen_gradient_scaler
from adaface.arc2face_models import CLIPTextModelWrapper

def reshape_tensor(x, num_heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, num_heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2).contiguous()
    # The reshape keeps the shape (bs, n_heads, length, dim_per_head); it only makes the layout explicit.
    x = x.reshape(bs, num_heads, length, -1)
    return x

# FFN. The Dropout layer is appended at the end (instead of being inserted in the middle),
# so that old ckpts can still be loaded.
def FeedForward(dim, mult=4, p_dropout=0.1):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
        nn.Dropout(p_dropout),
    )

# IP-Adapter FaceID class. Only used in knn-faces.py.
# From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py
class IP_MLPProjModel(nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = nn.Sequential(
            nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
            nn.GELU(),
            nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x

# group_dim: the tensor dimension that corresponds to the multiple groups.
class LearnedSoftAggregate(nn.Module):
    def __init__(self, num_feat, group_dim, keepdim=False):
        super(LearnedSoftAggregate, self).__init__()
        self.group_dim = group_dim
        # num_feat = 1: element-wise score function & softmax.
        # num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor.
        self.num_feat = num_feat
        self.feat2score = nn.Linear(num_feat, 1, bias=False)
        self.keepdim = keepdim

    def forward(self, x, score_basis=None):
        # If there's only one mode, do nothing.
        if x.shape[self.group_dim] == 1:
            if self.keepdim:
                return x
            else:
                return x.squeeze(self.group_dim)

        # Assume the last dim of x is the feature dim.
        if score_basis is None:
            score_basis = x

        if self.num_feat == 1:
            mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
        else:
            mode_scores = self.feat2score(score_basis)
        attn_probs = mode_scores.softmax(dim=self.group_dim)
        x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
        return x_aggr
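
# --- Illustrative sketch (not part of the original code) ---
# Shows how LearnedSoftAggregate collapses a "modes" dimension with a learned,
# feature-conditioned softmax. The tensor sizes below are arbitrary assumptions.
def _example_learned_soft_aggregate():
    aggr = LearnedSoftAggregate(num_feat=768, group_dim=1, keepdim=False)
    x = torch.randn(2, 4, 16, 768)   # [BS, num_modes=4, N, D]
    # feat2score maps each 768-dim feature vector to a scalar score; the softmax is taken
    # over group_dim=1 (the modes), and the weighted sum removes that dimension.
    y = aggr(x)                      # [2, 16, 768]
    return y.shape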

def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
                    num_output_vecs, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Project to [BS, lora_rank * output_dim * num_modes].
        # This takes a huge number of params: input_dim * lora_rank * output_dim * num_modes,
        # e.g. 512 * 32 * 768 * 4 = 50,331,648.
        nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
        # Reshape to [BS, num_modes, lora_rank, output_dim].
        Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # Aggregate [BS, num_modes, lora_rank, output_dim] -> [BS, lora_rank, output_dim].
        LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
            else Rearrange('b () q d -> b q d'),
        nn.Dropout(p_dropout),
        # Permute to [BS, output_dim, lora_rank].
        Rearrange('b q d -> b d q'),
        # Project to [BS, output_dim, num_output_vecs].
        nn.Linear(lora_rank, num_output_vecs, bias=False),
        # Permute to [BS, num_output_vecs, output_dim].
        Rearrange('b d q -> b q d'),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    )

def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Project to [BS, expansion_ratio * output_dim].
        nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
        # Reshape to [BS, expansion_ratio, output_dim].
        Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    )

# Input: [BS, N, D].
def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
    if output_dim == -1:
        output_dim = input_dim

    return nn.Sequential(
        nn.Linear(input_dim, output_dim * num_modes, bias=False),
        # Reshape to [BS, N, num_modes, output_dim].
        Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # If num_modes == 1, simply remove the mode dim. Otherwise, aggregate the modes.
        LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
            else Rearrange('b n () d -> b n d'),
        nn.Dropout(p_dropout),
    )
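
# --- Illustrative sketch (not part of the original code) ---
# Shape flow through MultimodeProjection: Linear -> [BS, N, num_modes*D],
# Rearrange -> [BS, N, num_modes, D], LearnedSoftAggregate over the mode dim -> [BS, N, D].
# The batch/sequence sizes below are arbitrary assumptions.
def _example_multimode_projection():
    proj = MultimodeProjection(input_dim=768, output_dim=768, num_modes=4)
    x = torch.randn(2, 16, 768)      # [BS, N, D]
    y = proj(x)                      # [2, 16, 768]
    return y.shape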

# Low-rank to high-rank transformation.
def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Permute to [BS, output_dim, lora_rank].
        Rearrange('b q d -> b d q'),
        # Project to [BS, output_dim, hira_rank * num_modes].
        nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
        # Reshape and permute to [BS, num_modes, hira_rank, output_dim].
        Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim].
        LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
            else Rearrange('b () q d -> b q d'),
        nn.Dropout(p_dropout),
    )

class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.num_heads = num_heads
        inner_dim = dim_head * num_heads

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latent_queries):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent_queries (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latent_queries = self.norm2(latent_queries)

        b, l, _ = latent_queries.shape

        q = self.to_q(latent_queries)
        kv_input = torch.cat((x, latent_queries), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.num_heads)
        k = reshape_tensor(k, self.num_heads)
        v = reshape_tensor(v, self.num_heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        # More stable with f16 than dividing afterwards.
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = attn @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
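
# --- Illustrative sketch (not part of the original code) ---
# PerceiverAttention lets a small set of latent queries attend to image features
# (concatenated with the latents themselves). The sizes below are arbitrary assumptions.
def _example_perceiver_attention():
    attn = PerceiverAttention(dim=768, dim_head=64, num_heads=8)
    x = torch.randn(2, 257, 768)          # image features [b, n1, D]
    latents = torch.randn(2, 16, 768)     # latent queries [b, n2, D]
    out = attn(x, latents)                # [2, 16, 768]: one output per latent query
    return out.shape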

class CrossAttention(nn.Module):
    # output_dim is always the same as input_dim.
    # num_q only matters when q_aware_to_v is True.
    # If q_aware_to_v is False, query x in forward() is still usable.
    def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
                 identity_to_q=False, identity_to_k=False, identity_to_v=False,
                 v_has_skip=True, q_aware_to_v=True, num_q=416, v_repeat=4,
                 q_aware_to_v_lora_rank=64, identity_to_out=False, out_has_skip=False):
        super().__init__()
        dim_head = input_dim // num_heads
        inner_dim = dim_head * num_heads

        self.num_heads = num_heads
        self.q_aware_to_v = q_aware_to_v
        self.v_has_skip = v_has_skip
        self.to_q = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=False),
            nn.LayerNorm(inner_dim, elementwise_affine=True)
        ) if not identity_to_q else nn.Identity()
        self.to_k = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=False),
            nn.LayerNorm(inner_dim, elementwise_affine=True)
        ) if not identity_to_k else nn.Identity()

        self.v_repeat = v_repeat
        self.num_q_group = num_q_group = num_q // v_repeat    # 416 / 4 = 104.

        # If q_aware_to_v is True, then self.to_v consists of num_q projections of input_dim to inner_dim.
        # Otherwise, self.to_v consists of a single projection of input_dim to inner_dim.
        if q_aware_to_v:
            # all_q_mid: 104 * 64 = 6656.
            all_q_mid = num_q_group * q_aware_to_v_lora_rank
            self.to_v = nn.Sequential(
                # number of params: 768 * 6656 = 5,111,808.
                # Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656].
                # Each 768-dim vec is dispersed into 104 64-dim vecs.
                nn.Linear(input_dim, all_q_mid, bias=False),
                nn.LayerNorm(all_q_mid, elementwise_affine=True),
                # Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1.
                Rearrange('b n q -> b q n', q=all_q_mid),
                # Each q_aware_to_v projection has its own linear layer.
                # The total number of parameters will be 6656*768 = 5,111,808.
                # Output: [BS, 104*768, 16]. Each 64-dim feature is expanded to 768 dims.
                nn.Conv1d(
                    in_channels=all_q_mid,
                    out_channels=num_q_group * input_dim,
                    kernel_size=1,
                    groups=num_q_group,
                    bias=False,
                ),
                # Output: [BS, 104, 16, 768].
                Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
                nn.LayerNorm(input_dim, elementwise_affine=True),
            )
        else:
            self.to_v = nn.Sequential(
                nn.Linear(input_dim, inner_dim, bias=False),
                nn.LayerNorm(inner_dim, elementwise_affine=True)
            ) if not identity_to_v else nn.Identity()

        if identity_to_out:
            assert not out_has_skip, "If identity_to_out is True, then out_has_skip has to be False."

        if identity_to_out:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(input_dim, input_dim, bias=False),
                nn.Dropout(p_dropout),
                nn.LayerNorm(inner_dim, elementwise_affine=True)
            )

        self.out_has_skip = out_has_skip
        self.attn_drop = nn.Dropout(p_dropout)

    def forward(self, x, context=None, attn_mat=None, return_attn=False):
        h = self.num_heads

        if context is None:
            context = x

        if attn_mat is None:
            # q: [BS, Q, D] -> [BS, Q, D].
            q = self.to_q(x)
            # k: [BS, L, D] -> [BS, L, D].
            k = self.to_k(context)
            # q: [6, 512, 128], k: [6, 17, 128].
            q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))

        if self.q_aware_to_v:
            # context: [BS, L, D]. v: [BS, Q, L, D].
            # There are effectively Q to_v projections.
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context.unsqueeze(1)
        else:
            # v: [BS, L, D].
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context

        #print(v.shape)

        if self.q_aware_to_v:
            # v: [6, 64, 17, 128].
            # v is query-specific, so there's an extra dim for the query.
            v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h).contiguous()
            # Each v is for a query group with 512/64 = 8 queries.
            # So each v is repeated 8 times to match the number of queries.
            # v: [6, 64, 17, 128] -> [6, 512, 17, 128].
            v = v.repeat(1, self.v_repeat, 1, 1)
        else:
            v = rearrange(v, 'b n (h d) -> (b h) n d', h=h).contiguous()

        if attn_mat is None:
            scale = q.size(-1) ** -0.25

            sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
            # sim: [6, 64, 17]. 6: bs 1 * h 6.
            # attention, what we cannot get enough of
            # NOTE: the normalization is done across tokens, not across pixels.
            # So for each pixel, the sum of attention scores across tokens is 1.
            attn = sim.softmax(dim=-1)
            attn = self.attn_drop(attn)
            #print(attn.std())
        else:
            attn = attn_mat

        if self.q_aware_to_v:
            # attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128].
            # out is combined with different attn weights and v for different queries.
            out = einsum('b i j, b i j d -> b i d', attn, v)
        else:
            # v: [6, 17, 128]. out: [6, 32, 128].
            out = einsum('b i j, b j d -> b i d', attn, v)

        # [6, 32, 128] -> [1, 32, 768].
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h).contiguous()

        if self.out_has_skip:
            out = self.to_out(out) + out
        else:
            out = self.to_out(out)

        if return_attn:
            return out, attn
        else:
            return out
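
# --- Illustrative sketch (not part of the original code) ---
# With q_aware_to_v=True, CrossAttention gives each group of queries its own value
# projection, realized as a single grouped 1x1 Conv1d (num_q_group groups, each mapping
# q_aware_to_v_lora_rank channels to input_dim channels). The sizes below mirror the
# defaults but are otherwise arbitrary assumptions.
def _example_q_aware_cross_attention():
    ca = CrossAttention(input_dim=768, num_heads=6, q_aware_to_v=True,
                        num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64)
    x = torch.randn(1, 416, 768)          # queries [BS, num_q, D]
    context = torch.randn(1, 16, 768)     # context [BS, L, D]
    out = ca(x, context)                  # [1, 416, 768]: one output per query
    return out.shape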

class ImgPrompt2TextPrompt(nn.Module):
    def __init__(self, placeholder_is_bg, num_id_vecs, dtype=torch.float32, *args, **kwargs):
        super().__init__()
        self.N_ID = num_id_vecs
        # If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
        self.N_SFX = 0
        if not placeholder_is_bg:
            self.initialize_text_components(*args, **kwargs)

        # prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
        # prompt2token_proj has the same architecture as the original arc2face text encoder,
        # but is retrained to do the inverse mapping.
        # To be initialized in the subclass.
        self.prompt2token_proj = None
        self.dtype = dtype

    def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
        self.N_SFX = num_static_img_suffix_embs
        # We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
        # So it's OK if static_img_suffix_embs is larger than the required number num_static_img_suffix_embs.
        # This holds even if num_static_img_suffix_embs is 0.
        if hasattr(self, 'static_img_suffix_embs') and self.static_img_suffix_embs is not None:
            if self.static_img_suffix_embs.shape[1] == self.N_SFX:
                print(f"static_img_suffix_embs had been initialized with {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
            elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
                print(f"static_img_suffix_embs had been initialized with {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
                self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
            elif self.N_SFX > 0:
                # self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
                print(f"static_img_suffix_embs had been initialized with {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
                self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX])
            else:
                # self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
                print(f"static_img_suffix_embs had been initialized with {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
                self.static_img_suffix_embs = None
        else:
            if self.N_SFX > 0:
                # static_img_suffix_embs either does not exist or is None
                # (the case of it being initialized but shorter than num_static_img_suffix_embs is handled
                # above; it should be very rare, so we don't try to reuse and extend a shorter
                # static_img_suffix_embs). So we reinitialize it.
                self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
            else:
                # If static_img_suffix_embs had been initialized, then it is set to None,
                # i.e., erased from the SubjBasisGenerator instance.
                self.static_img_suffix_embs = None

    # Implement a separate initialization function, so that it can be called from SubjBasisGenerator
    # after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
    # ckpts which were not subclassed from ImgPrompt2TextPrompt.
    def initialize_text_components(self, max_prompt_length=77, num_id_vecs=16,
                                   num_static_img_suffix_embs=0, img_prompt_dim=768):
        self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
        self.max_prompt_length = max_prompt_length
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        # clip_text_embeddings: CLIPTextEmbeddings instance.
        clip_text_embeddings = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").text_model.embeddings
        # clip_text_embeddings() and clip_text_embeddings.token_embedding() differ in that
        # clip_text_embeddings() adds positional embeddings, while clip_text_embeddings.token_embedding() doesn't.
        # Adding positional embeddings seems to help somewhat.
        # pad_tokens: pad_token_id 49407 repeated 77 times.
        # pad_token_id is the EOS token. The BOS token is 49406.
        pad_tokens = torch.tensor([self.tokenizer.pad_token_id]).repeat(self.max_prompt_length)
        # pad_embeddings: [77, 768].
        # pad_embeddings is still on CPU, but it should be moved to GPU automatically.
        # NOTE: detach pad_embeddings from the computation graph, otherwise
        # deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
        self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach()

    # image prompt space -> text prompt space.
    # return_emb_types: a list of strings, each string is among
    # ['full', 'core', 'full_pad', 'full_half_pad'].
    def inverse_img_prompt_embs(self, face_prompt_embs, list_extra_words,
                                return_emb_types, hidden_state_layer_weights=None,
                                enable_static_img_suffix_embs=False):
        '''
        face_prompt_embs: (BS, self.N_ID, 768), in the image prompt space.
                          Only the core embeddings, no paddings.
        list_extra_words: None or [s_1, ..., s_BS], where each s_i is a string of extra words
                          to be added to the prompt.
        '''

        if list_extra_words is not None:
            if len(list_extra_words) != len(face_prompt_embs):
                if len(face_prompt_embs) > 1:
                    print("Warn: list_extra_words has a different length from face_prompt_embs.")
                    if len(list_extra_words) == 1:
                        list_extra_words = list_extra_words * len(face_prompt_embs)
                    else:
                        breakpoint()
                else:
                    # len(face_prompt_embs) == 1. This occurs when same_subject_in_batch == True,
                    # e.g. in do_comp_prompt_distillation.
                    # But list_extra_words always corresponds to the actual batch size.
                    # So we only take the first element.
                    list_extra_words = list_extra_words[:1]

            for extra_words in list_extra_words:
                assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
            # 16 or 4 ", " are placeholders for face_prompt_embs.
            prompt_templates = [ "photo of a " + ", " * self.N_ID + list_extra_words[i]
                                 for i in range(len(list_extra_words)) ]
        else:
            # 16 or 4 ", " are placeholders for face_prompt_embs.
            # No extra words are added to the prompt. So we add 2 more ", " to the template to keep
            # the number of tokens roughly the same as when extra words are added.
            prompt_templates = [ "photo of a " + ", " * (self.N_ID + 2)
                                 for _ in range(len(face_prompt_embs)) ]

        # This step should be quite fast, and there's no need to cache the input_ids.
        # input_ids: [BS, 77].
        input_ids = self.tokenizer(
            prompt_templates,
            truncation=True,
            padding="max_length",
            max_length=self.max_prompt_length,
            return_tensors="pt",
        ).input_ids.to(face_prompt_embs.device)

        face_prompt_embs_orig_dtype = face_prompt_embs.dtype
        face_prompt_embs = face_prompt_embs.to(self.dtype)

        ID_END = 4 + self.N_ID
        PAD_BEGIN = ID_END + self.N_SFX + 2

        # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
        token_embs = self.prompt2token_proj(input_ids=input_ids, return_token_embs=True)
        # token 4: the first ", " in the template prompt.
        # Replace the embeddings of the 16 or 4 placeholder ", " with face_prompt_embs.
        token_embs[:, 4:ID_END] = face_prompt_embs
        # Only when do_unet_distill == True do we append the static image suffix embeddings.
        # Otherwise, static image suffix embeddings are ignored,
        # and token_embs[:, ID_END:ID_END+self.N_SFX] are the filler embeddings of the
        # extra ", " in the template prompt.
        if enable_static_img_suffix_embs and self.N_SFX > 0:
            # Put the static image suffix embeddings right after face_prompt_embs.
            token_embs[:, ID_END:ID_END+self.N_SFX] = self.static_img_suffix_embs[:, :self.N_SFX]

        # This call does the ordinary CLIP text encoding pass.
        prompt_embeds = self.prompt2token_proj(
            input_ids=input_ids,
            input_token_embs=token_embs,
            hidden_state_layer_weights=hidden_state_layer_weights,
            return_token_embs=False
        )[0]

        # Restore the original dtype of prompt_embeds: float16 -> float32.
        prompt_embeds = prompt_embeds.to(face_prompt_embs_orig_dtype)
        # token 4: the first ", " in the template prompt.
        # When N_ID == 16,
        # prompt_embeds 4:20 are the most important 16 embeddings that contain the subject's identity.
        # 20:22 are the embeddings of the (at most) two extra words.
        # [N, 77, 768] -> [N, 16, 768]
        if enable_static_img_suffix_embs:
            core_prompt_embs = prompt_embeds[:, 4:ID_END+self.N_SFX]
        else:
            core_prompt_embs = prompt_embeds[:, 4:ID_END]

        if list_extra_words is not None:
            # [N, 16, 768] -> [N, 18, 768]
            extra_words_embs = prompt_embeds[:, ID_END+self.N_SFX:PAD_BEGIN]
            core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)

        returned_prompt_embs = []
        for emb_type in return_emb_types:
            if emb_type == 'full':
                returned_prompt_embs.append(prompt_embeds)
            elif emb_type == 'full_half_pad':
                prompt_embeds2 = prompt_embeds.clone()
                # PAD_BEGIN is 22 or 10. Also exclude the last EOS token.
                # So we subtract (PAD_BEGIN + 1) from max_prompt_length.
                PADS = self.max_prompt_length - PAD_BEGIN - 1
                if PADS >= 2:
                    # Fill half of the remaining embeddings with pad embeddings.
                    prompt_embeds2[:, PAD_BEGIN:PAD_BEGIN+PADS//2] = self.pad_embeddings[PAD_BEGIN:PAD_BEGIN+PADS//2]
                returned_prompt_embs.append(prompt_embeds2)
            elif emb_type == 'full_pad':
                prompt_embeds2 = prompt_embeds.clone()
                # Replace the PAD_BEGIN-th to the second-to-last embeddings with pad embeddings.
                # Skip replacing the last embedding, which might have a special role.
                # (Although all padding tokens are the same EOS, the last token might acquire special semantics
                # due to its special position.)
                prompt_embeds2[:, PAD_BEGIN:-1] = self.pad_embeddings[PAD_BEGIN:-1]
                returned_prompt_embs.append(prompt_embeds2)
            elif emb_type == 'full_zeroed_extra':
                prompt_embeds2 = prompt_embeds.clone()
                # Only add two pad embeddings. The remaining embeddings are set to 0.
                # Make the positional embeddings align with the actual positions.
                prompt_embeds2[:, 22:24] = self.pad_embeddings[22:24]
                prompt_embeds2[:, 24:-1] = 0
                returned_prompt_embs.append(prompt_embeds2)
            elif emb_type == 'core':
                returned_prompt_embs.append(core_prompt_embs)
            else:
                breakpoint()

        return returned_prompt_embs
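
# --- Illustrative sketch (not part of the original code) ---
# Why ID_END = 4 + N_ID in inverse_img_prompt_embs(): the template
# "photo of a " + ", " * N_ID tokenizes to [BOS, photo, of, a, ",", ..., ","],
# so the N_ID placeholder commas occupy token positions 4 .. 4 + N_ID - 1, and exactly
# those positions are overwritten with face_prompt_embs. Running this downloads the CLIP tokenizer.
def _example_placeholder_token_positions(num_id_vecs=16):
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    template = "photo of a " + ", " * num_id_vecs
    input_ids = tokenizer(template, truncation=True, padding="max_length",
                          max_length=77, return_tensors="pt").input_ids[0]
    comma_id = tokenizer(",").input_ids[1]   # [BOS, ",", EOS] -> take the middle token.
    # The span that inverse_img_prompt_embs() replaces with face_prompt_embs:
    return bool((input_ids[4:4 + num_id_vecs] == comma_id).all())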

class SubjBasisGenerator(ImgPrompt2TextPrompt):
    def __init__(
        self,
        # Number of cross-attention heads of the bg prompt translator.
        # Taken as half of the 12 heads of OpenAI clip-vit-large-patch14:
        # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
        num_bg_encoder_heads=6,
        # Number of subject input identity vectors (only when the subject is not a face),
        # or number of background input identity vectors (whether or not the subject is a face).
        # 257: 257 CLIP tokens.
        num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
        num_id_vecs=16,                         # num_id_vecs: subj: 16. bg: 4.
        # Number of extra static learnable image embeddings appended to the translated ID embeddings.
        num_static_img_suffix_embs: int = 0,
        bg_image_embedding_dim=1024,            # CLIP image hidden layer feature dimension, as per config.json above.
        obj_embedding_dim=384,                  # DINO object feature dimension for objects.
        output_dim=768,                         # CLIP text embedding input dimension.
        placeholder_is_bg: bool = False,        # Whether the placeholder is for the image background tokens.
        prompt2token_proj_grad_scale: float = 0.4,  # Gradient scale for prompt2token_proj.
        learnable_hidden_state_weights_scheme: str = 'per-layer',  # none, per-layer.
        bg_prompt_translator_has_to_out_proj: bool = False,  # Whether the prompt_trans_layers have a to_out projection.
    ):
        # If not placeholder_is_bg, the superclass __init__ calls initialize_text_components().
        super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs,
                         max_prompt_length=77, num_static_img_suffix_embs=num_static_img_suffix_embs,
                         img_prompt_dim=output_dim)

        self.placeholder_is_bg = placeholder_is_bg
        self.num_out_embs = self.N_ID + self.N_SFX
        self.output_dim = output_dim
        # num_nonface_in_id_vecs should be the number of core ID embs, 16.
        # However, in that case, pos_embs is not used, so it doesn't matter if it's wrongly set.
        self.num_nonface_in_id_vecs = num_nonface_in_id_vecs['bg'] if placeholder_is_bg else num_nonface_in_id_vecs['subj']
        self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj

        if not self.placeholder_is_bg:
            # [1, 384] -> [1, 16, 768].
            # TODO: use CLIPTextModelWrapper as obj_proj_in.
            self.obj_proj_in = ExpandEmbs(obj_embedding_dim, output_dim, expansion_ratio=self.num_nonface_in_id_vecs)

            # ** prompt2token_proj does the actual job: **
            # it is the inverse projection that maps from faceid2img_prompt_embs to adaface_prompt_embs.
            # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
            # If self.placeholder_is_bg: prompt2token_proj is set to None.
            # Use an attention dropout of 0.2 to increase robustness.
            clip_dropout_config = None  #CLIPTextConfig.from_pretrained('openai/clip-vit-large-patch14', attention_dropout=0.05, dropout=0.05)
            self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14', config=clip_dropout_config)
            self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
            self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
            print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
            # If prompt2token_proj_grad_scale is 0, freeze all params in prompt2token_proj.
            # Otherwise, only freeze the token and positional embeddings of the original CLIPTextModel.
            self.freeze_prompt2token_proj()

            # These multipliers are relative to the original CLIPTextModel.
            self.prompt2token_proj_attention_multipliers = [1] * 12
            self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
            self.bg_proj_in = None
            self.pos_embs = self.pos_embs_ln = self.latent_queries = self.latent_queries_ln = None
        else:
            # For background placeholders, face and object embeddings are not used, as they are foreground.
            self.obj_proj_in = None
            self.bg_proj_in = nn.Sequential(
                nn.Linear(bg_image_embedding_dim, output_dim, bias=False),
                nn.LayerNorm(output_dim),
            )

            self.pos_embs = nn.Parameter(torch.zeros(1, self.num_nonface_in_id_vecs, output_dim))
            self.pos_embs_ln = nn.LayerNorm(output_dim)
            self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
            self.latent_queries_ln = nn.LayerNorm(output_dim)

            identity_to_v = False
            v_has_skip = not identity_to_v                              # True
            identity_to_out = not bg_prompt_translator_has_to_out_proj  # True
            out_has_skip = not identity_to_out                          # False
            # prompt_translator maps the clip image features (of the background) to the prompt embedding space.
            # It is only used during training when placeholder_is_bg is True.
            # prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection.
            # dim=768, num_bg_encoder_heads=6.
            self.prompt_translator = \
                CrossAttention(input_dim=output_dim, num_heads=num_bg_encoder_heads, p_dropout=0.05,
                               identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
                               q_aware_to_v=False, v_has_skip=v_has_skip,
                               num_q=0,  # When not q_aware_to_v, num_q is not referenced.
                               identity_to_out=identity_to_out,
                               out_has_skip=out_has_skip)

            self.output_scale = output_dim ** -0.5

            '''
            prompt_translator: CLIPEncoder
            # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
            # CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
                (0): CLIPEncoderLayer(
                    (self_attn): CLIPAttention(
                        (k_proj): Linear(in_features=768, out_features=768, bias=True)
                        (v_proj): Linear(in_features=768, out_features=768, bias=True)
                        (q_proj): Linear(in_features=768, out_features=768, bias=True)
                        (out_proj): Linear(in_features=768, out_features=768, bias=True)
                    )
                    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                    (mlp): CLIPMLP(
                        (activation_fn): QuickGELUActivation()
                        (fc1): Linear(in_features=768, out_features=3072, bias=True)
                        (fc2): Linear(in_features=3072, out_features=768, bias=True)
                    )
                    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                )
            '''

        print(repr(self))

    # raw_id_embs: only used when the subject is not a face. In that case, it's the DINO embeddings.
    # Otherwise, raw_id_embs is not used.
    # faceid2img_prompt_embs: [BS, 16, 768], the core ID prompt embeddings generated by ID2ImgPrompt.
    def forward(self, faceid2img_prompt_embs, clip_features=None, raw_id_embs=None, out_id_embs_cfg_scale=1.0,
                is_face=True, enable_static_img_suffix_embs=False):

        if not self.placeholder_is_bg:
            BS = faceid2img_prompt_embs.shape[0]
        else:
            # If bg, then faceid2img_prompt_embs is set to None, but clip_features is not None.
            BS = clip_features.shape[0]
            clip_features = clip_features.to(self.dtype)

        # No need to use raw_id_embs if placeholder_is_bg.
        if not self.placeholder_is_bg:
            if is_face:
                assert faceid2img_prompt_embs is not None
                # id2img_embs have been projected to the (modified) prompt embedding space
                # by ID2AdaPrompt::map_init_id_to_img_prompt_embs(). This prompt embedding space is modified
                # because the ID2ImgPrompt module (at least when it's arc2face) may have finetuned the
                # text encoder and the U-Net.
                # In embedding_manager: [BS, 16, 768] -> [BS, 77, 768].
                # faceid2img_prompt_embs is part of id2img_embs: [BS, 77, 768] -> [BS, 16, 768].
                # adaface_prompt_embs is projected to the prompt embedding space. This is the
                # original U-Net prompt embedding space.

                # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
                hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)

                # faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
                with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
                    # If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words,
                    # the 16 identity tokens and (at most) two extra words in adaface_prompt_embs, without BOS and EOS.
                    # If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs.
                    # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
                    # ada_id_embs: [BS, 16, 768].
                    # return_emb_types: a list of strings, each string is among
                    # ['full', 'core', 'full_pad', 'full_half_pad'].
                    ada_id_embs, = \
                        self.inverse_img_prompt_embs(faceid2img_prompt_embs,
                                                     list_extra_words=None,
                                                     return_emb_types=['core'],
                                                     hidden_state_layer_weights=hidden_state_layer_weights,
                                                     enable_static_img_suffix_embs=enable_static_img_suffix_embs)
                ada_id_embs = self.prompt2token_proj_grad_scaler(ada_id_embs)
            elif raw_id_embs is not None:
                # id_embs: [BS, 384] -> [BS, 18, 768].
                # obj_proj_in is expected to project the DINO object features to
                # the token embedding space. So there's no need to use prompt2token_proj.
                id_embs = self.obj_proj_in(raw_id_embs)
            else:
                breakpoint()
        else:
            # Otherwise, context is the ad-hoc CLIP image features.
            # id_embs: [BS, 257, 768].
            id_embs = self.bg_proj_in(clip_features)

        if self.placeholder_is_bg:
            id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
            latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
            # If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer
            # can generate 257 embs, and we take the first 16*4=64.
            # The output of prompt_translator is exactly num_out_embs == 64 tokens. id_embs_out: [BS, 64, 768].
            # prompt_translator: better named bg_prompt_translator. It maps the bg features
            # to bg prompt embeddings.
            with torch.set_grad_enabled(self.training):
                id_embs_out = self.prompt_translator(latent_queries, id_embs)
            adaface_out_embs = id_embs_out * self.output_scale    # * 0.036
        else:
            adaface_out_embs = ada_id_embs
            # If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of ada_id_embs and pad_embeddings.
            if out_id_embs_cfg_scale != 1:
                # pad_embeddings: [77, 768] -> [16, 768] -> [1, 16, 768].
                # NOTE: Never do cfg on the static image suffix embeddings.
                # So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
                # even if enable_static_img_suffix_embs=True.
                pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).to(ada_id_embs.device)
                adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
                                                  + pad_embeddings * (1 - out_id_embs_cfg_scale)

        return adaface_out_embs

    def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
        if learnable_hidden_state_weights_scheme == 'none':
            self.hidden_state_layer_weights = None
            # A grad scaler with alpha = 1 is nn.Identity(), which outputs None given None as input.
            self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
            print("hidden_state_layer_weights is set to None.")
        elif learnable_hidden_state_weights_scheme == 'per-layer':
            # Learnable weights of the last 3 layers, initialized to put more focus on the last layer.
            # 'per-layer': Different weights for different layers, but the same for different channels.
            # hidden_state_layer_weights: [3, 1].
            self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
                                                           requires_grad=True)
            # A gradient scaler of 5 makes the gradients on hidden_state_layer_weights 5 times larger.
            self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
            print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
        else:
            breakpoint()

    def extend_prompt2token_proj_attention(self, prompt2token_proj_attention_multipliers=None,
                                           begin_layer_idx=-1, end_layer_idx=-1, multiplier=1, perturb_std=0.1):
        if begin_layer_idx == -1:
            begin_layer_idx = 0
        if end_layer_idx == -1:
            end_layer_idx = 11

        if prompt2token_proj_attention_multipliers is None and multiplier == 1:
            print("prompt2token_proj_attention_multipliers are all 1. No extension is done.")
            return
        elif prompt2token_proj_attention_multipliers is None:
            # prompt2token_proj_attention_multipliers are relative to the current prompt2token_proj.
            prompt2token_proj_attention_multipliers = [1] * 12
            for i in range(begin_layer_idx, end_layer_idx+1):
                prompt2token_proj_attention_multipliers[i] = multiplier
        # Otherwise, use the given prompt2token_proj_attention_multipliers.
        num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(prompt2token_proj_attention_multipliers, perturb_std)
        # Update prompt2token_proj_attention_multipliers (relative to the original CLIPTextModel).
        for i in range(begin_layer_idx, end_layer_idx+1):
            self.prompt2token_proj_attention_multipliers[i] *= prompt2token_proj_attention_multipliers[i]

        print(f"{num_extended_layers} layers in prompt2token_proj_attention are extended by {prompt2token_proj_attention_multipliers}")
        return num_extended_layers

    def squeeze_prompt2token_proj_attention(self, prompt2token_proj_attention_divisors=None,
                                            begin_layer_idx=-1, end_layer_idx=-1, divisor=1):
        if begin_layer_idx == -1:
            begin_layer_idx = 0
        if end_layer_idx == -1:
            end_layer_idx = 11

        if prompt2token_proj_attention_divisors is None and divisor == 1:
            print("prompt2token_proj_attention_divisors are all 1. No squeezing is done.")
            return
        elif prompt2token_proj_attention_divisors is None:
            prompt2token_proj_attention_divisors = [1] * 12
            for i in range(begin_layer_idx, end_layer_idx+1):
                prompt2token_proj_attention_divisors[i] = divisor
        # Otherwise, use the given prompt2token_proj_attention_divisors.

        num_squeezed_layers = self.prompt2token_proj.squeeze_clip_attention_MKV_divisor(prompt2token_proj_attention_divisors)
        # Update prompt2token_proj_attention_multipliers (relative to the original CLIPTextModel).
        for i in range(begin_layer_idx, end_layer_idx+1):
            self.prompt2token_proj_attention_multipliers[i] //= prompt2token_proj_attention_divisors[i]

        print(f"{num_squeezed_layers} layers in prompt2token_proj_attention are squeezed by {prompt2token_proj_attention_divisors}")
        return num_squeezed_layers

    def freeze_prompt2token_proj(self):
        # Only applicable to the fg basis generator.
        # If bg, then prompt2token_proj is set to None, so there's no need to freeze it.
        # This way, we don't have to check whether it's for subj or bg.
        if self.placeholder_is_bg:
            return

        if self.prompt2token_proj_grad_scale == 0:
            frozen_components_name = 'all'
            frozen_param_set = self.prompt2token_proj.named_parameters()
        else:
            frozen_components_name = 'token_pos_embeddings'
            frozen_param_set = self.prompt2token_proj.text_model.embeddings.named_parameters()

        if self.prompt2token_proj is not None:
            frozen_param_names = []
            for param_name, param in frozen_param_set:
                if param.requires_grad:
                    param.requires_grad = False
                    frozen_param_names.append(param_name)
                # If param is already frozen, then there's no need to freeze it again.

            print(f"{frozen_components_name}: {len(frozen_param_names)} params in Subj prompt2token_proj are frozen.")
            #print(f"Frozen parameters:\n{frozen_param_names}")

    def patch_old_subj_basis_generator_ckpt(self):
        # Fix compatibility with the previous version.
        if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
            self.bg_prompt_translator_has_to_out_proj = False
        if not hasattr(self, 'num_out_embs'):
            self.num_out_embs = -1
        if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
            self.N_ID = self.num_id_vecs
        if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
            self.num_nonface_in_id_vecs = self.N_ID
        if not hasattr(self, 'dtype'):
            self.dtype = torch.float32

        if self.placeholder_is_bg:
            if not hasattr(self, 'pos_embs') or self.pos_embs is None:
                self.pos_embs = nn.Parameter(torch.zeros(1, self.num_nonface_in_id_vecs, self.output_dim))
            if not hasattr(self, 'latent_queries') or self.latent_queries is None:
                self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, self.output_dim))
            # The background encoder doesn't require initializing text components.
        else:
            self.initialize_hidden_state_layer_weights('per-layer', 'cpu')
            if not hasattr(self, 'prompt2token_proj_attention_multipliers'):
                # Please manually set prompt2token_proj_attention_multipliers in the ckpt.
                breakpoint()
            self.initialize_text_components(max_prompt_length=77, num_id_vecs=self.N_ID,
                                            num_static_img_suffix_embs=self.N_SFX,
                                            img_prompt_dim=self.output_dim)

    def __repr__(self):
        type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
        return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
               f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"
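
# --- Illustrative usage sketch (not part of the original code) ---
# How a foreground (subject) SubjBasisGenerator is typically driven. In the real pipeline,
# faceid2img_prompt_embs comes from an ID2ImgPrompt module (e.g. arc2face); here random
# data is used, and prompt2token_proj holds the raw pretrained CLIP weights, so only the
# output shape is meaningful. Instantiation downloads openai/clip-vit-large-patch14.
def _example_subj_basis_generator():
    subj_gen = SubjBasisGenerator(num_id_vecs=16, num_static_img_suffix_embs=0,
                                  placeholder_is_bg=False)
    subj_gen.eval()
    faceid2img_prompt_embs = torch.randn(2, 16, 768)   # [BS, N_ID, 768], image prompt space.
    with torch.no_grad():
        adaface_subj_embs = subj_gen(faceid2img_prompt_embs,
                                     clip_features=None, raw_id_embs=None,
                                     out_id_embs_cfg_scale=1.0, is_face=True)
    return adaface_subj_embs.shape                     # [2, 16, 768], text prompt space.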