AzizBelaweid committed on
Commit 1ed562f
Parent: 6f39d18

Update modeling_pharia.py (#4)


- Update modeling_pharia.py (d86f6e95c9a4b8d3fcbcb7c1169cc7e4dcee0906)

Files changed (1)
  1. modeling_pharia.py +20 -1
modeling_pharia.py CHANGED
@@ -764,9 +764,28 @@ class PhariaForCausalLM(PhariaPreTrainedModel):

         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
+        loss = 0.0
+
+        if self.training and labels is None:
+            raise ValueError(
+                "You have to specify the `labels` tensor when training the model."
+            )
+
+        if self.training and labels is not None:
+            # Shift logits and labels for causal language modeling
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = outputs['labels'][..., 1:].contiguous()
+
+            # Flatten the tokens
+            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+            shift_labels = shift_labels.view(-1)
+
+            # Compute loss
+            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=1)  # Pad token ID for Pharia is 1
+            loss = loss_fct(shift_logits, shift_labels)
 
         return CausalLMOutputWithPast(
-            loss=0.0,
+            loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
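
For readers who want to see the new loss path in isolation, here is a minimal, self-contained sketch of the shift-and-flatten cross-entropy computation the diff introduces. The shift logic, variable names, and pad token id of 1 mirror the diff; the random tensors and their sizes are placeholders rather than real Pharia inputs, and the sketch indexes a plain labels tensor directly instead of going through outputs['labels'] as the committed code does.

import torch

# Placeholder shapes standing in for real model outputs.
batch_size, seq_len, vocab_size = 2, 8, 128
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -2:] = 1  # pretend the last two positions are padding (Pharia pad id = 1)

# Shift so that the logits at position t are scored against the token at t+1,
# as in the patched forward().
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten to (batch * (seq_len - 1), vocab) and (batch * (seq_len - 1),).
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)

# Cross-entropy that skips padded positions (ignore_index matches the pad id).
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=1)
loss = loss_fct(shift_logits, shift_labels)
print(loss.item())

As committed, forward() only computes this loss when the model is in training mode (model.train()) and a labels tensor is passed; in eval mode the returned loss stays at 0.0, and calling the model in training mode without labels raises a ValueError.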