NIRVANALAN commited on
Commit
592a426
1 Parent(s): caf9793

update dep

Browse files
Files changed (1) hide show
  1. dit/dit_i23d.py +7 -2
dit/dit_i23d.py CHANGED
@@ -9,10 +9,15 @@ from pdb import set_trace as st
9
 
10
  from ldm.modules.attention import MemoryEfficientCrossAttention
11
  from .dit_models_xformers import DiT, get_2d_sincos_pos_embed, ImageCondDiTBlock, FinalLayer, CaptionEmbedder, approx_gelu, ImageCondDiTBlockPixelArt, t2i_modulate, ImageCondDiTBlockPixelArtRMSNorm, T2IFinalLayer, ImageCondDiTBlockPixelArtRMSNormNoClip
12
- from apex.normalization import FusedLayerNorm as LayerNorm
13
- from apex.normalization import FusedRMSNorm as RMSNorm
14
  from timm.models.vision_transformer import Mlp
15
 
 
 
 
 
 
 
 
16
  # from vit.vit_triplane import XYZPosEmbed
17
 
18
 
 
9
 
10
  from ldm.modules.attention import MemoryEfficientCrossAttention
11
  from .dit_models_xformers import DiT, get_2d_sincos_pos_embed, ImageCondDiTBlock, FinalLayer, CaptionEmbedder, approx_gelu, ImageCondDiTBlockPixelArt, t2i_modulate, ImageCondDiTBlockPixelArtRMSNorm, T2IFinalLayer, ImageCondDiTBlockPixelArtRMSNormNoClip
 
 
12
  from timm.models.vision_transformer import Mlp
13
 
14
+ try:
15
+ from apex.normalization import FusedLayerNorm as LayerNorm
16
+ from apex.normalization import FusedRMSNorm as RMSNorm
17
+ except:
18
+ from torch.nn import LayerNorm
19
+ from dit.norm import RMSNorm
20
+
21
  # from vit.vit_triplane import XYZPosEmbed
22
 
23