Update model.py
model.py
CHANGED
@@ -295,8 +295,12 @@ class StreamMultiDiffusion(nn.Module):
     def reset_latent(self) -> None:
         # initialize x_t_latent (it can be any random tensor)
         b = (self.denoising_steps_num - 1) * self.frame_bff_size
-
-
+        if not hasattr(self, 'x_t_latent_buffer'):
+            self.register_buffer('x_t_latent_buffer', torch.zeros(
+                (b, 4, self.latent_height, self.latent_width), dtype=self.dtype, device=self.device))
+        else:
+            self.x_t_latent_buffer = torch.zeros(
+                (b, 4, self.latent_height, self.latent_width), dtype=self.dtype, device=self.device)
 
     def reset_state(self) -> None:
         # TODO Reset states for context switch between multiple users.
@@ -305,24 +309,35 @@ class StreamMultiDiffusion(nn.Module):
     def prepare(self) -> None:
         # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
         self.timesteps = self.scheduler.timesteps.to(self.device)
-
+        sub_timesteps = []
         for t in self.t_list:
-
-            sub_timesteps_tensor = torch.tensor(
-
+            sub_timesteps.append(self.timesteps[t])
+        sub_timesteps_tensor = torch.tensor(sub_timesteps, dtype=torch.long, device=self.device)
+        if not hasattr(self, 'sub_timesteps_tensor'):
+            self.register_buffer('sub_timesteps_tensor', sub_timesteps_tensor.repeat_interleave(self.frame_bff_size, dim=0))
+        else:
+            self.sub_timesteps_tensor = sub_timesteps_tensor.repeat_interleave(self.frame_bff_size, dim=0)
 
         c_skip_list = []
         c_out_list = []
-        for timestep in
+        for timestep in sub_timesteps:
             c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep)
             c_skip_list.append(c_skip)
             c_out_list.append(c_out)
-
-
+        c_skip = torch.stack(c_skip_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
+        c_out = torch.stack(c_out_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
+        if not hasattr(self, 'c_skip'):
+            self.register_buffer('c_skip', c_skip)
+        else:
+            self.c_skip = c_skip
+        if not hasattr(self, 'c_out'):
+            self.register_buffer('c_out', c_out)
+        else:
+            self.c_out = c_out
 
         alpha_prod_t_sqrt_list = []
         beta_prod_t_sqrt_list = []
-        for timestep in
+        for timestep in sub_timesteps:
             alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
             beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
             alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
@@ -331,12 +346,24 @@ class StreamMultiDiffusion(nn.Module):
                              .to(dtype=self.dtype, device=self.device))
         beta_prod_t_sqrt = (torch.stack(beta_prod_t_sqrt_list).view(len(self.t_list), 1, 1, 1)
                             .to(dtype=self.dtype, device=self.device))
-
-
+        if not hasattr(self, 'alpha_prod_t_sqrt'):
+            self.register_buffer('alpha_prod_t_sqrt', alpha_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0))
+        else:
+            self.alpha_prod_t_sqrt = alpha_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
+        if not hasattr(self, 'beta_prod_t_sqrt'):
+            self.register_buffer('beta_prod_t_sqrt', beta_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0))
+        else:
+            self.beta_prod_t_sqrt = beta_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
 
         noise_lvs = ((1 - self.scheduler.alphas_cumprod.to(self.device)[self.sub_timesteps_tensor]) ** 0.5)
-
-
+        if not hasattr(self, 'noise_lvs'):
+            self.register_buffer('noise_lvs', noise_lvs[None, :, None, None, None])
+        else:
+            self.noise_lvs = noise_lvs[None, :, None, None, None]
+        if not hasattr(self, 'next_noise_lvs'):
+            self.register_buffer('next_noise_lvs', torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None])
+        else:
+            self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
 
     @torch.no_grad()
     def get_text_prompts(self, image: Image.Image) -> str:
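A note on the pattern this commit repeats: each tensor is stored with register_buffer on first call and plainly assigned afterwards, because nn.Module.register_buffer raises a KeyError when the attribute name already exists. A minimal sketch of that pattern, factored into a hypothetical _register_or_assign helper (not part of this commit):

import torch
import torch.nn as nn

class BufferedModule(nn.Module):
    # Hypothetical helper illustrating the register-or-assign pattern.
    def _register_or_assign(self, name: str, tensor: torch.Tensor) -> None:
        # register_buffer raises KeyError if the name already exists, so
        # later calls (e.g. re-running prepare()) fall back to plain
        # assignment, which updates the existing buffer in place.
        if not hasattr(self, name):
            self.register_buffer(name, tensor)
        else:
            setattr(self, name, tensor)

m = BufferedModule()
m._register_or_assign('sub_timesteps_tensor', torch.arange(4))
m._register_or_assign('sub_timesteps_tensor', torch.arange(8))  # second call: assign only
assert 'sub_timesteps_tensor' in dict(m.named_buffers())

Registering (rather than only assigning) means the tensors follow the module through .to(device) and appear in state_dict(), which plain Python attributes would not.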
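Several of the new buffers are expanded with repeat_interleave(self.frame_bff_size, dim=0), duplicating each denoising level once per frame so the buffer lines up with a latent batch laid out as (levels * frames, C, H, W). A toy sketch with made-up sizes:

import torch

frame_bff_size = 2                                  # illustrative value
sub_timesteps = torch.tensor([999, 759, 499, 259])  # one timestep per denoising level
expanded = sub_timesteps.repeat_interleave(frame_bff_size, dim=0)
print(expanded)  # tensor([999, 999, 759, 759, 499, 499, 259, 259])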
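get_scalings_for_boundary_condition_discrete marks the scheduler as LCM-style: c_skip and c_out blend the network output with its input so the consistency function is exact at the boundary t = 0. A hedged sketch of those scalings, assuming the defaults used by diffusers' LCMScheduler (sigma_data = 0.5, timestep_scaling = 10); the installed scheduler is authoritative:

import torch

def scalings_for_boundary_condition_discrete(timestep, sigma_data=0.5, timestep_scaling=10.0):
    # Approximation of diffusers.LCMScheduler.get_scalings_for_boundary_condition_discrete;
    # treat this as illustration, not as the source implementation.
    scaled_t = timestep * timestep_scaling
    c_skip = sigma_data**2 / (scaled_t**2 + sigma_data**2)
    c_out = scaled_t / (scaled_t**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

c_skip, c_out = scalings_for_boundary_condition_discrete(torch.tensor(999.0))
# At timestep 0 this degenerates to (c_skip, c_out) = (1, 0): the model is
# skipped and the prediction equals the input, as consistency models require.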
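Finally, noise_lvs stores sqrt(1 - alpha_bar_t), the standard deviation of the noise present in the latent at each selected timestep, and next_noise_lvs is the same schedule shifted one level toward clean with a trailing zero, since after the last step the latent is treated as noise-free. A toy sketch, assuming a linear-beta schedule purely for illustration:

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)            # illustrative schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t

sub_timesteps_tensor = torch.tensor([999, 759, 499, 259])
noise_lvs = (1 - alphas_cumprod[sub_timesteps_tensor]) ** 0.5
next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])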