hywu commited on
Commit
b8573dc
1 Parent(s): 44064bd

update modeling file

Browse files
Files changed (1) hide show
  1. modeling_camelidae.py +3 -0
modeling_camelidae.py CHANGED
@@ -20,6 +20,7 @@
20
  """ PyTorch LLaMA model."""
21
  import math
22
  from typing import List, Optional, Tuple, Union
 
23
 
24
  import numpy as np
25
  import copy
@@ -53,6 +54,7 @@ logger = logging.get_logger(__name__)
53
  _CONFIG_FOR_DOC = "CamelidaeConfig"
54
 
55
 
 
56
  class MoEModelOutputWithPast(ModelOutput):
57
  last_hidden_state: torch.FloatTensor = None
58
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -61,6 +63,7 @@ class MoEModelOutputWithPast(ModelOutput):
61
  router_logits: Optional[Tuple[torch.FloatTensor]] = None
62
 
63
 
 
64
  class MoECausalLMOutputWithPast(ModelOutput):
65
  loss: Optional[torch.FloatTensor] = None
66
  aux_loss: Optional[torch.FloatTensor] = None
 
20
  """ PyTorch LLaMA model."""
21
  import math
22
  from typing import List, Optional, Tuple, Union
23
+ from dataclasses import dataclass
24
 
25
  import numpy as np
26
  import copy
 
54
  _CONFIG_FOR_DOC = "CamelidaeConfig"
55
 
56
 
57
+ @dataclass
58
  class MoEModelOutputWithPast(ModelOutput):
59
  last_hidden_state: torch.FloatTensor = None
60
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
 
63
  router_logits: Optional[Tuple[torch.FloatTensor]] = None
64
 
65
 
66
+ @dataclass
67
  class MoECausalLMOutputWithPast(ModelOutput):
68
  loss: Optional[torch.FloatTensor] = None
69
  aux_loss: Optional[torch.FloatTensor] = None