Files changed (1) hide show
  1. rotary.py +4 -0
rotary.py CHANGED
@@ -31,6 +31,10 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
31
  """
32
  ro_dim = cos.shape[-1] * 2
33
  assert ro_dim <= x.shape[-1]
 
 
 
 
34
  cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
35
  sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
36
  return torch.cat(
 
31
  """
32
  ro_dim = cos.shape[-1] * 2
33
  assert ro_dim <= x.shape[-1]
34
+ cos, sin = (
35
+ cos[:x.shape[1]],
36
+ sin[:x.shape[1]],
37
+ )
38
  cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
39
  sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
40
  return torch.cat(