namespace-Pt
committed on
Commit
•
d7bd91c
1
Parent(s):
b6abf69
Upload configuration_llama.py with huggingface_hub
Browse files
- configuration_llama.py +6 -12
configuration_llama.py
CHANGED
@@ -143,8 +143,9 @@ class LlamaConfig(PretrainedConfig):
|
|
143 |
beacon_attend_previous=True,
|
144 |
beacon_ratio=[8],
|
145 |
beacon_ratio_mix="step-random",
|
146 |
-
|
147 |
-
|
|
|
148 |
**kwargs,
|
149 |
):
|
150 |
self.vocab_size = vocab_size
|
@@ -177,9 +178,9 @@ class LlamaConfig(PretrainedConfig):
|
|
177 |
self.beacon_ratio = beacon_ratio
|
178 |
self.beacon_stride_mix = beacon_stride_mix
|
179 |
self.beacon_ratio_mix = beacon_ratio_mix
|
180 |
-
self.
|
181 |
-
self.
|
182 |
-
self.
|
183 |
|
184 |
super().__init__(
|
185 |
pad_token_id=pad_token_id,
|
@@ -210,10 +211,3 @@ class LlamaConfig(PretrainedConfig):
|
|
210 |
)
|
211 |
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
212 |
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
213 |
-
|
214 |
-
def _beacon_validation(self):
|
215 |
-
for stride in self.beacon_stride:
|
216 |
-
assert self.beacon_window >= stride, f"Make sure the beacon_window {self.beacon_window} >= beacon_stride {stride}!"
|
217 |
-
assert self.beacon_attn in ["segmentation", "step-expansion", "full-coverage"], f"beacon_attn {self.beacon_attn} not implemented!"
|
218 |
-
assert self.beacon_stride_mix in ["instance-random", "step-random", "mix-random"], f"beacon_stride_mix {self.beacon_stride_mix} not implemented!"
|
219 |
-
assert self.beacon_ratio_mix in ["instance-random", "step-random", "mix-random"] or "adapt-" in self.beacon_ratio_mix, f"beacon_ratio_mix {self.beacon_ratio_mix} not implemented!"
|
|
|
143 |
beacon_attend_previous=True,
|
144 |
beacon_ratio=[8],
|
145 |
beacon_ratio_mix="step-random",
|
146 |
+
beacon_param=["q","k","v","o"],
|
147 |
+
retrieval_method=None,
|
148 |
+
retrieval_topk=None,
|
149 |
**kwargs,
|
150 |
):
|
151 |
self.vocab_size = vocab_size
|
|
|
178 |
self.beacon_ratio = beacon_ratio
|
179 |
self.beacon_stride_mix = beacon_stride_mix
|
180 |
self.beacon_ratio_mix = beacon_ratio_mix
|
181 |
+
self.beacon_param = beacon_param
|
182 |
+
self.retrieval_method = retrieval_method
|
183 |
+
self.retrieval_topk = retrieval_topk
|
184 |
|
185 |
super().__init__(
|
186 |
pad_token_id=pad_token_id,
|
|
|
211 |
)
|
212 |
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
213 |
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|