jupyterjazz committed
Commit 8b64fa8
1 Parent(s): f2e0e62

chore: remove parallelmha


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (1)
  1. mha.py +0 -315
mha.py CHANGED
@@ -7,8 +7,6 @@ import torch
 import torch.nn as nn
 from einops import rearrange, repeat

-from flash_attn.utils.distributed import get_dim_for_local_rank
-
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
@@ -706,316 +704,3 @@ class MHA(nn.Module):
                 context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
-
-
-class ParallelMHA(nn.Module):
-    """Multi-head self-attention and cross-attention"""
-
-    def __init__(
-        self,
-        embed_dim,
-        num_heads,
-        process_group,
-        num_heads_kv=None,
-        qkv_proj_bias=True,
-        out_proj_bias=True,
-        dropout=0.0,
-        softmax_scale=None,
-        causal=False,
-        layer_idx=None,
-        rotary_emb_dim=0,
-        rotary_emb_base=10000.0,
-        rotary_emb_scale_base=None,
-        rotary_emb_interleaved=False,
-        use_alibi=False,
-        use_flash_attn=False,
-        checkpointing=False,
-        sequence_parallel=True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.causal = causal
-        self.layer_idx = layer_idx
-        self.rotary_emb_dim = rotary_emb_dim
-        self.use_flash_attn = use_flash_attn
-        self.checkpointing = checkpointing
-        self.process_group = process_group
-        self.world_size = process_group.size()
-        self.local_rank = torch.distributed.get_rank(process_group)
-
-        self.num_heads = num_heads
-        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
-
-        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
-        assert (
-            self.num_heads % self.num_heads_kv == 0
-        ), "num_heads must be divisible by num_heads_kv"
-
-        self.num_heads_per_rank = get_dim_for_local_rank(
-            self.num_heads, self.world_size, self.local_rank
-        )
-        self.num_heads_kv_per_rank = get_dim_for_local_rank(
-            self.num_heads_kv, self.world_size, self.local_rank
-        )
-        self.head_dim = self.embed_dim // num_heads
-        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
-
-        if use_alibi:
-            assert use_flash_attn, "ALiBi code path requires flash_attn"
-            num_heads_local = math.ceil(self.num_heads / self.world_size)
-            alibi_slopes = torch.tensor(
-                get_alibi_slopes(num_heads)[
-                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
-                ],
-                device=device,
-            )
-        else:
-            alibi_slopes = None
-
-        if self.rotary_emb_dim > 0:
-            assert RotaryEmbedding is not None, "rotary_emb is not installed"
-            self.rotary_emb = RotaryEmbedding(
-                self.rotary_emb_dim,
-                base=rotary_emb_base,
-                scale_base=rotary_emb_scale_base,
-                interleaved=rotary_emb_interleaved,
-                device=device,
-            )
-
-        if ColumnParallelLinear is None or RowParallelLinear is None:
-            raise ImportError("fused_dense is not installed")
-        self.Wqkv = ColumnParallelLinear(
-            embed_dim,
-            qkv_dim,
-            process_group,
-            bias=qkv_proj_bias,
-            sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
-            **factory_kwargs,
-        )
-        inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
-            if use_flash_attn
-            else SelfAttention
-        )
-        inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
-            if use_flash_attn
-            else CrossAttention
-        )
-        self.inner_attn = inner_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
-        )
-        self.inner_cross_attn = inner_cross_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
-        )
-        self.out_proj = RowParallelLinear(
-            embed_dim,
-            embed_dim,
-            process_group,
-            bias=out_proj_bias,
-            sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim,
-            **factory_kwargs,
-        )
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
-        dtype = self.out_proj.weight.dtype if dtype is None else dtype
-        device = self.out_proj.weight.device
-        return torch.empty(
-            batch_size,
-            max_seqlen,
-            2,
-            self.num_heads_kv_per_rank,
-            self.head_dim,
-            dtype=dtype,
-            device=device,
-        )
-
-    def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
-        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        return _update_kv_cache(kv, inference_params, self.layer_idx)
-
-    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
-        """
-        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
-        q: (batch_size, seqlen_q, nheads, head_dim)
-        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
-        """
-        assert inference_params is not None and inference_params.seqlen_offset > 0
-        assert self.use_flash_attn
-        if self.rotary_emb_dim > 0:
-            assert self.rotary_emb.scale is None, "This code path does not support xPos"
-            self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_seqlen, device=q.device, dtype=q.dtype
-            )
-            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
-        else:
-            rotary_cos, rotary_sin = None, None
-        batch = q.shape[0]
-        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
-        cache_seqlens = (
-            inference_params.lengths_per_sample[:batch]
-            if inference_params.lengths_per_sample is not None
-            else inference_params.seqlen_offset
-        )
-        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
-        context = flash_attn_with_kvcache(
-            q,
-            kv_cache[:, :, 0],
-            kv_cache[:, :, 1],
-            kv[:, :, 0],
-            kv[:, :, 1],
-            rotary_cos=rotary_cos,
-            rotary_sin=rotary_sin,
-            cache_seqlens=cache_seqlens,
-            softmax_scale=self.inner_cross_attn.softmax_scale,
-            causal=self.inner_cross_attn.causal,
-            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
-            alibi_slopes=alibi_slopes,
-        )
-        return context
-
-    def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention"""
-        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses seqlen_offset and not lengths_per_sample.
-            kv = self._update_kv_cache(kv, inference_params)
-            return self.inner_cross_attn(q, kv)
-        else:
-            batch = q.shape[0]
-            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
-            cache_seqlens = (
-                inference_params.lengths_per_sample[:batch]
-                if inference_params.lengths_per_sample is not None
-                else inference_params.seqlen_offset
-            )
-            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
-            context = flash_attn_with_kvcache(
-                q,
-                kv_cache[:, :, 0],
-                kv_cache[:, :, 1],
-                kv[:, :, 0],
-                kv[:, :, 1],
-                cache_seqlens=cache_seqlens,
-                softmax_scale=self.inner_cross_attn.softmax_scale,
-                causal=self.inner_cross_attn.causal,
-                alibi_slopes=alibi_slopes,
-            )
-            return context
-
-    def forward(
-        self, x, seqlen=None, inference_params=None, cu_seqlens=None, max_seqlen=None, **kwargs
-    ):
-        """
-        Arguments:
-            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None and cu_seqlens=None.
-                (seqlen, hidden_dim) if cu_seqlens not None, seqlen equal cu_seqlens[-1].
-                If seqlen is not None and cu_seqlens=None, x is (batch * seqlen, hidden_dim). This is so that when we
-                split x during sequence parallel, we split the batch * seqlen dimension
-                (in case batch is small).
-            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
-                of the sequences in the batch, used to index into x. Only applicable when using
-                FlashAttention.
-            max_seqlen: int. Maximum sequence length in the batch.
-        """
-        if cu_seqlens is not None:
-            assert max_seqlen is not None
-            assert seqlen is None
-            assert self.use_flash_attn
-        if inference_params is not None:
-            assert cu_seqlens is None and max_seqlen is None
-        qkv = self.Wqkv(x)
-        if seqlen is not None:
-            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
-        kwargs = (
-            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
-            if self.use_flash_attn
-            else kwargs
-        )
-        seqlen_offset = (
-            0
-            if inference_params is None
-            else (
-                inference_params.lengths_per_sample
-                if inference_params.lengths_per_sample is not None
-                else inference_params.seqlen_offset
-            )
-        )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else max_seqlen
-        )
-        if self.num_heads_kv == self.num_heads:
-            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-            if (
-                inference_params is None
-                or inference_params.seqlen_offset == 0
-                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
-                or not self.use_flash_attn
-            ):
-                if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(
-                        qkv,
-                        seqlen_offset=seqlen_offset,
-                        cu_seqlens=cu_seqlens,
-                        max_seqlen=rotary_max_seqlen,
-                    )
-                if inference_params is None:
-                    if not self.checkpointing:
-                        context = self.inner_attn(qkv, **kwargs)
-                    else:
-                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
-                else:
-                    context = self._update_kvcache_attention(
-                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
-                    )
-            else:
-                context = self._apply_rotary_update_kvcache_attention(
-                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
-                )
-        else:
-            q = rearrange(
-                qkv[..., : self.num_heads_per_rank * self.head_dim],
-                "... (h d) -> ... h d",
-                d=self.head_dim,
-            )
-            kv = rearrange(
-                qkv[..., self.num_heads_per_rank * self.head_dim :],
-                "... (two hkv d) -> ... two hkv d",
-                two=2,
-                d=self.head_dim,
-            )
-            if (
-                inference_params is None
-                or inference_params.seqlen_offset == 0
-                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
-                or not self.use_flash_attn
-            ):
-                if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(
-                        q,
-                        kv,
-                        seqlen_offset=seqlen_offset,
-                        cu_seqlens=cu_seqlens,
-                        max_seqlen=rotary_max_seqlen,
-                    )
-                if inference_params is None:
-                    if not self.checkpointing:
-                        context = self.inner_cross_attn(q, kv, **kwargs)
-                    else:
-                        context = torch.utils.checkpoint.checkpoint(
-                            self.inner_cross_attn, q, kv, **kwargs
-                        )
-                else:
-                    context = self._update_kvcache_attention(q, kv, inference_params)
-            else:
-                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
-        context = rearrange(context, "... h d -> ... (h d)")
-        if seqlen is not None:
-            context = rearrange(context, "b s d -> (b s) d")
-        out = self.out_proj(context)
-        return out
 
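For context, the class removed in this commit is the tensor-parallel variant of MHA, whose heads and projections are sharded across a torch.distributed process group. Below is a minimal, hypothetical usage sketch of how ParallelMHA could be instantiated before this change; it is not taken from the repository. The import path, sizes, and launch setup are assumptions, and it presumes flash-attn (with its fused_dense kernels) is installed and torch.distributed has been initialized, e.g. via torchrun.

# Hypothetical sketch (not part of the diff): constructing the removed ParallelMHA.
# Assumes a NCCL process group is initialized and flash-attn fused_dense is available.
import torch
import torch.distributed as dist

from mha import ParallelMHA  # import path as it existed before this commit (assumption)

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
process_group = dist.group.WORLD

# Illustrative sizes: 1024-dim embeddings, 16 heads sharded across the group.
mha = ParallelMHA(
    embed_dim=1024,
    num_heads=16,
    process_group=process_group,
    causal=True,
    rotary_emb_dim=64,
    use_flash_attn=True,
    sequence_parallel=False,  # keep the full (batch, seqlen) input on every rank
    device="cuda",
    dtype=torch.float16,
)

x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
out = mha(x)  # (batch, seqlen, embed_dim); attention heads are computed per rank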