omerbartal
commited on
Commit
•
7b90989
1
Parent(s):
2104e5b
Upload region_control.py
Browse files- region_control.py +208 -0
region_control.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
2 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
|
3 |
+
|
4 |
+
# suppress partial model loading warning
|
5 |
+
logging.set_verbosity_error()
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torchvision.transforms as T
|
10 |
+
import argparse
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
|
15 |
+
def seed_everything(seed):
|
16 |
+
torch.manual_seed(seed)
|
17 |
+
torch.cuda.manual_seed(seed)
|
18 |
+
# torch.backends.cudnn.deterministic = True
|
19 |
+
# torch.backends.cudnn.benchmark = True
|
20 |
+
|
21 |
+
|
22 |
+
def get_views(panorama_height, panorama_width, window_size=64, stride=8):
|
23 |
+
panorama_height /= 8
|
24 |
+
panorama_width /= 8
|
25 |
+
num_blocks_height = (panorama_height - window_size) // stride + 1
|
26 |
+
num_blocks_width = (panorama_width - window_size) // stride + 1
|
27 |
+
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
28 |
+
views = []
|
29 |
+
for i in range(total_num_blocks):
|
30 |
+
h_start = int((i // num_blocks_width) * stride)
|
31 |
+
h_end = h_start + window_size
|
32 |
+
w_start = int((i % num_blocks_width) * stride)
|
33 |
+
w_end = w_start + window_size
|
34 |
+
views.append((h_start, h_end, w_start, w_end))
|
35 |
+
return views
|
36 |
+
|
37 |
+
|
38 |
+
class MultiDiffusion(nn.Module):
|
39 |
+
def __init__(self, device, sd_version='2.0', hf_key=None):
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.device = device
|
43 |
+
self.sd_version = sd_version
|
44 |
+
|
45 |
+
print(f'[INFO] loading stable diffusion...')
|
46 |
+
if hf_key is not None:
|
47 |
+
print(f'[INFO] using hugging face custom model key: {hf_key}')
|
48 |
+
model_key = hf_key
|
49 |
+
elif self.sd_version == '2.1':
|
50 |
+
model_key = "stabilityai/stable-diffusion-2-1-base"
|
51 |
+
elif self.sd_version == '2.0':
|
52 |
+
model_key = "stabilityai/stable-diffusion-2-base"
|
53 |
+
elif self.sd_version == '1.5':
|
54 |
+
model_key = "runwayml/stable-diffusion-v1-5"
|
55 |
+
else:
|
56 |
+
raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
|
57 |
+
|
58 |
+
# Create model
|
59 |
+
self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
|
60 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
|
61 |
+
self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
|
62 |
+
self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)
|
63 |
+
|
64 |
+
self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
|
65 |
+
|
66 |
+
print(f'[INFO] loaded stable diffusion!')
|
67 |
+
|
68 |
+
@torch.no_grad()
|
69 |
+
def get_random_background(self, n_samples):
|
70 |
+
# sample random background with a constant rgb value
|
71 |
+
backgrounds = torch.rand(n_samples, 3, device=self.device)[:, :, None, None].repeat(1, 1, 512, 512)
|
72 |
+
return torch.cat([self.encode_imgs(bg.unsqueeze(0)) for bg in backgrounds])
|
73 |
+
|
74 |
+
@torch.no_grad()
|
75 |
+
def get_text_embeds(self, prompt, negative_prompt):
|
76 |
+
# Tokenize text and get embeddings
|
77 |
+
text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
|
78 |
+
truncation=True, return_tensors='pt')
|
79 |
+
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
80 |
+
|
81 |
+
# Do the same for unconditional embeddings
|
82 |
+
uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
|
83 |
+
return_tensors='pt')
|
84 |
+
|
85 |
+
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
86 |
+
|
87 |
+
# Cat for final embeddings
|
88 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
89 |
+
return text_embeddings
|
90 |
+
|
91 |
+
@torch.no_grad()
|
92 |
+
def encode_imgs(self, imgs):
|
93 |
+
imgs = 2 * imgs - 1
|
94 |
+
posterior = self.vae.encode(imgs).latent_dist
|
95 |
+
latents = posterior.sample() * 0.18215
|
96 |
+
return latents
|
97 |
+
|
98 |
+
@torch.no_grad()
|
99 |
+
def decode_latents(self, latents):
|
100 |
+
latents = 1 / 0.18215 * latents
|
101 |
+
imgs = self.vae.decode(latents).sample
|
102 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
103 |
+
return imgs
|
104 |
+
|
105 |
+
@torch.no_grad()
|
106 |
+
def generate(self, masks, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
|
107 |
+
guidance_scale=7.5, bootstrapping=20):
|
108 |
+
|
109 |
+
# get bootstrapping backgrounds
|
110 |
+
# can move this outside of the function to speed up generation. i.e., calculate in init
|
111 |
+
bootstrapping_backgrounds = self.get_random_background(bootstrapping)
|
112 |
+
|
113 |
+
# Prompts -> text embeds
|
114 |
+
text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2 * len(prompts), 77, 768]
|
115 |
+
|
116 |
+
# Define panorama grid and get views
|
117 |
+
latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
|
118 |
+
noise = latent.clone().repeat(len(prompts) - 1, 1, 1, 1)
|
119 |
+
views = get_views(height, width)
|
120 |
+
count = torch.zeros_like(latent)
|
121 |
+
value = torch.zeros_like(latent)
|
122 |
+
|
123 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
124 |
+
|
125 |
+
with torch.autocast('cuda'):
|
126 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
127 |
+
count.zero_()
|
128 |
+
value.zero_()
|
129 |
+
|
130 |
+
for h_start, h_end, w_start, w_end in views:
|
131 |
+
masks_view = masks[:, :, h_start:h_end, w_start:w_end]
|
132 |
+
latent_view = latent[:, :, h_start:h_end, w_start:w_end].repeat(len(prompts), 1, 1, 1)
|
133 |
+
if i < bootstrapping:
|
134 |
+
bg = bootstrapping_backgrounds[torch.randint(0, bootstrapping, (len(prompts) - 1,))]
|
135 |
+
bg = self.scheduler.add_noise(bg, noise[:, :, h_start:h_end, w_start:w_end], t)
|
136 |
+
latent_view[1:] = latent_view[1:] * masks_view[1:] + bg * (1 - masks_view[1:])
|
137 |
+
|
138 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
139 |
+
latent_model_input = torch.cat([latent_view] * 2)
|
140 |
+
|
141 |
+
# predict the noise residual
|
142 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
|
143 |
+
|
144 |
+
# perform guidance
|
145 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
146 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
147 |
+
|
148 |
+
# compute the denoising step with the reference model
|
149 |
+
latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
|
150 |
+
|
151 |
+
value[:, :, h_start:h_end, w_start:w_end] += (latents_view_denoised * masks_view).sum(dim=0,
|
152 |
+
keepdims=True)
|
153 |
+
count[:, :, h_start:h_end, w_start:w_end] += masks_view.sum(dim=0, keepdims=True)
|
154 |
+
|
155 |
+
# take the MultiDiffusion step
|
156 |
+
latent = torch.where(count > 0, value / count, value)
|
157 |
+
|
158 |
+
# Img latents -> imgs
|
159 |
+
imgs = self.decode_latents(latent) # [1, 3, 512, 512]
|
160 |
+
img = T.ToPILImage()(imgs[0].cpu())
|
161 |
+
return img
|
162 |
+
|
163 |
+
|
164 |
+
def preprocess_mask(mask_path, h, w, device):
|
165 |
+
mask = np.array(Image.open(mask_path).convert("L"))
|
166 |
+
mask = mask.astype(np.float32) / 255.0
|
167 |
+
mask = mask[None, None]
|
168 |
+
mask[mask < 0.5] = 0
|
169 |
+
mask[mask >= 0.5] = 1
|
170 |
+
mask = torch.from_numpy(mask).to(device)
|
171 |
+
mask = torch.nn.functional.interpolate(mask, size=(h, w), mode='nearest')
|
172 |
+
return mask
|
173 |
+
|
174 |
+
|
175 |
+
if __name__ == '__main__':
|
176 |
+
parser = argparse.ArgumentParser()
|
177 |
+
parser.add_argument('--mask_paths', type=list)
|
178 |
+
parser.add_argument('--bg_prompt', type=str)
|
179 |
+
parser.add_argument('--bg_negative', type=str) # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
|
180 |
+
parser.add_argument('--fg_prompts', type=list)
|
181 |
+
parser.add_argument('--fg_negative', type=list) # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
|
182 |
+
parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0'],
|
183 |
+
help="stable diffusion version")
|
184 |
+
parser.add_argument('--H', type=int, default=768)
|
185 |
+
parser.add_argument('--W', type=int, default=512)
|
186 |
+
parser.add_argument('--seed', type=int, default=0)
|
187 |
+
parser.add_argument('--steps', type=int, default=50)
|
188 |
+
parser.add_argument('--bootstrapping', type=int, default=20)
|
189 |
+
opt = parser.parse_args()
|
190 |
+
|
191 |
+
seed_everything(opt.seed)
|
192 |
+
|
193 |
+
device = torch.device('cuda')
|
194 |
+
|
195 |
+
sd = MultiDiffusion(device, opt.sd_version)
|
196 |
+
|
197 |
+
fg_masks = torch.cat([preprocess_mask(mask_path, opt.H // 8, opt.W // 8, device) for mask_path in opt.mask_paths])
|
198 |
+
bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
|
199 |
+
bg_mask[bg_mask < 0] = 0
|
200 |
+
masks = torch.cat([bg_mask, fg_masks])
|
201 |
+
|
202 |
+
prompts = [opt.bg_prompt] + opt.fg_prompts
|
203 |
+
neg_prompts = [opt.bg_negative] + opt.fg_negative
|
204 |
+
|
205 |
+
img = sd.generate(masks, prompts, neg_prompts, opt.H, opt.W, opt.steps, bootstrapping=opt.bootstrapping)
|
206 |
+
|
207 |
+
# save image
|
208 |
+
img.save('out.png')
|