Update src/distilabel_dataset_generator/sft.py
Browse files
src/distilabel_dataset_generator/sft.py
CHANGED
@@ -223,13 +223,12 @@ def generate_dataset(
|
|
223 |
num_turns=1,
|
224 |
num_rows=5,
|
225 |
private=True,
|
226 |
-
|
227 |
-
|
228 |
-
token: OAuthToken = None,
|
229 |
progress=gr.Progress(),
|
230 |
):
|
231 |
-
if
|
232 |
-
if not
|
233 |
raise gr.Error("Please provide a dataset name to push the dataset to.")
|
234 |
if token is None:
|
235 |
raise gr.Error(
|
@@ -280,14 +279,13 @@ def generate_dataset(
|
|
280 |
|
281 |
distiset = result_queue.get()
|
282 |
|
283 |
-
if
|
284 |
progress(0.95, desc="Pushing dataset to Hugging Face Hub.")
|
285 |
-
repo_id = f"{orgs_selector}/{dataset_name}"
|
286 |
distiset.push_to_hub(
|
287 |
repo_id=repo_id,
|
288 |
private=private,
|
289 |
include_script=False,
|
290 |
-
token=token
|
291 |
)
|
292 |
gr.Info(
|
293 |
f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
|
@@ -339,7 +337,6 @@ with gr.Blocks(
|
|
339 |
)
|
340 |
gr.Column(scale=1)
|
341 |
|
342 |
-
#table = gr.HTML(_format_dataframe_as_html(DEFAULT_DATASET))
|
343 |
table = gr.DataFrame(
|
344 |
value=DEFAULT_DATASET,
|
345 |
interactive=False,
|
@@ -347,7 +344,7 @@ with gr.Blocks(
|
|
347 |
|
348 |
)
|
349 |
|
350 |
-
btn_generate_system_prompt.click(
|
351 |
fn=generate_system_prompt,
|
352 |
inputs=[dataset_description],
|
353 |
outputs=[system_prompt],
|
@@ -365,12 +362,10 @@ with gr.Blocks(
|
|
365 |
outputs=[table],
|
366 |
show_progress=True,
|
367 |
)
|
368 |
-
|
369 |
# Add a header for the full dataset generation section
|
370 |
-
gr.Markdown("## Generate full dataset
|
371 |
gr.Markdown("Once you're satisfied with the sample, generate a larger dataset and push it to the hub.")
|
372 |
-
|
373 |
-
btn_login: gr.LoginButton | None = get_login_button()
|
374 |
with gr.Column() as push_to_hub_ui:
|
375 |
with gr.Row(variant="panel"):
|
376 |
num_turns = gr.Number(
|
@@ -386,11 +381,12 @@ with gr.Blocks(
|
|
386 |
maximum=5000,
|
387 |
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
388 |
)
|
389 |
-
|
390 |
|
391 |
with gr.Row(variant="panel"):
|
392 |
-
|
393 |
-
|
|
|
394 |
|
395 |
btn_generate_full_dataset = gr.Button(
|
396 |
value="⚗️ Generate Full Dataset", variant="primary"
|
@@ -403,12 +399,8 @@ with gr.Blocks(
|
|
403 |
num_turns,
|
404 |
num_rows,
|
405 |
private,
|
406 |
-
|
407 |
-
dataset_name_push_to_hub,
|
408 |
],
|
409 |
outputs=[table],
|
410 |
show_progress=True,
|
411 |
)
|
412 |
-
|
413 |
-
app.load(get_org_dropdown, outputs=[orgs_selector])
|
414 |
-
app.load(fn=swap_visibilty, outputs=push_to_hub_ui)
|
|
|
223 |
num_turns=1,
|
224 |
num_rows=5,
|
225 |
private=True,
|
226 |
+
repo_id=None,
|
227 |
+
token=None,
|
|
|
228 |
progress=gr.Progress(),
|
229 |
):
|
230 |
+
if repo_id is not None:
|
231 |
+
if not repo_id:
|
232 |
raise gr.Error("Please provide a dataset name to push the dataset to.")
|
233 |
if token is None:
|
234 |
raise gr.Error(
|
|
|
279 |
|
280 |
distiset = result_queue.get()
|
281 |
|
282 |
+
if repo_id is not None:
|
283 |
progress(0.95, desc="Pushing dataset to Hugging Face Hub.")
|
|
|
284 |
distiset.push_to_hub(
|
285 |
repo_id=repo_id,
|
286 |
private=private,
|
287 |
include_script=False,
|
288 |
+
token=token,
|
289 |
)
|
290 |
gr.Info(
|
291 |
f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
|
|
|
337 |
)
|
338 |
gr.Column(scale=1)
|
339 |
|
|
|
340 |
table = gr.DataFrame(
|
341 |
value=DEFAULT_DATASET,
|
342 |
interactive=False,
|
|
|
344 |
|
345 |
)
|
346 |
|
347 |
+
result = btn_generate_system_prompt.click(
|
348 |
fn=generate_system_prompt,
|
349 |
inputs=[dataset_description],
|
350 |
outputs=[system_prompt],
|
|
|
362 |
outputs=[table],
|
363 |
show_progress=True,
|
364 |
)
|
365 |
+
|
366 |
# Add a header for the full dataset generation section
|
367 |
+
gr.Markdown("## Generate full dataset")
|
368 |
gr.Markdown("Once you're satisfied with the sample, generate a larger dataset and push it to the hub.")
|
|
|
|
|
369 |
with gr.Column() as push_to_hub_ui:
|
370 |
with gr.Row(variant="panel"):
|
371 |
num_turns = gr.Number(
|
|
|
381 |
maximum=5000,
|
382 |
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
383 |
)
|
384 |
+
|
385 |
|
386 |
with gr.Row(variant="panel"):
|
387 |
+
hf_token = gr.Textbox(label="HF token")
|
388 |
+
repo_id = gr.Textbox(label="HF repo ID", placeholder="owner/dataset_name")
|
389 |
+
private = gr.Checkbox(label="Private dataset", value=True, interactive=True)
|
390 |
|
391 |
btn_generate_full_dataset = gr.Button(
|
392 |
value="⚗️ Generate Full Dataset", variant="primary"
|
|
|
399 |
num_turns,
|
400 |
num_rows,
|
401 |
private,
|
402 |
+
repo_id,
|
|
|
403 |
],
|
404 |
outputs=[table],
|
405 |
show_progress=True,
|
406 |
)
|
|
|
|
|
|