Text Generation
Transformers
PyTorch
Safetensors
English
gpt_refact
code
custom_code
Eval Results
svakhreev commited on
Commit
fa166e1
1 Parent(s): c5d31de

Update modeling_gpt_refact.py

Browse files
Files changed (1) hide show
  1. modeling_gpt_refact.py +12 -0
modeling_gpt_refact.py CHANGED
@@ -503,6 +503,18 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
503
 
504
  # Initialize weights and apply final processing
505
  self.post_init()
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
508
  if inputs_embeds is not None and past_key_values is None:
 
503
 
504
  # Initialize weights and apply final processing
505
  self.post_init()
506
+
507
+ # gradient checkpointing support for lower versions of transformers
508
+ import transformers
509
+ from packaging import version
510
+
511
+ def _set_gradient_checkpointing(module, enable=False):
512
+ if isinstance(module, GPTRefactModel):
513
+ module.gradient_checkpointing = enable
514
+
515
+ v = version.parse(transformers.__version__)
516
+ if v.major <= 4 and v.minor < 35:
517
+ self._set_gradient_checkpointing = _set_gradient_checkpointing
518
 
519
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
520
  if inputs_embeds is not None and past_key_values is None: