Match args/kwargs for wrapped function

#7
Files changed (1)
  1. modeling_mpt.py +7 -6
modeling_mpt.py CHANGED
@@ -242,19 +242,20 @@ class MPTModel(MPTPreTrainedModel):
             if self.gradient_checkpointing and self.training:
 
                 def create_custom_forward(module):
-                    def custom_forward(*inputs):
+                    def custom_forward(*inputs, **kwargs):
                         # None for past_key_value
-                        return module(*inputs)
+                        return module(*inputs, **kwargs)
 
                     return custom_forward
 
                 (x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     x,
-                    past_key_value,
-                    attn_bias,
-                    attention_mask,
-                    self.is_causal,
+                    past_key_value=past_key_value,
+                    attn_bias=attn_bias,
+                    attention_mask=attention_mask,
+                    is_causal=self.is_causal,
+                    output_attentions=bool(output_attentions)
                 )
             else:
                 (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
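
For reference, the point of the change is that the old wrapper silently dropped any keyword arguments, so the checkpointed call could not mirror the plain `block(...)` call in the `else` branch. Below is a minimal, self-contained sketch of the pattern with a hypothetical toy module standing in for the MPT block (none of these names are from the repo). One caveat worth knowing: `torch.utils.checkpoint.checkpoint` only forwards extra keyword arguments to the wrapped function on its non-reentrant path, so the sketch passes `use_reentrant=False` explicitly.

# Minimal sketch (hypothetical toy module, not the MPT code) of a checkpoint
# wrapper that forwards both positional and keyword arguments.
import torch
import torch.utils.checkpoint


def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        # Forward everything unchanged so the wrapper matches the
        # signature of whatever it wraps.
        return module(*inputs, **kwargs)

    return custom_forward


linear = torch.nn.Linear(8, 8)  # stand-in for an MPT transformer block


def block_forward(x, scale=1.0):
    # Stand-in for block(x, past_key_value=..., attn_bias=..., ...).
    return linear(x) * scale


x = torch.randn(2, 8, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block_forward),
    x,                    # positional tensor input
    scale=0.5,            # keyword argument, reaches block_forward ...
    use_reentrant=False,  # ... only with non-reentrant checkpointing
)
out.sum().backward()      # block_forward is recomputed during backward
print(x.grad.shape)       # torch.Size([2, 8])

Passing everything after `x` by keyword, as the diff does, keeps the checkpointed branch aligned with the `else` branch, so both invoke the block identically (including `output_attentions`).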