OrionZheng commited on
Commit
1536202
1 Parent(s): 0a98445

Update modeling_openmoe.py

Browse files
Files changed (1) hide show
  1. modeling_openmoe.py +92 -84
modeling_openmoe.py CHANGED
@@ -48,40 +48,6 @@ logger = logging.get_logger(__name__)
48
 
49
  _CONFIG_FOR_DOC = "LlamaConfig"
50
 
51
- class LlamaRotaryEmbedding(nn.Module):
52
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
53
- super().__init__()
54
-
55
- self.dim = dim
56
- self.max_position_embeddings = max_position_embeddings
57
- self.base = base
58
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
59
- self.register_buffer("inv_freq", inv_freq, persistent=False)
60
-
61
- # Build here to make `torch.jit.trace` work.
62
- self._set_cos_sin_cache(
63
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
64
- )
65
-
66
- def _set_cos_sin_cache(self, seq_len, device, dtype):
67
- self.max_seq_len_cached = seq_len
68
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
69
-
70
- freqs = torch.outer(t, self.inv_freq) # (seq_len, dim//2)
71
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
72
- emb = torch.cat((freqs, freqs), dim=-1) # (seq_len, dim)
73
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
74
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
75
-
76
- def forward(self, x, seq_len=None):
77
- # x: [bs, num_attention_heads, seq_len, head_size]
78
- if seq_len > self.max_seq_len_cached:
79
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
80
-
81
- return (
82
- self.cos_cached[:seq_len].to(dtype=x.dtype),
83
- self.sin_cached[:seq_len].to(dtype=x.dtype),
84
- )
85
 
86
  def set_openmoe_args(
87
  config: LlamaConfig,
@@ -191,6 +157,72 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
191
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
192
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  def rotate_half(x):
195
  """Rotates half the hidden dims of the input."""
196
  x1 = x[..., : x.shape[-1] // 2]
@@ -198,33 +230,6 @@ def rotate_half(x):
198
  return torch.cat((-x2, x1), dim=-1)
199
 
200
 
201
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
202
- """Applies Rotary Position Embedding to the query and key tensors.
203
-
204
- Args:
205
- q (`torch.Tensor`): The query tensor.
206
- k (`torch.Tensor`): The key tensor.
207
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
208
- sin (`torch.Tensor`): The sine part of the rotary embedding.
209
- position_ids (`torch.Tensor`):
210
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
211
- used to pass offsetted position ids when working with a KV-cache.
212
- unsqueeze_dim (`int`, *optional*, defaults to 1):
213
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
214
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
215
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
216
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
217
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
218
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
219
- Returns:
220
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
221
- """
222
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
223
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
224
- q_embed = (q * cos) + (rotate_half(q) * sin)
225
- k_embed = (k * cos) + (rotate_half(k) * sin)
226
- return q_embed, k_embed
227
-
228
  def SwiGLU(x):
229
  """Gated linear unit activation function.
230
  Args:
@@ -297,24 +302,15 @@ class OpenMoeAttention(nn.Module):
297
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
298
  self.pretraining_tp = config.pretraining_tp
299
  self.max_position_embeddings = config.max_position_embeddings
300
- self.rope_theta = config.rope_theta
301
 
302
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
303
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
304
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
305
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
306
- self._init_rope()
307
-
308
- def _init_rope(self):
309
- if self.config.rope_scaling is None:
310
- self.rotary_emb = LlamaRotaryEmbedding(
311
- self.head_dim,
312
- max_position_embeddings=self.max_position_embeddings,
313
- base=self.rope_theta,
314
- )
315
- else:
316
- raise ValueError(f"Only Original RotaryEmbedding is supported yet")
317
-
318
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
319
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
320
 
@@ -350,28 +346,40 @@ class OpenMoeAttention(nn.Module):
350
  key_states = self.k_proj(hidden_states)
351
  value_states = self.v_proj(hidden_states)
352
 
353
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, q_len, head_dim)
354
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, q_len, head_dim)
355
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, q_len, head_dim)
356
 
357
- kv_seq_len = key_states.shape[-2]
 
 
 
 
358
  if past_key_value is not None:
359
- kv_seq_len += past_key_value[0].shape[-2]
360
  # reuse k, v, self_attention
361
- key_states = torch.cat([past_key_value[0], key_states], dim=2) # (bsz, num_heads, q_len+past_kv_len, head_dim)
362
- value_states = torch.cat([past_key_value[1], value_states], dim=2) # (bsz, num_heads, q_len+past_kv_len, head_dim)
363
 
364
  past_key_value = (key_states, value_states) if use_cache else None
365
 
366
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
367
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
 
 
 
 
 
 
 
368
 
369
  # repeat k/v heads if n_kv_heads < n_heads
370
  key_states = repeat_kv(key_states, self.num_key_value_groups)
371
  value_states = repeat_kv(value_states, self.num_key_value_groups)
372
 
373
  if HAS_FLASH_ATTN and use_kernel:
374
- exec("from flash_attn import flash_attn_func")
375
 
376
  query_states = query_states.transpose(1, 2)
377
  key_states = key_states.transpose(1, 2)
 
48
 
49
  _CONFIG_FOR_DOC = "LlamaConfig"
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def set_openmoe_args(
53
  config: LlamaConfig,
 
157
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
158
 
159
 
160
+ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timescale=10000.0):
161
+ """Generate Sin/Cos for Rotary Embeddings.
162
+
163
+ Args:
164
+ features: an integer
165
+ length: an integer
166
+ min_timescale: an optional float
167
+ max_timescale: an optional float
168
+
169
+ Returns:
170
+ output_sin: a float32 Tensor with shape [length, features]
171
+ output_cos: a float32 Tensor with shape [length, features]
172
+ """
173
+ fraction = torch.arange(0, features, 2, dtype=torch.float32) / features
174
+ timescale = min_timescale * (max_timescale / min_timescale) ** fraction
175
+ rotational_frequency = 1.0 / timescale
176
+
177
+ sinusoid_inp = torch.einsum("i,j->ij", torch.arange(length, dtype=torch.float32), rotational_frequency)
178
+
179
+ sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)
180
+
181
+ return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
182
+
183
+
184
+ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None):
185
+ # q: (bs, q_len, num_heads, head_dim)
186
+ # k: (bs, q_len [+past_kv_len], num_heads, head_dim)
187
+ # cos: (max_seq_len, head_dim)
188
+ # sin: (max_seq_len, head_dim)
189
+ # rotary_index: (bs, 1) # only used during decoding, when one query token is input at a time
190
+ """Helper function to apply Rotary Embeddings."""
191
+ cos = cos.to(q.dtype)
192
+ sin = sin.to(q.dtype)
193
+
194
+ if len(k.shape) == 3: # for multi query attention
195
+ k = k.unsqueeze(2)
196
+ multiquery = True
197
+ else:
198
+ multiquery = False
199
+
200
+ batch, qlen, qheads, d = q.shape
201
+ kbatch, klen, kheads, kd = k.shape
202
+ assert batch == kbatch, f"{batch} != {kbatch}"
203
+ assert d == kd, f"{d} != {kd}"
204
+ if decode and qlen == 1 and rotary_index is not None:
205
+ qcos = cos[rotary_index, :] # (bs, 1, head_dim)
206
+ qsin = sin[rotary_index, :] # (bs, 1, head_dim)
207
+ qcos = qcos.unsqueeze(2) # (bs, q_len=1, 1, head_dim) # broadcast to all heads
208
+ qsin = qsin.unsqueeze(2) # (bs, q_len=1, 1, head_dim)
209
+ else:
210
+ qcos, qsin = cos[:qlen, :], sin[:qlen, :] # (q_len, head_dim)
211
+ qcos = qcos.unsqueeze(0).unsqueeze(2) # (1, q_len, 1, head_dim)
212
+ qsin = qsin.unsqueeze(0).unsqueeze(2)
213
+
214
+ kcos, ksin = cos[:klen, :], sin[:klen, :] # (k_len, head_dim)
215
+ kcos = kcos.unsqueeze(0).unsqueeze(2) # (1, k_len, 1, head_dim) # broadcast to the whole batch, broadcast to all heads
216
+ ksin = ksin.unsqueeze(0).unsqueeze(2) # (1, k_len, 1, head_dim)
217
+ out_q = (q * qcos) + (rotate_half(q) * qsin)
218
+ out_k = (k * kcos) + (rotate_half(k) * ksin)
219
+
220
+ if multiquery:
221
+ out_k = out_k.squeeze(2)
222
+
223
+ return out_q, out_k
224
+
225
+
226
  def rotate_half(x):
227
  """Rotates half the hidden dims of the input."""
228
  x1 = x[..., : x.shape[-1] // 2]
 
230
  return torch.cat((-x2, x1), dim=-1)
231
 
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  def SwiGLU(x):
234
  """Gated linear unit activation function.
235
  Args:
 
302
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
303
  self.pretraining_tp = config.pretraining_tp
304
  self.max_position_embeddings = config.max_position_embeddings
 
305
 
306
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
307
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
308
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
309
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
310
+ sin, cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4)
311
+ self.register_buffer('sin', sin)
312
+ self.register_buffer('cos', cos)
313
+
 
 
 
 
 
 
 
 
314
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
315
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
316
 
 
346
  key_states = self.k_proj(hidden_states)
347
  value_states = self.v_proj(hidden_states)
348
 
349
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
350
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
351
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
352
 
353
+ kv_seq_len = key_states.shape[-2]
354
+ if past_key_value is not None:
355
+ kv_seq_len += past_key_value[0].shape[-2]
356
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
357
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
358
  if past_key_value is not None:
 
359
  # reuse k, v, self_attention
360
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
361
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
362
 
363
  past_key_value = (key_states, value_states) if use_cache else None
364
 
365
+ query_states = query_states.transpose(1, 2)
366
+ key_states = key_states.transpose(1, 2)
367
+ max_length = max(query_states.shape[1], key_states.shape[1])
368
+ assert max_length <= self.sin.shape[0]
369
+ sin, cos = self.sin[:max_length], self.cos[:max_length]
370
+ # TODO: for inference, we can add emb kv into cache to avoid computation
371
+ query_states, key_states = apply_rotary_embedding(
372
+ query_states, key_states, cos, sin, decode=True if q_len == 1 else False, rotary_index=position_ids
373
+ )
374
+ query_states = query_states.transpose(1, 2)
375
+ key_states = key_states.transpose(1, 2)
376
 
377
  # repeat k/v heads if n_kv_heads < n_heads
378
  key_states = repeat_kv(key_states, self.num_key_value_groups)
379
  value_states = repeat_kv(value_states, self.num_key_value_groups)
380
 
381
  if HAS_FLASH_ATTN and use_kernel:
382
+ from flash_attn import flash_attn_func
383
 
384
  query_states = query_states.transpose(1, 2)
385
  key_states = key_states.transpose(1, 2)