Match args/kwargs for wrapped function
Browse files- modeling_mpt.py +7 -6
modeling_mpt.py
CHANGED
@@ -242,19 +242,20 @@ class MPTModel(MPTPreTrainedModel):
|
|
242 |
if self.gradient_checkpointing and self.training:
|
243 |
|
244 |
def create_custom_forward(module):
|
245 |
-
def custom_forward(*inputs):
|
246 |
# None for past_key_value
|
247 |
-
return module(*inputs)
|
248 |
|
249 |
return custom_forward
|
250 |
|
251 |
(x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
|
252 |
create_custom_forward(block),
|
253 |
x,
|
254 |
-
past_key_value,
|
255 |
-
attn_bias,
|
256 |
-
attention_mask,
|
257 |
-
self.is_causal,
|
|
|
258 |
)
|
259 |
else:
|
260 |
(x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
|
|
|
242 |
if self.gradient_checkpointing and self.training:
|
243 |
|
244 |
def create_custom_forward(module):
|
245 |
+
def custom_forward(*inputs, **kwargs):
|
246 |
# None for past_key_value
|
247 |
+
return module(*inputs, **kwargs)
|
248 |
|
249 |
return custom_forward
|
250 |
|
251 |
(x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
|
252 |
create_custom_forward(block),
|
253 |
x,
|
254 |
+
past_key_value=past_key_value,
|
255 |
+
attn_bias=attn_bias,
|
256 |
+
attention_mask=attention_mask,
|
257 |
+
is_causal=self.is_causal,
|
258 |
+
output_attentions=bool(output_attentions)
|
259 |
)
|
260 |
else:
|
261 |
(x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
|