Doubiiu committed on
Commit c2d6308
1 Parent(s): 93e8f1d

Upload 27 files

lvdm/models/.DS_Store ADDED
Binary file (6.15 kB)
 
lvdm/models/__pycache__/ddpm3d.cpython-39.pyc CHANGED
Binary files a/lvdm/models/__pycache__/ddpm3d.cpython-39.pyc and b/lvdm/models/__pycache__/ddpm3d.cpython-39.pyc differ
 
lvdm/models/__pycache__/utils_diffusion.cpython-39.pyc CHANGED
Binary files a/lvdm/models/__pycache__/utils_diffusion.cpython-39.pyc and b/lvdm/models/__pycache__/utils_diffusion.cpython-39.pyc differ
 
lvdm/models/ddpm3d.py CHANGED
@@ -20,7 +20,7 @@ import pytorch_lightning as pl
 from utils.utils import instantiate_from_config
 from lvdm.ema import LitEma
 from lvdm.distributions import DiagonalGaussianDistribution
-from lvdm.models.utils_diffusion import make_beta_schedule
+from lvdm.models.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr
 from lvdm.basics import disabled_train
 from lvdm.common import (
     extract_into_tensor,
@@ -63,6 +63,7 @@ class DDPM(pl.LightningModule):
                  use_positional_encodings=False,
                  learn_logvar=False,
                  logvar_init=0.,
+                 rescale_betas_zero_snr=False,
                  ):
         super().__init__()
         assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
@@ -81,6 +82,7 @@ class DDPM(pl.LightningModule):
         self.model = DiffusionWrapper(unet_config, conditioning_key)
         #count_params(self.model, verbose=True)
         self.use_ema = use_ema
+        self.rescale_betas_zero_snr = rescale_betas_zero_snr
         if self.use_ema:
             self.model_ema = LitEma(self.model)
             mainlogger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
@@ -115,6 +117,9 @@ class DDPM(pl.LightningModule):
         else:
             betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                        cosine_s=cosine_s)
+        if self.rescale_betas_zero_snr:
+            betas = rescale_zero_terminal_snr(betas)
+
         alphas = 1. - betas
         alphas_cumprod = np.cumprod(alphas, axis=0)
         alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
@@ -135,8 +140,13 @@ class DDPM(pl.LightningModule):
         self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
         self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
         self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
-        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        if self.parameterization != 'v':
+            self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+            self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+        else:
+            self.register_buffer('sqrt_recip_alphas_cumprod', torch.zeros_like(to_torch(alphas_cumprod)))
+            self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.zeros_like(to_torch(alphas_cumprod)))

         # calculations for posterior q(x_{t-1} | x_t, x_0)
         posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
@@ -365,6 +375,8 @@ class LatentDiffusion(DDPM):
                  base_scale=0.7,
                  turning_step=400,
                  loop_video=False,
+                 fps_condition_type='fs',
+                 perframe_ae=False,
                  *args, **kwargs):
         self.num_timesteps_cond = default(num_timesteps_cond, 1)
         self.scale_by_std = scale_by_std
@@ -380,6 +392,8 @@ class LatentDiffusion(DDPM):
         self.noise_strength = noise_strength
         self.use_dynamic_rescale = use_dynamic_rescale
         self.loop_video = loop_video
+        self.fps_condition_type = fps_condition_type
+        self.perframe_ae = perframe_ae
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
         except:
@@ -470,9 +484,18 @@ class LatentDiffusion(DDPM):
         else:
             reshape_back = False

-        encoder_posterior = self.first_stage_model.encode(x)
-        results = self.get_first_stage_encoding(encoder_posterior).detach()
-
+        ## consume more GPU memory but faster
+        if not self.perframe_ae:
+            encoder_posterior = self.first_stage_model.encode(x)
+            results = self.get_first_stage_encoding(encoder_posterior).detach()
+        else:  ## consume less GPU memory but slower
+            results = []
+            for index in range(x.shape[0]):
+                frame_batch = self.first_stage_model.encode(x[index:index+1,:,:,:])
+                frame_result = self.get_first_stage_encoding(frame_batch).detach()
+                results.append(frame_result)
+            results = torch.cat(results, dim=0)
+
         if reshape_back:
             results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t)

@@ -486,10 +509,17 @@ class LatentDiffusion(DDPM):
         else:
             reshape_back = False

-        z = 1. / self.scale_factor * z
-        results = self.first_stage_model.decode(z, **kwargs)
-
+        if not self.perframe_ae:
+            z = 1. / self.scale_factor * z
+            results = self.first_stage_model.decode(z, **kwargs)
+        else:
+            results = []
+            for index in range(z.shape[0]):
+                frame_z = 1. / self.scale_factor * z[index:index+1,:,:,:]
+                frame_result = self.first_stage_model.decode(frame_z, **kwargs)
+                results.append(frame_result)
+            results = torch.cat(results, dim=0)

         if reshape_back:
             results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t)
         return results
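Note: the new perframe_ae flag above trades speed for peak GPU memory when the first-stage VAE processes a whole video. A minimal standalone sketch of the same pattern; the vae object, its encode(...).sample() interface and the 0.18215 scale factor are illustrative assumptions, not part of this commit:

import torch

def encode_video_frames(vae, frames, perframe_ae=False, scale_factor=0.18215):
    # frames: (B*T, C, H, W) after folding the time axis into the batch axis
    if not perframe_ae:
        # one big forward pass: fastest, but memory scales with B*T
        z = vae.encode(frames).sample()
    else:
        # one frame at a time: slower, near-constant memory footprint
        chunks = []
        for i in range(frames.shape[0]):
            chunks.append(vae.encode(frames[i:i + 1]).sample())
        z = torch.cat(chunks, dim=0)
    return scale_factor * z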
lvdm/models/samplers/__pycache__/ddim.cpython-39.pyc CHANGED
Binary files a/lvdm/models/samplers/__pycache__/ddim.cpython-39.pyc and b/lvdm/models/samplers/__pycache__/ddim.cpython-39.pyc differ
 
lvdm/models/samplers/__pycache__/ddim_multiplecond.cpython-39.pyc CHANGED
Binary files a/lvdm/models/samplers/__pycache__/ddim_multiplecond.cpython-39.pyc and b/lvdm/models/samplers/__pycache__/ddim_multiplecond.cpython-39.pyc differ
 
lvdm/models/samplers/ddim.py CHANGED
@@ -1,9 +1,10 @@
 import numpy as np
 from tqdm import tqdm
 import torch
-from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps
+from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
 from lvdm.common import noise_like
 from lvdm.common import extract_into_tensor
+import copy


 class DDIMSampler(object):
@@ -80,7 +81,8 @@ class DDIMSampler(object):
                unconditional_conditioning=None,
                precision=None,
                fs=None,
-               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               timestep_spacing='uniform', #uniform_trailing for starting from last timestep
+               guidance_rescale=0.0,
                **kwargs
                ):
@@ -98,7 +100,7 @@ class DDIMSampler(object):
             if conditioning.shape[0] != batch_size:
                 print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

-        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose)
+        self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose)

         # make shape
         if len(shape) == 3:
@@ -107,8 +109,7 @@ class DDIMSampler(object):
         elif len(shape) == 4:
             C, T, H, W = shape
             size = (batch_size, C, T, H, W)
-        # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
-
+
         samples, intermediates = self.ddim_sampling(conditioning, size,
                                                     callback=callback,
                                                     img_callback=img_callback,
@@ -126,6 +127,7 @@ class DDIMSampler(object):
                                                     verbose=verbose,
                                                     precision=precision,
                                                     fs=fs,
+                                                    guidance_rescale=guidance_rescale,
                                                     **kwargs)
         return samples, intermediates

@@ -135,7 +137,7 @@ class DDIMSampler(object):
                       callback=None, timesteps=None, quantize_denoised=False,
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0,
                       **kwargs):
         device = self.model.betas.device
         b = shape[0]
@@ -162,6 +164,8 @@ class DDIMSampler(object):
             iterator = time_range

         clean_cond = kwargs.pop("clean_cond", False)
+
+        # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning)
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
             ts = torch.full((b,), step, device=device, dtype=torch.long)
@@ -174,18 +178,20 @@ class DDIMSampler(object):
                 else:
                     img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass? <ddim inversion>
                 img = img_orig * mask + (1. - mask) * img  # keep original & modify use img
-
+
+
+
+
             outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                       quantize_denoised=quantize_denoised, temperature=temperature,
                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
                                       corrector_kwargs=corrector_kwargs,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
                                       unconditional_conditioning=unconditional_conditioning,
-                                      mask=mask,x0=x0,fs=fs,
+                                      mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale,
                                       **kwargs)


-
             img, pred_x0 = outs
             if callback: callback(i)
             if img_callback: img_callback(pred_x0, i)
@@ -200,7 +206,7 @@ class DDIMSampler(object):
     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None,
-                      uc_type=None, conditional_guidance_scale_temporal=None,mask=None,x0=None, **kwargs):
+                      uc_type=None, conditional_guidance_scale_temporal=None,mask=None,x0=None,guidance_rescale=0.0,**kwargs):
         b, *_, device = *x.shape, x.device
         if x.dim() == 5:
             is_video = True
@@ -208,28 +214,33 @@ class DDIMSampler(object):
             is_video = False

         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
-            e_t = self.model.apply_model(x, t, c, **kwargs)  # unet denoiser
+            model_output = self.model.apply_model(x, t, c, **kwargs)  # unet denoiser
         else:
-            ### with unconditional condition
+            ### do_classifier_free_guidance
             if isinstance(c, torch.Tensor) or isinstance(c, dict):
                 e_t_cond = self.model.apply_model(x, t, c, **kwargs)
                 e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
             else:
                 raise NotImplementedError

-            e_t = e_t_uncond + unconditional_guidance_scale * (e_t_cond - e_t_uncond)
+            model_output = e_t_uncond + unconditional_guidance_scale * (e_t_cond - e_t_uncond)

-        if self.model.parameterization == "v":
-            e_t = self.model.predict_eps_from_z_and_v(x, t, e_t)
+            if guidance_rescale > 0.0:
+                model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output

         if score_corrector is not None:
-            assert self.model.parameterization == "eps"
+            assert self.model.parameterization == "eps", 'not implemented'
             e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
         alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
         sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
         sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
         # select parameters corresponding to the currently considered timestep
@@ -246,7 +257,7 @@ class DDIMSampler(object):
         if self.model.parameterization != "v":
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         else:
-            pred_x0 = self.model.predict_start_from_z_and_v(x, t, e_t)
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

         if self.model.use_dynamic_rescale:
             scale_t = torch.full(size, self.ddim_scale_arr[index], device=device)
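Note: the refactor above keeps the raw network prediction in model_output and only derives e_t / pred_x0 from it where needed. A self-contained sketch of the v-parameterization identities this relies on; these are the standard conversions (Salimans & Ho, 2022), and the repo's predict_eps_from_z_and_v / predict_start_from_z_and_v helpers are assumed to implement the same formulas:

import torch

# toy values for a single timestep t
alpha_bar_t = torch.tensor(0.7)          # cumulative alpha at t (illustrative)
x_t = torch.randn(1, 4, 8, 8)            # noisy latent
v = torch.randn(1, 4, 8, 8)              # network output under v-parameterization

# eps and x0 recovered from (x_t, v)
e_t     = alpha_bar_t.sqrt() * v + (1 - alpha_bar_t).sqrt() * x_t
pred_x0 = alpha_bar_t.sqrt() * x_t - (1 - alpha_bar_t).sqrt() * v

# consistency check: x_t is rebuilt exactly from (pred_x0, e_t)
assert torch.allclose(alpha_bar_t.sqrt() * pred_x0 + (1 - alpha_bar_t).sqrt() * e_t, x_t, atol=1e-5)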
lvdm/models/samplers/ddim_multiplecond.py CHANGED
@@ -1,297 +1,323 @@
1
- import numpy as np
2
- from tqdm import tqdm
3
- import torch
4
- from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps
5
- from lvdm.common import noise_like
6
- from lvdm.common import extract_into_tensor
7
-
8
-
9
- class DDIMSampler(object):
10
- def __init__(self, model, schedule="linear", **kwargs):
11
- super().__init__()
12
- self.model = model
13
- self.ddpm_num_timesteps = model.num_timesteps
14
- self.schedule = schedule
15
- self.counter = 0
16
-
17
- def register_buffer(self, name, attr):
18
- if type(attr) == torch.Tensor:
19
- if attr.device != torch.device("cuda"):
20
- attr = attr.to(torch.device("cuda"))
21
- setattr(self, name, attr)
22
-
23
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
- alphas_cumprod = self.model.alphas_cumprod
27
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
-
30
- self.register_buffer('betas', to_torch(self.model.betas))
31
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
-
34
- # calculations for diffusion q(x_t | x_{t-1}) and others
35
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
-
41
- # ddim sampling parameters
42
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
- ddim_timesteps=self.ddim_timesteps,
44
- eta=ddim_eta,verbose=verbose)
45
- self.register_buffer('ddim_sigmas', ddim_sigmas)
46
- self.register_buffer('ddim_alphas', ddim_alphas)
47
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
-
54
- @torch.no_grad()
55
- def sample(self,
56
- S,
57
- batch_size,
58
- shape,
59
- conditioning=None,
60
- callback=None,
61
- normals_sequence=None,
62
- img_callback=None,
63
- quantize_x0=False,
64
- eta=0.,
65
- mask=None,
66
- x0=None,
67
- temperature=1.,
68
- noise_dropout=0.,
69
- score_corrector=None,
70
- corrector_kwargs=None,
71
- verbose=True,
72
- schedule_verbose=False,
73
- x_T=None,
74
- log_every_t=100,
75
- unconditional_guidance_scale=1.,
76
- unconditional_conditioning=None,
77
- precision=None,
78
- fs=None,
79
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
- **kwargs
81
- ):
82
-
83
- # check condition bs
84
- if conditioning is not None:
85
- if isinstance(conditioning, dict):
86
- try:
87
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
88
- except:
89
- cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
90
-
91
- if cbs != batch_size:
92
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
93
- else:
94
- if conditioning.shape[0] != batch_size:
95
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
96
-
97
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose)
98
-
99
- # make shape
100
- if len(shape) == 3:
101
- C, H, W = shape
102
- size = (batch_size, C, H, W)
103
- elif len(shape) == 4:
104
- C, T, H, W = shape
105
- size = (batch_size, C, T, H, W)
106
- # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
107
-
108
- samples, intermediates = self.ddim_sampling(conditioning, size,
109
- callback=callback,
110
- img_callback=img_callback,
111
- quantize_denoised=quantize_x0,
112
- mask=mask, x0=x0,
113
- ddim_use_original_steps=False,
114
- noise_dropout=noise_dropout,
115
- temperature=temperature,
116
- score_corrector=score_corrector,
117
- corrector_kwargs=corrector_kwargs,
118
- x_T=x_T,
119
- log_every_t=log_every_t,
120
- unconditional_guidance_scale=unconditional_guidance_scale,
121
- unconditional_conditioning=unconditional_conditioning,
122
- verbose=verbose,
123
- precision=precision,
124
- fs=fs,
125
- **kwargs)
126
- return samples, intermediates
127
-
128
- @torch.no_grad()
129
- def ddim_sampling(self, cond, shape,
130
- x_T=None, ddim_use_original_steps=False,
131
- callback=None, timesteps=None, quantize_denoised=False,
132
- mask=None, x0=None, img_callback=None, log_every_t=100,
133
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
134
- unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,
135
- **kwargs):
136
- device = self.model.betas.device
137
- b = shape[0]
138
- if x_T is None:
139
- img = torch.randn(shape, device=device)
140
- else:
141
- img = x_T
142
- if precision is not None:
143
- if precision == 16:
144
- img = img.to(dtype=torch.float16)
145
-
146
-
147
- if timesteps is None:
148
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
149
- elif timesteps is not None and not ddim_use_original_steps:
150
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
151
- timesteps = self.ddim_timesteps[:subset_end]
152
-
153
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
154
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
155
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
156
- if verbose:
157
- iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
158
- else:
159
- iterator = time_range
160
-
161
- clean_cond = kwargs.pop("clean_cond", False)
162
- for i, step in enumerate(iterator):
163
- index = total_steps - i - 1
164
- ts = torch.full((b,), step, device=device, dtype=torch.long)
165
-
166
- ## use mask to blend noised original latent (img_orig) & new sampled latent (img)
167
- if mask is not None:
168
- assert x0 is not None
169
- if clean_cond:
170
- img_orig = x0
171
- else:
172
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? <ddim inversion>
173
- img = img_orig * mask + (1. - mask) * img # keep original & modify use img
174
-
175
- outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
176
- quantize_denoised=quantize_denoised, temperature=temperature,
177
- noise_dropout=noise_dropout, score_corrector=score_corrector,
178
- corrector_kwargs=corrector_kwargs,
179
- unconditional_guidance_scale=unconditional_guidance_scale,
180
- unconditional_conditioning=unconditional_conditioning,
181
- mask=mask,x0=x0,fs=fs,
182
- **kwargs)
183
-
184
-
185
-
186
- img, pred_x0 = outs
187
- if callback: callback(i)
188
- if img_callback: img_callback(pred_x0, i)
189
-
190
- if index % log_every_t == 0 or index == total_steps - 1:
191
- intermediates['x_inter'].append(img)
192
- intermediates['pred_x0'].append(pred_x0)
193
-
194
- return img, intermediates
195
-
196
- @torch.no_grad()
197
- def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
198
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
199
- unconditional_guidance_scale=1., unconditional_conditioning=None,
200
- uc_type=None, cfg_img=None,mask=None,x0=None, **kwargs):
201
- b, *_, device = *x.shape, x.device
202
- if x.dim() == 5:
203
- is_video = True
204
- else:
205
- is_video = False
206
- if cfg_img is None:
207
- cfg_img = unconditional_guidance_scale
208
-
209
- unconditional_conditioning_img_nonetext = kwargs['unconditional_conditioning_img_nonetext']
210
-
211
-
212
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
213
- e_t = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
214
- else:
215
- ### with unconditional condition
216
- e_t_cond = self.model.apply_model(x, t, c, **kwargs)
217
- e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
218
- e_t_uncond_img = self.model.apply_model(x, t, unconditional_conditioning_img_nonetext, **kwargs)
219
- # text cfg
220
- e_t = e_t_uncond + cfg_img * (e_t_uncond_img - e_t_uncond) + unconditional_guidance_scale * (e_t_cond - e_t_uncond_img)
221
-
222
- if self.model.parameterization == "v":
223
- e_t = self.model.predict_eps_from_z_and_v(x, t, e_t)
224
-
225
-
226
- if score_corrector is not None:
227
- assert self.model.parameterization == "eps"
228
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
229
-
230
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
231
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
232
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
233
- sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
234
- # select parameters corresponding to the currently considered timestep
235
-
236
- if is_video:
237
- size = (b, 1, 1, 1, 1)
238
- else:
239
- size = (b, 1, 1, 1)
240
- a_t = torch.full(size, alphas[index], device=device)
241
- a_prev = torch.full(size, alphas_prev[index], device=device)
242
- sigma_t = torch.full(size, sigmas[index], device=device)
243
- sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device)
244
-
245
- # current prediction for x_0
246
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
247
-
248
- if quantize_denoised:
249
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
250
- # direction pointing to x_t
251
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
252
-
253
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
254
- if noise_dropout > 0.:
255
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
256
-
257
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
258
-
259
- return x_prev, pred_x0
260
-
261
- @torch.no_grad()
262
- def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
263
- use_original_steps=False, callback=None):
264
-
265
- timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
266
- timesteps = timesteps[:t_start]
267
-
268
- time_range = np.flip(timesteps)
269
- total_steps = timesteps.shape[0]
270
- print(f"Running DDIM Sampling with {total_steps} timesteps")
271
-
272
- iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
273
- x_dec = x_latent
274
- for i, step in enumerate(iterator):
275
- index = total_steps - i - 1
276
- ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
277
- x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
278
- unconditional_guidance_scale=unconditional_guidance_scale,
279
- unconditional_conditioning=unconditional_conditioning)
280
- if callback: callback(i)
281
- return x_dec
282
-
283
- @torch.no_grad()
284
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
285
- # fast, but does not allow for exact reconstruction
286
- # t serves as an index to gather the correct alphas
287
- if use_original_steps:
288
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
289
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
290
- else:
291
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
292
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
293
-
294
- if noise is None:
295
- noise = torch.randn_like(x0)
296
- return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
297
  extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
 
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ import torch
4
+ from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
5
+ from lvdm.common import noise_like
6
+ from lvdm.common import extract_into_tensor
7
+ import copy
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+ self.counter = 0
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27
+ alphas_cumprod = self.model.alphas_cumprod
28
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30
+
31
+ if self.model.use_dynamic_rescale:
32
+ self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
33
+ self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])
34
+
35
+ self.register_buffer('betas', to_torch(self.model.betas))
36
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
37
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
38
+
39
+ # calculations for diffusion q(x_t | x_{t-1}) and others
40
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
42
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
44
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
45
+
46
+ # ddim sampling parameters
47
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
48
+ ddim_timesteps=self.ddim_timesteps,
49
+ eta=ddim_eta,verbose=verbose)
50
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
51
+ self.register_buffer('ddim_alphas', ddim_alphas)
52
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
53
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
54
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
55
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
56
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
57
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
58
+
59
+ @torch.no_grad()
60
+ def sample(self,
61
+ S,
62
+ batch_size,
63
+ shape,
64
+ conditioning=None,
65
+ callback=None,
66
+ normals_sequence=None,
67
+ img_callback=None,
68
+ quantize_x0=False,
69
+ eta=0.,
70
+ mask=None,
71
+ x0=None,
72
+ temperature=1.,
73
+ noise_dropout=0.,
74
+ score_corrector=None,
75
+ corrector_kwargs=None,
76
+ verbose=True,
77
+ schedule_verbose=False,
78
+ x_T=None,
79
+ log_every_t=100,
80
+ unconditional_guidance_scale=1.,
81
+ unconditional_conditioning=None,
82
+ precision=None,
83
+ fs=None,
84
+ timestep_spacing='uniform', #uniform_trailing for starting from last timestep
85
+ guidance_rescale=0.0,
86
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
87
+ **kwargs
88
+ ):
89
+
90
+ # check condition bs
91
+ if conditioning is not None:
92
+ if isinstance(conditioning, dict):
93
+ try:
94
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
95
+ except:
96
+ cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
97
+
98
+ if cbs != batch_size:
99
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
100
+ else:
101
+ if conditioning.shape[0] != batch_size:
102
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
103
+
104
+ # print('==> timestep_spacing: ', timestep_spacing, guidance_rescale)
105
+ self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose)
106
+
107
+ # make shape
108
+ if len(shape) == 3:
109
+ C, H, W = shape
110
+ size = (batch_size, C, H, W)
111
+ elif len(shape) == 4:
112
+ C, T, H, W = shape
113
+ size = (batch_size, C, T, H, W)
114
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
115
+
116
+ samples, intermediates = self.ddim_sampling(conditioning, size,
117
+ callback=callback,
118
+ img_callback=img_callback,
119
+ quantize_denoised=quantize_x0,
120
+ mask=mask, x0=x0,
121
+ ddim_use_original_steps=False,
122
+ noise_dropout=noise_dropout,
123
+ temperature=temperature,
124
+ score_corrector=score_corrector,
125
+ corrector_kwargs=corrector_kwargs,
126
+ x_T=x_T,
127
+ log_every_t=log_every_t,
128
+ unconditional_guidance_scale=unconditional_guidance_scale,
129
+ unconditional_conditioning=unconditional_conditioning,
130
+ verbose=verbose,
131
+ precision=precision,
132
+ fs=fs,
133
+ guidance_rescale=guidance_rescale,
134
+ **kwargs)
135
+ return samples, intermediates
136
+
137
+ @torch.no_grad()
138
+ def ddim_sampling(self, cond, shape,
139
+ x_T=None, ddim_use_original_steps=False,
140
+ callback=None, timesteps=None, quantize_denoised=False,
141
+ mask=None, x0=None, img_callback=None, log_every_t=100,
142
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
143
+ unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0,
144
+ **kwargs):
145
+ device = self.model.betas.device
146
+ b = shape[0]
147
+ if x_T is None:
148
+ img = torch.randn(shape, device=device)
149
+ else:
150
+ img = x_T
151
+ if precision is not None:
152
+ if precision == 16:
153
+ img = img.to(dtype=torch.float16)
154
+
155
+
156
+ if timesteps is None:
157
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
158
+ elif timesteps is not None and not ddim_use_original_steps:
159
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
160
+ timesteps = self.ddim_timesteps[:subset_end]
161
+
162
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
163
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
164
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
165
+ if verbose:
166
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
167
+ else:
168
+ iterator = time_range
169
+
170
+ clean_cond = kwargs.pop("clean_cond", False)
171
+
172
+ # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning)
173
+ for i, step in enumerate(iterator):
174
+ index = total_steps - i - 1
175
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
176
+
177
+ ## use mask to blend noised original latent (img_orig) & new sampled latent (img)
178
+ if mask is not None:
179
+ assert x0 is not None
180
+ if clean_cond:
181
+ img_orig = x0
182
+ else:
183
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? <ddim inversion>
184
+ img = img_orig * mask + (1. - mask) * img # keep original & modify use img
185
+
186
+
187
+
188
+
189
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
190
+ quantize_denoised=quantize_denoised, temperature=temperature,
191
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
192
+ corrector_kwargs=corrector_kwargs,
193
+ unconditional_guidance_scale=unconditional_guidance_scale,
194
+ unconditional_conditioning=unconditional_conditioning,
195
+ mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale,
196
+ **kwargs)
197
+
198
+
199
+
200
+ img, pred_x0 = outs
201
+ if callback: callback(i)
202
+ if img_callback: img_callback(pred_x0, i)
203
+
204
+ if index % log_every_t == 0 or index == total_steps - 1:
205
+ intermediates['x_inter'].append(img)
206
+ intermediates['pred_x0'].append(pred_x0)
207
+
208
+ return img, intermediates
209
+
210
+ @torch.no_grad()
211
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
212
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
213
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
214
+ uc_type=None, cfg_img=None,mask=None,x0=None,guidance_rescale=0.0, **kwargs):
215
+ b, *_, device = *x.shape, x.device
216
+ if x.dim() == 5:
217
+ is_video = True
218
+ else:
219
+ is_video = False
220
+ if cfg_img is None:
221
+ cfg_img = unconditional_guidance_scale
222
+
223
+ unconditional_conditioning_img_nonetext = kwargs['unconditional_conditioning_img_nonetext']
224
+
225
+
226
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
227
+ model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
228
+ else:
229
+ ### with unconditional condition
230
+ e_t_cond = self.model.apply_model(x, t, c, **kwargs)
231
+ e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
232
+ e_t_uncond_img = self.model.apply_model(x, t, unconditional_conditioning_img_nonetext, **kwargs)
233
+ # text cfg
234
+ model_output = e_t_uncond + cfg_img * (e_t_uncond_img - e_t_uncond) + unconditional_guidance_scale * (e_t_cond - e_t_uncond_img)
235
+ if guidance_rescale > 0.0:
236
+ model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale)
237
+
238
+ if self.model.parameterization == "v":
239
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
240
+ else:
241
+ e_t = model_output
242
+
243
+ if score_corrector is not None:
244
+ assert self.model.parameterization == "eps", 'not implemented'
245
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
246
+
247
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
248
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
249
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
250
+ sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
251
+ # select parameters corresponding to the currently considered timestep
252
+
253
+ if is_video:
254
+ size = (b, 1, 1, 1, 1)
255
+ else:
256
+ size = (b, 1, 1, 1)
257
+ a_t = torch.full(size, alphas[index], device=device)
258
+ a_prev = torch.full(size, alphas_prev[index], device=device)
259
+ sigma_t = torch.full(size, sigmas[index], device=device)
260
+ sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device)
261
+
262
+ # current prediction for x_0
263
+ if self.model.parameterization != "v":
264
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
265
+ else:
266
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
267
+
268
+ if self.model.use_dynamic_rescale:
269
+ scale_t = torch.full(size, self.ddim_scale_arr[index], device=device)
270
+ prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device)
271
+ rescale = (prev_scale_t / scale_t)
272
+ pred_x0 *= rescale
273
+
274
+ if quantize_denoised:
275
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
276
+ # direction pointing to x_t
277
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
278
+
279
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
280
+ if noise_dropout > 0.:
281
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
282
+
283
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
284
+
285
+ return x_prev, pred_x0
286
+
287
+ @torch.no_grad()
288
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
289
+ use_original_steps=False, callback=None):
290
+
291
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
292
+ timesteps = timesteps[:t_start]
293
+
294
+ time_range = np.flip(timesteps)
295
+ total_steps = timesteps.shape[0]
296
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
297
+
298
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
299
+ x_dec = x_latent
300
+ for i, step in enumerate(iterator):
301
+ index = total_steps - i - 1
302
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
303
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
304
+ unconditional_guidance_scale=unconditional_guidance_scale,
305
+ unconditional_conditioning=unconditional_conditioning)
306
+ if callback: callback(i)
307
+ return x_dec
308
+
309
+ @torch.no_grad()
310
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
311
+ # fast, but does not allow for exact reconstruction
312
+ # t serves as an index to gather the correct alphas
313
+ if use_original_steps:
314
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
315
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
316
+ else:
317
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
318
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
319
+
320
+ if noise is None:
321
+ noise = torch.randn_like(x0)
322
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
323
  extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
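Note: p_sample_ddim in the file above combines three denoiser calls (fully unconditional, image-conditioned with empty text, and fully conditioned) with two guidance scales. A self-contained numeric sketch of that combination; random tensors stand in for the three network outputs, only the formula comes from the file:

import torch

e_t_uncond     = torch.randn(1, 4, 16, 40, 64)   # no text, no image condition
e_t_uncond_img = torch.randn(1, 4, 16, 40, 64)   # image condition only, empty text
e_t_cond       = torch.randn(1, 4, 16, 40, 64)   # full (image + text) condition

cfg_img  = 7.5   # image guidance scale (defaults to the text scale when cfg_img is None)
cfg_text = 7.5   # unconditional_guidance_scale in the sampler

# image guidance pushes from "nothing" toward "image"; text guidance pushes from "image" toward "image + text"
model_output = (e_t_uncond
                + cfg_img * (e_t_uncond_img - e_t_uncond)
                + cfg_text * (e_t_cond - e_t_uncond_img))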
lvdm/models/utils_diffusion.py CHANGED
@@ -57,14 +57,20 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
     if ddim_discr_method == 'uniform':
         c = num_ddpm_timesteps // num_ddim_timesteps
         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+        steps_out = ddim_timesteps + 1
+    elif ddim_discr_method == 'uniform_trailing':
+        c = num_ddpm_timesteps / num_ddim_timesteps
+        ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0, -c))).astype(np.int64)
+        steps_out = ddim_timesteps - 1
     elif ddim_discr_method == 'quad':
         ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+        steps_out = ddim_timesteps + 1
     else:
         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

     # assert ddim_timesteps.shape[0] == num_ddim_timesteps
     # add one to get the final alpha values right (the ones from first scale to data during sampling)
-    steps_out = ddim_timesteps + 1
+    # steps_out = ddim_timesteps + 1
     if verbose:
         print(f'Selected timesteps for ddim sampler: {steps_out}')
     return steps_out
@@ -101,4 +107,52 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
         t1 = i / num_diffusion_timesteps
         t2 = (i + 1) / num_diffusion_timesteps
         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas)
+    return np.array(betas)
+
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+    Args:
+        betas (`numpy.ndarray`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `numpy.ndarray`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = np.cumprod(alphas, axis=0)
+    alphas_bar_sqrt = np.sqrt(alphas_cumprod)
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = np.concatenate([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
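Note: a quick check of the two schedule changes above, zero terminal SNR and the new uniform_trailing spacing. The import path matches this file; the linear beta schedule below is only an illustrative stand-in for make_beta_schedule:

import numpy as np
from lvdm.models.utils_diffusion import make_ddim_timesteps, rescale_zero_terminal_snr

betas = np.linspace(0.00085, 0.012, 1000)          # illustrative linear schedule
betas_zsnr = rescale_zero_terminal_snr(betas)

# terminal SNR: alpha_bar at the last step is now (numerically) zero
print(np.cumprod(1. - betas)[-1], np.cumprod(1. - betas_zsnr)[-1])

# 'uniform' starts near t=1 and never reaches the last timestep; 'uniform_trailing' ends at t=999
print(make_ddim_timesteps('uniform', 10, 1000, verbose=False))           # [  1 101 ... 901]
print(make_ddim_timesteps('uniform_trailing', 10, 1000, verbose=False))  # [ 99 199 ... 999]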
lvdm/modules/.DS_Store ADDED
Binary file (6.15 kB)
 
lvdm/modules/encoders/resampler.py CHANGED
@@ -1,145 +1,145 @@
1
- # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
- # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
- # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
4
- import math
5
- import torch
6
- import torch.nn as nn
7
-
8
-
9
- class ImageProjModel(nn.Module):
10
- """Projection Model"""
11
- def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
12
- super().__init__()
13
- self.cross_attention_dim = cross_attention_dim
14
- self.clip_extra_context_tokens = clip_extra_context_tokens
15
- self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
16
- self.norm = nn.LayerNorm(cross_attention_dim)
17
-
18
- def forward(self, image_embeds):
19
- #embeds = image_embeds
20
- embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
21
- clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
22
- clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
23
- return clip_extra_context_tokens
24
-
25
-
26
- # FFN
27
- def FeedForward(dim, mult=4):
28
- inner_dim = int(dim * mult)
29
- return nn.Sequential(
30
- nn.LayerNorm(dim),
31
- nn.Linear(dim, inner_dim, bias=False),
32
- nn.GELU(),
33
- nn.Linear(inner_dim, dim, bias=False),
34
- )
35
-
36
-
37
- def reshape_tensor(x, heads):
38
- bs, length, width = x.shape
39
- #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
40
- x = x.view(bs, length, heads, -1)
41
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
42
- x = x.transpose(1, 2)
43
- # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
44
- x = x.reshape(bs, heads, length, -1)
45
- return x
46
-
47
-
48
- class PerceiverAttention(nn.Module):
49
- def __init__(self, *, dim, dim_head=64, heads=8):
50
- super().__init__()
51
- self.scale = dim_head**-0.5
52
- self.dim_head = dim_head
53
- self.heads = heads
54
- inner_dim = dim_head * heads
55
-
56
- self.norm1 = nn.LayerNorm(dim)
57
- self.norm2 = nn.LayerNorm(dim)
58
-
59
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
60
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
61
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
62
-
63
-
64
- def forward(self, x, latents):
65
- """
66
- Args:
67
- x (torch.Tensor): image features
68
- shape (b, n1, D)
69
- latent (torch.Tensor): latent features
70
- shape (b, n2, D)
71
- """
72
- x = self.norm1(x)
73
- latents = self.norm2(latents)
74
-
75
- b, l, _ = latents.shape
76
-
77
- q = self.to_q(latents)
78
- kv_input = torch.cat((x, latents), dim=-2)
79
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
80
-
81
- q = reshape_tensor(q, self.heads)
82
- k = reshape_tensor(k, self.heads)
83
- v = reshape_tensor(v, self.heads)
84
-
85
- # attention
86
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
87
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
88
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
89
- out = weight @ v
90
-
91
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
92
-
93
- return self.to_out(out)
94
-
95
-
96
- class Resampler(nn.Module):
97
- def __init__(
98
- self,
99
- dim=1024,
100
- depth=8,
101
- dim_head=64,
102
- heads=16,
103
- num_queries=8,
104
- embedding_dim=768,
105
- output_dim=1024,
106
- ff_mult=4,
107
- video_length=None, # using frame-wise version or not
108
- ):
109
- super().__init__()
110
- ## queries for a single frame / image
111
- self.num_queries = num_queries
112
- self.video_length = video_length
113
-
114
- ## <num_queries> queries for each frame
115
- if video_length is not None:
116
- num_queries = num_queries * video_length
117
-
118
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
119
- self.proj_in = nn.Linear(embedding_dim, dim)
120
- self.proj_out = nn.Linear(dim, output_dim)
121
- self.norm_out = nn.LayerNorm(output_dim)
122
-
123
- self.layers = nn.ModuleList([])
124
- for _ in range(depth):
125
- self.layers.append(
126
- nn.ModuleList(
127
- [
128
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
129
- FeedForward(dim=dim, mult=ff_mult),
130
- ]
131
- )
132
- )
133
-
134
- def forward(self, x):
135
- latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C
136
- x = self.proj_in(x)
137
-
138
- for attn, ff in self.layers:
139
- latents = attn(x, latents) + latents
140
- latents = ff(latents) + latents
141
-
142
- latents = self.proj_out(latents)
143
- latents = self.norm_out(latents) # B L C or B (T L) C
144
-
145
  return latents
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+ # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class ImageProjModel(nn.Module):
10
+ """Projection Model"""
11
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
12
+ super().__init__()
13
+ self.cross_attention_dim = cross_attention_dim
14
+ self.clip_extra_context_tokens = clip_extra_context_tokens
15
+ self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
16
+ self.norm = nn.LayerNorm(cross_attention_dim)
17
+
18
+ def forward(self, image_embeds):
19
+ #embeds = image_embeds
20
+ embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
21
+ clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
22
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
23
+ return clip_extra_context_tokens
24
+
25
+
26
+ # FFN
27
+ def FeedForward(dim, mult=4):
28
+ inner_dim = int(dim * mult)
29
+ return nn.Sequential(
30
+ nn.LayerNorm(dim),
31
+ nn.Linear(dim, inner_dim, bias=False),
32
+ nn.GELU(),
33
+ nn.Linear(inner_dim, dim, bias=False),
34
+ )
35
+
36
+
37
+ def reshape_tensor(x, heads):
38
+ bs, length, width = x.shape
39
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
40
+ x = x.view(bs, length, heads, -1)
41
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
42
+ x = x.transpose(1, 2)
43
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
44
+ x = x.reshape(bs, heads, length, -1)
45
+ return x
46
+
47
+
48
+ class PerceiverAttention(nn.Module):
49
+ def __init__(self, *, dim, dim_head=64, heads=8):
50
+ super().__init__()
51
+ self.scale = dim_head**-0.5
52
+ self.dim_head = dim_head
53
+ self.heads = heads
54
+ inner_dim = dim_head * heads
55
+
56
+ self.norm1 = nn.LayerNorm(dim)
57
+ self.norm2 = nn.LayerNorm(dim)
58
+
59
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
60
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
61
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
62
+
63
+
64
+ def forward(self, x, latents):
65
+ """
66
+ Args:
67
+ x (torch.Tensor): image features
68
+ shape (b, n1, D)
69
+ latent (torch.Tensor): latent features
70
+ shape (b, n2, D)
71
+ """
72
+ x = self.norm1(x)
73
+ latents = self.norm2(latents)
74
+
75
+ b, l, _ = latents.shape
76
+
77
+ q = self.to_q(latents)
78
+ kv_input = torch.cat((x, latents), dim=-2)
79
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
80
+
81
+ q = reshape_tensor(q, self.heads)
82
+ k = reshape_tensor(k, self.heads)
83
+ v = reshape_tensor(v, self.heads)
84
+
85
+ # attention
86
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
87
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
88
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
89
+ out = weight @ v
90
+
91
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
92
+
93
+ return self.to_out(out)
94
+
95
+
96
+ class Resampler(nn.Module):
97
+ def __init__(
98
+ self,
99
+ dim=1024,
100
+ depth=8,
101
+ dim_head=64,
102
+ heads=16,
103
+ num_queries=8,
104
+ embedding_dim=768,
105
+ output_dim=1024,
106
+ ff_mult=4,
107
+ video_length=None, # using frame-wise version or not
108
+ ):
109
+ super().__init__()
110
+ ## queries for a single frame / image
111
+ self.num_queries = num_queries
112
+ self.video_length = video_length
113
+
114
+ ## <num_queries> queries for each frame
115
+ if video_length is not None:
116
+ num_queries = num_queries * video_length
117
+
118
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
119
+ self.proj_in = nn.Linear(embedding_dim, dim)
120
+ self.proj_out = nn.Linear(dim, output_dim)
121
+ self.norm_out = nn.LayerNorm(output_dim)
122
+
123
+ self.layers = nn.ModuleList([])
124
+ for _ in range(depth):
125
+ self.layers.append(
126
+ nn.ModuleList(
127
+ [
128
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
129
+ FeedForward(dim=dim, mult=ff_mult),
130
+ ]
131
+ )
132
+ )
133
+
134
+ def forward(self, x):
135
+ latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C
136
+ x = self.proj_in(x)
137
+
138
+ for attn, ff in self.layers:
139
+ latents = attn(x, latents) + latents
140
+ latents = ff(latents) + latents
141
+
142
+ latents = self.proj_out(latents)
143
+ latents = self.norm_out(latents) # B L C or B (T L) C
144
+
145
  return latents
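Note: a short shape check for the Resampler above. The class and its defaults come from this file; CLIP-style image tokens with 257 patches and width 768 are an illustrative assumption:

import torch
from lvdm.modules.encoders.resampler import Resampler

resampler = Resampler(dim=1024, depth=4, num_queries=16, embedding_dim=768,
                      output_dim=1024, video_length=16)
image_tokens = torch.randn(2, 257, 768)       # (batch, n_tokens, embedding_dim)
context = resampler(image_tokens)
print(context.shape)                          # torch.Size([2, 256, 1024]) -> 16 frames x 16 queries each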
lvdm/modules/networks/__pycache__/openaimodel3d.cpython-39.pyc CHANGED
Binary files a/lvdm/modules/networks/__pycache__/openaimodel3d.cpython-39.pyc and b/lvdm/modules/networks/__pycache__/openaimodel3d.cpython-39.pyc differ
 
lvdm/modules/networks/openaimodel3d.py CHANGED
@@ -373,14 +373,13 @@ class UNetModel(nn.Module):
             linear(time_embed_dim, time_embed_dim),
         )
         if fs_condition:
-            self.framestride_embed = nn.Sequential(
+            self.fps_embedding = nn.Sequential(
                 linear(model_channels, time_embed_dim),
                 nn.SiLU(),
                 linear(time_embed_dim, time_embed_dim),
             )
-            nn.init.zeros_(self.framestride_embed[-1].weight)
-            nn.init.zeros_(self.framestride_embed[-1].bias)
-
+            nn.init.zeros_(self.fps_embedding[-1].weight)
+            nn.init.zeros_(self.fps_embedding[-1].bias)
         ## Input Block
         self.input_blocks = nn.ModuleList(
             [
@@ -572,7 +571,8 @@ class UNetModel(nn.Module):
                 fs = torch.tensor(
                     [self.default_fs] * b, dtype=torch.long, device=x.device)
             fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype)
-            fs_embed = self.framestride_embed(fs_emb)
+
+            fs_embed = self.fps_embedding(fs_emb)
             fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
             emb = emb + fs_embed
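Note: the renamed fps_embedding keeps the zero-initialized last layer, so frame-stride/fps conditioning starts as an exact no-op and only takes effect as those weights are trained. A self-contained illustration; model_channels=320 and time_embed_dim=1280 are typical values assumed for the sketch, and plain nn.Linear stands in for the repo's linear() helper:

import torch
import torch.nn as nn

model_channels, time_embed_dim = 320, 1280
fps_embedding = nn.Sequential(
    nn.Linear(model_channels, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
)
nn.init.zeros_(fps_embedding[-1].weight)
nn.init.zeros_(fps_embedding[-1].bias)

fs_emb = torch.randn(2, model_channels)          # stands in for timestep_embedding(fs, model_channels)
fs_embed = fps_embedding(fs_emb)
print(torch.count_nonzero(fs_embed))             # 0 -> emb = emb + fs_embed is unchanged at initialization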