yairschiff
commited on
Commit
•
0600f8f
1
Parent(s):
2a6b613
Enable mambav2 compat
Browse files- modeling_rcps.py +5 -2
modeling_rcps.py
CHANGED
@@ -10,9 +10,12 @@ from torch import nn
|
|
10 |
from torch.nn import functional as F
|
11 |
|
12 |
try:
|
13 |
-
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
14 |
except ImportError:
|
15 |
-
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
class RCPSEmbedding(nn.Module):
|
|
|
10 |
from torch.nn import functional as F
|
11 |
|
12 |
try:
|
13 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
|
14 |
except ImportError:
|
15 |
+
try:
|
16 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
|
17 |
+
except ImportError:
|
18 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
19 |
|
20 |
|
21 |
class RCPSEmbedding(nn.Module):
|