app.py CHANGED
@@ -54,6 +54,17 @@ demo = gr.TabbedInterface(
54
  margin-bottom: 20px;
55
  }
56
  }
 
 
 
 
 
 
 
 
 
 
 
57
  </style>
58
  <div class="header-container">
59
  <div class="logo-container">
@@ -62,7 +73,7 @@ demo = gr.TabbedInterface(
62
  </a>
63
  </div>
64
  <div class="title-container">
65
- <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
66
  <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
67
  </div>
68
  </div>
 
54
  margin-bottom: 20px;
55
  }
56
  }
57
+ button[role="tab"].selected,
58
+ button[role="tab"][aria-selected="true"],
59
+ button[role="tab"][data-tab-id][aria-selected="true"] {
60
+ background-color: #000000;
61
+ color: white;
62
+ border: none;
63
+ font-size: 16px;
64
+ font-weight: bold;
65
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
66
+ transition: background-color 0.3s ease, color 0.3s ease;
67
+ }
68
  </style>
69
  <div class="header-container">
70
  <div class="logo-container">
 
73
  </a>
74
  </div>
75
  <div class="title-container">
76
+ <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
77
  <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
78
  </div>
79
  </div>
pdm.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -6,11 +6,12 @@ authors = [
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
  dependencies = [
9
- "distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@develop",
10
  "gradio[oauth]<5,>=4.38",
11
  "transformers>=4.44.2",
 
12
  ]
13
- requires-python = ">=3.10"
14
  readme = "README.md"
15
  license = {text = "apache 2"}
16
 
 
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
  dependencies = [
9
+ "distilabel[hf-inference-endpoints,argilla]==1.4.0",
10
  "gradio[oauth]<5,>=4.38",
11
  "transformers>=4.44.2",
12
+ "sentence-transformers>=3.2.0",
13
  ]
14
+ requires-python = "<3.13,>=3.10"
15
  readme = "README.md"
16
  license = {text = "apache 2"}
17
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  transformers
2
  gradio[oauth]
3
- distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@develop
4
- beautifulsoup4
 
 
1
  transformers
2
  gradio[oauth]
3
+ distilabel[hf-inference-endpoints,argilla]
4
+ beautifulsoup4
5
+ sentence-transformers
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,6 +1,8 @@
 
1
  import io
2
- from typing import Union
3
 
 
4
  import gradio as gr
5
  import pandas as pd
6
  from datasets import Dataset
@@ -8,7 +10,12 @@ from distilabel.distiset import Distiset
8
  from distilabel.steps.tasks.text_generation import TextGeneration
9
  from gradio.oauth import OAuthToken
10
  from huggingface_hub import upload_file
 
11
 
 
 
 
 
12
  from src.distilabel_dataset_generator.pipelines.sft import (
13
  DEFAULT_BATCH_SIZE,
14
  DEFAULT_DATASET_DESCRIPTIONS,
@@ -21,12 +28,21 @@ from src.distilabel_dataset_generator.pipelines.sft import (
21
  get_response_generator,
22
  )
23
  from src.distilabel_dataset_generator.utils import (
 
24
  get_login_button,
25
  get_org_dropdown,
26
  swap_visibilty,
27
  )
28
 
29
 
 
 
 
 
 
 
 
 
30
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
31
  progress(0.0, desc="Generating system prompt")
32
  if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
@@ -82,7 +98,7 @@ def generate_dataset(
82
  num_rows: int = 5,
83
  is_sample: bool = False,
84
  progress=gr.Progress(),
85
- ):
86
  progress(0.0, desc="(1/2) Generating instructions")
87
  magpie_generator = get_magpie_generator(
88
  num_turns, num_rows, system_prompt, is_sample
@@ -191,7 +207,12 @@ def push_to_hub(
191
  repo_name: str = None,
192
  oauth_token: Union[OAuthToken, None] = None,
193
  progress=gr.Progress(),
194
- ):
 
 
 
 
 
195
  progress(0.1, desc="Setting up dataset")
196
  repo_id = _check_push_to_hub(org_name, repo_name)
197
  distiset = Distiset(
@@ -208,7 +229,167 @@ def push_to_hub(
208
  create_pr=False,
209
  )
210
  progress(1.0, desc="Dataset pushed to hub")
211
- return dataframe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
 
214
  def upload_pipeline_code(
@@ -313,7 +494,7 @@ with gr.Blocks(
313
  # Add a header for the full dataset generation section
314
  gr.Markdown("## Generate full dataset")
315
  gr.Markdown(
316
- "Once you're satisfied with the sample, generate a larger dataset and push it to the Hub."
317
  )
318
 
319
  with gr.Column() as push_to_hub_ui:
@@ -333,27 +514,64 @@ with gr.Blocks(
333
  maximum=500,
334
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
335
  )
336
- with gr.Row(variant="panel"):
337
- org_name = get_org_dropdown()
338
- repo_name = gr.Textbox(
339
- label="Repo name", placeholder="dataset_name", value="my-distiset"
340
- )
341
- private = gr.Checkbox(
342
- label="Private dataset",
343
- value=True,
344
- interactive=True,
345
- scale=0.5,
346
- )
347
- with gr.Row() as regenerate_row:
348
- btn_generate_full_dataset = gr.Button(
349
- value="Generate", variant="primary", scale=2
350
- )
351
- btn_generate_and_push_to_hub = gr.Button(
352
- value="Generate and Push to Hub", variant="primary", scale=2
353
- )
354
- btn_push_to_hub = gr.Button(
355
- value="Push to Hub", variant="primary", scale=2
356
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  with gr.Row():
358
  final_dataset = gr.Dataframe(
359
  value=DEFAULT_DATASETS[0],
@@ -365,7 +583,25 @@ with gr.Blocks(
365
  with gr.Row():
366
  success_message = gr.Markdown(visible=False)
367
 
368
- def show_success_message(org_name, repo_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  return gr.Markdown(
370
  value=f"""
371
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
@@ -378,7 +614,7 @@ with gr.Blocks(
378
  </a>
379
  </p>
380
  </div>
381
- """,
382
  visible=True,
383
  )
384
 
@@ -407,8 +643,11 @@ with gr.Blocks(
407
  inputs=[sample_dataset],
408
  outputs=[final_dataset],
409
  )
410
-
411
- btn_generate_full_dataset.click(
 
 
 
412
  fn=hide_success_message,
413
  outputs=[success_message],
414
  ).then(
@@ -418,6 +657,30 @@ with gr.Blocks(
418
  show_progress=True,
419
  )
420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  btn_generate_and_push_to_hub.click(
422
  fn=hide_success_message,
423
  outputs=[success_message],
@@ -437,7 +700,7 @@ with gr.Blocks(
437
  outputs=[],
438
  show_progress=True,
439
  ).success(
440
- fn=show_success_message,
441
  inputs=[org_name, repo_name],
442
  outputs=[success_message],
443
  )
@@ -456,11 +719,30 @@ with gr.Blocks(
456
  outputs=[],
457
  show_progress=True,
458
  ).success(
459
- fn=show_success_message,
460
  inputs=[org_name, repo_name],
461
  outputs=[success_message],
462
  )
463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  system_prompt.change(
465
  fn=generate_pipeline_code,
466
  inputs=[system_prompt, num_turns, num_rows],
 
1
+ import ast
2
  import io
3
+ from typing import Dict, List, Union
4
 
5
+ import argilla as rg
6
  import gradio as gr
7
  import pandas as pd
8
  from datasets import Dataset
 
10
  from distilabel.steps.tasks.text_generation import TextGeneration
11
  from gradio.oauth import OAuthToken
12
  from huggingface_hub import upload_file
13
+ from huggingface_hub.hf_api import HfApi
14
 
15
+ from src.distilabel_dataset_generator.pipelines.embeddings import (
16
+ get_embeddings,
17
+ get_sentence_embedding_dimensions,
18
+ )
19
  from src.distilabel_dataset_generator.pipelines.sft import (
20
  DEFAULT_BATCH_SIZE,
21
  DEFAULT_DATASET_DESCRIPTIONS,
 
28
  get_response_generator,
29
  )
30
  from src.distilabel_dataset_generator.utils import (
31
+ get_argilla_client,
32
  get_login_button,
33
  get_org_dropdown,
34
  swap_visibilty,
35
  )
36
 
37
 
38
def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
    """Parse a stringified list of chat-message dicts back into Python objects.

    Dataframe round-trips can drop the commas between message dicts
    (e.g. ``"[{..., 'role': 'user'} {..., 'role': 'assistant'}]"``), so when a
    direct literal parse fails we re-insert a comma after each role value and
    retry. The previous unconditional replacement produced ``,,`` on inputs
    that already had separators and made ``ast.literal_eval`` raise.
    """
    try:
        # Fast path: the string is already a valid Python literal.
        return ast.literal_eval(messages)
    except (SyntaxError, ValueError):
        # Repair path: restore the missing separators between message dicts.
        # NOTE(review): assumes 'role' is the last key of each dict — confirm
        # this matches how the dataframe stringifies the messages column.
        return ast.literal_eval(
            messages.replace("'user'}", "'user'},")
            .replace("'system'}", "'system'},")
            .replace("'assistant'}", "'assistant'},")
        )
44
+
45
+
46
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
47
  progress(0.0, desc="Generating system prompt")
48
  if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
 
98
  num_rows: int = 5,
99
  is_sample: bool = False,
100
  progress=gr.Progress(),
101
+ ) -> pd.DataFrame:
102
  progress(0.0, desc="(1/2) Generating instructions")
103
  magpie_generator = get_magpie_generator(
104
  num_turns, num_rows, system_prompt, is_sample
 
207
  repo_name: str = None,
208
  oauth_token: Union[OAuthToken, None] = None,
209
  progress=gr.Progress(),
210
+ ) -> pd.DataFrame:
211
+ original_dataframe = dataframe.copy(deep=True)
212
+ if "messages" in dataframe.columns:
213
+ dataframe["messages"] = dataframe["messages"].apply(
214
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
215
+ )
216
  progress(0.1, desc="Setting up dataset")
217
  repo_id = _check_push_to_hub(org_name, repo_name)
218
  distiset = Distiset(
 
229
  create_pr=False,
230
  )
231
  progress(1.0, desc="Dataset pushed to hub")
232
+ return original_dataframe
233
+
234
+
235
def push_to_argilla(
    dataframe: pd.DataFrame,
    dataset_name: str,
    oauth_token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Publish the generated dataset to Argilla for human review.

    Ensures the logged-in Hugging Face user has a matching Argilla user and
    workspace, builds review settings (chat-style when a "messages" column is
    present, prompt/completion-style otherwise) with a 1-5 rating question,
    length metadata and embedding vectors, then logs every record.

    Returns an untouched deep copy of the input dataframe so the calling UI
    keeps displaying the original (string-form) data.
    """
    # Keep a pristine copy to hand back to the UI; the working copy below is
    # mutated (messages parsed to dicts, metadata/embedding columns added).
    original_dataframe = dataframe.copy(deep=True)
    if "messages" in dataframe.columns:
        # Cells may hold stringified message lists; normalize to list-of-dicts.
        dataframe["messages"] = dataframe["messages"].apply(
            lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
        )
    try:
        progress(0.1, desc="Setting up user and workspace")
        client = get_argilla_client()
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]

        # Create user if it doesn't exist
        rg_user = client.users(username=hf_user)
        if rg_user is None:
            rg_user = client.users.add(rg.User(username=hf_user, role="admin"))

        # Create workspace if it doesn't exist
        workspace = client.workspaces(name=rg_user.username)
        if workspace is None:
            workspace = client.workspaces.add(rg.Workspace(name=rg_user.username))
            workspace.add_user(rg_user)

        if "messages" in dataframe.columns:
            # Multi-turn data: review the whole conversation as one chat field.
            settings = rg.Settings(
                fields=[
                    rg.ChatField(
                        name="messages",
                        description="The messages in the conversation",
                        title="Messages",
                    ),
                ],
                questions=[
                    rg.RatingQuestion(
                        name="rating",
                        title="Rating",
                        description="The rating of the conversation",
                        values=list(range(1, 6)),
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name="user_message_length", title="User Message Length"
                    ),
                    rg.IntegerMetadataProperty(
                        name="assistant_message_length",
                        title="Assistant Message Length",
                    ),
                ],
                vectors=[
                    rg.VectorField(
                        name="messages_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                ],
                guidelines="Please review the conversation and provide a score for the assistant's response.",
            )

            # Per-record metadata: total character counts per role.
            dataframe["user_message_length"] = dataframe["messages"].apply(
                lambda x: sum([len(y["content"]) for y in x if y["role"] == "user"])
            )
            dataframe["assistant_message_length"] = dataframe["messages"].apply(
                lambda x: sum(
                    [len(y["content"]) for y in x if y["role"] == "assistant"]
                )
            )
            # Embed the concatenated message contents for semantic search.
            dataframe["messages_embeddings"] = get_embeddings(
                dataframe["messages"].apply(
                    lambda x: " ".join([y["content"] for y in x])
                )
            )
        else:
            # Single-turn data: separate system prompt / prompt / completion fields.
            settings = rg.Settings(
                fields=[
                    rg.TextField(
                        name="system_prompt",
                        title="System Prompt",
                        description="The system prompt used for the conversation",
                        required=False,
                    ),
                    rg.TextField(
                        name="prompt",
                        title="Prompt",
                        description="The prompt used for the conversation",
                    ),
                    rg.TextField(
                        name="completion",
                        title="Completion",
                        description="The completion from the assistant",
                    ),
                ],
                questions=[
                    rg.RatingQuestion(
                        name="rating",
                        title="Rating",
                        description="The rating of the conversation",
                        values=list(range(1, 6)),
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name="prompt_length", title="Prompt Length"
                    ),
                    rg.IntegerMetadataProperty(
                        name="completion_length", title="Completion Length"
                    ),
                ],
                vectors=[
                    rg.VectorField(
                        name="prompt_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                ],
                guidelines="Please review the conversation and correct the prompt and completion where needed.",
            )
            dataframe["prompt_length"] = dataframe["prompt"].apply(len)
            dataframe["completion_length"] = dataframe["completion"].apply(len)
            dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])

        progress(0.5, desc="Creating dataset")
        # Reuse an existing dataset in the user's workspace, or create one.
        rg_dataset = client.datasets(name=dataset_name, workspace=rg_user.username)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=dataset_name,
                workspace=rg_user.username,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()
        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        rg_dataset.records.log(records=hf_dataset)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        # Surface any failure as a Gradio error so the UI displays it.
        raise gr.Error(f"Error pushing dataset to Argilla: {e}")
    return original_dataframe
375
+
376
+
377
def validate_argilla_dataset_name(
    dataset_name: str,
    final_dataset: pd.DataFrame,
    add_to_existing_dataset: bool,
    oauth_token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Validate the Argilla dataset configuration before pushing.

    Raises gr.Error when the name is empty, or when a dataset with that name
    already exists in the user's workspace and appending was not allowed.
    Returns *final_dataset* unchanged (it is passed through so the Gradio
    event chain can feed it to the next step). The previous ``-> str``
    annotation was wrong: the function returns the dataframe, not a string.
    """
    progress(0, desc="Validating dataset configuration")
    # Fail fast on a missing name before any network round-trips.
    if dataset_name is None or dataset_name == "":
        raise gr.Error("Dataset name is required")
    hf_user = HfApi().whoami(token=oauth_token.token)["name"]
    client = get_argilla_client()
    dataset = client.datasets(name=dataset_name, workspace=hf_user)
    if dataset and not add_to_existing_dataset:
        raise gr.Error(f"Dataset {dataset_name} already exists")
    return final_dataset
393
 
394
 
395
  def upload_pipeline_code(
 
494
  # Add a header for the full dataset generation section
495
  gr.Markdown("## Generate full dataset")
496
  gr.Markdown(
497
+ "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
498
  )
499
 
500
  with gr.Column() as push_to_hub_ui:
 
514
  maximum=500,
515
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
516
  )
517
+
518
+ with gr.Tab(label="Argilla"):
519
+ if get_argilla_client():
520
+ with gr.Row(variant="panel"):
521
+ dataset_name = gr.Textbox(
522
+ label="Dataset name",
523
+ placeholder="dataset_name",
524
+ value="my-distiset",
525
+ )
526
+ add_to_existing_dataset = gr.Checkbox(
527
+ label="Allow adding records to existing dataset",
528
+ info="When selected, you do need to ensure the number of turns in the conversation is the same as the number of turns in the existing dataset.",
529
+ value=False,
530
+ interactive=True,
531
+ scale=0.5,
532
+ )
533
+
534
+ with gr.Row(variant="panel"):
535
+ btn_generate_full_dataset_copy = gr.Button(
536
+ value="Generate", variant="primary", scale=2
537
+ )
538
+ btn_generate_and_push_to_argilla = gr.Button(
539
+ value="Generate and Push to Argilla",
540
+ variant="primary",
541
+ scale=2,
542
+ )
543
+ btn_push_to_argilla = gr.Button(
544
+ value="Push to Argilla", variant="primary", scale=2
545
+ )
546
+ else:
547
+ gr.Markdown(
548
+ "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla."
549
+ )
550
+ with gr.Tab("Hugging Face Hub"):
551
+ with gr.Row(variant="panel"):
552
+ org_name = get_org_dropdown()
553
+ repo_name = gr.Textbox(
554
+ label="Repo name",
555
+ placeholder="dataset_name",
556
+ value="my-distiset",
557
+ )
558
+ private = gr.Checkbox(
559
+ label="Private dataset",
560
+ value=True,
561
+ interactive=True,
562
+ scale=0.5,
563
+ )
564
+ with gr.Row(variant="panel"):
565
+ btn_generate_full_dataset = gr.Button(
566
+ value="Generate", variant="primary", scale=2
567
+ )
568
+ btn_generate_and_push_to_hub = gr.Button(
569
+ value="Generate and Push to Hub", variant="primary", scale=2
570
+ )
571
+ btn_push_to_hub = gr.Button(
572
+ value="Push to Hub", variant="primary", scale=2
573
+ )
574
+
575
  with gr.Row():
576
  final_dataset = gr.Dataframe(
577
  value=DEFAULT_DATASETS[0],
 
583
  with gr.Row():
584
  success_message = gr.Markdown(visible=False)
585
 
586
    def show_success_message_argilla():
        """Build the success banner shown after a push to Argilla completes."""
        client = get_argilla_client()
        # Link the user straight to the Argilla instance the dataset went to.
        argilla_api_url = client.api_url
        return gr.Markdown(
            value=f"""
            <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
                <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
                <p style="margin-top: 0.5em;">
                    Your dataset is now available at:
                    <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
                        {argilla_api_url}
                    </a>
                </p>
            </div>
            """,
            visible=True,
        )
603
+
604
+ def show_success_message_hub(org_name, repo_name):
605
  return gr.Markdown(
606
  value=f"""
607
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
 
614
  </a>
615
  </p>
616
  </div>
617
+ """,
618
  visible=True,
619
  )
620
 
 
643
  inputs=[sample_dataset],
644
  outputs=[final_dataset],
645
  )
646
+ gr.on(
647
+ triggers=[
648
+ btn_generate_full_dataset.click,
649
+ btn_generate_full_dataset_copy.click,
650
+ ],
651
  fn=hide_success_message,
652
  outputs=[success_message],
653
  ).then(
 
657
  show_progress=True,
658
  )
659
 
660
+ btn_generate_and_push_to_argilla.click(
661
+ fn=validate_argilla_dataset_name,
662
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
663
+ outputs=[final_dataset],
664
+ show_progress=True,
665
+ ).success(
666
+ fn=hide_success_message,
667
+ outputs=[success_message],
668
+ ).success(
669
+ fn=generate_dataset,
670
+ inputs=[system_prompt, num_turns, num_rows],
671
+ outputs=[final_dataset],
672
+ show_progress=True,
673
+ ).success(
674
+ fn=push_to_argilla,
675
+ inputs=[final_dataset, dataset_name],
676
+ outputs=[final_dataset],
677
+ show_progress=True,
678
+ ).success(
679
+ fn=show_success_message_argilla,
680
+ inputs=[],
681
+ outputs=[success_message],
682
+ )
683
+
684
  btn_generate_and_push_to_hub.click(
685
  fn=hide_success_message,
686
  outputs=[success_message],
 
700
  outputs=[],
701
  show_progress=True,
702
  ).success(
703
+ fn=show_success_message_hub,
704
  inputs=[org_name, repo_name],
705
  outputs=[success_message],
706
  )
 
719
  outputs=[],
720
  show_progress=True,
721
  ).success(
722
+ fn=show_success_message_hub,
723
  inputs=[org_name, repo_name],
724
  outputs=[success_message],
725
  )
726
 
727
+ btn_push_to_argilla.click(
728
+ fn=hide_success_message,
729
+ outputs=[success_message],
730
+ ).success(
731
+ fn=validate_argilla_dataset_name,
732
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
733
+ outputs=[final_dataset],
734
+ show_progress=True,
735
+ ).success(
736
+ fn=push_to_argilla,
737
+ inputs=[final_dataset, dataset_name],
738
+ outputs=[final_dataset],
739
+ show_progress=True,
740
+ ).success(
741
+ fn=show_success_message_argilla,
742
+ inputs=[],
743
+ outputs=[success_message],
744
+ )
745
+
746
  system_prompt.change(
747
  fn=generate_pipeline_code,
748
  inputs=[system_prompt, num_turns, num_rows],
src/distilabel_dataset_generator/pipelines/embeddings.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from sentence_transformers import SentenceTransformer
4
+ from sentence_transformers.models import StaticEmbedding
5
+
6
+ # Initialize a StaticEmbedding module
7
+ static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
8
+ model = SentenceTransformer(modules=[static_embedding])
9
+
10
+
11
def get_embeddings(texts: List[str]) -> List[List[float]]:
    """Encode each text into a plain-Python embedding vector.

    Uses the module-level SentenceTransformer model; ``tolist`` converts the
    numpy rows into JSON-serializable lists of floats.
    """
    vectors = model.encode(texts)
    return [vector.tolist() for vector in vectors]
13
+
14
+
15
def get_sentence_embedding_dimensions() -> int:
    """Return the output dimensionality of the module-level embedding model.

    Used to size Argilla vector fields so they match get_embeddings output.
    """
    return model.get_sentence_embedding_dimension()
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -189,7 +189,7 @@ with Pipeline(name="sft") as pipeline:
189
  tokenizer_id=MODEL,
190
  magpie_pre_query_template="llama3",
191
  generation_kwargs={{
192
- "temperature": 0.8,
193
  "do_sample": True,
194
  "max_new_tokens": 2048,
195
  "stop_sequences": {_STOP_SEQUENCES}
@@ -231,7 +231,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
231
  api_key=_get_next_api_key(),
232
  magpie_pre_query_template="llama3",
233
  generation_kwargs={
234
- "temperature": 0.8,
235
  "do_sample": True,
236
  "max_new_tokens": 256 if is_sample else 512,
237
  "stop_sequences": _STOP_SEQUENCES,
@@ -250,7 +250,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
250
  api_key=_get_next_api_key(),
251
  magpie_pre_query_template="llama3",
252
  generation_kwargs={
253
- "temperature": 0.8,
254
  "do_sample": True,
255
  "max_new_tokens": 256 if is_sample else 1024,
256
  "stop_sequences": _STOP_SEQUENCES,
 
189
  tokenizer_id=MODEL,
190
  magpie_pre_query_template="llama3",
191
  generation_kwargs={{
192
+ "temperature": 1,
193
  "do_sample": True,
194
  "max_new_tokens": 2048,
195
  "stop_sequences": {_STOP_SEQUENCES}
 
231
  api_key=_get_next_api_key(),
232
  magpie_pre_query_template="llama3",
233
  generation_kwargs={
234
+ "temperature": 1,
235
  "do_sample": True,
236
  "max_new_tokens": 256 if is_sample else 512,
237
  "stop_sequences": _STOP_SEQUENCES,
 
250
  api_key=_get_next_api_key(),
251
  magpie_pre_query_template="llama3",
252
  generation_kwargs={
253
+ "temperature": 1,
254
  "do_sample": True,
255
  "max_new_tokens": 256 if is_sample else 1024,
256
  "stop_sequences": _STOP_SEQUENCES,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
 
2
 
 
3
  import gradio as gr
4
  from gradio.oauth import (
5
  OAUTH_CLIENT_ID,
@@ -81,3 +83,15 @@ def swap_visibilty(oauth_token: OAuthToken = None):
81
  return gr.update(elem_classes=["main_ui_logged_in"])
82
  else:
83
  return gr.update(elem_classes=["main_ui_logged_out"])
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Union
3
 
4
+ import argilla as rg
5
  import gradio as gr
6
  from gradio.oauth import (
7
  OAUTH_CLIENT_ID,
 
83
  return gr.update(elem_classes=["main_ui_logged_in"])
84
  else:
85
  return gr.update(elem_classes=["main_ui_logged_out"])
86
+
87
+
88
def get_argilla_client() -> Union[rg.Argilla, None]:
    """Return an Argilla client from environment configuration, or None.

    The SDG-reviewer-specific variables take precedence over the generic
    ones. Any failure (missing/invalid credentials, unreachable server)
    yields None so callers can simply truth-test the result.
    """
    api_url = os.getenv("ARGILLA_API_URL_SDG_REVIEWER") or os.getenv(
        "ARGILLA_API_URL"
    )
    api_key = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER") or os.getenv(
        "ARGILLA_API_KEY"
    )
    try:
        return rg.Argilla(api_url=api_url, api_key=api_key)
    except Exception:
        return None