Commit
·
f7b33f1
1
Parent(s):
75f9ac3
feat: map to trl/autotrain compatible columns
Browse files
src/distilabel_dataset_generator/sft.py
CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
|
|
5 |
from distilabel.distiset import Distiset
|
6 |
from distilabel.llms import InferenceEndpointsLLM
|
7 |
from distilabel.pipeline import Pipeline
|
|
|
8 |
from distilabel.steps.tasks import MagpieGenerator, TextGeneration
|
9 |
|
10 |
from src.distilabel_dataset_generator.utils import (
|
@@ -141,8 +142,18 @@ DEFAULT_DATASET = pd.DataFrame(
|
|
141 |
|
142 |
|
143 |
def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, token: str = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
with Pipeline(name="sft") as pipeline:
|
145 |
-
|
146 |
llm=InferenceEndpointsLLM(
|
147 |
model_id=MODEL,
|
148 |
tokenizer_id=MODEL,
|
@@ -150,6 +161,7 @@ def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, token: str =
|
|
150 |
generation_kwargs={
|
151 |
"temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
|
152 |
"do_sample": True,
|
|
|
153 |
"stop_sequences": [
|
154 |
"<|eot_id|>",
|
155 |
"<|end_of_text|>",
|
@@ -163,7 +175,12 @@ def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, token: str =
|
|
163 |
n_turns=num_turns,
|
164 |
num_rows=num_rows,
|
165 |
system_prompt=system_prompt,
|
|
|
|
|
|
|
|
|
166 |
)
|
|
|
167 |
distiset: Distiset = pipeline.run()
|
168 |
result_queue.put(distiset)
|
169 |
|
@@ -212,7 +229,9 @@ def generate_dataset(
|
|
212 |
"Please sign in with Hugging Face to be able to push the dataset to the Hub."
|
213 |
)
|
214 |
|
215 |
-
gr.Info(
|
|
|
|
|
216 |
result_queue = multiprocessing.Queue()
|
217 |
p = multiprocessing.Process(
|
218 |
target=_run_pipeline,
|
@@ -223,7 +242,7 @@ def generate_dataset(
|
|
223 |
distiset = result_queue.get()
|
224 |
|
225 |
if dataset_name is not None:
|
226 |
-
gr.Info("Pushing dataset to Hugging Face Hub
|
227 |
repo_id = f"{orgs_selector}/{dataset_name}"
|
228 |
distiset.push_to_hub(
|
229 |
repo_id=repo_id,
|
@@ -231,17 +250,19 @@ def generate_dataset(
|
|
231 |
include_script=False,
|
232 |
token=token.token,
|
233 |
)
|
234 |
-
gr.Info(
|
|
|
|
|
235 |
else:
|
236 |
# If not pushing to hub generate the dataset directly
|
237 |
distiset = distiset["default"]["train"]
|
238 |
if num_turns == 1:
|
239 |
-
outputs = distiset.to_pandas()[["
|
240 |
else:
|
241 |
outputs = {"conversation_id": [], "role": [], "content": []}
|
242 |
-
conversations = distiset["
|
243 |
for idx, entry in enumerate(conversations):
|
244 |
-
for message in entry["
|
245 |
outputs["conversation_id"].append(idx + 1)
|
246 |
outputs["role"].append(message["role"])
|
247 |
outputs["content"].append(message["content"])
|
|
|
5 |
from distilabel.distiset import Distiset
|
6 |
from distilabel.llms import InferenceEndpointsLLM
|
7 |
from distilabel.pipeline import Pipeline
|
8 |
+
from distilabel.steps import KeepColumns
|
9 |
from distilabel.steps.tasks import MagpieGenerator, TextGeneration
|
10 |
|
11 |
from src.distilabel_dataset_generator.utils import (
|
|
|
142 |
|
143 |
|
144 |
def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, token: str = None):
|
145 |
+
output_mappings = (
|
146 |
+
{
|
147 |
+
"instruction": "prompt",
|
148 |
+
"response": "completion",
|
149 |
+
}
|
150 |
+
if num_turns == 1
|
151 |
+
else {
|
152 |
+
"conversation": "messages",
|
153 |
+
}
|
154 |
+
)
|
155 |
with Pipeline(name="sft") as pipeline:
|
156 |
+
magpie = MagpieGenerator(
|
157 |
llm=InferenceEndpointsLLM(
|
158 |
model_id=MODEL,
|
159 |
tokenizer_id=MODEL,
|
|
|
161 |
generation_kwargs={
|
162 |
"temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
|
163 |
"do_sample": True,
|
164 |
+
"max_new_tokens": 2048,
|
165 |
"stop_sequences": [
|
166 |
"<|eot_id|>",
|
167 |
"<|end_of_text|>",
|
|
|
175 |
n_turns=num_turns,
|
176 |
num_rows=num_rows,
|
177 |
system_prompt=system_prompt,
|
178 |
+
output_mappings=output_mappings,
|
179 |
+
)
|
180 |
+
keep_columns = KeepColumns(
|
181 |
+
columns=list(output_mappings.values()) + ["model_name"],
|
182 |
)
|
183 |
+
magpie.connect(keep_columns)
|
184 |
distiset: Distiset = pipeline.run()
|
185 |
result_queue.put(distiset)
|
186 |
|
|
|
229 |
"Please sign in with Hugging Face to be able to push the dataset to the Hub."
|
230 |
)
|
231 |
|
232 |
+
gr.Info(
|
233 |
+
"Started pipeline execution. This might take a while, depending on the number of rows and turns you have selected. Don't close this page."
|
234 |
+
)
|
235 |
result_queue = multiprocessing.Queue()
|
236 |
p = multiprocessing.Process(
|
237 |
target=_run_pipeline,
|
|
|
242 |
distiset = result_queue.get()
|
243 |
|
244 |
if dataset_name is not None:
|
245 |
+
gr.Info("Pushing dataset to Hugging Face Hub.")
|
246 |
repo_id = f"{orgs_selector}/{dataset_name}"
|
247 |
distiset.push_to_hub(
|
248 |
repo_id=repo_id,
|
|
|
250 |
include_script=False,
|
251 |
token=token.token,
|
252 |
)
|
253 |
+
gr.Info(
|
254 |
+
f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
|
255 |
+
)
|
256 |
else:
|
257 |
# If not pushing to hub generate the dataset directly
|
258 |
distiset = distiset["default"]["train"]
|
259 |
if num_turns == 1:
|
260 |
+
outputs = distiset.to_pandas()[["prompt", "completion"]]
|
261 |
else:
|
262 |
outputs = {"conversation_id": [], "role": [], "content": []}
|
263 |
+
conversations = distiset["messages"]
|
264 |
for idx, entry in enumerate(conversations):
|
265 |
+
for message in entry["messages"]:
|
266 |
outputs["conversation_id"].append(idx + 1)
|
267 |
outputs["role"].append(message["role"])
|
268 |
outputs["content"].append(message["content"])
|