Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from PIL import Image | |
from diffusers import UNet2DConditionModel | |
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput | |
from transformers import CLIPVisionModel | |
from dataclasses import dataclass | |
from typing import Optional, Tuple | |
from transformers.utils import ModelOutput | |
import numpy as np | |
import argparse | |
from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline | |
from diffusers import ( | |
UNet2DConditionModel, | |
DDIMScheduler, | |
) | |
def str2bool(v): | |
if isinstance(v, bool): | |
return v | |
if v.lower() in ("yes", "true", "t", "y", "1"): | |
return True | |
elif v.lower() in ("no", "false", "f", "n", "0"): | |
return False | |
else: | |
raise argparse.ArgumentTypeError("Boolean value expected.") | |
# perturb_tensor() adds a fixed amount of noise to the tensor. | |
def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=False, | |
std_dim=-1, norm_dim=-1, verbose=True): | |
orig_ts = ts | |
if perturb_std_is_relative: | |
ts_std_mean = ts.std(dim=std_dim).mean().detach() | |
perturb_std *= ts_std_mean | |
# ts_std_mean: 50~80 for unnormalized images, perturb_std: 2.5-4 for 0.05 noise. | |
if verbose: | |
print(f"ts_std_mean: {ts_std_mean:.03f}, perturb_std: {perturb_std:.03f}") | |
noise = torch.randn_like(ts) * perturb_std | |
if keep_norm: | |
orig_norm = ts.norm(dim=norm_dim, keepdim=True) | |
ts = ts + noise | |
new_norm = ts.norm(dim=norm_dim, keepdim=True).detach() | |
ts = ts * orig_norm / (new_norm + 1e-8) | |
else: | |
ts = ts + noise | |
if verbose: | |
print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).item():.03f}") | |
return ts | |
def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_dim=-1): | |
ts = torch.from_numpy(np_array).to(dtype=torch.float32) | |
ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim) | |
return ts.numpy().astype(np_array.dtype) | |
def calc_stats(emb_name, embeddings, mean_dim=0): | |
print("%s:" %emb_name) | |
repeat_count = [1] * embeddings.ndim | |
repeat_count[mean_dim] = embeddings.shape[mean_dim] | |
# Average across the mean_dim dim. | |
# Make emb_mean the same size as embeddings, as required by F.l1_loss. | |
emb_mean = embeddings.mean(mean_dim, keepdim=True).repeat(repeat_count) | |
l1_loss = F.l1_loss(embeddings, emb_mean) | |
# F.l2_loss doesn't take sqrt. So the loss is very small. | |
# Compute it manually. | |
l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt() | |
norms = torch.norm(embeddings, dim=1).detach().cpu().numpy() | |
print("L1: %.4f, L2: %.4f" %(l1_loss.item(), l2_loss.item())) | |
print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std())) | |
# Revised from RevGrad, by removing the grad negation. | |
class ScaleGrad(torch.autograd.Function): | |
def forward(ctx, input_, alpha_, debug=False): | |
ctx.save_for_backward(alpha_, debug) | |
output = input_ | |
if debug: | |
print(f"input: {input_.abs().mean().item()}") | |
return output | |
def backward(ctx, grad_output): # pragma: no cover | |
# saved_tensors returns a tuple of tensors. | |
alpha_, debug = ctx.saved_tensors | |
if ctx.needs_input_grad[0]: | |
grad_output2 = grad_output * alpha_ | |
if debug: | |
print(f"grad_output2: {grad_output2.abs().mean().item()}") | |
else: | |
grad_output2 = None | |
return grad_output2, None, None | |
class GradientScaler(nn.Module): | |
def __init__(self, alpha=1., debug=False, *args, **kwargs): | |
""" | |
A gradient scaling layer. | |
This layer has no parameters, and simply scales the gradient in the backward pass. | |
""" | |
super().__init__(*args, **kwargs) | |
self._alpha = torch.tensor(alpha, requires_grad=False) | |
self._debug = torch.tensor(debug, requires_grad=False) | |
def forward(self, input_): | |
_debug = self._debug if hasattr(self, '_debug') else False | |
return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug) | |
def gen_gradient_scaler(alpha, debug=False): | |
if alpha == 1: | |
return nn.Identity() | |
if alpha > 0: | |
return GradientScaler(alpha, debug=debug) | |
else: | |
assert alpha == 0 | |
# Don't use lambda function here, otherwise the object can't be pickled. | |
return torch.detach | |
def pad_image_obj_to_square(image_obj, new_size=-1): | |
# Remove alpha channel if it exists. | |
if image_obj.mode == 'RGBA': | |
image_obj = image_obj.convert('RGB') | |
# Pad input to be width == height | |
width, height = orig_size = image_obj.size | |
new_width, new_height = max(width, height), max(width, height) | |
if width != height: | |
if width > height: | |
pads = (0, (width - height) // 2) | |
elif height > width: | |
pads = ((height - width) // 2, 0) | |
square_image_obj = Image.new("RGB", (new_width, new_height)) | |
# pads indicates the upper left corner to paste the input. | |
square_image_obj.paste(image_obj, pads) | |
#square_image_obj = square_image_obj.resize((512, 512)) | |
print(f"{width}x{height} -> {new_width}x{new_height} -> {square_image_obj.size}") | |
long_short_ratio = max(width, height) / min(width, height) | |
else: | |
square_image_obj = image_obj | |
pads = (0, 0) | |
long_short_ratio = 1 | |
if new_size > 0: | |
# Resize the shorter edge to 512. | |
square_image_obj = square_image_obj.resize([int(new_size * long_short_ratio), int(new_size * long_short_ratio)]) | |
return square_image_obj, pads, orig_size | |
class UNetEnsemble(nn.Module): | |
# The first unet is the unet already loaded in a pipeline. | |
def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', torch_dtype=torch.float16): | |
super().__init__() | |
self.unets = nn.ModuleList() | |
if unets is not None: | |
self.unets += [ unet.to(device) for unet in unets ] | |
if unet_types is not None: | |
for unet_type in unet_types: | |
if unet_type == "arc2face": | |
from adaface.arc2face_models import create_arc2face_pipeline | |
unet = create_arc2face_pipeline(unet_only=True) | |
elif unet_type == "consistentID": | |
unet = create_consistentid_pipeline(unet_only=True) | |
else: | |
breakpoint() | |
self.unets.append(unet.to(device=device)) | |
if extra_unet_dirpaths is not None: | |
for unet_path in extra_unet_dirpaths: | |
unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype) | |
self.unets.append(unet.to(device=device)) | |
if unet_weights is None: | |
unet_weights = [1.] * len(self.unets) | |
elif len(self.unets) < len(unet_weights): | |
unet_weights = unet_weights[:len(self.unets)] | |
elif len(self.unets) > len(unet_weights): | |
breakpoint() | |
unet_weights = torch.tensor(unet_weights, dtype=torch_dtype) | |
unet_weights = unet_weights / unet_weights.sum() | |
self.unet_weights = nn.Parameter(unet_weights, requires_grad=False) | |
print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights.data.cpu().numpy()}") | |
# Set these fields to be compatible with diffusers. | |
self.dtype = self.unets[0].dtype | |
self.device = self.unets[0].device | |
self.config = self.unets[0].config | |
def forward(self, *args, **kwargs): | |
return_dict = kwargs.get('return_dict', True) | |
teacher_contexts = kwargs.pop('encoder_hidden_states', None) | |
# Only one teacher_context is provided. That means all unets will use the same teacher_context. | |
# We repeat the teacher_contexts to match the number of unets. | |
if not isinstance(teacher_contexts, (list, tuple)): | |
teacher_contexts = [teacher_contexts] | |
if len(teacher_contexts) == 1 and len(self.unets) > 1: | |
teacher_contexts = teacher_contexts * len(self.unets) | |
samples = [] | |
for unet, teacher_context in zip(self.unets, teacher_contexts): | |
sample = unet(encoder_hidden_states=teacher_context, *args, **kwargs) | |
if not return_dict: | |
sample = sample[0] | |
else: | |
sample = sample.sample | |
samples.append(sample) | |
samples = torch.stack(samples, dim=0) | |
unet_weights = self.unet_weights.reshape(-1, *([1] * (samples.ndim - 1))) | |
sample = (samples * unet_weights).sum(dim=0) | |
if not return_dict: | |
return (sample,) | |
else: | |
return UNet2DConditionOutput(sample=sample) | |
def create_consistentid_pipeline(base_model_path="models/sd15-dste8-vae.safetensors", | |
dtype=torch.float16, unet_only=False): | |
pipe = ConsistentIDPipeline.from_single_file(base_model_path) | |
# consistentID specific modules are still in fp32. Will be converted to fp16 | |
# later with .to(device, torch_dtype) by the caller. | |
pipe.load_ConsistentID_model( | |
consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin", | |
bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth", | |
) | |
# Avoid passing dtype to ConsistentIDPipeline.from_single_file(), | |
# because we've overloaded .to() to convert consistentID specific modules as well, | |
# but diffusers will call .to(dtype) in .from_single_file(), | |
# and at that moment, the consistentID specific modules are not loaded yet. | |
pipe.to(dtype=dtype) | |
# We load the pipeline first, then use the unet in the pipeline. | |
# Since the pipeline initialization will load LoRA into the unet, | |
# now we have the unet with LoRA loaded. | |
if unet_only: | |
# We release text_encoder and VAE to save memory. | |
pipe.release_components(["text_encoder", "vae"]) | |
return pipe.unet | |
noise_scheduler = DDIMScheduler( | |
num_train_timesteps=1000, | |
beta_start=0.00085, | |
beta_end=0.012, | |
beta_schedule="scaled_linear", | |
clip_sample=False, | |
set_alpha_to_one=False, | |
steps_offset=1, | |
) | |
pipe.scheduler = noise_scheduler | |
return pipe | |
class BaseModelOutputWithPooling2(ModelOutput): | |
""" | |
Base class for model's outputs that also contains a pooling of the last hidden states. | |
Args: | |
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): | |
Sequence of hidden-states at the output of the last layer of the model. | |
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): | |
Last layer hidden-state of the first token of the sequence (classification token) after further processing | |
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns | |
the classification token after processing through a linear layer and a tanh activation function. The linear | |
layer weights are trained from the next sentence prediction (classification) objective during pretraining. | |
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): | |
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + | |
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. | |
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. | |
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): | |
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, | |
sequence_length)`. | |
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention | |
heads. | |
""" | |
last_hidden_state: torch.FloatTensor = None | |
pooler_output: torch.FloatTensor = None | |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None | |
attentions: Optional[Tuple[torch.FloatTensor]] = None | |
attn_mask: Optional[torch.FloatTensor] = None | |
# Revised from CLIPVisionTransformer to support attention mask. | |
# self: a CLIPVisionTransformer instance. | |
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821 | |
# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224] | |
# attn_mask: B*H*W attention mask. | |
def CLIPVisionTransformer_forward_with_mask(self, pixel_values = None, attn_mask=None, | |
output_attentions = None, | |
output_hidden_states = None, return_dict = None): | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
if pixel_values is None: | |
raise ValueError("You have to specify pixel_values") | |
# Visual tokens are flattended in embeddings(). | |
# self.embeddings: CLIPVisionEmbeddings. | |
# hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds). | |
# 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False). | |
hidden_states = self.embeddings(pixel_values) | |
hidden_states = self.pre_layrnorm(hidden_states) | |
if attn_mask is not None: | |
# feat_edge_size: 16. | |
feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int) | |
# attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16]. | |
attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest') | |
# Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256]. | |
attn_mask = attn_mask.flatten(2) | |
# Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257]. | |
# This 1 corresponds to class_embeds, which is always attended to. | |
attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1) | |
attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1) | |
else: | |
attn_mask_pairs = None | |
# encoder: CLIPEncoder. | |
encoder_outputs = self.encoder( | |
inputs_embeds=hidden_states, | |
# New feature: (***The official documentation is wrong***) | |
# attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*): | |
# Mask to avoid performing attention on pairs of token. Mask values selected in `[0, 1]`: | |
# - 1 for pairs that are **not masked**, | |
# - 0 for pairs that are **masked**. | |
# attention_mask is eventually used by CLIPEncoderLayer: | |
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370 | |
attention_mask=attn_mask_pairs, | |
output_attentions=output_attentions, # False | |
output_hidden_states=output_hidden_states, # True | |
return_dict=return_dict, # True | |
) | |
# last_hidden_state: [BS, 257, 1280] | |
last_hidden_state = encoder_outputs[0] | |
pooled_output = last_hidden_state[:, 0, :] | |
pooled_output = self.post_layernorm(pooled_output) | |
# return_dict is True. | |
if not return_dict: | |
return (last_hidden_state, pooled_output) + encoder_outputs[1:] | |
return BaseModelOutputWithPooling2( | |
last_hidden_state=last_hidden_state, | |
pooler_output=pooled_output, | |
hidden_states=encoder_outputs.hidden_states, | |
attentions=encoder_outputs.attentions, | |
# Newly added: return resized flattened attention mask. | |
# [BS, 1, 257] -> [BS, 257, 1] | |
attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None | |
) | |
def CLIPVisionModel_forward_with_mask(self, pixel_values = None, attn_mask = None, output_attentions = None, | |
output_hidden_states = None, return_dict = None): | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
return self.vision_model( | |
pixel_values=pixel_values, | |
attn_mask=attn_mask, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
# patch_clip_image_encoder_with_mask() is applicable to both CLIPVisionModel and CLIPVisionModelWithProjection. | |
def patch_clip_image_encoder_with_mask(clip_image_encoder): | |
clip_image_encoder.vision_model.forward = CLIPVisionTransformer_forward_with_mask.__get__(clip_image_encoder.vision_model) | |
clip_image_encoder.forward = CLIPVisionModel_forward_with_mask.__get__(clip_image_encoder) | |
return clip_image_encoder | |
class CLIPVisionModelWithMask(CLIPVisionModel): | |
def __init__(self, config): | |
super().__init__(config) | |
# Replace vision_model.forward() with the new one that supports mask. | |
patch_clip_image_encoder_with_mask(self) | |