zxdu20 commited on
Commit
10b3f35
1 Parent(s): 1c2a801

flash_attention_2 (#51)

Browse files

- Add eager and sdpa attention implementations (835c717962e2632f116db776a087970c22e4a6c1)
- Add support for flash attention 2 (a7eaddd0ac0e89cf779dce9596635369178e15cf)
- Merge branch 'main' into attention (29038ea19de709ec833a7ad9e86e838e274194f2)

Files changed (2) hide show
  1. config.json +1 -0
  2. modeling_chatglm.py +203 -81
config.json CHANGED
@@ -17,6 +17,7 @@
17
  "apply_residual_connection_post_layernorm": false,
18
  "attention_dropout": 0.0,
19
  "attention_softmax_in_fp32": true,
 
20
  "bias_dropout_fusion": true,
21
  "ffn_hidden_size": 13696,
22
  "fp32_residual_connection": false,
 
17
  "apply_residual_connection_post_layernorm": false,
18
  "attention_dropout": 0.0,
19
  "attention_softmax_in_fp32": true,
20
+ "attn_implementation": "sdpa",
21
  "bias_dropout_fusion": true,
22
  "ffn_hidden_size": 13696,
23
  "fp32_residual_connection": false,
modeling_chatglm.py CHANGED
@@ -21,12 +21,17 @@ from transformers.modeling_outputs import (
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
- from transformers.utils import logging, is_torch_npu_available
 
25
  from transformers.generation.logits_process import LogitsProcessor
26
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
27
 
28
  from .configuration_chatglm import ChatGLMConfig
29
 
 
 
 
 
30
  # flags required to enable jit fusion kernels
31
 
32
  if sys.platform != 'darwin' and not is_torch_npu_available():
@@ -40,6 +45,7 @@ logger = logging.get_logger(__name__)
40
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
41
  _CONFIG_FOR_DOC = "ChatGLMConfig"
42
 
 
43
  def default_init(cls, *args, **kwargs):
44
  return cls(*args, **kwargs)
45
 
@@ -159,12 +165,13 @@ class RMSNorm(torch.nn.Module):
159
  class CoreAttention(torch.nn.Module):
160
  def __init__(self, config: ChatGLMConfig, layer_number):
161
  super(CoreAttention, self).__init__()
162
-
163
  self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
164
  self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
165
  if self.apply_query_key_layer_scaling:
166
  self.attention_softmax_in_fp32 = True
167
  self.layer_number = max(1, layer_number)
 
168
 
169
  projection_size = config.kv_channels * config.num_attention_heads
170
 
@@ -183,91 +190,198 @@ class CoreAttention(torch.nn.Module):
183
  self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
184
 
185
  def forward(self, query_layer, key_layer, value_layer, attention_mask):
186
- pytorch_major_version = int(torch.__version__.split('.')[0])
187
- if pytorch_major_version >= 2:
188
- if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
189
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
190
- is_causal=True)
191
- else:
192
- if attention_mask is not None:
193
- attention_mask = ~attention_mask
194
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
195
- attention_mask)
196
- context_layer = context_layer.transpose(1, 2).contiguous()
197
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
198
- context_layer = context_layer.reshape(*new_context_layer_shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  else:
200
- # Raw attention scores
 
 
 
 
 
 
 
 
201
 
202
- # [b, np, sq, sk]
203
- output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
204
 
205
- # [b, np, sq, hn] -> [b * np, sq, hn]
206
- query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
207
- # [b, np, sk, hn] -> [b * np, sk, hn]
208
- key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- # preallocting input tensor: [b * np, sq, sk]
211
- matmul_input_buffer = torch.empty(
212
- output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
213
- device=query_layer.device
 
 
 
 
 
 
 
 
 
 
 
214
  )
215
 
216
- # Raw attention scores. [b * np, sq, sk]
217
- matmul_result = torch.baddbmm(
218
- matmul_input_buffer,
219
- query_layer, # [b * np, sq, hn]
220
- key_layer.transpose(1, 2), # [b * np, hn, sk]
221
- beta=0.0,
222
- alpha=(1.0 / self.norm_factor),
 
 
 
 
 
 
 
223
  )
224
 
225
- # change view to [b, np, sq, sk]
226
- attention_scores = matmul_result.view(*output_size)
227
-
228
- # ===========================
229
- # Attention probs and dropout
230
- # ===========================
231
-
232
- # attention scores and attention mask [b, np, sq, sk]
233
- if self.attention_softmax_in_fp32:
234
- attention_scores = attention_scores.float()
235
- if self.coeff is not None:
236
- attention_scores = attention_scores * self.coeff
237
- if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
238
- attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
239
- device=attention_scores.device, dtype=torch.bool)
240
- attention_mask.tril_()
241
- attention_mask = ~attention_mask
242
- if attention_mask is not None:
243
- attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
244
- attention_probs = F.softmax(attention_scores, dim=-1)
245
- attention_probs = attention_probs.type_as(value_layer)
246
-
247
- # This is actually dropping out entire tokens to attend to, which might
248
- # seem a bit unusual, but is taken from the original Transformer paper.
249
- attention_probs = self.attention_dropout(attention_probs)
250
-
251
- # query layer shape: [b * np, sq, hn]
252
- # value layer shape: [b, np, sk, hn]
253
- # attention shape: [b, np, sq, sk]
254
- # context layer shape: [b, np, sq, hn]
255
- output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
256
- # change view [b * np, sk, hn]
257
- value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
258
- # change view [b * np, sq, sk]
259
- attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
260
- # matmul: [b * np, sq, hn]
261
- context_layer = torch.bmm(attention_probs, value_layer)
262
- # change view [b, np, sq, hn]
263
- context_layer = context_layer.view(*output_size)
264
- # [b, np, sq, hn] --> [b, sq, np, hn]
265
- context_layer = context_layer.transpose(1, 2).contiguous()
266
- # [b, sq, np, hn] --> [b, sq, hp]
267
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
268
- context_layer = context_layer.reshape(*new_context_layer_shape)
269
 
270
- return context_layer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
 
273
  class SelfAttention(torch.nn.Module):
@@ -299,7 +413,7 @@ class SelfAttention(torch.nn.Module):
299
  device=device, **_config_to_kwargs(config)
300
  )
301
 
302
- self.core_attention = CoreAttention(config, self.layer_number)
303
 
304
  # Output.
305
  self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
@@ -378,7 +492,8 @@ class SelfAttention(torch.nn.Module):
378
  value_layer = torch.cat((cache_v, value_layer), dim=2)
379
  if use_cache:
380
  if kv_cache is None:
381
- kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
 
382
  else:
383
  kv_cache = (key_layer, value_layer)
384
  else:
@@ -644,12 +759,18 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
644
  config_class = ChatGLMConfig
645
  base_model_prefix = "transformer"
646
  _no_split_modules = ["GLMBlock"]
 
 
647
 
648
  def _init_weights(self, module: nn.Module):
649
  """Initialize the weights."""
650
  return
651
 
652
  def get_masks(self, input_ids, past_key_values, padding_mask=None):
 
 
 
 
653
  batch_size, seq_length = input_ids.shape
654
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
655
  full_attention_mask.tril_()
@@ -724,7 +845,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
724
  config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
725
  )
726
 
727
- self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
 
728
  device=device, dtype=config.torch_dtype)
729
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
730
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
 
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging, is_torch_npu_available, is_flash_attn_greater_or_equal_2_10, \
25
+ is_flash_attn_2_available
26
  from transformers.generation.logits_process import LogitsProcessor
27
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
28
 
29
  from .configuration_chatglm import ChatGLMConfig
30
 
31
+ if is_flash_attn_2_available():
32
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
33
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
34
+
35
  # flags required to enable jit fusion kernels
36
 
37
  if sys.platform != 'darwin' and not is_torch_npu_available():
 
45
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
46
  _CONFIG_FOR_DOC = "ChatGLMConfig"
47
 
48
+
49
  def default_init(cls, *args, **kwargs):
50
  return cls(*args, **kwargs)
51
 
 
165
  class CoreAttention(torch.nn.Module):
166
  def __init__(self, config: ChatGLMConfig, layer_number):
167
  super(CoreAttention, self).__init__()
168
+ self.config = config
169
  self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
170
  self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
171
  if self.apply_query_key_layer_scaling:
172
  self.attention_softmax_in_fp32 = True
173
  self.layer_number = max(1, layer_number)
174
+ self.is_causal = True
175
 
176
  projection_size = config.kv_channels * config.num_attention_heads
177
 
 
190
  self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
191
 
192
  def forward(self, query_layer, key_layer, value_layer, attention_mask):
193
+ # [b, np, sq, sk]
194
+ output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
195
+
196
+ # [b, np, sq, hn] -> [b * np, sq, hn]
197
+ query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
198
+ # [b, np, sk, hn] -> [b * np, sk, hn]
199
+ key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
200
+
201
+ # preallocting input tensor: [b * np, sq, sk]
202
+ matmul_input_buffer = torch.empty(
203
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
204
+ device=query_layer.device
205
+ )
206
+
207
+ # Raw attention scores. [b * np, sq, sk]
208
+ matmul_result = torch.baddbmm(
209
+ matmul_input_buffer,
210
+ query_layer, # [b * np, sq, hn]
211
+ key_layer.transpose(1, 2), # [b * np, hn, sk]
212
+ beta=0.0,
213
+ alpha=(1.0 / self.norm_factor),
214
+ )
215
+
216
+ # change view to [b, np, sq, sk]
217
+ attention_scores = matmul_result.view(*output_size)
218
+
219
+ # ===========================
220
+ # Attention probs and dropout
221
+ # ===========================
222
+
223
+ # attention scores and attention mask [b, np, sq, sk]
224
+ if self.attention_softmax_in_fp32:
225
+ attention_scores = attention_scores.float()
226
+ if self.coeff is not None:
227
+ attention_scores = attention_scores * self.coeff
228
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
229
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
230
+ device=attention_scores.device, dtype=torch.bool)
231
+ attention_mask.tril_()
232
+ attention_mask = ~attention_mask
233
+ if attention_mask is not None:
234
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
235
+ attention_probs = F.softmax(attention_scores, dim=-1)
236
+ attention_probs = attention_probs.type_as(value_layer)
237
+
238
+ # This is actually dropping out entire tokens to attend to, which might
239
+ # seem a bit unusual, but is taken from the original Transformer paper.
240
+ attention_probs = self.attention_dropout(attention_probs)
241
+
242
+ # query layer shape: [b * np, sq, hn]
243
+ # value layer shape: [b, np, sk, hn]
244
+ # attention shape: [b, np, sq, sk]
245
+ # context layer shape: [b, np, sq, hn]
246
+ output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
247
+ # change view [b * np, sk, hn]
248
+ value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
249
+ # change view [b * np, sq, sk]
250
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
251
+ # matmul: [b * np, sq, hn]
252
+ context_layer = torch.bmm(attention_probs, value_layer)
253
+ # change view [b, np, sq, hn]
254
+ context_layer = context_layer.view(*output_size)
255
+ # [b, np, sq, hn] --> [b, sq, np, hn]
256
+ context_layer = context_layer.transpose(1, 2).contiguous()
257
+ # [b, sq, np, hn] --> [b, sq, hp]
258
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
259
+ context_layer = context_layer.reshape(*new_context_layer_shape)
260
+
261
+ return context_layer
262
+
263
+
264
+ class SdpaAttention(CoreAttention):
265
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
266
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
267
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
268
+ is_causal=True,
269
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
270
  else:
271
+ if attention_mask is not None:
272
+ attention_mask = ~attention_mask
273
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
274
+ attention_mask,
275
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
276
+ context_layer = context_layer.transpose(1, 2).contiguous()
277
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
278
+ context_layer = context_layer.reshape(*new_context_layer_shape)
279
+ return context_layer
280
 
 
 
281
 
282
+ def _get_unpad_data(attention_mask):
283
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
284
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
285
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
286
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
287
+ return (
288
+ indices,
289
+ cu_seqlens,
290
+ max_seqlen_in_batch,
291
+ )
292
+
293
+
294
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
295
+ class FlashAttention2(CoreAttention):
296
+ def __init__(self, *args, **kwargs):
297
+ super().__init__(*args, **kwargs)
298
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
299
 
300
+ def forward(self, query_states, key_states, value_states, attention_mask):
301
+ query_states = query_states.transpose(1, 2)
302
+ key_states = key_states.transpose(1, 2)
303
+ value_states = value_states.transpose(1, 2)
304
+ batch_size, query_length = query_states.shape[:2]
305
+ if not self._flash_attn_uses_top_left_mask:
306
+ causal = self.is_causal
307
+ else:
308
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
309
+ causal = self.is_causal and query_length != 1
310
+ dropout = self.config.attention_dropout if self.training else 0.0
311
+ # Contains at least one padding token in the sequence
312
+ if attention_mask is not None:
313
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
314
+ query_states, key_states, value_states, attention_mask, query_length
315
  )
316
 
317
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
318
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
319
+
320
+ attn_output_unpad = flash_attn_varlen_func(
321
+ query_states,
322
+ key_states,
323
+ value_states,
324
+ cu_seqlens_q=cu_seqlens_q,
325
+ cu_seqlens_k=cu_seqlens_k,
326
+ max_seqlen_q=max_seqlen_in_batch_q,
327
+ max_seqlen_k=max_seqlen_in_batch_k,
328
+ dropout_p=dropout,
329
+ softmax_scale=None,
330
+ causal=causal,
331
  )
332
 
333
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
334
+ else:
335
+ attn_output = flash_attn_func(
336
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
337
+ )
338
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
339
+ return attn_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
342
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
343
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
344
+
345
+ key_layer = index_first_axis(
346
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
347
+ )
348
+ value_layer = index_first_axis(
349
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
350
+ )
351
+ if query_length == kv_seq_len:
352
+ query_layer = index_first_axis(
353
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), indices_k
354
+ )
355
+ cu_seqlens_q = cu_seqlens_k
356
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
357
+ indices_q = indices_k
358
+ elif query_length == 1:
359
+ max_seqlen_in_batch_q = 1
360
+ cu_seqlens_q = torch.arange(
361
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
362
+ ) # There is a memcpy here, that is very bad.
363
+ indices_q = cu_seqlens_q[:-1]
364
+ query_layer = query_layer.squeeze(1)
365
+ else:
366
+ # The -q_len: slice assumes left padding.
367
+ attention_mask = attention_mask[:, -query_length:]
368
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
369
+
370
+ return (
371
+ query_layer,
372
+ key_layer,
373
+ value_layer,
374
+ indices_q,
375
+ (cu_seqlens_q, cu_seqlens_k),
376
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
377
+ )
378
+
379
+
380
+ CORE_ATTENTION_CLASSES = {
381
+ "eager": CoreAttention,
382
+ "sdpa": SdpaAttention,
383
+ "flash_attention_2": FlashAttention2
384
+ }
385
 
386
 
387
  class SelfAttention(torch.nn.Module):
 
413
  device=device, **_config_to_kwargs(config)
414
  )
415
 
416
+ self.core_attention = CORE_ATTENTION_CLASSES[config._attn_implementation](config, self.layer_number)
417
 
418
  # Output.
419
  self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
 
492
  value_layer = torch.cat((cache_v, value_layer), dim=2)
493
  if use_cache:
494
  if kv_cache is None:
495
+ kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
496
+ dim=1)
497
  else:
498
  kv_cache = (key_layer, value_layer)
499
  else:
 
759
  config_class = ChatGLMConfig
760
  base_model_prefix = "transformer"
761
  _no_split_modules = ["GLMBlock"]
762
+ _supports_flash_attn_2 = True
763
+ _supports_sdpa = True
764
 
765
  def _init_weights(self, module: nn.Module):
766
  """Initialize the weights."""
767
  return
768
 
769
  def get_masks(self, input_ids, past_key_values, padding_mask=None):
770
+ if self.config._attn_implementation == "flash_attention_2":
771
+ if padding_mask is not None and not padding_mask.all():
772
+ return padding_mask
773
+ return None
774
  batch_size, seq_length = input_ids.shape
775
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
776
  full_attention_mask.tril_()
 
845
  config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
846
  )
847
 
848
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
849
+ original_impl=config.original_rope,
850
  device=device, dtype=config.torch_dtype)
851
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
852
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,