Update modeling_phi.py
Browse files- modeling_phi.py +1 -1
modeling_phi.py
CHANGED
@@ -308,7 +308,6 @@ class PhiAttention(nn.Module):
|
|
308 |
past_key_value: Optional[Cache] = None,
|
309 |
output_attentions: bool = False,
|
310 |
use_cache: bool = False,
|
311 |
-
**kwargs,
|
312 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
313 |
bsz, q_len, _ = hidden_states.size()
|
314 |
|
@@ -358,6 +357,7 @@ class PhiAttention(nn.Module):
|
|
358 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
359 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
360 |
|
|
|
361 |
attn_weights = torch.matmul(
|
362 |
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
363 |
) / math.sqrt(self.head_dim)
|
|
|
308 |
past_key_value: Optional[Cache] = None,
|
309 |
output_attentions: bool = False,
|
310 |
use_cache: bool = False,
|
|
|
311 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
312 |
bsz, q_len, _ = hidden_states.size()
|
313 |
|
|
|
357 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
358 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
359 |
|
360 |
+
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
361 |
attn_weights = torch.matmul(
|
362 |
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
363 |
) / math.sqrt(self.head_dim)
|