Alpha-Romeo committed

Commit: 7d28380
Parent(s): b0afe49

add cond stage to trainable parameters
ControlNet/ControlNet.ipynb
CHANGED
The diff for this file is too large to render.
ControlNet/cldm/cldm.py
CHANGED
@@ -2,6 +2,8 @@ import einops
 import torch
 import torch as th
 import torch.nn as nn
+import random
+import bitsandbytes as bnb
 from torchvision.transforms import Resize
 
 from ldm.modules.diffusionmodules.util import (
@@ -305,12 +307,15 @@ class ControlNet(nn.Module):
 
 class ControlInpaintLDM(LatentDiffusion):
 
-    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
+    def __init__(self, control_stage_config, control_key, u_cond_percent, only_mid_control, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.control_model = instantiate_from_config(control_stage_config)
         self.control_key = control_key
         self.only_mid_control = only_mid_control
         self.control_scales = [1.0] * 13
+        self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
+        self.proj_out = nn.Linear(1024, 768)
+        self.u_cond_percent = u_cond_percent
 
     @torch.no_grad()
     def get_input(self, batch, k, bs=None, *args, **kwargs):
@@ -380,6 +385,7 @@ class ControlInpaintLDM(LatentDiffusion):
 
         if self.cond_stage_trainable:
             c = self.get_learned_conditioning(c)
+            c = self.proj_out(c)
 
         if sample:
             # get denoise row
@@ -412,15 +418,38 @@ class ControlInpaintLDM(LatentDiffusion):
         shape = (self.channels, h // 8, w // 8)
         samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
         return samples, intermediates
-
+
     def configure_optimizers(self):
         lr = self.learning_rate
         params = list(self.control_model.parameters())
+        if self.cond_stage_trainable:
+            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+            params = params + list(self.cond_stage_model.final_ln.parameters()) + list(self.cond_stage_model.mapper.parameters()) + list(self.proj_out.parameters())
+        self.params = params
+        self.params_with_white = params + list(self.learnable_vector)
         if not self.sd_locked:
             params += list(self.model.diffusion_model.output_blocks.parameters())
             params += list(self.model.diffusion_model.out.parameters())
-        opt = torch.optim.AdamW(params, lr=lr)
+        # opt = torch.optim.AdamW(params, lr=lr)
+        opt = bnb.optim.Adam8bit(params, lr=lr)
+        self.opt = opt
         return opt
+
+    def forward(self, x, c, *args, **kwargs):
+        self.opt.params = self.params
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c['c_crossattn'][0] = self.get_learned_conditioning(c['c_crossattn'][0])
+                c['c_crossattn'][0] = self.proj_out(c['c_crossattn'][0])
+        u_cond_prop = random.uniform(0, 1)
+        if u_cond_prop < self.u_cond_percent:
+            self.opt.params = self.params_with_white
+            c['c_crossattn'][0] = self.learnable_vector.repeat(x.shape[0], 1, 1)
+            return self.p_losses(x, c, t, *args, **kwargs)
+        return self.p_losses(x, c, t, *args, **kwargs)
+
 
     def low_vram_shift(self, is_diffusing):
         if is_diffusing:
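Two details of the new optimizer wiring are worth flagging. list(self.learnable_vector) iterates over the (1, 1, 768) nn.Parameter and yields its slices along the first dimension, not the Parameter object itself, so the null-conditioning vector may never actually reach the optimizer. And assigning to self.opt.params inside forward() does not rewire a torch or bitsandbytes optimizer, which tracks its tensors through param_groups; parameters that receive no gradient on a given step are simply skipped, so the vector can be registered once, up front. A minimal sketch of that simplification, reusing the attribute names from the diff above (a sketch only, not verified against this repo):

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.control_model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params += list(self.cond_stage_model.final_ln.parameters())
            params += list(self.cond_stage_model.mapper.parameters())
            params += list(self.proj_out.parameters())
        # Wrap the Parameter in a list literal so the optimizer receives the
        # (1, 1, 768) tensor itself; steps where it gets no gradient are no-ops.
        params += [self.learnable_vector]
        if not self.sd_locked:
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        # 8-bit Adam keeps optimizer state in 8-bit to reduce VRAM, as in the diff.
        return bnb.optim.Adam8bit(params, lr=lr)

With the vector registered once, the u_cond_percent branch in forward() still performs classifier-free-guidance-style conditioning dropout (swap the cross-attention context for the learned null vector with probability u_cond_percent), just without mutating the optimizer mid-training.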
ControlNet/environment.yaml
CHANGED
@@ -1,12 +1,13 @@
 name: control
 channels:
   - pytorch
+  - anaconda
   - defaults
 dependencies:
   - python=3.8.5
   - pip=20.3
   - cudatoolkit=11.3
-  - pytorch=1.
+  - pytorch=1.13.1
   - torchvision=0.13.1
   - numpy=1.23.1
   - pip:
@@ -36,4 +37,5 @@ dependencies:
     - ipdb==0.13.11
     - ipython==8.11.0
     - ipykernel==6.21.2
+    - bitsandbytes==0.37.1
 
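One caveat on the pins: torchvision 0.13.1 was built against PyTorch 1.12.x, so the solver may reject the new pytorch=1.13.1 pin unless torchvision is bumped to the matching 0.14.1. The added bitsandbytes==0.37.1 also needs a working CUDA runtime; a quick smoke test (assuming a CUDA-capable GPU is available):

    import torch
    import bitsandbytes as bnb

    # One optimizer step through a throwaway layer; this fails fast if the
    # bitsandbytes CUDA binaries do not match the installed runtime.
    linear = torch.nn.Linear(16, 16).cuda()
    opt = bnb.optim.Adam8bit(linear.parameters(), lr=1e-4)
    linear(torch.randn(4, 16, device="cuda")).sum().backward()
    opt.step()
    print("bitsandbytes OK:", bnb.__version__)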
ControlNet/ldm/models/diffusion/ddpm.py
CHANGED
@@ -552,8 +552,6 @@ class LatentDiffusion(DDPM):
         reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
         ignore_keys = kwargs.pop("ignore_keys", [])
         super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
-        self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
-        self.u_cond_percent = u_cond_percent
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
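This is a relocation rather than a deletion: the same learnable_vector and u_cond_percent now live in ControlInpaintLDM.__init__ (see the cldm.py diff above), so the generic LatentDiffusion base class no longer carries subclass-specific state.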
ControlNet/ldm/modules/encoders/modules.py
CHANGED
@@ -137,7 +137,6 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
         super().__init__()
         self.transformer = CLIPVisionModel.from_pretrained(version)
         self.final_ln = LayerNorm(1024)
-        self.proj_out = nn.Linear(1024, 768)
         self.mapper = Transformer(
             1,
             1024,
@@ -162,7 +161,6 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
         z = z.unsqueeze(1)
         z = self.mapper(z)
         z = self.final_ln(z)
-        z = self.proj_out(z)
         return z
 
     def encode(self, image):
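With proj_out removed here, FrozenCLIPImageEmbedder returns raw 1024-dim tokens, and the 1024-to-768 projection moves into ControlInpaintLDM, where configure_optimizers picks it up; this is what the commit message means by adding the cond stage to the trainable parameters. A shape-only sketch with hypothetical stand-ins for the real modules:

    import torch
    import torch.nn as nn

    final_ln = nn.LayerNorm(1024)     # stand-in for the embedder's final_ln
    proj_out = nn.Linear(1024, 768)   # now owned (and trained) by ControlInpaintLDM

    z = torch.randn(4, 1, 1024)       # CLIP image tokens after the mapper
    cond = proj_out(final_ln(z))      # (4, 1, 768): SD 1.x cross-attention width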