Text Generation
Transformers
PyTorch
Safetensors
English
gpt_refact
code
custom_code
Eval Results
4 papers
svakhreev commited on
Commit
c88bbae
1 Parent(s): cc8ed8d

Update modeling_gpt_refact.py

Browse files
Files changed (1) hide show
  1. modeling_gpt_refact.py +12 -0
modeling_gpt_refact.py CHANGED
@@ -369,6 +369,12 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
369
  # Initialize weights and apply final processing
370
  self.post_init()
371
 
 
 
 
 
 
 
372
  def forward(
373
  self,
374
  input_ids: Optional[torch.Tensor] = None,
@@ -509,6 +515,12 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
509
  # Initialize weights and apply final processing
510
  self.post_init()
511
 
 
 
 
 
 
 
512
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
513
  if inputs_embeds is not None and past_key_values is None:
514
  model_inputs = {"inputs_embeds": inputs_embeds}
 
369
  # Initialize weights and apply final processing
370
  self.post_init()
371
 
372
+ def get_input_embeddings(self):
373
+ return self.wte
374
+
375
+ def set_input_embeddings(self, new_embeddings):
376
+ self.wte = new_embeddings
377
+
378
  def forward(
379
  self,
380
  input_ids: Optional[torch.Tensor] = None,
 
515
  # Initialize weights and apply final processing
516
  self.post_init()
517
 
518
+ def get_output_embeddings(self):
519
+ return self.lm_head
520
+
521
+ def set_output_embeddings(self, new_embeddings):
522
+ self.lm_head = new_embeddings
523
+
524
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
525
  if inputs_embeds is not None and past_key_values is None:
526
  model_inputs = {"inputs_embeds": inputs_embeds}