Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import wandb | |
import cv2 | |
import torch.nn.functional as F | |
import numpy as np | |
from facenet_pytorch import MTCNN | |
from torchvision import transforms | |
from dreamsim import dreamsim | |
from einops import rearrange | |
import kornia.augmentation as K | |
import lpips | |
from pretrained_models.arcface import Backbone | |
from utils.vis_utils import add_text_to_image | |
from utils.utils import extract_faces_and_landmarks | |
import clip | |
class Loss:
    """Base class shared by every loss in this module.

    Stores the compute ``dtype``, the visualization period ``visualize_every_k``
    and the ``accelerator`` handle (defaults to None, but most subclasses
    dereference it), and counts invocations in ``self.iteration`` so subclasses
    can decide when to visualize.

    Subclasses implement ``forward(**kwargs)``; instances are called with
    keyword arguments only.
    """

    def __init__(self, visualize_every_k=-1, dtype=torch.float32, accelerator=None, **kwargs):
        # Unknown keyword arguments are accepted and ignored.
        self.accelerator = accelerator
        self.dtype = dtype
        self.visualize_every_k = visualize_every_k
        # Starts at -1 so the first call brings the counter to 0.
        self.iteration = -1

    def __call__(self, **kwargs):
        """Advance the iteration counter, then delegate to ``forward``."""
        self.iteration = self.iteration + 1
        return self.forward(**kwargs)
class L1Loss(Loss):
    """Mean absolute error between ``predict`` and ``target``.

    Args (forward):
        predict (torch.Tensor): the prediction, e.g. pixels decoded from a 1-step LCM sample.
        target (torch.Tensor): the reference tensor compared against ``predict``.
    Any extra keyword arguments are ignored.
    """

    def forward(self, predict: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
        mean_abs_error = F.l1_loss(
            predict,
            target,
            reduction="mean",
        )
        return mean_abs_error
class DreamSIMLoss(Loss):
    """DreamSim perceptual distance between predicted and encoder-input images.

    DreamSim (https://dreamsim-nights.github.io/) is similar to LPIPS but is trained
    on a dataset of human similarity judgements. It expects RGB images of size
    224x224 with values in [0, 1], so the inputs are rescaled with ``(x + 1) / 2``.
    This assumes the inputs are in [-1, 1]. They are then resized bicubically to
    224x224.

    Args:
        device: device on which the DreamSim model is created.

    Args (forward):
        predicted_pixel_values (torch.Tensor): images predicted with 1-step LCM and the VAE decoder.
        encoder_pixel_values (torch.Tensor): the input image to the encoder.
    Returns the mean DreamSim distance over the batch.
    """

    def __init__(self, device: str = 'cuda:0', **kwargs):
        super().__init__(**kwargs)
        self.model, _ = dreamsim(pretrained=True, device=device)
        self.model.to(dtype=self.dtype, device=device)
        self.model = self.accelerator.prepare(self.model)
        self.transforms = transforms.Compose([
            transforms.Lambda(lambda x: (x + 1) / 2),  # [-1, 1] -> [0, 1]
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)])

    def forward(
        self,
        predicted_pixel_values: torch.Tensor,
        encoder_pixel_values: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # Tensor.to returns a new tensor; rebind so the cast to the model dtype takes effect.
        predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
        encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
        return self.model(self.transforms(predicted_pixel_values), self.transforms(encoder_pixel_values)).mean()
class LPIPSLoss(Loss):
    """LPIPS (VGG backbone) perceptual distance between ``predict`` and ``target``.

    The LPIPS network is moved to ``accelerator.device`` in ``self.dtype`` and
    wrapped with ``accelerator.prepare``.

    Args (forward):
        predict (torch.Tensor): predicted images, e.g. decoded from a 1-step LCM sample.
        target (torch.Tensor): reference images.
    NOTE(review): inputs are presumably in [-1, 1], as ``lpips.LPIPS`` expects by
    default. Confirm this against the callers.
    Returns the mean LPIPS distance over the batch.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = lpips.LPIPS(net='vgg')
        self.model.to(dtype=self.dtype, device=self.accelerator.device)
        self.model = self.accelerator.prepare(self.model)

    def forward(self, predict, target, **kwargs):
        # Tensor.to is not in-place; rebind so the inputs actually match the model dtype.
        predict = predict.to(dtype=self.dtype)
        target = target.to(dtype=self.dtype)
        return self.model(predict, target).mean()
class LCMVisualization(Loss):
    """Pseudo-loss that only logs LCM training images to Weights & Biases.

    A visualization is emitted when ``visualize_every_k > 0`` and the current
    iteration is a multiple of it. The batch is laid out as three columns:
    encoder input | decoder input | predicted LCM image. Samples are stacked
    vertically, and the image is logged to every ``wandb`` tracker of
    ``self.accelerator``. A zero tensor is always returned.

    Args (forward):
        predicted_pixel_values (torch.Tensor): images predicted with 1-step LCM and the VAE decoder.
        pixel_values (torch.Tensor): the input image to the decoder.
        encoder_pixel_values (torch.Tensor): the input image to the encoder.
        timesteps (torch.Tensor): timesteps written into the caption.
    The image tensors are rearranged as (n, c, h, w). ``np.hstack`` requires the
    three stacked columns to have the same height (n*h) and channel count.
    """

    @staticmethod
    def _to_numpy_column(images: torch.Tensor) -> np.ndarray:
        """Stack an (n, c, h, w) batch vertically into an (n*h, w, c) float32 numpy array."""
        # Cast with .float() first: numpy has no bfloat16, so .numpy() would raise on bf16 tensors.
        return rearrange(images.detach().float(), "n c h w -> (n h) w c").cpu().numpy()

    def forward(
        self,
        predicted_pixel_values: torch.Tensor,
        pixel_values: torch.Tensor,
        encoder_pixel_values: torch.Tensor,
        timesteps: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
            image = np.hstack([
                self._to_numpy_column(encoder_pixel_values),
                self._to_numpy_column(pixel_values),
                self._to_numpy_column(predicted_pixel_values),
            ])
            for tracker in self.accelerator.trackers:
                if tracker.name == 'wandb':
                    tracker.log({"TrainVisualization": wandb.Image(image, caption=f"Encoder Input Image, Decoder Input Image, Predicted LCM Image. Timesteps {timesteps.cpu().tolist()}")})
        return torch.tensor(0.0)
class L2Loss(Loss):
    """Mean squared error (e.g. the regular diffusion noise loss), computed in float32.

    Args (forward):
        predict (torch.Tensor): the prediction, e.g. noise predicted by the diffusion model.
        target (torch.Tensor): the target, e.g. the noise actually added to the image.
        weights (torch.Tensor, optional): multiplied element-wise (with broadcasting)
            into the squared error before averaging.
    """

    def forward(
        self,
        predict: torch.Tensor,
        target: torch.Tensor,
        weights: torch.Tensor = None,
        **kwargs
    ) -> torch.Tensor:
        predict_fp32 = predict.float()
        target_fp32 = target.float()
        if weights is None:
            return F.mse_loss(predict_fp32, target_fp32, reduction="mean")
        squared_error = (predict_fp32 - target_fp32).pow(2)
        return (squared_error * weights).mean()
class HuberLoss(Loss):
    """Pseudo-Huber loss ``sqrt(d**2 + c**2) - c`` with ``d = predict - target``, in float32.

    Args:
        huber_c: the ``c`` constant; smaller values behave more like L1.

    Args (forward):
        predict (torch.Tensor): the prediction.
        target (torch.Tensor): the target.
        weights (torch.Tensor, optional): multiplied element-wise into the loss before averaging.
    """

    def __init__(self, huber_c=0.001, **kwargs):
        super().__init__(**kwargs)
        self.huber_c = huber_c

    def forward(
        self,
        predict: torch.Tensor,
        target: torch.Tensor,
        weights: torch.Tensor = None,
        **kwargs
    ) -> torch.Tensor:
        c = self.huber_c
        residual = predict.float() - target.float()
        per_element = torch.sqrt(residual ** 2 + c ** 2) - c
        if weights is None:
            return per_element.mean()
        return (per_element * weights).mean()
class WeightedNoiseLoss(Loss):
    """Weighted diffusion MSE between predicted and target noise, computed in float32.

    Both tensors are multiplied by ``weights`` before the MSE. Each element's
    squared error is therefore scaled by ``weights**2``, which is identical to
    ``weights`` for 0/1 masks.

    Args (forward):
        predict (torch.Tensor): noise predicted by the diffusion model.
        target (torch.Tensor): the noise actually added to the image.
        weights (torch.Tensor): weighting for each batch item (must broadcast against
            ``predict``). It can be used, e.g., to zero out the loss for InstantID
            training when keypoint extraction fails.
    """

    def forward(
        self,
        predict: torch.Tensor,
        target: torch.Tensor,
        weights,
        **kwargs
    ) -> torch.Tensor:
        weighted_predict = predict.float() * weights
        weighted_target = target.float() * weights
        return F.mse_loss(weighted_predict, weighted_target, reduction="mean")
class IDLoss(Loss):
    """Identity-preservation loss using a pretrained ArcFace (IR-SE-50) backbone.

    Faces are detected and cropped with MTCNN via ``extract_faces_and_landmarks``.
    The crops are adaptively average-pooled to 112x112 and embedded with
    ``Backbone``. Weights are only available for 112-pixel inputs. The per-sample
    loss is ``1 - <e_target, e_pred>`` (dot product of the embeddings).
    NOTE(review): the embeddings are assumed to be L2-normalized by the backbone,
    in which case this is one minus the cosine similarity. Confirm this in
    ``pretrained_models.arcface``.

    Args:
        pretrained_arcface_path: path to the ArcFace state dict (required).
        skip_not_found: if True, samples whose face was not found in either the
            predicted or the encoder image are excluded from the loss. If False,
            all samples are used.
    """

    def __init__(self, pretrained_arcface_path: str, skip_not_found=True, **kwargs):
        super().__init__(**kwargs)
        assert pretrained_arcface_path is not None, (
            "please pass `pretrained_arcface_path` in the losses config. You can download the pretrained model from "
            "https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing"
        )
        device = self.accelerator.device
        self.mtcnn = MTCNN(device=device)
        # Calling the MTCNN module now runs detect().
        # NOTE(review): presumably so the accelerator-prepared module performs detection; confirm.
        self.mtcnn.forward = self.mtcnn.detect
        self.facenet_input_size = 112  # Has to be 112, can't find weights for 224 size.
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # Load to CPU first; the backbone is moved to the accelerator device just below.
        self.facenet.load_state_dict(torch.load(pretrained_arcface_path, map_location="cpu"))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((self.facenet_input_size, self.facenet_input_size))
        self.facenet.requires_grad_(False)
        self.facenet.eval()
        self.facenet.to(device=device, dtype=self.dtype)  # not implemented for half precision
        self.face_pool.to(device=device, dtype=self.dtype)  # not implemented for half precision
        self.visualization_resize = transforms.Resize((self.facenet_input_size, self.facenet_input_size), interpolation=transforms.InterpolationMode.BICUBIC)
        # Original points are 112 * 96; 8 was added to the x axis to make them 112 * 112.
        self.reference_facial_points = np.array([[38.29459953, 51.69630051],
                                                 [72.53179932, 51.50139999],
                                                 [56.02519989, 71.73660278],
                                                 [41.54930115, 92.3655014],
                                                 [70.72990036, 92.20410156]
                                                 ])
        self.facenet, self.face_pool, self.mtcnn = self.accelerator.prepare(self.facenet, self.face_pool, self.mtcnn)
        self.skip_not_found = skip_not_found

    def extract_feats(self, x: torch.Tensor):
        """Pool face crops to the backbone input size and return the backbone embeddings."""
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    @staticmethod
    def _valid_indices(batch_size, predicted_invalid_indices, source_invalid_indices):
        """Return the indices in ``range(batch_size)`` not flagged invalid for either image."""
        return [
            i for i in range(batch_size)
            if i not in predicted_invalid_indices and i not in source_invalid_indices
        ]

    def forward(
        self,
        predicted_pixel_values: torch.Tensor,
        encoder_pixel_values: torch.Tensor,
        timesteps: torch.Tensor,
        **kwargs
    ):
        encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
        predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
        predicted_pixel_values_face, predicted_invalid_indices = extract_faces_and_landmarks(predicted_pixel_values, mtcnn=self.mtcnn)
        with torch.no_grad():
            encoder_pixel_values_face, source_invalid_indices = extract_faces_and_landmarks(encoder_pixel_values, mtcnn=self.mtcnn)

        batch_size = predicted_pixel_values.shape[0]
        if self.skip_not_found:
            valid_indices = self._valid_indices(batch_size, predicted_invalid_indices, source_invalid_indices)
        else:
            valid_indices = list(range(batch_size))
        # Explicit long dtype so an empty list still yields a valid index tensor.
        valid_indices = torch.tensor(valid_indices, dtype=torch.long, device=predicted_pixel_values.device)

        should_visualize = self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0
        if len(valid_indices) == 0:
            # It's done this way so the `backwards` will delete the computation graph of the predicted_pixel_values.
            loss = (predicted_pixel_values_face * 0.0).mean()
            if should_visualize:
                self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
            return loss

        with torch.no_grad():
            pixel_values_feats = self.extract_feats(encoder_pixel_values_face[valid_indices])
        # Gradients flow only through the predicted-image branch.
        predicted_pixel_values_feats = self.extract_feats(predicted_pixel_values_face[valid_indices])
        loss = 1 - torch.einsum("bi,bi->b", pixel_values_feats, predicted_pixel_values_feats)
        if should_visualize:
            self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
        return loss.mean()

    def visualize(
        self,
        predicted_pixel_values: torch.Tensor,
        encoder_pixel_values: torch.Tensor,
        predicted_pixel_values_face: torch.Tensor,
        encoder_pixel_values_face: torch.Tensor,
        timesteps: torch.Tensor,
        valid_indices: torch.Tensor,
        loss: torch.Tensor,
    ) -> None:
        """Log predicted/target images and face crops side by side to every wandb tracker."""
        def to_panel(images, title):
            # Resize to 112x112, stack samples vertically and convert to uint8.
            # The images are assumed to be in [-1, 1]. Values are clipped so
            # out-of-range predictions do not wrap around in the uint8 cast.
            column = rearrange(self.visualization_resize(images), "n c h w -> (n h) w c").detach().float().cpu().numpy()
            as_uint8 = np.clip((column * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
            return add_text_to_image(as_uint8, title, add_below=False)

        final_image = np.hstack([
            to_panel(predicted_pixel_values, "Pred Images"),
            to_panel(encoder_pixel_values, "Target Images"),
            to_panel(predicted_pixel_values_face, "Pred Faces"),
            to_panel(encoder_pixel_values_face, "Target Faces"),
        ])
        for tracker in self.accelerator.trackers:
            if tracker.name == 'wandb':
                tracker.log({"IDLoss Visualization": wandb.Image(final_image, caption=f"loss: {loss.cpu().tolist()} timesteps: {timesteps.cpu().tolist()}, valid_indices: {valid_indices.cpu().tolist()}")})
class ImageAugmentations(torch.nn.Module):
    """Batch-expanding image augmentations used by the CLIP loss to discourage adversarial outputs.

    The input batch is adaptive-average-pooled to ``output_size`` x ``output_size`` and
    replicated ``augmentations_number`` times. The first copy is kept as-is; every
    later copy goes through a random affine transform followed by a random
    perspective transform, each applied with probability ``p``.
    """

    def __init__(self, output_size, augmentations_number, p=0.7):
        super().__init__()
        self.output_size = output_size
        self.augmentations_number = augmentations_number
        self.augmentations = torch.nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"),  # type: ignore
            K.RandomPerspective(0.7, p=p),
        )
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((self.output_size, self.output_size))
        self.device = None

    def forward(self, input):
        """Extend the batch with augmented copies.

        For an input batch [I1, I2] the output is
        [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2, ...].

        Args:
            input: batch of shape [batch, C, H, W].
        Returns:
            Tensor of shape [batch * augmentations_number, C, output_size, output_size].
        """
        n_originals = input.shape[0]
        pooled = self.avg_pool(input)
        # Unlike ordinary augmentation, this multiplies the number of samples.
        replicated = pooled.repeat(self.augmentations_number, 1, 1, 1)
        untouched, to_augment = replicated[:n_originals], replicated[n_originals:]
        return torch.cat([untouched, self.augmentations(to_augment)], dim=0)
class CLIPLoss(Loss):
    """Text-image alignment loss using OpenAI CLIP ViT-B/16.

    Predicted images, assumed to be in [-1, 1] (SD output), go through three steps:
    mapping to [0, 1], CLIP's own resize and center crop, and CLIP's normalization.
    The loss is ``mean(1 - logit / 100)`` over the matching image/prompt pairs.

    Args:
        augmentations_number: number of copies produced by ``self.image_augmentations``.
            NOTE(review): the augmentations (and ``self.clip_normalize``) are built
            but not used in ``forward``.
    """

    def __init__(self, augmentations_number: int = 4, **kwargs):
        super().__init__(**kwargs)
        self.clip_model, clip_preprocess = clip.load("ViT-B/16", device=self.accelerator.device, jit=False)
        self.clip_model.device = None
        self.clip_model.eval().requires_grad_(False)
        self.preprocess = transforms.Compose(
            # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])]
            # Resize + center crop, to match CLIP input scale assumptions.
            + clip_preprocess.transforms[:2]
            # Skip the PIL-to-tensor conversion steps (indices 2-3); keep the rest.
            + clip_preprocess.transforms[4:]
        )
        self.clip_size = self.clip_model.visual.input_resolution
        self.clip_normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )
        self.image_augmentations = ImageAugmentations(output_size=self.clip_size,
                                                      augmentations_number=augmentations_number)
        self.clip_model, self.image_augmentations = self.accelerator.prepare(self.clip_model, self.image_augmentations)

    def forward(self, decoder_prompts, predicted_pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the mean CLIP misalignment between images and their prompts.

        Args:
            decoder_prompts: a prompt string, or a sequence (list, tuple, ...) with one
                prompt per image. A single prompt is scored against every image.
            predicted_pixel_values: image batch, assumed to be in [-1, 1].
        """
        if isinstance(decoder_prompts, str):
            decoder_prompts = [decoder_prompts]
        else:
            # Accept any sequence of prompts (e.g. tuples), not only lists.
            decoder_prompts = list(decoder_prompts)
        # truncate=True: prompts longer than CLIP's context are cut instead of raising.
        tokens = clip.tokenize(decoder_prompts, truncate=True).to(predicted_pixel_values.device)
        image = self.preprocess(predicted_pixel_values)
        # Presumably shaped (num_images, num_prompts).
        logits_per_image, _ = self.clip_model(image, tokens)
        if logits_per_image.shape[1] == 1:
            # One prompt shared by all images: score every image against it.
            matched_logits = logits_per_image[:, 0]
        else:
            matched_logits = torch.diagonal(logits_per_image)
        # Dividing by 100 (presumably CLIP's logit scale) makes this ~ 1 - cosine similarity.
        return (1. - matched_logits / 100).mean()
class DINOLoss(Loss):
    """DINO-style distillation loss with target images as teacher and predictions as student.

    Both batches go through the same preprocessing:
    1. map from [-1, 1] (SD output) to [0, 1];
    2. resize to 256 and center-crop to 224;
    3. normalize with the ``dino_preprocess`` mean/std.

    They are then encoded in a single ``dino_model`` call. Teacher token features
    are centered by ``self.center``, sharpened with ``teacher_temp`` and
    softmaxed over the last dimension. The loss is the cross-entropy against the
    student's log-softmax at ``student_temp``.

    Args:
        dino_model: image encoder whose output exposes ``last_hidden_state``.
        dino_preprocess: processor providing ``rescale_factor``, ``image_mean``, ``image_std``.
        output_hidden_states: must be False; ``forward`` raises ValueError otherwise.
        center_momentum: EMA momentum used by ``update_center``.
        student_temp: student temperature.
        teacher_temp: teacher temperature (fixed).
        warmup_teacher_temp, warmup_teacher_temp_epochs: accepted but currently unused.
    """

    def __init__(
            self,
            dino_model,
            dino_preprocess,
            output_hidden_states: bool = False,
            center_momentum: float = 0.9,
            student_temp: float = 0.1,
            teacher_temp: float = 0.04,
            warmup_teacher_temp: float = 0.04,
            warmup_teacher_temp_epochs: int = 30,
            **kwargs):
        super().__init__(**kwargs)
        self.dino_model = dino_model
        self.output_hidden_states = output_hidden_states
        self.rescale_factor = dino_preprocess.rescale_factor
        self.preprocess = transforms.Compose(
            [
                # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
                transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
                transforms.Resize(size=256),
                transforms.CenterCrop(size=(224, 224)),
                transforms.Normalize(mean=dino_preprocess.image_mean, std=dino_preprocess.image_std)
            ]
        )
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        # One center per (token, feature) for the [B, 257, 1024] hidden states used in forward.
        # NOTE(review): hard-coded; other backbones/resolutions will not broadcast.
        self.center = torch.zeros(1, 257, 1024).to(self.accelerator.device, dtype=self.dtype)
        # TODO: teacher-temperature warm-up (warmup_teacher_temp -> teacher_temp over
        # warmup_teacher_temp_epochs); a too high temperature makes training unstable
        # at the beginning. Currently fixed to teacher_temp.
        self.dino_model = self.accelerator.prepare(self.dino_model)

    def forward(
            self,
            target: torch.Tensor,
            predict: torch.Tensor,
            weights: torch.Tensor = None,
            **kwargs) -> torch.Tensor:
        if self.output_hidden_states:
            raise ValueError("Output hidden states not supported for DINO loss.")
        predict = self.preprocess(predict)
        target = self.preprocess(target)
        # Encode teacher and student images in one call, then split them again.
        encoder_input = torch.cat([target, predict]).to(self.dino_model.device, dtype=self.dino_model.dtype)
        image_enc_hidden_states = self.dino_model(encoder_input).last_hidden_state
        teacher_output, student_output = image_enc_hidden_states.chunk(2, dim=0)  # [B, 257, 1024]
        student_out = student_output.float() / self.student_temp
        # Teacher centering and sharpening; the teacher distribution is a fixed target.
        teacher_out = F.softmax((teacher_output.float() - self.center) / self.teacher_temp, dim=-1).detach()
        loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1, keepdim=True)
        # update_center is not called here, so self.center keeps its current value.
        if weights is not None:
            loss = loss * weights
        return loss.mean()

    @torch.no_grad()
    def update_center(self, teacher_output):
        """EMA-update ``self.center`` with the mean teacher output across the global batch."""
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        # accelerator.reduce returns the reduced tensor; it does not reduce in place.
        batch_center = self.accelerator.reduce(batch_center, reduction="sum")
        batch_center = batch_center / (len(teacher_output) * self.accelerator.num_processes)
        batch_center = batch_center.to(self.center.dtype)
        # EMA update.
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)