Commit
·
e1fdeee
1
Parent(s):
2f12409
chore: add token
Browse files
src/distilabel_dataset_generator/sft.py
CHANGED
@@ -116,24 +116,12 @@ The prompt you write should follow the same style and structure as the following
|
|
116 |
User dataset description:
|
117 |
"""
|
118 |
|
119 |
-
MODEL = "meta-llama/Meta-Llama-3.1-
|
120 |
-
|
121 |
-
generate_description = TextGeneration(
|
122 |
-
llm=InferenceEndpointsLLM(
|
123 |
-
model_id=MODEL,
|
124 |
-
tokenizer_id=MODEL,
|
125 |
-
generation_kwargs={
|
126 |
-
"temperature": 0.8,
|
127 |
-
"max_new_tokens": 2048,
|
128 |
-
"do_sample": True,
|
129 |
-
},
|
130 |
-
),
|
131 |
-
use_system_prompt=True,
|
132 |
-
)
|
133 |
-
generate_description.load()
|
134 |
|
135 |
|
136 |
-
def _run_pipeline(result_queue, _num_turns, _num_rows, _system_prompt):
|
|
|
|
|
137 |
with Pipeline(name="sft") as pipeline:
|
138 |
magpie_step = MagpieGenerator(
|
139 |
llm=InferenceEndpointsLLM(
|
@@ -143,6 +131,7 @@ def _run_pipeline(result_queue, _num_turns, _num_rows, _system_prompt):
|
|
143 |
generation_kwargs={
|
144 |
"temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
|
145 |
},
|
|
|
146 |
),
|
147 |
n_turns=_num_turns,
|
148 |
num_rows=_num_rows,
|
@@ -152,7 +141,21 @@ def _run_pipeline(result_queue, _num_turns, _num_rows, _system_prompt):
|
|
152 |
result_queue.put(distiset)
|
153 |
|
154 |
|
155 |
-
def _generate_system_prompt(_dataset_description):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
return next(
|
157 |
generate_description.process(
|
158 |
[
|
|
|
116 |
User dataset description:
|
117 |
"""
|
118 |
|
119 |
+
MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
|
122 |
+
def _run_pipeline(
|
123 |
+
result_queue, _num_turns, _num_rows, _system_prompt, _token: OAuthToken = None
|
124 |
+
):
|
125 |
with Pipeline(name="sft") as pipeline:
|
126 |
magpie_step = MagpieGenerator(
|
127 |
llm=InferenceEndpointsLLM(
|
|
|
131 |
generation_kwargs={
|
132 |
"temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
|
133 |
},
|
134 |
+
api_key=_token,
|
135 |
),
|
136 |
n_turns=_num_turns,
|
137 |
num_rows=_num_rows,
|
|
|
141 |
result_queue.put(distiset)
|
142 |
|
143 |
|
144 |
+
def _generate_system_prompt(_dataset_description, _token: OAuthToken = None):
|
145 |
+
generate_description = TextGeneration(
|
146 |
+
llm=InferenceEndpointsLLM(
|
147 |
+
model_id=MODEL,
|
148 |
+
tokenizer_id=MODEL,
|
149 |
+
generation_kwargs={
|
150 |
+
"temperature": 0.8,
|
151 |
+
"max_new_tokens": 2048,
|
152 |
+
"do_sample": True,
|
153 |
+
},
|
154 |
+
api_key=_token,
|
155 |
+
),
|
156 |
+
use_system_prompt=True,
|
157 |
+
)
|
158 |
+
generate_description.load()
|
159 |
return next(
|
160 |
generate_description.process(
|
161 |
[
|