gugarosa commited on
Commit
e0f03c4
1 Parent(s): 051d15f

Update modeling_phi.py

Browse files
Files changed (1) hide show
  1. modeling_phi.py +1 -1
modeling_phi.py CHANGED
@@ -308,7 +308,6 @@ class PhiAttention(nn.Module):
308
  past_key_value: Optional[Cache] = None,
309
  output_attentions: bool = False,
310
  use_cache: bool = False,
311
- **kwargs,
312
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
313
  bsz, q_len, _ = hidden_states.size()
314
 
@@ -358,6 +357,7 @@ class PhiAttention(nn.Module):
358
  key_states = repeat_kv(key_states, self.num_key_value_groups)
359
  value_states = repeat_kv(value_states, self.num_key_value_groups)
360
 
 
361
  attn_weights = torch.matmul(
362
  query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
363
  ) / math.sqrt(self.head_dim)
 
308
  past_key_value: Optional[Cache] = None,
309
  output_attentions: bool = False,
310
  use_cache: bool = False,
 
311
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
312
  bsz, q_len, _ = hidden_states.size()
313
 
 
357
  key_states = repeat_kv(key_states, self.num_key_value_groups)
358
  value_states = repeat_kv(value_states, self.num_key_value_groups)
359
 
360
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
361
  attn_weights = torch.matmul(
362
  query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
363
  ) / math.sqrt(self.head_dim)