Crystalcareai committed
Commit • 835534a • 1 Parent(s): f2459a7
Update modeling_quiet.py

modeling_quiet.py CHANGED (+28 -2)
@@ -22,6 +22,7 @@ import inspect
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
+from dataclasses import dataclass

 import torch
 import torch.nn.functional as F
@@ -56,6 +57,31 @@ logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "QuietConfig"

+@dataclass
+class ModelOutput:
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+    """
+
+    def to_tuple(self):
+        """
+        Convert the output to a tuple.
+        """
+        return tuple(self[k] for k in self.keys())
+
+@dataclass
+class BaseModelOutput(ModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+@dataclass
+class QuietModelOutputWithPast(BaseModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    logits: torch.FloatTensor = None

 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -1097,7 +1123,7 @@ class QuietModel(QuietPreTrainedModel):

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
+        return QuietModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
@@ -1216,7 +1242,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         )

         hidden_states = outputs.last_hidden_state
-        base_logits = outputs.logits
+        base_logits = outputs.logits

         thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
         thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
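
For context, a minimal, self-contained sketch (not part of the commit) of how the dataclass-based output introduced above is meant to be consumed: QuietModel.forward returns a QuietModelOutputWithPast when return_dict is set, and QuietForCausalLM reads last_hidden_state and logits from it by attribute, as in the last two hunks. The base classes are collapsed into a single dataclass here, and the tensor shapes and values are illustrative only.

# Illustrative sketch only; field names mirror QuietModelOutputWithPast from
# the diff, everything else (shapes, toy tensors) is made up.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class QuietModelOutputWithPast:
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    logits: torch.FloatTensor = None


# Roughly what QuietModel.forward would build when return_dict is set:
outputs = QuietModelOutputWithPast(
    last_hidden_state=torch.zeros(1, 8, 32),  # (batch, seq_len, hidden_size)
    logits=torch.zeros(1, 8, 100),            # (batch, seq_len, vocab_size)
)

# How QuietForCausalLM consumes it, mirroring the last hunk:
hidden_states = outputs.last_hidden_state
base_logits = outputs.logits
print(hidden_states.shape, base_logits.shape)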
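
The to_tuple() body added in the first hunk iterates self.keys() and indexes self[k]; a plain @dataclass does not define those, so this presumably relies on mapping-style access being provided elsewhere (the transformers ModelOutput it mirrors gets it by subclassing OrderedDict). A hedged sketch, using a hypothetical mixin that is not taken from the repo, of the kind of access that body assumes:

from dataclasses import dataclass, fields
from typing import Optional, Tuple

import torch


class DictLikeOutput:
    # Hypothetical helper, not in the commit: supplies the keys()/__getitem__
    # access that the committed to_tuple() body relies on.
    def keys(self):
        return [f.name for f in fields(self)]  # field names in declaration order

    def __getitem__(self, key):
        return getattr(self, key)

    def to_tuple(self):
        # Same body as the committed ModelOutput.to_tuple
        return tuple(self[k] for k in self.keys())


@dataclass
class BaseModelOutput(DictLikeOutput):
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


out = BaseModelOutput(last_hidden_state=torch.zeros(1, 4, 8))
print(out.to_tuple())  # (tensor of shape (1, 4, 8), None, None)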