Fabrice-TIERCELIN committed
Commit dbb0d36
1 Parent(s): c6701f2

Upload llama_flash_attn_monkey_patch.py

llava/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,115 @@
+ from typing import Optional, Tuple
+ import warnings
+
+ import torch
+
+ import transformers
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+ try:
+     from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+ except ImportError:
+     from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+ from flash_attn.bert_padding import unpad_input, pad_input
+
+
+ def forward(
+     self,
+     hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.Tensor] = None,
+     past_key_value: Optional[Tuple[torch.Tensor]] = None,
+     output_attentions: bool = False,
+     use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+     if output_attentions:
+         warnings.warn(
+             "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+         )
+
+     bsz, q_len, _ = hidden_states.size()
+
+     query_states = (
+         self.q_proj(hidden_states)
+         .view(bsz, q_len, self.num_heads, self.head_dim)
+         .transpose(1, 2)
+     )
+     key_states = (
+         self.k_proj(hidden_states)
+         .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+         .transpose(1, 2)
+     )
+     value_states = (
+         self.v_proj(hidden_states)
+         .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+         .transpose(1, 2)
+     )  # shape: (b, num_heads, s, head_dim)
+
+     kv_seq_len = key_states.shape[-2]
+     if past_key_value is not None:
+         kv_seq_len += past_key_value[0].shape[-2]
+
+     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+     query_states, key_states = apply_rotary_pos_emb(
+         query_states, key_states, cos, sin, position_ids
+     )
+
+     if past_key_value is not None:
+         # reuse k, v
+         key_states = torch.cat([past_key_value[0], key_states], dim=2)
+         value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+     past_key_value = (key_states, value_states) if use_cache else None
+
+     # repeat k/v heads if n_kv_heads < n_heads
+     key_states = repeat_kv(key_states, self.num_key_value_groups)
+     value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+     # Transform the data into the format required by flash attention
+     qkv = torch.stack([query_states, key_states, value_states], dim=2)
+     qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]
+     key_padding_mask = attention_mask
+
+     if key_padding_mask is None:
+         qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
+         cu_q_lens = torch.arange(
+             0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+         )
+         max_s = q_len
+         output = flash_attn_unpadded_qkvpacked_func(
+             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+         )
+         output = output.view(bsz, q_len, -1)
+     else:
+         qkv = qkv.reshape(bsz, q_len, -1)
+         qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
+         qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+         output_unpad = flash_attn_unpadded_qkvpacked_func(
+             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+         )
+         output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+         output = pad_input(output_unpad, indices, bsz, q_len)
+
+     return self.o_proj(output), None, past_key_value
+
+
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
+ # requires the attention mask to be the same as the key_padding_mask
+ def _prepare_decoder_attention_mask(
+     self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+     # [bsz, seq_len]
+     return attention_mask
+
+
+ def replace_llama_attn_with_flash_attn():
+     cuda_major, cuda_minor = torch.cuda.get_device_capability()
+     if cuda_major < 8:
+         warnings.warn(
+             "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+             "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+         )
+     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+         _prepare_decoder_attention_mask
+     )
+     transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
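For context, a minimal usage sketch (not part of this commit): the patch must be applied before the LLaMA model is instantiated so that the replaced `forward` and `_prepare_decoder_attention_mask` are in place when the `LlamaAttention` modules are built. The checkpoint name and dtype below are placeholders, and the sketch assumes `flash-attn` is installed and a compute-capability >= 8.0 GPU is available.

# Usage sketch (assumptions: flash-attn installed, Ampere-or-newer GPU, fp16 weights).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from llava.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Patch LlamaAttention.forward and LlamaModel._prepare_decoder_attention_mask
# before the model is created.
replace_llama_attn_with_flash_attn()

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16
).cuda()

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model(**inputs)  # attention now runs through the flash-attn kernel
print(out.logits.shape)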