Performance improvements
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -22,11 +22,12 @@ from src.distilabel_dataset_generator.utils import (
|
|
22 |
)
|
23 |
|
24 |
|
25 |
-
def _run_pipeline(result_queue, num_turns, num_rows, system_prompt):
    """Build the generation pipeline, run it, and hand the result back.

    The pipeline is obtained from ``get_pipeline`` using the given number of
    turns, number of rows and system prompt. It is executed with
    ``use_cache=False`` and the resulting distiset is placed on
    ``result_queue`` so the caller can retrieve it.

    Args:
        result_queue: Queue-like object; the distiset is delivered via ``put``.
        num_turns: Forwarded to ``get_pipeline`` as its first argument.
        num_rows: Forwarded to ``get_pipeline`` as its second argument.
        system_prompt: Forwarded to ``get_pipeline`` as its third argument.
    """
    generation_pipeline = get_pipeline(num_turns, num_rows, system_prompt)
    # Caching is explicitly disabled for this run.
    result_queue.put(generation_pipeline.run(use_cache=False))
|
@@ -54,7 +55,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
54 |
|
55 |
def generate_sample_dataset(system_prompt, progress=gr.Progress()):
    """Generate a one-row, single-turn sample dataset for ``system_prompt``.

    Progress is reported at 0.1 before generation starts and at 1.0 once
    ``generate_dataset`` has returned.

    Args:
        system_prompt: System prompt forwarded to ``generate_dataset``.
        progress: Progress callback; also forwarded to ``generate_dataset``.

    Returns:
        Whatever ``generate_dataset`` returns for ``num_turns=1`` and
        ``num_rows=1``.
    """
    progress(0.1, desc="Initializing sample dataset generation")
    sample = generate_dataset(
        system_prompt,
        num_turns=1,
        num_rows=1,
        progress=progress,
    )
    progress(1.0, desc="Sample dataset generated")
    return sample
|
60 |
|
@@ -68,6 +69,7 @@ def generate_dataset(
|
|
68 |
repo_name: str = None,
|
69 |
oauth_token: str = None,
|
70 |
progress=gr.Progress(),
|
|
|
71 |
):
|
72 |
repo_id = (
|
73 |
f"{org_name}/{repo_name}"
|
@@ -88,8 +90,9 @@ def generate_dataset(
|
|
88 |
gr.Info(
|
89 |
"You can only generate a dataset with 1000 or fewer rows. Setting to 1000."
|
90 |
)
|
91 |
-
|
92 |
-
|
|
|
93 |
duration = 60
|
94 |
elif num_rows < 30:
|
95 |
duration = 120
|
@@ -105,7 +108,7 @@ def generate_dataset(
|
|
105 |
result_queue = multiprocessing.Queue()
|
106 |
p = multiprocessing.Process(
|
107 |
target=_run_pipeline,
|
108 |
-
args=(result_queue, num_turns, num_rows, system_prompt),
|
109 |
)
|
110 |
|
111 |
try:
|
@@ -175,28 +178,31 @@ with gr.Blocks(
|
|
175 |
)
|
176 |
with gr.Row():
|
177 |
gr.Column(scale=1)
|
178 |
-
btn_generate_system_prompt = gr.Button(value="Generate sample
|
179 |
gr.Column(scale=1)
|
|
|
180 |
|
181 |
system_prompt = gr.TextArea(
|
182 |
-
label="
|
183 |
value=DEFAULT_SYSTEM_PROMPT,
|
184 |
)
|
185 |
|
186 |
-
with gr.Row():
|
187 |
-
gr.Column(scale=1)
|
188 |
-
btn_generate_sample_dataset = gr.Button(
|
189 |
-
value="Regenerate sample dataset",
|
190 |
-
)
|
191 |
-
gr.Column(scale=1)
|
192 |
-
|
193 |
with gr.Row():
|
194 |
table = gr.DataFrame(
|
195 |
value=DEFAULT_DATASET,
|
|
|
196 |
interactive=False,
|
197 |
wrap=True,
|
198 |
)
|
199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
result = btn_generate_system_prompt.click(
|
201 |
fn=generate_system_prompt,
|
202 |
inputs=[dataset_description],
|
@@ -233,10 +239,10 @@ with gr.Blocks(
|
|
233 |
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
234 |
)
|
235 |
num_rows = gr.Number(
|
236 |
-
value=
|
237 |
label="Number of rows in the dataset",
|
238 |
minimum=1,
|
239 |
-
maximum=
|
240 |
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
241 |
)
|
242 |
|
@@ -249,16 +255,24 @@ with gr.Blocks(
|
|
249 |
visible=False,
|
250 |
)
|
251 |
org_name = get_org_dropdown()
|
252 |
-
repo_name = gr.Textbox(label="Repo name", placeholder="dataset_name")
|
253 |
private = gr.Checkbox(
|
254 |
label="Private dataset", value=True, interactive=True, scale=0.5
|
255 |
)
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
|
|
261 |
success_message = gr.Markdown(visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
def show_success_message(org_name, repo_name):
|
264 |
return gr.Markdown(
|
@@ -294,7 +308,7 @@ with gr.Blocks(
|
|
294 |
repo_name,
|
295 |
oauth_token,
|
296 |
],
|
297 |
-
outputs=[
|
298 |
show_progress=True,
|
299 |
).success(
|
300 |
fn=show_success_message,
|
|
|
22 |
)
|
23 |
|
24 |
|
25 |
+
def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, is_sample):
    """Build the generation pipeline, run it, and hand the result back.

    The pipeline is obtained from ``get_pipeline`` using the given number of
    turns, number of rows, system prompt and sample flag. It is executed with
    ``use_cache=False`` and the resulting distiset is placed on
    ``result_queue`` so the caller can retrieve it.

    Args:
        result_queue: Queue-like object; the distiset is delivered via ``put``.
        num_turns: Forwarded to ``get_pipeline`` as its first argument.
        num_rows: Forwarded to ``get_pipeline`` as its second argument.
        system_prompt: Forwarded to ``get_pipeline`` as its third argument.
        is_sample: Forwarded to ``get_pipeline`` as its fourth argument.
    """
    generation_pipeline = get_pipeline(
        num_turns,
        num_rows,
        system_prompt,
        is_sample,
    )
    # Caching is explicitly disabled for this run.
    result_queue.put(generation_pipeline.run(use_cache=False))
|
|
|
55 |
|
56 |
def generate_sample_dataset(system_prompt, progress=gr.Progress()):
    """Generate a one-row, single-turn sample dataset for ``system_prompt``.

    Progress is reported at 0.1 before generation starts and at 1.0 once
    ``generate_dataset`` has returned. The call is flagged with
    ``is_sample=True``.

    Args:
        system_prompt: System prompt forwarded to ``generate_dataset``.
        progress: Progress callback; also forwarded to ``generate_dataset``.

    Returns:
        Whatever ``generate_dataset`` returns for ``num_turns=1``,
        ``num_rows=1`` and ``is_sample=True``.
    """
    progress(0.1, desc="Initializing sample dataset generation")
    sample = generate_dataset(
        system_prompt,
        num_turns=1,
        num_rows=1,
        progress=progress,
        is_sample=True,
    )
    progress(1.0, desc="Sample dataset generated")
    return sample
|
61 |
|
|
|
69 |
repo_name: str = None,
|
70 |
oauth_token: str = None,
|
71 |
progress=gr.Progress(),
|
72 |
+
is_sample: bool = False,
|
73 |
):
|
74 |
repo_id = (
|
75 |
f"{org_name}/{repo_name}"
|
|
|
90 |
gr.Info(
|
91 |
"You can only generate a dataset with 1000 or fewer rows. Setting to 1000."
|
92 |
)
|
93 |
+
if num_rows < 5:
|
94 |
+
duration = 25
|
95 |
+
elif num_rows < 10:
|
96 |
duration = 60
|
97 |
elif num_rows < 30:
|
98 |
duration = 120
|
|
|
108 |
result_queue = multiprocessing.Queue()
|
109 |
p = multiprocessing.Process(
|
110 |
target=_run_pipeline,
|
111 |
+
args=(result_queue, num_turns, num_rows, system_prompt, is_sample),
|
112 |
)
|
113 |
|
114 |
try:
|
|
|
178 |
)
|
179 |
with gr.Row():
|
180 |
gr.Column(scale=1)
|
181 |
+
btn_generate_system_prompt = gr.Button(value="Generate sample")
|
182 |
gr.Column(scale=1)
|
183 |
+
|
184 |
|
185 |
system_prompt = gr.TextArea(
|
186 |
+
label="System prompt for dataset generation. You can tune it and regenerate the sample",
|
187 |
value=DEFAULT_SYSTEM_PROMPT,
|
188 |
)
|
189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
with gr.Row():
|
191 |
table = gr.DataFrame(
|
192 |
value=DEFAULT_DATASET,
|
193 |
+
label="Sample dataset. Prompts and completions truncated to 256 tokens.",
|
194 |
interactive=False,
|
195 |
wrap=True,
|
196 |
)
|
197 |
|
198 |
+
|
199 |
+
with gr.Row():
|
200 |
+
gr.Column(scale=1)
|
201 |
+
btn_generate_sample_dataset = gr.Button(
|
202 |
+
value="Regenerate sample",
|
203 |
+
)
|
204 |
+
gr.Column(scale=1)
|
205 |
+
|
206 |
result = btn_generate_system_prompt.click(
|
207 |
fn=generate_system_prompt,
|
208 |
inputs=[dataset_description],
|
|
|
239 |
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
240 |
)
|
241 |
num_rows = gr.Number(
|
242 |
+
value=10,
|
243 |
label="Number of rows in the dataset",
|
244 |
minimum=1,
|
245 |
+
maximum=500,
|
246 |
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
247 |
)
|
248 |
|
|
|
255 |
visible=False,
|
256 |
)
|
257 |
org_name = get_org_dropdown()
|
258 |
+
repo_name = gr.Textbox(label="Repo name", placeholder="dataset_name", value="my-distiset")
|
259 |
private = gr.Checkbox(
|
260 |
label="Private dataset", value=True, interactive=True, scale=0.5
|
261 |
)
|
262 |
+
with gr.Row() as regenerate_row:
|
263 |
+
gr.Column(scale=1)
|
264 |
+
btn_generate_full_dataset = gr.Button(
|
265 |
+
value="Generate Full Dataset", variant="primary"
|
266 |
+
)
|
267 |
+
gr.Column(scale=1)
|
268 |
success_message = gr.Markdown(visible=False)
|
269 |
+
with gr.Row():
|
270 |
+
final_dataset = gr.DataFrame(
|
271 |
+
value=DEFAULT_DATASET,
|
272 |
+
label="Generated dataset",
|
273 |
+
interactive=False,
|
274 |
+
wrap=True,
|
275 |
+
)
|
276 |
|
277 |
def show_success_message(org_name, repo_name):
|
278 |
return gr.Markdown(
|
|
|
308 |
repo_name,
|
309 |
oauth_token,
|
310 |
],
|
311 |
+
outputs=[final_dataset],
|
312 |
show_progress=True,
|
313 |
).success(
|
314 |
fn=show_success_message,
|