# adaface-neurips
# Commit f0b9ada: "Integrate do_neg_id_prompt_weight, fix bugs, various refinements" (17.9 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from diffusers import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from transformers import CLIPVisionModel
from dataclasses import dataclass
from typing import Optional, Tuple
from transformers.utils import ModelOutput
import numpy as np
import argparse
from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline
from diffusers import (
UNet2DConditionModel,
DDIMScheduler,
)
def str2bool(v):
    """Parse a command-line style boolean value.

    Booleans are passed through unchanged. Strings are compared
    case-insensitively against common spellings of true and false.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(v, bool):
        return v
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=False,
                   std_dim=-1, norm_dim=-1, verbose=True):
    """Return ``ts`` plus Gaussian noise; ``ts`` itself is not modified.

    Args:
        ts: input tensor.
        perturb_std: noise standard deviation. If ``perturb_std_is_relative`` is
            True, it is multiplied by ``ts.std(dim=std_dim).mean()``, so it acts as
            a fraction of the data's own spread.
        perturb_std_is_relative: whether to scale ``perturb_std`` as above.
        keep_norm: if True, rescale the noisy tensor so its norms along
            ``norm_dim`` equal the original norms.
        std_dim: dim used to measure std for the relative scale.
        norm_dim: dim along which norms are preserved when ``keep_norm`` is True.
        verbose: print the noise scale and the cosine similarity between the
            perturbed and the original tensor.

    Returns:
        Perturbed tensor with the same shape as ``ts``.
    """
    if perturb_std_is_relative:
        ref_std = ts.std(dim=std_dim).mean().detach()
        perturb_std *= ref_std
        # Typical values: ref_std is 50~80 for unnormalized images, which
        # gives perturb_std of 2.5-4 for a relative level of 0.05.
        if verbose:
            print(f"ts_std_mean: {ref_std:.03f}, perturb_std: {perturb_std:.03f}")

    noise = torch.randn_like(ts) * perturb_std
    noisy = ts + noise
    if keep_norm:
        target_norm = ts.norm(dim=norm_dim, keepdim=True)
        # The denominator is detached so gradients only flow through the
        # numerator, and 1e-8 guards against division by zero.
        noisy_norm = noisy.norm(dim=norm_dim, keepdim=True).detach()
        noisy = noisy * target_norm / (noisy_norm + 1e-8)

    if verbose:
        similarity = F.cosine_similarity(noisy.flatten(), ts.flatten(), dim=0).item()
        print(f"Correlations between new and original tensors: {similarity:.03f}")
    return noisy
def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_dim=-1):
    """NumPy front-end for perturb_tensor().

    The array is perturbed as a float32 tensor and the result is cast back
    to the input array's dtype.
    """
    out_dtype = np_array.dtype
    as_float = torch.from_numpy(np_array).to(dtype=torch.float32)
    perturbed = perturb_tensor(as_float, perturb_std, perturb_std_is_relative, std_dim=std_dim)
    return perturbed.numpy().astype(out_dtype)
def calc_stats(emb_name, embeddings, mean_dim=0, norm_dim=1):
    """Print dispersion and norm statistics for a tensor of embeddings.

    Args:
        emb_name: label printed on the first line.
        embeddings: tensor of embeddings.
        mean_dim: dim along which the mean embedding is taken. L1/L2 below are
            deviations from that mean.
        norm_dim: dim along which per-embedding norms are computed. Defaults to
            1, which this function always used.

    Prints:
        - ``L1``: mean absolute deviation from the mean embedding.
        - ``L2``: root-mean-square deviation from the mean embedding.
        - min / max / mean / std of the norms along ``norm_dim``.
    """
    print(f"{emb_name}:")
    # Broadcast the mean back to the full shape as a view (no copy, unlike
    # .repeat()). Matching shapes keep F.l1_loss from warning about
    # input/target size mismatch.
    emb_mean = embeddings.mean(mean_dim, keepdim=True).expand_as(embeddings)
    l1_loss = F.l1_loss(embeddings, emb_mean)
    # F.mse_loss would give the mean squared deviation without the sqrt, so
    # compute the RMS deviation explicitly.
    l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt()
    norms = torch.norm(embeddings, dim=norm_dim).detach().cpu().numpy()
    print("L1: %.4f, L2: %.4f" % (l1_loss.item(), l2_loss.item()))
    print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f"
          % (norms.min(), norms.max(), norms.mean(), norms.std()))
# Revised from RevGrad, by removing the grad negation.
class ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by ``alpha_`` in backward.

    Usage: ``ScaleGrad.apply(input_, alpha_, debug)``. ``alpha_`` may be a tensor
    or a Python number. ``debug`` may be a bool or a 0-dim bool tensor, and
    enables printing of mean absolute input / gradient values.
    """

    @staticmethod
    def forward(ctx, input_, alpha_, debug=False):
        # save_for_backward() only accepts tensors (or None), so a plain-number
        # alpha_ is wrapped in a tensor. The debug flag is stored as a Python
        # bool on ctx instead of being saved.
        if not torch.is_tensor(alpha_):
            alpha_ = torch.tensor(alpha_)
        ctx.save_for_backward(alpha_)
        ctx.debug = bool(debug)
        if ctx.debug:
            print(f"input: {input_.abs().mean().item()}")
        # The input is returned unchanged: this op only rescales gradients.
        return input_

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # saved_tensors returns a tuple of tensors.
        (alpha_,) = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_output2 = grad_output * alpha_
            if ctx.debug:
                print(f"grad_output2: {grad_output2.abs().mean().item()}")
        else:
            grad_output2 = None
        # One gradient slot per forward() input: input_, alpha_, debug.
        return grad_output2, None, None
class GradientScaler(nn.Module):
    """A gradient scaling layer.

    This layer has no parameters. It is the identity in the forward pass and
    scales the gradient by ``alpha`` in the backward pass, via ScaleGrad.
    """

    def __init__(self, alpha=1., debug=False, *args, **kwargs):
        """
        Args:
            alpha: factor applied to incoming gradients in the backward pass.
            debug: forwarded to ScaleGrad to print mean |input| / |grad|.
        """
        super().__init__(*args, **kwargs)
        # Plain tensor attributes (not buffers/parameters): they are not in
        # state_dict() and not moved by .to(). forward() moves alpha to the
        # input's device on every call.
        self._alpha = torch.tensor(alpha, requires_grad=False)
        self._debug = torch.tensor(debug, requires_grad=False)

    def forward(self, input_):
        # `_debug` may be missing on instances created before that attribute
        # existed (presumably unpickled older objects). Fall back to a tensor
        # flag rather than a Python bool, because ScaleGrad may hand it to
        # ctx.save_for_backward(), which rejects non-tensors.
        _debug = getattr(self, '_debug', None)
        if _debug is None:
            _debug = torch.tensor(False)
        return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)

    def extra_repr(self):
        return f"alpha={self._alpha.tolist()}"
def gen_gradient_scaler(alpha, debug=False):
    """Return a picklable callable that scales gradients by ``alpha``.

    - ``alpha == 1``: ``nn.Identity()``; gradients pass through unchanged.
    - ``alpha > 0``: a ``GradientScaler(alpha, debug=debug)`` layer.
    - ``alpha == 0``: ``torch.detach``; gradients are blocked entirely.

    Raises:
        ValueError: if ``alpha`` is negative (or otherwise not >= 0, e.g. NaN).
    """
    if alpha == 1:
        return nn.Identity()
    if alpha > 0:
        return GradientScaler(alpha, debug=debug)
    if alpha == 0:
        # Don't use lambda function here, otherwise the object can't be pickled.
        return torch.detach
    # Raise explicitly instead of using `assert`: asserts are stripped under
    # `python -O`, which would silently fall through to torch.detach.
    raise ValueError(f"alpha must be >= 0, got {alpha}")
def pad_image_obj_to_square(image_obj, new_size=-1):
    """Pad a PIL image with black borders so that width == height.

    The input is pasted onto a black square canvas whose side equals the
    longer edge, offset by half the size difference (floored) along the
    shorter axis. Optionally, the square is then resized so that the
    *shorter* original edge spans ``new_size`` pixels.

    Args:
        image_obj: PIL image. Any non-RGB mode (RGBA, L, LA, P, ...) is converted
            to RGB first, so the returned image is always RGB.
        new_size: if > 0, target length of the shorter original edge. The
            square output side becomes ``int(new_size * long_edge / short_edge)``.

    Returns:
        ``(square_image_obj, pads, orig_size)``. ``pads`` is the (x, y)
        upper-left corner where the input was pasted, and ``orig_size`` is the
        input's (width, height).
    """
    # Convert every non-RGB mode, not only RGBA, so the output mode does not
    # depend on whether padding was needed.
    if image_obj.mode != 'RGB':
        image_obj = image_obj.convert('RGB')

    width, height = orig_size = image_obj.size
    side = max(width, height)
    if width == height:
        square_image_obj = image_obj
        pads = (0, 0)
        long_short_ratio = 1
    else:
        # pads is the (x, y) paste position; only the shorter axis is offset.
        if width > height:
            pads = (0, (width - height) // 2)
        else:
            pads = ((height - width) // 2, 0)
        square_image_obj = Image.new("RGB", (side, side))
        square_image_obj.paste(image_obj, pads)
        print(f"{width}x{height} -> {side}x{side} -> {square_image_obj.size}")
        long_short_ratio = side / min(width, height)

    if new_size > 0:
        # Scale so the shorter original edge becomes new_size pixels.
        edge = int(new_size * long_short_ratio)
        square_image_obj = square_image_obj.resize((edge, edge))
    return square_image_obj, pads, orig_size
class UNetEnsemble(nn.Module):
    """Weighted average of the outputs of several UNets.

    UNets are collected in this order:
    - ``unets``: already-instantiated modules; the first is typically the unet
      already loaded in a pipeline.
    - ``unet_types``: "arc2face" / "consistentID".
    - ``extra_unet_dirpaths``: loaded with UNet2DConditionModel.from_pretrained.

    The per-UNet weights are normalized to sum to 1 and stored as a frozen
    parameter. ``dtype``, ``device`` and ``config`` mirror the first UNet so
    the ensemble can stand in for a diffusers UNet.
    """

    def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', torch_dtype=torch.float16):
        """
        Args:
            unets: iterable of UNet modules, or None.
            unet_types: iterable of UNet type names to create, or None.
            extra_unet_dirpaths: iterable of pretrained UNet directories, or None.
            unet_weights: per-UNet weights; defaults to uniform. Surplus weights
                are ignored.
            device: device that every UNet and the weights are placed on.
            torch_dtype: dtype of the weights and of UNets loaded from directories.

        Raises:
            ValueError: unknown ``unet_type``, no UNets at all, or fewer weights
                than UNets.
        """
        super().__init__()
        self.unets = nn.ModuleList()
        if unets is not None:
            self.unets += [unet.to(device) for unet in unets]

        if unet_types is not None:
            for unet_type in unet_types:
                if unet_type == "arc2face":
                    from adaface.arc2face_models import create_arc2face_pipeline
                    unet = create_arc2face_pipeline(unet_only=True)
                elif unet_type == "consistentID":
                    unet = create_consistentid_pipeline(unet_only=True)
                else:
                    raise ValueError(f"Unknown unet_type: {unet_type!r}")
                self.unets.append(unet.to(device=device))

        if extra_unet_dirpaths is not None:
            for unet_path in extra_unet_dirpaths:
                unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype)
                self.unets.append(unet.to(device=device))

        if len(self.unets) == 0:
            raise ValueError("UNetEnsemble needs at least one UNet.")

        if unet_weights is None:
            unet_weights = [1.] * len(self.unets)
        elif len(self.unets) < len(unet_weights):
            unet_weights = unet_weights[:len(self.unets)]
        elif len(self.unets) > len(unet_weights):
            raise ValueError(
                f"{len(self.unets)} UNets were loaded but only {len(unet_weights)} unet_weights were given."
            )

        # Create the weights on the same device as the UNets, so forward()
        # does not mix CPU weights with device outputs.
        unet_weights = torch.tensor(unet_weights, dtype=torch_dtype, device=device)
        unet_weights = unet_weights / unet_weights.sum()
        self.unet_weights = nn.Parameter(unet_weights, requires_grad=False)

        print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights.data.cpu().numpy()}")
        # Set these fields to be compatible with diffusers. They are plain
        # attributes captured here; later .to() calls do not update them.
        self.dtype = self.unets[0].dtype
        self.device = self.unets[0].device
        self.config = self.unets[0].config

    def forward(self, *args, **kwargs):
        """Run every UNet and return the weighted sum of their outputs.

        ``encoder_hidden_states`` may be a single tensor shared by all UNets, or
        a list/tuple with one context per UNet. All other arguments are passed
        through unchanged. Returns a tuple ``(sample,)`` if ``return_dict`` is
        False, otherwise ``UNet2DConditionOutput(sample=sample)``.
        """
        return_dict = kwargs.get('return_dict', True)
        teacher_contexts = kwargs.pop('encoder_hidden_states', None)
        # A single teacher_context is shared by all unets, so repeat it to
        # match the number of unets.
        if not isinstance(teacher_contexts, (list, tuple)):
            teacher_contexts = [teacher_contexts]
        if len(teacher_contexts) == 1 and len(self.unets) > 1:
            teacher_contexts = teacher_contexts * len(self.unets)
        if len(teacher_contexts) < len(self.unets):
            raise ValueError(
                f"Got {len(teacher_contexts)} encoder_hidden_states for {len(self.unets)} UNets."
            )

        samples = []
        for unet, teacher_context in zip(self.unets, teacher_contexts):
            output = unet(encoder_hidden_states=teacher_context, *args, **kwargs)
            samples.append(output.sample if return_dict else output[0])
        samples = torch.stack(samples, dim=0)

        # Broadcast one weight per UNet over the stacked sample dims. The
        # .to() is a no-op when the weights already live on the output device.
        unet_weights = self.unet_weights.to(samples.device)
        unet_weights = unet_weights.reshape(-1, *([1] * (samples.ndim - 1)))
        sample = (samples * unet_weights).sum(dim=0)

        if not return_dict:
            return (sample,)
        return UNet2DConditionOutput(sample=sample)
def create_consistentid_pipeline(base_model_path="models/sd15-dste8-vae.safetensors",
                                 dtype=torch.float16, unet_only=False,
                                 consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
                                 bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth"):
    """Build a ConsistentIDPipeline, or only its UNet, from a single-file checkpoint.

    Args:
        base_model_path: single-file base model checkpoint.
        dtype: dtype the whole pipeline, including the ConsistentID modules,
            is cast to.
        unet_only: if True, release the text encoder and VAE and return only
            ``pipe.unet``.
        consistentID_weight_path: ConsistentID weights passed to
            ``load_ConsistentID_model()``.
        bise_net_weight_path: BiSeNet weights passed to
            ``load_ConsistentID_model()``.

    Returns:
        ``pipe.unet`` if ``unet_only``; otherwise the pipeline with a DDIM
        scheduler attached.
    """
    pipe = ConsistentIDPipeline.from_single_file(base_model_path)
    # consistentID specific modules are still in fp32. Will be converted to fp16
    # later with .to(device, torch_dtype) by the caller.
    pipe.load_ConsistentID_model(
        consistentID_weight_path=consistentID_weight_path,
        bise_net_weight_path=bise_net_weight_path,
    )
    # Avoid passing dtype to ConsistentIDPipeline.from_single_file(),
    # because we've overloaded .to() to convert consistentID specific modules as well,
    # but diffusers will call .to(dtype) in .from_single_file(),
    # and at that moment, the consistentID specific modules are not loaded yet.
    pipe.to(dtype=dtype)
    # We load the pipeline first, then use the unet in the pipeline.
    # Since the pipeline initialization will load LoRA into the unet,
    # now we have the unet with LoRA loaded.
    if unet_only:
        # We release text_encoder and VAE to save memory.
        pipe.release_components(["text_encoder", "vae"])
        return pipe.unet

    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    pipe.scheduler = noise_scheduler
    return pipe
@dataclass
class BaseModelOutputWithPooling2(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states,
    extended with an optional `attn_mask` field.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        attn_mask (`torch.FloatTensor`, *optional*):
            Per-token mask aligned with `last_hidden_state`. `CLIPVisionTransformer_forward_with_mask`
            fills it with the input mask resized to the patch grid, flattened, and prefixed with a 1
            for the class token, giving shape `(batch_size, sequence_length, 1)`. `None` when no mask
            was supplied.
    """
    # Field order matters: ModelOutput exposes the fields positionally (tuple-style indexing).
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attn_mask: Optional[torch.FloatTensor] = None
# Revised from CLIPVisionTransformer to support attention mask.
# self: a CLIPVisionTransformer instance.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821
# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224]
# attn_mask: B*H*W attention mask.
def CLIPVisionTransformer_forward_with_mask(self, pixel_values = None, attn_mask=None,
                                            output_attentions = None,
                                            output_hidden_states = None, return_dict = None):
    """Mask-aware replacement for CLIPVisionTransformer.forward().

    Meant to be bound onto a CLIPVisionTransformer instance (as done by
    patch_clip_image_encoder_with_mask) so that `self` is that transformer.

    Args:
        pixel_values: preprocessed images, e.g. [BS, 3, 224, 224].
        attn_mask: optional pixel-level mask [BS, H, W]. It is resized with
            nearest-neighbour interpolation to the patch grid, flattened, and
            prefixed with a 1 for the class token. Its outer product with itself
            is then passed to the encoder as a pairwise token mask. (Assumes a
            floating-point mask, since F.interpolate is applied to it; TODO
            confirm for bool masks.)
        output_attentions, output_hidden_states, return_dict: as in transformers;
            None falls back to the corresponding `self.config` value.

    Returns:
        If return_dict: BaseModelOutputWithPooling2, whose `attn_mask` is the
        flattened token mask of shape [BS, num_tokens, 1] (or None).
        Otherwise: the tuple (last_hidden_state, pooled_output, *encoder_outputs[1:]).

    Raises:
        ValueError: if pixel_values is None.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    if pixel_values is None:
        raise ValueError("You have to specify pixel_values")
    # Visual tokens are flattened in embeddings().
    # self.embeddings: CLIPVisionEmbeddings.
    # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds).
    # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False).
    hidden_states = self.embeddings(pixel_values)
    hidden_states = self.pre_layrnorm(hidden_states)
    if attn_mask is not None:
        # Side length of the patch grid, which is assumed to be square; the -1
        # excludes the class token. feat_edge_size: 16.
        feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
        # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16].
        attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
        # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256].
        attn_mask = attn_mask.flatten(2)
        # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257].
        # This 1 corresponds to class_embeds, which is always attended to.
        attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
        # Outer product over tokens: [BS, 257, 1] @ [BS, 1, 257] => [BS, 257, 257],
        # then unsqueeze to [BS, 1, 257, 257]. Entry (i, j) is mask[i] * mask[j],
        # i.e. for a 0/1 mask it is 1 only when both tokens are unmasked.
        attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
    else:
        attn_mask_pairs = None
    # encoder: CLIPEncoder.
    # NOTE(review): some transformers versions add `attention_mask` to the attention
    # logits instead of multiplying by it. In that case a 0/1 mask acts as a soft
    # +1 bias rather than a hard mask. Verify against the installed transformers
    # version before relying on hard masking.
    encoder_outputs = self.encoder(
        inputs_embeds=hidden_states,
        # New feature: (***The official documentation is wrong***)
        # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
        # Mask to avoid performing attention on pairs of token. Mask values selected in `[0, 1]`:
        # - 1 for pairs that are **not masked**,
        # - 0 for pairs that are **masked**.
        # attention_mask is eventually used by CLIPEncoderLayer:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370
        attention_mask=attn_mask_pairs,
        output_attentions=output_attentions, # False
        output_hidden_states=output_hidden_states, # True
        return_dict=return_dict, # True
    )
    # last_hidden_state: [BS, 257, 1280]
    last_hidden_state = encoder_outputs[0]
    # Pool by taking the class-token (first) position, then apply post_layernorm.
    pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.post_layernorm(pooled_output)
    # return_dict is expected to be True; the tuple form below does not include attn_mask.
    if not return_dict:
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
    return BaseModelOutputWithPooling2(
        last_hidden_state=last_hidden_state,
        pooler_output=pooled_output,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        # Newly added: return resized flattened attention mask.
        # [BS, 1, 257] -> [BS, 257, 1]
        attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
    )
def CLIPVisionModel_forward_with_mask(self, pixel_values = None, attn_mask = None, output_attentions = None,
                                      output_hidden_states = None, return_dict = None):
    """Mask-aware forward() for a CLIP image encoder wrapper.

    Resolves ``return_dict`` against ``self.config.use_return_dict`` when it is
    None, then forwards every argument by keyword to ``self.vision_model``.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict
    vision_kwargs = dict(
        pixel_values=pixel_values,
        attn_mask=attn_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    return self.vision_model(**vision_kwargs)
# patch_clip_image_encoder_with_mask() is applicable to both CLIPVisionModel and CLIPVisionModelWithProjection.
def patch_clip_image_encoder_with_mask(clip_image_encoder):
    """Install the mask-aware forward() methods on a CLIP image encoder, in place.

    ``vision_model.forward`` becomes CLIPVisionTransformer_forward_with_mask and
    the encoder's own ``forward`` becomes CLIPVisionModel_forward_with_mask.
    Each is bound to its instance through the descriptor protocol (``__get__``).

    Returns:
        The same encoder object, so the call can be chained.
    """
    vision_model = clip_image_encoder.vision_model
    vision_model.forward = CLIPVisionTransformer_forward_with_mask.__get__(vision_model)
    clip_image_encoder.forward = CLIPVisionModel_forward_with_mask.__get__(clip_image_encoder)
    return clip_image_encoder
class CLIPVisionModelWithMask(CLIPVisionModel):
    """CLIPVisionModel whose forward() also accepts an ``attn_mask`` argument.

    The mask-aware forward() implementations are installed on each instance
    at construction time by patch_clip_image_encoder_with_mask().
    """

    def __init__(self, config):
        super(CLIPVisionModelWithMask, self).__init__(config)
        # Swap vision_model.forward() and self.forward() for the mask-aware versions.
        patch_clip_image_encoder_with_mask(self)