jupyterjazz committed
Commit 11ba200
1 Parent(s): 77a17f7

refactor: revert alibi stuff


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (1):
  1. mha.py +33 -5
mha.py CHANGED
@@ -56,7 +56,15 @@ class FlashSelfAttention(nn.Module):
            (default: 0.0)
    """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=False,
+    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -64,6 +72,7 @@ class FlashSelfAttention(nn.Module):
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -87,6 +96,8 @@ class FlashSelfAttention(nn.Module):
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
@@ -99,6 +110,7 @@ class FlashSelfAttention(nn.Module):
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
@@ -108,6 +120,7 @@ class FlashSelfAttention(nn.Module):
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                deterministic=self.deterministic,
            )

@@ -123,7 +136,15 @@ class FlashCrossAttention(nn.Module):
            (default: 0.0)
    """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        alibi_slopes=None,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -131,6 +152,7 @@ class FlashCrossAttention(nn.Module):
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
@@ -160,6 +182,8 @@ class FlashCrossAttention(nn.Module):
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
@@ -179,6 +203,7 @@ class FlashCrossAttention(nn.Module):
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
@@ -192,6 +217,7 @@ class FlashCrossAttention(nn.Module):
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                deterministic=self.deterministic,
            )

@@ -367,6 +393,7 @@ class MHA(nn.Module):
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
+        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
@@ -396,6 +423,8 @@ class MHA(nn.Module):
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -426,12 +455,12 @@ class MHA(nn.Module):
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
@@ -584,7 +613,6 @@ class MHA(nn.Module):
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
-            # assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
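Note: after this revert the wrappers again accept window_size=(-1, -1), which disables sliding-window (local) attention, and an optional per-head alibi_slopes buffer that forward() casts to float32 before the kernel call (the flash-attn kernels take fp32 slopes). A minimal usage sketch of the reverted FlashSelfAttention API follows; it is illustrative only, assuming flash-attn and a CUDA device, and the import path, head count, and shapes are placeholders, not part of this commit.

# Illustrative sketch, not part of this commit.
import torch
from mha import FlashSelfAttention, get_alibi_slopes  # assumed module layout; adjust to this repo

num_heads, head_dim, seqlen, batch = 8, 64, 128, 2

attn = FlashSelfAttention(
    causal=True,
    attention_dropout=0.0,
    window_size=(-1, -1),  # (-1, -1) = full attention, no sliding window
    alibi_slopes=torch.tensor(get_alibi_slopes(num_heads), device="cuda"),
)

# Packed QKV layout used by the qkvpacked flash-attn kernels: (batch, seqlen, 3, nheads, headdim),
# fp16/bf16 on CUDA.
qkv = torch.randn(batch, seqlen, 3, num_heads, head_dim, device="cuda", dtype=torch.float16)
out = attn(qkv)  # -> (batch, seqlen, nheads, headdim)

With cu_seqlens and max_seqlen supplied instead, the same module dispatches to the varlen (unpadded) kernel, mirroring the two branches visible in the diff above.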
 
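The use_alibi path in MHA.__init__ builds the per-head slopes with get_alibi_slopes(num_heads), a helper this commit references but does not modify. For reference, the conventional ALiBi slope schedule (Press et al.) is a geometric sequence starting at 2**(-8/n), with non-power-of-two head counts interleaving two such sequences; the sketch below shows that convention and is not necessarily identical to this repo's helper.

import math

def alibi_slopes_reference(num_heads):
    # Reference-only reimplementation of the standard ALiBi slope schedule.
    def power_of_two_slopes(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))  # equals 2 ** (-8 / n)
        return [start ** (i + 1) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return power_of_two_slopes(num_heads)
    closest = 2 ** math.floor(math.log2(num_heads))
    # Interleave: slopes for the nearest power of two, plus every other slope
    # from the schedule for twice that size, until num_heads values exist.
    return (
        power_of_two_slopes(closest)
        + power_of_two_slopes(2 * closest)[0::2][: num_heads - closest]
    )

# e.g. 8 heads -> [0.5, 0.25, 0.125, ..., 0.00390625]; MHA wraps the result in torch.tensor(...).

These slopes are what FlashSelfAttention and FlashCrossAttention register as the alibi_slopes buffer and pass through to the flash-attn kernels.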