jupyterjazz
commited on
Commit
•
cadf946
1
Parent(s):
7ad815b
fix: rope dim mismatch
Browse filesSigned-off-by: jupyterjazz <saba.sturua@jina.ai>
rotary.py
CHANGED
@@ -31,6 +31,10 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
|
31 |
"""
|
32 |
ro_dim = cos.shape[-1] * 2
|
33 |
assert ro_dim <= x.shape[-1]
|
|
|
|
|
|
|
|
|
34 |
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
35 |
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
36 |
return torch.cat(
|
|
|
31 |
"""
|
32 |
ro_dim = cos.shape[-1] * 2
|
33 |
assert ro_dim <= x.shape[-1]
|
34 |
+
cos, sin = (
|
35 |
+
cos[:x.shape[1]],
|
36 |
+
sin[:x.shape[1]],
|
37 |
+
)
|
38 |
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
39 |
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
40 |
return torch.cat(
|