# Page header captured from the Hugging Face Spaces listing:
# Spaces: Runtime error
import torch | |
from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer | |
from omegaconf import OmegaConf | |
import math | |
import imageio | |
from PIL import Image | |
import torchvision | |
import torch.nn.functional as F | |
import torch | |
import numpy as np | |
from PIL import Image | |
import time | |
import datetime | |
import torch | |
import sys | |
import os | |
from torchvision import datasets | |
import pickle | |
# StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl | |
# Precision switch: True imports the model/scheduler classes from the
# `my_half_diffusers` package and selects float16 dtypes; False imports the
# same names from `my_diffusers` and selects float64 dtypes.
use_half_prec = True
if use_half_prec:
    from my_half_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_half_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_half_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
else:
    from my_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
# dtypes used below when loading the models and when converting images to arrays
torch_dtype = torch.float16 if use_half_prec else torch.float64
np_dtype = np.float16 if use_half_prec else np.float64
import random | |
from tqdm.auto import tqdm | |
from torch import autocast | |
from difflib import SequenceMatcher | |
# Build our CLIP model; only its text encoder (`clip`) is used below to embed prompts
model_path_clip = "openai/clip-vit-large-patch14"
clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch_dtype)
clip = clip_model.text_model
# Getting our HF Auth token: prefer the `auth_token` environment variable and
# fall back to the first line of a local `hf_auth` file when it is not set
auth_token = os.environ.get('auth_token')
if auth_token is None:
    with open('hf_auth', 'r') as f:
        auth_token = f.readlines()[0].strip()
model_path_diffusion = "CompVis/stable-diffusion-v1-4"
# Build our SD model (UNet denoiser and VAE) from the fp16 revision of the checkpoint
unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder="unet", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder="vae", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
# Push to device; the models are cast to double precision only when
# half precision is disabled
device = 'cuda'
if use_half_prec:
    unet.to(device)
    vae.to(device)
    clip.to(device)
else:
    unet.double().to(device)
    vae.double().to(device)
    clip.double().to(device)
print("Loaded all models")
# The safety checker is imported from the `diffusers` package itself, not from
# the my_diffusers / my_half_diffusers packages used for the UNet and VAE.
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
# load safety model (feature extractor + checker share one model id)
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
def load_replacement(x):
    """Return a placeholder image shaped like ``x``, or ``x`` itself on failure.

    The placeholder is ``assets/rick.jpeg`` converted to RGB, resized to the
    first two dimensions of ``x`` (treated as height, width), scaled to
    [0, 1] and cast to ``x``'s dtype.

    Any exception raised along the way (missing asset, shape mismatch, ...)
    is deliberately swallowed and the input is handed back unchanged.
    """
    try:
        shape = x.shape
        target_size = (shape[1], shape[0])  # PIL wants (width, height)
        placeholder = Image.open("assets/rick.jpeg").convert("RGB")
        placeholder = placeholder.resize(target_size)
        placeholder = (np.array(placeholder) / 255.0).astype(x.dtype)
        assert placeholder.shape == x.shape
    except Exception:
        # Best effort only: fall back to the image we were given.
        return x
    return placeholder
def check_safety(x_image):
    """Run the safety checker on a batch of images and black out flagged ones.

    Args:
        x_image: numpy array of images. Assumed to be float values in [0, 1]
            laid out as (batch, height, width, channels), which is what the
            caller in this file passes -- TODO confirm for other callers.
            A single (height, width, channels) image is also accepted for
            the PIL conversion step.

    Returns:
        ``(x_checked_image, has_nsfw_concept)`` exactly as returned by
        ``safety_checker``, except that every image whose flag is truthy has
        been multiplied by zero in place.
    """
    # The original called `numpy_to_pil`, which is neither defined nor
    # imported in this file; the conversion is done inline instead.
    batch = x_image[None, ...] if x_image.ndim == 3 else x_image
    pil_images = [
        Image.fromarray((frame * 255).round().astype("uint8")) for frame in batch
    ]
    safety_checker_input = safety_feature_extractor(pil_images, return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image, clip_input=safety_checker_input.pixel_values
    )
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for i, flagged in enumerate(has_nsfw_concept):
        if flagged:
            # Blank the image rather than substituting a placeholder.
            x_checked_image[i] *= 0
    return x_checked_image, has_nsfw_concept
def EDICT_editing(im_path,
                  base_prompt,
                  edit_prompt,
                  use_p2p=False,
                  steps=50,
                  mix_weight=0.93,
                  init_image_strength=0.8,
                  guidance_scale=3,
                  run_baseline=False,
                  width=512, height=512):
    """
    Main call of our research, performs editing with either EDICT or DDIM

    Args:
        im_path: image to run on. Either a path (str), a PIL image, or any
            other object, which is forwarded unchanged to
            coupled_stablediffusion as ``init_image``.
        base_prompt: conditional prompt to deterministically noise with
        edit_prompt: desired text conditioning
        use_p2p: if True, denoise with ``base_prompt`` as the prompt and
            ``edit_prompt`` as the prompt-to-prompt edit; if False, denoise
            directly with ``edit_prompt``.
        steps: ddim steps
        mix_weight: Weight of mixing layers.
            Higher means more consistent generations but divergence in inversion
            Lower means opposite
            This is fairly tuned and can get good results
        init_image_strength: Editing strength. Higher = more dramatic edit.
            Typically [0.6, 0.9] is good range.
            Definitely tunable per-image/maybe best results are at a different value
        guidance_scale: classifier-free guidance scale
            3 I've found is the best for both our method and basic DDIM inversion
            Higher can result in more distorted results
        run_baseline:
            VERY IMPORTANT
            True is the DDIM baseline, False is EDICT
        width, height: working resolution passed to coupled_stablediffusion.
            Overwritten from the image size when ``im_path`` is a PIL image.

    Output:
        PAIR of Images (list of two PIL images)
        If run_baseline=True then [0] will be edit and [1] will be original
        If run_baseline=False then they will be two nearly identical edited versions
    """
    if isinstance(im_path, str):
        # Paths are centre-cropped and resized to 512x512 by the loader.
        orig_im = load_im_into_format_from_path(im_path)
    elif isinstance(im_path, Image.Image):
        # isinstance replaces Image.isImageType, which newer Pillow releases
        # deprecate and remove.
        width, height = im_path.size

        # add max dim for sake of memory
        max_dim = max(width, height)
        if max_dim > 1024:
            factor = 1024 / max_dim
            width = int(width * factor)
            height = int(height * factor)
            im_path = im_path.resize((width, height))

        # Make sure the shorter side is at least 512 pixels.
        min_dim = min(width, height)
        if min_dim < 512:
            factor = 512 / min_dim
            width = int(width * factor)
            height = int(height * factor)
            im_path = im_path.resize((width, height))

        # Round down to multiples of 64; the image itself is not cropped
        # here, coupled_stablediffusion resizes it to (width, height).
        width = width - (width % 64)
        height = height - (height % 64)

        orig_im = im_path  # general_crop(im_path, width, height)
    else:
        orig_im = im_path

    # compute latent pair (second one will be original latent if run_baseline=True)
    latents = coupled_stablediffusion(base_prompt,
                                      reverse=True,
                                      init_image=orig_im,
                                      init_image_strength=init_image_strength,
                                      steps=steps,
                                      mix_weight=mix_weight,
                                      guidance_scale=guidance_scale,
                                      run_baseline=run_baseline,
                                      width=width, height=height)

    # Denoise intermediate state with new conditioning
    gen = coupled_stablediffusion(edit_prompt if (not use_p2p) else base_prompt,
                                  None if (not use_p2p) else edit_prompt,
                                  fixed_starting_latent=latents,
                                  init_image_strength=init_image_strength,
                                  steps=steps,
                                  mix_weight=mix_weight,
                                  guidance_scale=guidance_scale,
                                  run_baseline=run_baseline,
                                  width=width, height=height)
    return gen
def img2img_editing(im_path,
                    edit_prompt,
                    steps=50,
                    init_image_strength=0.7,
                    guidance_scale=3):
    """
    Basic SDEdit/img2img: load the image at ``im_path``, then let
    baseline_stablediffusion noise it and denoise it with ``edit_prompt``.

    Returns whatever baseline_stablediffusion returns for these arguments.
    """
    sdedit_kwargs = dict(
        init_image_strength=init_image_strength,
        steps=steps,
        init_image=load_im_into_format_from_path(im_path),
        guidance_scale=guidance_scale,
    )
    return baseline_stablediffusion(edit_prompt, **sdedit_kwargs)
def center_crop(im):
    """Crop the largest centred square out of ``im`` and return the result.

    ``im`` must expose ``.size`` as (width, height) and a ``.crop(box)``
    method taking a (left, top, right, bottom) tuple.
    """
    w, h = im.size
    side = min(w, h)
    # Box is centred on the image; coordinates may be fractional.
    box = ((w - side) / 2, (h - side) / 2, (w + side) / 2, (h + side) / 2)
    return im.crop(box)
def general_crop(im, target_w, target_h):
    """Crop a centred ``target_w`` x ``target_h`` region out of ``im``.

    The previous box, (target_w/2, target_h/2, width - target_w/2,
    height - target_h/2), produced a region of size
    (width - target_w) x (height - target_h) instead of the requested one.

    Args:
        im: object exposing ``.size`` as (width, height) and ``.crop(box)``.
        target_w: width of the region to keep.
        target_h: height of the region to keep.

    Returns:
        The result of ``im.crop`` for the centred box.
    """
    width, height = im.size
    left = (width - target_w) / 2
    top = (height - target_h) / 2
    # Crop the center of the image
    return im.crop((left, top, left + target_w, top + target_h))
def load_im_into_format_from_path(im_path):
    """Open the image at ``im_path``, centre-crop it to a square and resize to 512x512."""
    squared = center_crop(Image.open(im_path))
    return squared.resize((512, 512))
#### P2P STUFF #### | |
def init_attention_weights(weight_tuples):
    """Install per-token attention weights on the UNet's CrossAttention modules.

    Args:
        weight_tuples: iterable of ``(token_index, weight)`` pairs. Indices
            outside ``[0, clip_tokenizer.model_max_length)`` are ignored;
            every token not mentioned keeps weight 1.

    Modules whose name contains "attn2" receive the weight vector (moved to
    ``device``) as ``last_attn_slice_weights``; modules whose name contains
    "attn1" get ``None``.
    """
    max_tokens = clip_tokenizer.model_max_length
    token_weights = torch.ones(max_tokens)
    for position, weight in weight_tuples:
        if 0 <= position < max_tokens:
            token_weights[position] = weight

    for layer_name, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in layer_name:
            layer.last_attn_slice_weights = token_weights.to(device)
        if "attn1" in layer_name:
            layer.last_attn_slice_weights = None
def init_attention_edit(tokens, tokens_edit):
    """Build the token mask / index map between two prompts and store it on the UNet.

    The two token-id sequences are aligned with ``SequenceMatcher``. For
    every "equal" span, and every "replace" span of unchanged length, the
    edited-prompt positions are marked in the mask and mapped back to the
    corresponding base-prompt positions.

    Args:
        tokens: tokenizer output for the base prompt (``.input_ids`` is read).
        tokens_edit: tokenizer output for the edited prompt.

    Modules whose name contains "attn2" receive ``last_attn_slice_mask`` and
    ``last_attn_slice_indices`` (both moved to ``device``); modules whose
    name contains "attn1" get ``None`` for both.
    """
    max_tokens = clip_tokenizer.model_max_length
    keep_mask = torch.zeros(max_tokens)
    source_positions = torch.arange(max_tokens, dtype=torch.long)
    remap = torch.zeros(max_tokens, dtype=torch.long)

    base_ids = tokens.input_ids.numpy()[0]
    edit_ids = tokens_edit.input_ids.numpy()[0]

    matcher = SequenceMatcher(None, base_ids, edit_ids)
    for tag, a0, a1, b0, b1 in matcher.get_opcodes():
        if b0 >= max_tokens:
            continue
        same_length_replace = tag == "replace" and (a1 - a0) == (b1 - b0)
        if tag == "equal" or same_length_replace:
            keep_mask[b0:b1] = 1
            remap[b0:b1] = source_positions[a0:a1]

    for layer_name, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in layer_name:
            layer.last_attn_slice_mask = keep_mask.to(device)
            layer.last_attn_slice_indices = remap.to(device)
        if "attn1" in layer_name:
            layer.last_attn_slice_mask = None
            layer.last_attn_slice_indices = None
def init_attention_func():
    """Replace ``_attention`` on every CrossAttention module of the UNet.

    The replacement computes sliced softmax attention and, depending on
    per-module flags, can reuse a previously saved attention map
    (optionally blended through a mask / index map), save the current map,
    and multiply the map by per-token weights. Each flag is reset to False
    once it has been acted on.

    Also resets ``last_attn_slice`` and the three flags on every
    CrossAttention module.
    """
    def new_attention(self, query, key, value, sequence_length, dim):
        n_rows = query.shape[0]
        hidden_states = torch.zeros(
            (n_rows, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        # Without an explicit slice size, process everything in one slice.
        slice_size = hidden_states.shape[0] if self._slice_size is None else self._slice_size
        n_slices = hidden_states.shape[0] // slice_size
        for slice_no in range(n_slices):
            lo = slice_no * slice_size
            hi = (slice_no + 1) * slice_size

            scores = torch.einsum("b i d, b j d -> b i j", query[lo:hi], key[lo:hi]) * self.scale
            attn_slice = scores.softmax(dim=-1)

            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is None:
                    # No mask: take the saved attention map wholesale.
                    attn_slice = self.last_attn_slice
                else:
                    # Masked positions come from the saved map (re-indexed),
                    # the rest from the freshly computed one.
                    new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
                self.use_last_attn_slice = False

            if self.save_last_attn_slice:
                self.last_attn_slice = attn_slice
                self.save_last_attn_slice = False

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                attn_slice = attn_slice * self.last_attn_slice_weights
                self.use_last_attn_weights = False

            hidden_states[lo:hi] = torch.einsum("b i j, b j d -> b i d", attn_slice, value[lo:hi])

        # reshape hidden_states
        return self.reshape_batch_dim_to_heads(hidden_states)

    for _, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        layer.last_attn_slice = None
        layer.use_last_attn_slice = False
        layer.use_last_attn_weights = False
        layer.save_last_attn_slice = False
        # Bind the function to this instance so `self` is the module.
        layer._attention = new_attention.__get__(layer, type(layer))
def _set_attention_flag(layer_tag, flag_name, value):
    """Set ``flag_name`` to ``value`` on every CrossAttention module of ``unet`` whose name contains ``layer_tag``."""
    for name, module in unet.named_modules():
        if type(module).__name__ == "CrossAttention" and layer_tag in name:
            setattr(module, flag_name, value)


def use_last_tokens_attention(use=True):
    """Toggle reuse of the saved attention map on the "attn2" modules."""
    _set_attention_flag("attn2", "use_last_attn_slice", use)


def use_last_tokens_attention_weights(use=True):
    """Toggle application of the per-token weights on the "attn2" modules."""
    _set_attention_flag("attn2", "use_last_attn_weights", use)


def use_last_self_attention(use=True):
    """Toggle reuse of the saved attention map on the "attn1" modules."""
    _set_attention_flag("attn1", "use_last_attn_slice", use)


def save_last_tokens_attention(save=True):
    """Toggle saving of the next attention map on the "attn2" modules."""
    _set_attention_flag("attn2", "save_last_attn_slice", save)


def save_last_self_attention(save=True):
    """Toggle saving of the next attention map on the "attn1" modules."""
    _set_attention_flag("attn1", "save_last_attn_slice", save)
####################################
##### BASELINE ALGORITHM, ONLY USED NOW FOR SDEDIT #####
def baseline_stablediffusion(prompt="",
                             prompt_edit=None,
                             null_prompt='',
                             prompt_edit_token_weights=[],
                             prompt_edit_tokens_start=0.0,
                             prompt_edit_tokens_end=1.0,
                             prompt_edit_spatial_start=0.0,
                             prompt_edit_spatial_end=1.0,
                             clip_start=0.0,
                             clip_end=1.0,
                             guidance_scale=7,
                             steps=50,
                             seed=1,
                             width=512, height=512,
                             init_image=None, init_image_strength=0.5,
                             fixed_starting_latent = None,
                             prev_image= None,
                             grid=None,
                             clip_guidance=None,
                             clip_guidance_scale=1,
                             num_cutouts=4,
                             cut_power=1,
                             scheduler_str='lms',
                             return_latent=False,
                             one_pass=False,
                             normalize_noise_pred=False):
    """Classifier-free-guided Stable Diffusion sampling (txt2img / img2img / prompt-to-prompt).

    Args:
        prompt: conditional prompt.
        prompt_edit: optional edited prompt; when given, attention maps saved
            from the ``prompt`` pass are reused for the ``prompt_edit`` pass.
        null_prompt: prompt used for the unconditional branch.
        prompt_edit_token_weights: ``(token_index, weight)`` pairs forwarded
            to init_attention_weights (only iterated, never mutated here).
        prompt_edit_tokens_start / prompt_edit_tokens_end: range of
            ``t / num_train_timesteps`` in which saved "attn2" maps are reused.
        prompt_edit_spatial_start / prompt_edit_spatial_end: same, for "attn1".
        clip_start / clip_end: range in which ``clip_guidance`` is applied.
        guidance_scale: classifier-free guidance scale; must be an ``int``.
        steps: number of scheduler steps.
        seed: CUDA RNG seed; ``None`` draws a random one.
        width, height: output size, rounded down to multiples of 64.
        init_image: optional PIL image to start from (img2img).
        init_image_strength: fraction of the schedule that is actually run
            when starting from an image or a fixed latent.
        fixed_starting_latent: optional latent used instead of fresh noise.
        prev_image: optional PIL image for a secondary branch; that branch
            raises NotImplementedError inside the sampling loop.
        grid: unused in this function.
        clip_guidance, clip_guidance_scale, num_cutouts, cut_power: forwarded
            to ``new_cond_fn``, which is not defined in this file --
            NOTE(review): enabling clip_guidance will fail unless it is
            provided elsewhere.
        scheduler_str: one of 'ddim', 'lms', 'pndm', 'ddpm'.
        return_latent: return the final latent instead of a decoded image.
        one_pass: stop after the first loop iteration.
        normalize_noise_pred: rescale the guided prediction to the norm of
            the unconditional one.

    Returns:
        The final latent tensor if ``return_latent`` is set, otherwise a PIL
        image of the first decoded sample after the safety check.

    Raises:
        KeyError: for an unknown ``scheduler_str``.
        NotImplementedError: for DDIM img2img, and for ``prev_image``.
    """
    width = width - width % 64
    height = height - height % 64

    #If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
    generator = torch.cuda.manual_seed(seed)

    #Set inference timesteps to scheduler
    scheduler_dict = {'ddim':DDIMScheduler,
                      'lms':LMSDiscreteScheduler,
                      'pndm':PNDMScheduler,
                      'ddpm':DDPMScheduler}
    scheduler_call = scheduler_dict[scheduler_str]
    if scheduler_str == 'ddim':
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                  beta_schedule="scaled_linear",
                                  clip_sample=False, set_alpha_to_one=False)
    else:
        scheduler = scheduler_call(beta_schedule="scaled_linear",
                                   num_train_timesteps=1000)
    scheduler.set_timesteps(steps)

    if prev_image is not None:
        prev_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
                                              beta_end=0.012,
                                              beta_schedule="scaled_linear",
                                              num_train_timesteps=1000)
        prev_scheduler.set_timesteps(steps)

    #Preprocess image if it exists (img2img)
    if init_image is not None:
        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        init_image = np.array(init_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if init_image.shape[1] > 3:
            init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])

        #Move image to GPU
        init_image = init_image.to(device)

        #Encode image
        with autocast(device):
            init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215

        t_start = steps - int(steps * init_image_strength)
    else:
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_start = 0

    #Generate random normal noise
    if fixed_starting_latent is None:
        noise = torch.randn(init_latent.shape, generator=generator, device=device, dtype=unet.dtype)
        if scheduler_str == 'ddim':
            if init_image is not None:
                # DDIM img2img noising was never finished. The original
                # raised the misspelled `notImplementedError` (a NameError).
                raise NotImplementedError
            latent = noise
        else:
            latent = scheduler.add_noise(init_latent, noise,
                                         t_start).to(device)
    else:
        latent = fixed_starting_latent
        t_start = steps - int(steps * init_image_strength)

    if prev_image is not None:
        #Resize and prev_image for numpy b h w c -> torch b c h w
        prev_image = prev_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        prev_image = np.array(prev_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
        prev_image = torch.from_numpy(prev_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if prev_image.shape[1] > 3:
            prev_image = prev_image[:, :3] * prev_image[:, 3:] + (1 - prev_image[:, 3:])

        #Move image to GPU
        prev_image = prev_image.to(device)

        #Encode image
        with autocast(device):
            prev_init_latent = vae.encode(prev_image).latent_dist.sample(generator=generator) * 0.18215

        t_start = steps - int(steps * init_image_strength)

        # NOTE(review): `noise` only exists when fixed_starting_latent is None.
        prev_latent = prev_scheduler.add_noise(prev_init_latent, noise, t_start).to(device)
    else:
        prev_latent = None

    # NOTE(review): the source's indentation was lost; everything from here
    # through the VAE decode is assumed to run inside this autocast context.
    #Process clip
    with autocast(device):
        tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
        embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state

        tokens_conditional = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
        embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state

        #Process prompt editing
        assert not ((prompt_edit is not None) and (prev_image is not None))
        if prompt_edit is not None:
            tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
            embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state
            init_attention_edit(tokens_conditional, tokens_conditional_edit)
        elif prev_image is not None:
            init_attention_edit(tokens_conditional, tokens_conditional)

        init_attention_func()
        init_attention_weights(prompt_edit_token_weights)

        timesteps = scheduler.timesteps[t_start:]

        assert isinstance(guidance_scale, int)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_index = t_start + i

            latent_model_input = latent
            if scheduler_str=='lms':
                sigma = scheduler.sigmas[t_index] # last is first and first is last
                latent_model_input = (latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
            else:
                assert scheduler_str in ['ddim', 'pndm', 'ddpm']

            #Predict the unconditional noise residual
            if len(t.shape) == 0:
                t = t[None].to(unet.device)
            noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional,
                                     ).sample

            if prev_latent is not None:
                prev_latent_model_input = prev_latent
                # NOTE(review): `sigma` is only defined for scheduler_str == 'lms'.
                prev_latent_model_input = (prev_latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
                prev_noise_pred_uncond = unet(prev_latent_model_input, t,
                                              encoder_hidden_states=embedding_unconditional,
                                              ).sample

            #Prepare the Cross-Attention layers
            if prompt_edit is not None or prev_latent is not None:
                save_last_tokens_attention()
                save_last_self_attention()
            else:
                #Use weights on non-edited prompt when edit is None
                use_last_tokens_attention_weights()

            #Predict the conditional noise residual and save the cross-attention layer activations
            if prev_latent is not None:
                # The prev_image branch is unfinished; the statements that
                # followed this raise (and the later prev_latent scheduler
                # step) were unreachable and have been dropped.
                raise NotImplementedError # I totally lost track of what this is
            noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional,
                                   ).sample

            #Edit the Cross-Attention layer activations
            t_scale = t / scheduler.num_train_timesteps
            if prompt_edit is not None:
                if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
                    use_last_tokens_attention()
                if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
                    use_last_self_attention()

                #Use weights on edited prompt
                use_last_tokens_attention_weights()

                #Predict the edited conditional noise residual using the cross-attention masks
                noise_pred_cond = unet(latent_model_input, t,
                                       encoder_hidden_states=embedding_conditional_edit).sample

            #Perform guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            if clip_guidance is not None and t_scale >= clip_start and t_scale <= clip_end:
                noise_pred, latent = new_cond_fn(latent, t, t_index,
                                                 embedding_conditional, noise_pred,clip_guidance,
                                                 clip_guidance_scale,
                                                 num_cutouts,
                                                 scheduler, unet,use_cutouts=True,
                                                 cut_power=cut_power)
            if normalize_noise_pred:
                noise_pred = noise_pred * noise_pred_uncond.norm() / noise_pred.norm()

            if scheduler_str == 'ddim':
                # forward_step returns the new latent directly, not a
                # SchedulerOutput, so there is no `.prev_sample` to read.
                latent = forward_step(scheduler, noise_pred,
                                      t,
                                      latent)
            else:
                latent = scheduler.step(noise_pred,
                                        t_index,
                                        latent).prev_sample

            if one_pass: break

        #scale and decode the image latents with vae
        if return_latent: return latent
        latent = latent / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image, _ = check_safety(image)
    image = (image[0] * 255).round().astype("uint8")
    return Image.fromarray(image)
#################################### | |
#### HELPER FUNCTIONS FOR OUR METHOD ##### | |
def get_alpha_and_beta(t, scheduler):
    """Return ``(alpha_cumprod, 1 - alpha_cumprod)`` at timestep ``t``.

    Meant to be called for both the current and the previous timestep.

    Args:
        t: timestep tensor. Integer (``torch.long``) values index
            ``scheduler.alphas_cumprod`` directly; negative floating values
            map to ``scheduler.final_alpha_cumprod``; other floating values
            are linearly interpolated between the two neighbouring entries.
        scheduler: object exposing ``alphas_cumprod`` and
            ``final_alpha_cumprod``.

    Returns:
        Tuple ``(alpha, beta)`` with ``beta = 1 - alpha``.
    """
    if t.dtype == torch.long:
        alpha = scheduler.alphas_cumprod[t]
        return alpha, 1 - alpha

    if t < 0:
        return scheduler.final_alpha_cumprod, 1 - scheduler.final_alpha_cumprod

    low = t.floor().long()
    high = t.ceil().long()
    frac = t - low  # how far t sits above `low`, in [0, 1)

    low_alpha = scheduler.alphas_cumprod[low]
    high_alpha = scheduler.alphas_cumprod[high]
    # The closer t is to `high`, the more weight `high_alpha` gets (the
    # original had these two weights swapped).
    interpolated_alpha = low_alpha * (1 - frac) + high_alpha * frac
    interpolated_beta = 1 - interpolated_alpha
    return interpolated_alpha, interpolated_beta
# A DDIM forward step function | |
def forward_step(
    self,
    model_output,
    timestep: int,
    sample,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    return_dict: bool = True,
    use_double=False,
):
    """Take one deterministic DDIM step from ``timestep`` to the previous timestep.

    Args:
        self: a scheduler on which ``set_timesteps`` has been called.
        model_output: noise prediction for ``sample`` at ``timestep``.
        timestep: current timestep.
        sample: current latent.
        eta, use_clipped_model_output, generator, return_dict, use_double:
            accepted but not used by this function.

    Returns:
        The updated latent tensor (a plain tensor, not a SchedulerOutput).

    Raises:
        ValueError: if ``set_timesteps`` has not been called.
        NotImplementedError: if ``timestep`` exceeds the largest scheduled one.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # The previous timestep lies one stride below; true division makes it a float.
    prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps

    if timestep > self.timesteps.max():
        raise NotImplementedError("Need to double check what the overflow is")

    alpha_t, beta_t = get_alpha_and_beta(timestep, self)
    alpha_prev, _ = get_alpha_and_beta(prev_timestep, self)

    inv_quotient = 1. / ((alpha_t / alpha_prev) ** 0.5)

    scaled_sample = inv_quotient * sample
    removed_noise = inv_quotient * (beta_t ** 0.5) * model_output
    added_noise = ((1 - alpha_prev) ** 0.5) * model_output
    return scaled_sample - removed_noise + added_noise
# A DDIM reverse step function, the inverse of above | |
def reverse_step(
    self,
    model_output,
    timestep: int,
    sample,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    return_dict: bool = True,
    use_double=False,
):
    """Algebraic inverse of ``forward_step``: go from the previous timestep back to ``timestep``.

    Args:
        self: a scheduler on which ``set_timesteps`` has been called.
        model_output: noise prediction used for this step.
        timestep: timestep being stepped to.
        sample: latent at the previous (less noisy) timestep.
        eta, use_clipped_model_output, generator, return_dict, use_double:
            accepted but not used by this function.

    Returns:
        The updated latent tensor.

    Raises:
        ValueError: if ``set_timesteps`` has not been called.
        NotImplementedError: if ``timestep`` exceeds the largest scheduled one.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps

    if timestep > self.timesteps.max():
        raise NotImplementedError

    # The original also did a direct `self.alphas_cumprod[timestep]` lookup
    # here; it was overwritten immediately and would raise for a float
    # timestep, which get_alpha_and_beta handles.
    alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
    alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)

    alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)

    first_term = alpha_quotient * sample
    second_term = ((beta_prod_t)**0.5) * model_output
    third_term = alpha_quotient * ((1 - alpha_prod_t_prev)**0.5) * model_output
    return first_term + second_term - third_term
def latent_to_image(latent):
    """Decode ``latent`` with the VAE (after dividing by 0.18215) and return it via prep_image_for_return."""
    decoded = vae.decode(latent.to(vae.dtype) / 0.18215).sample
    return prep_image_for_return(decoded)
def prep_image_for_return(image):
    """Convert a decoded image batch tensor into a PIL image of its first element.

    Values are mapped with ``x / 2 + 0.5`` and clamped to [0, 1], the tensor
    is moved to the CPU and permuted with (0, 2, 3, 1), and the first entry
    is scaled to uint8 before being wrapped in a PIL image.
    """
    unit_range = (image / 2 + 0.5).clamp(0, 1)
    as_numpy = unit_range.cpu().permute(0, 2, 3, 1).numpy()
    first_image = (as_numpy[0] * 255).round().astype("uint8")
    return Image.fromarray(first_image)
############################# | |
##### MAIN EDICT FUNCTION ####### | |
# Use EDICT_editing to perform calls | |
def coupled_stablediffusion(prompt="",
                            prompt_edit=None,
                            null_prompt='',
                            prompt_edit_token_weights=[],
                            prompt_edit_tokens_start=0.0,
                            prompt_edit_tokens_end=1.0,
                            prompt_edit_spatial_start=0.0,
                            prompt_edit_spatial_end=1.0,
                            guidance_scale=7.0, steps=50,
                            seed=1, width=512, height=512,
                            init_image=None, init_image_strength=1.0,
                            run_baseline=False,
                            use_lms=False,
                            leapfrog_steps=True,
                            reverse=False,
                            return_latents=False,
                            fixed_starting_latent=None,
                            beta_schedule='scaled_linear',
                            mix_weight=0.93):
    """Run the coupled (two-latent) DDIM process forwards (generation) or in reverse (inversion).

    Args:
        prompt: conditional prompt.
        prompt_edit: optional prompt-to-prompt edit prompt; when given,
            attention maps saved from the ``prompt`` pass are reused for the
            ``prompt_edit`` pass.
        null_prompt: prompt used for the unconditional branch.
        prompt_edit_token_weights: ``(token_index, weight)`` pairs forwarded
            to init_attention_weights.
        prompt_edit_tokens_start / prompt_edit_tokens_end: range of
            ``t / num_train_timesteps`` in which saved "attn2" maps are reused.
        prompt_edit_spatial_start / prompt_edit_spatial_end: same, for "attn1".
        guidance_scale: classifier-free guidance scale.
        steps: number of DDIM steps; ``0`` short-circuits (see Returns).
        seed: CUDA RNG seed; ``None`` draws a random one.
        width, height: working size, rounded down to multiples of 64.
        init_image: image to invert; requires ``reverse=True``. May be a PIL
            image, a tensor (treated as a latent), or a list of either.
        init_image_strength: fraction of the schedule that is run when
            starting from an image or from ``fixed_starting_latent``.
        run_baseline: if True only latent 0 is stepped, conditioned on
            itself, with no mixing layers and no leapfrog ordering.
        use_lms: must be False (asserted).
        leapfrog_steps: alternate which latent is updated first from one
            timestep to the next.
        reverse: run the inversion direction (requires ``init_image``).
        return_latents: return the latent pair instead of decoded images.
        fixed_starting_latent: latent, or list of two latents, to start the
            forward process from instead of fresh noise.
        beta_schedule: forwarded to DDIMScheduler.
        mix_weight: weight of the mixing layers between the two latents.

    Returns:
        * ``reverse`` or ``return_latents``: the list of two latents.
        * otherwise: a list of two PIL images decoded from the two latents.
        * ``steps == 0``: the encoded ``init_image`` latent if an image was
          given, else a PIL image decoded from the starting latent.
    """
    #If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
    generator = torch.cuda.manual_seed(seed)

    def image_to_latent(im):
        # Convert a PIL image (or pass through a latent tensor) to a VAE latent on `device`.
        if isinstance(im, torch.Tensor):
            # assume it's the latent
            # used to avoid clipping new generation before inversion
            init_latent = im.to(device)
        else:
            #Resize and transpose for numpy b h w c -> torch b c h w
            im = im.resize((width, height), resample=Image.Resampling.LANCZOS)
            im = np.array(im).astype(np_dtype) / 255.0 * 2.0 - 1.0
            # check if black and white
            if len(im.shape) < 3:
                im = np.stack([im for _ in range(3)], axis=2) # putting at end b/c channels

            im = torch.from_numpy(im[np.newaxis, ...].transpose(0, 3, 1, 2))

            #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
            if im.shape[1] > 3:
                im = im[:, :3] * im[:, 3:] + (1 - im[:, 3:])

            #Move image to GPU
            im = im.to(device)
            #Encode image
            if use_half_prec:
                init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
            else:
                with autocast(device):
                    init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
        return init_latent

    assert not use_lms, "Can't invert LMS the same as DDIM"
    if run_baseline: leapfrog_steps=False
    #Change size to multiple of 64 to prevent size mismatches inside model
    width = width - width % 64
    height = height - height % 64

    #Preprocess image if it exists (img2img)
    if init_image is not None:
        assert reverse # want to be performing deterministic noising
        # can take either pair (output of generative process) or single image
        if isinstance(init_image, list):
            if isinstance(init_image[0], torch.Tensor):
                init_latent = [t.clone() for t in init_image]
            else:
                init_latent = [image_to_latent(im) for im in init_image]
        else:
            init_latent = image_to_latent(init_image)
        # this is t_start for forward, t_end for reverse
        t_limit = steps - int(steps * init_image_strength)
    else:
        assert not reverse, 'Need image to reverse from'
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_limit = 0

    if reverse:
        latent = init_latent
    else:
        #Generate random normal noise
        noise = torch.randn(init_latent.shape,
                            generator=generator,
                            device=device,
                            dtype=torch_dtype)
        if fixed_starting_latent is None:
            latent = noise
        else:
            if isinstance(fixed_starting_latent, list):
                latent = [l.clone() for l in fixed_starting_latent]
            else:
                latent = fixed_starting_latent.clone()
            # starting from a provided latent: skip the first part of the schedule
            t_limit = steps - int(steps * init_image_strength)
    if isinstance(latent, list): # initializing from pair of images
        latent_pair = latent
    else: # initializing from noise
        latent_pair = [latent.clone(), latent.clone()]

    # Nothing to sample: return the encoded image, or decode the starting latent
    if steps==0:
        if init_image is not None:
            return image_to_latent(init_image)
        else:
            image = vae.decode(latent.to(vae.dtype) / 0.18215).sample
            return prep_image_for_return(image)

    #Set inference timesteps to scheduler (one identically configured scheduler per latent)
    schedulers = []
    for i in range(2):
        # num_raw_timesteps = max(1000, steps)
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                  beta_schedule=beta_schedule,
                                  num_train_timesteps=1000,
                                  clip_sample=False,
                                  set_alpha_to_one=False)
        scheduler.set_timesteps(steps)
        schedulers.append(scheduler)

    # NOTE(review): the source's indentation was lost; the autocast context is
    # assumed to cover only the two base text embeddings below -- confirm.
    with autocast(device):
        # CLIP Text Embeddings
        tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length",
                                              max_length=clip_tokenizer.model_max_length,
                                              truncation=True, return_tensors="pt",
                                              return_overflowing_tokens=True)
        embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state

        tokens_conditional = clip_tokenizer(prompt, padding="max_length",
                                            max_length=clip_tokenizer.model_max_length,
                                            truncation=True, return_tensors="pt",
                                            return_overflowing_tokens=True)
        embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state

    #Process prompt editing (if running Prompt-to-Prompt)
    if prompt_edit is not None:
        tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length",
                                                 max_length=clip_tokenizer.model_max_length,
                                                 truncation=True, return_tensors="pt",
                                                 return_overflowing_tokens=True)
        embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state

        init_attention_edit(tokens_conditional, tokens_conditional_edit)

    init_attention_func()
    init_attention_weights(prompt_edit_token_weights)

    # Skip the first t_limit scheduled timesteps; inversion walks the rest in flipped order
    timesteps = schedulers[0].timesteps[t_limit:]
    if reverse: timesteps = timesteps.flip(0)

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        t_scale = t / schedulers[0].num_train_timesteps

        if (reverse) and (not run_baseline):
            # Reverse mixing layer: undoes the forward mixing layer at the
            # bottom of this loop, updating latent 1 first and then latent 0
            new_latents = [l.clone() for l in latent_pair]
            new_latents[1] = (new_latents[1].clone() - (1-mix_weight)*new_latents[0].clone()) / mix_weight
            new_latents[0] = (new_latents[0].clone() - (1-mix_weight)*new_latents[1].clone()) / mix_weight
            latent_pair = new_latents

        # alternate EDICT steps
        for latent_i in range(2):
            if run_baseline and latent_i==1: continue # just have one sequence for baseline
            # this modifies latent_pair[i] while using
            # latent_pair[(i+1)%2]
            if reverse and (not run_baseline):
                if leapfrog_steps:
                    # what i would be from going other way
                    orig_i = len(timesteps) - (i+1)
                    offset = (orig_i+1) % 2
                    latent_i = (latent_i + offset) % 2
                else:
                    # Do 1 then 0
                    latent_i = (latent_i+1)%2
            else:
                if leapfrog_steps:
                    offset = i%2
                    latent_i = (latent_i + offset) % 2

            # latent_j conditions the noise prediction, latent_i is the one being stepped
            latent_j = ((latent_i+1) % 2) if not run_baseline else latent_i

            latent_model_input = latent_pair[latent_j]
            latent_base = latent_pair[latent_i]

            #Predict the unconditional noise residual
            noise_pred_uncond = unet(latent_model_input, t,
                                     encoder_hidden_states=embedding_unconditional).sample

            #Prepare the Cross-Attention layers
            if prompt_edit is not None:
                save_last_tokens_attention()
                save_last_self_attention()
            else:
                #Use weights on non-edited prompt when edit is None
                use_last_tokens_attention_weights()

            #Predict the conditional noise residual and save the cross-attention layer activations
            noise_pred_cond = unet(latent_model_input, t,
                                   encoder_hidden_states=embedding_conditional).sample

            #Edit the Cross-Attention layer activations
            if prompt_edit is not None:
                t_scale = t / schedulers[0].num_train_timesteps
                if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
                    use_last_tokens_attention()
                if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
                    use_last_self_attention()

                #Use weights on edited prompt
                use_last_tokens_attention_weights()

                #Predict the edited conditional noise residual using the cross-attention masks
                noise_pred_cond = unet(latent_model_input,
                                       t,
                                       encoder_hidden_states=embedding_conditional_edit).sample

            #Perform guidance
            grad = (noise_pred_cond - noise_pred_uncond)
            noise_pred = noise_pred_uncond + guidance_scale * grad

            # reverse_step / forward_step return the new latent directly
            step_call = reverse_step if reverse else forward_step
            new_latent = step_call(schedulers[latent_i],
                                   noise_pred,
                                   t,
                                   latent_base)# .prev_sample
            # cast back to the dtype of the latent that was stepped
            new_latent = new_latent.to(latent_base.dtype)

            latent_pair[latent_i] = new_latent

        if (not reverse) and (not run_baseline):
            # Mixing layer (contraction) during generative process; applied
            # sequentially, so latent 1 is mixed with the already-updated latent 0
            new_latents = [l.clone() for l in latent_pair]
            new_latents[0] = (mix_weight*new_latents[0] + (1-mix_weight)*new_latents[1]).clone()
            new_latents[1] = ((1-mix_weight)*new_latents[0] + (mix_weight)*new_latents[1]).clone()
            latent_pair = new_latents

    #scale and decode the image latents with vae, can return latents instead of images
    if reverse or return_latents:
        # `results` always has one element, so this returns latent_pair itself
        results = [latent_pair]
        return results if len(results)>1 else results[0]

    # decode latents to images
    images = []
    for latent_i in range(2):
        latent = latent_pair[latent_i] / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample
        images.append(image)

    # Return images
    return_arr = []
    for image in images:
        image = prep_image_for_return(image)
        return_arr.append(image)
    # `results` always has one element, so this returns return_arr itself
    results = [return_arr]
    return results if len(results)>1 else results[0]