Update modeling_mpt.py (#55)
Browse files- Update modeling_mpt.py (775941c1853f41fa81585628a00c5b6eb33c1880)
- modeling_mpt.py +1 -1
modeling_mpt.py
CHANGED
@@ -182,7 +182,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
182 |
x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
|
183 |
assert isinstance(self.emb_drop, nn.Module)
|
184 |
x = self.emb_drop(x_shrunk)
|
185 |
-
(attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=
|
186 |
if use_cache and past_key_values is None:
|
187 |
past_key_values = [() for _ in range(self.config.n_layers)]
|
188 |
all_hidden_states = () if output_hidden_states else None
|
|
|
182 |
x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
|
183 |
assert isinstance(self.emb_drop, nn.Module)
|
184 |
x = self.emb_drop(x_shrunk)
|
185 |
+
(attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
|
186 |
if use_cache and past_key_values is None:
|
187 |
past_key_values = [() for _ in range(self.config.n_layers)]
|
188 |
all_hidden_states = () if output_hidden_states else None
|