NIRVANALAN committed on
Commit
945a01c
1 Parent(s): bcb1668
Files changed (1) hide show
  1. ldm/modules/attention.py +2 -2
ldm/modules/attention.py CHANGED
@@ -257,8 +257,8 @@ class MemoryEfficientCrossAttention(nn.Module):
257
  # self.q_rmsnorm = RMSNorm(query_dim, eps=1e-5)
258
  # self.k_rmsnorm = RMSNorm(context_dim, eps=1e-5)
259
 
260
- self.q_norm = RMSNorm(self.dim_head, elementwise_affine=True) if qk_norm else nn.Identity()
261
- self.k_norm = RMSNorm(self.dim_head, elementwise_affine=True) if qk_norm else nn.Identity()
262
 
263
  # self.enable_rmsnorm = enable_rmsnorm
264
 
 
257
  # self.q_rmsnorm = RMSNorm(query_dim, eps=1e-5)
258
  # self.k_rmsnorm = RMSNorm(context_dim, eps=1e-5)
259
 
260
+ self.q_norm = RMSNorm(self.dim_head, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity()
261
+ self.k_norm = RMSNorm(self.dim_head, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity()
262
 
263
  # self.enable_rmsnorm = enable_rmsnorm
264