Sapir Weissbuch commited on
Commit
00c2119
·
unverified ·
2 Parent(s): ee102d4 eec4cb2

Merge pull request #11 from LightricksResearch/rm-dist-util

Browse files
xora/models/autoencoders/vae_encode.py CHANGED
@@ -6,8 +6,10 @@ from torch import Tensor
6
 
7
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
8
  from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
9
- import xora.utils.dist_util
10
-
 
 
11
 
12
  def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
13
  """
@@ -54,10 +56,12 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
54
  encode_bs = len(media_items) // split_size
55
  # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
56
  latents = []
57
- dist_util.execute_graph()
 
58
  for image_batch in media_items.split(encode_bs):
59
  latents.append(vae.encode(image_batch).latent_dist.sample())
60
- dist_util.execute_graph()
 
61
  latents = torch.cat(latents, dim=0)
62
  else:
63
  latents = vae.encode(media_items).latent_dist.sample()
 
6
 
7
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
8
  from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
9
+ try:
10
+ import torch_xla.core.xla_model as xm
11
+ except:
12
+ pass
13
 
14
  def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
15
  """
 
56
  encode_bs = len(media_items) // split_size
57
  # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
58
  latents = []
59
+ if media_items.device.type == "xla":
60
+ xm.mark_step()
61
  for image_batch in media_items.split(encode_bs):
62
  latents.append(vae.encode(image_batch).latent_dist.sample())
63
+ if media_items.device.type == "xla":
64
+ xm.mark_step()
65
  latents = torch.cat(latents, dim=0)
66
  else:
67
  latents = vae.encode(media_items).latent_dist.sample()
xora/utils/dist_util.py DELETED
@@ -1,5 +0,0 @@
1
- from enum import Enum
2
-
3
- def execute_graph() -> None:
4
- if _acceleration_type == AccelerationType.TPU:
5
- xm.mark_step()