yairschiff commited on
Commit
0600f8f
1 Parent(s): 2a6b613

Enable mambav2 compat

Browse files
Files changed (1) hide show
  1. modeling_rcps.py +5 -2
modeling_rcps.py CHANGED
@@ -10,9 +10,12 @@ from torch import nn
10
  from torch.nn import functional as F
11
 
12
  try:
13
- from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
14
  except ImportError:
15
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
 
 
 
16
 
17
 
18
  class RCPSEmbedding(nn.Module):
 
10
  from torch.nn import functional as F
11
 
12
  try:
13
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
14
  except ImportError:
15
+ try:
16
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
17
+ except ImportError:
18
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
19
 
20
 
21
  class RCPSEmbedding(nn.Module):