dvilasuero HF staff commited on
Commit
6a4ac56
·
verified ·
1 Parent(s): f945ced

Performance improvements

Browse files
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -22,11 +22,12 @@ from src.distilabel_dataset_generator.utils import (
22
  )
23
 
24
 
25
- def _run_pipeline(result_queue, num_turns, num_rows, system_prompt):
26
  pipeline = get_pipeline(
27
  num_turns,
28
  num_rows,
29
  system_prompt,
 
30
  )
31
  distiset: Distiset = pipeline.run(use_cache=False)
32
  result_queue.put(distiset)
@@ -54,7 +55,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
54
 
55
  def generate_sample_dataset(system_prompt, progress=gr.Progress()):
56
  progress(0.1, desc="Initializing sample dataset generation")
57
- result = generate_dataset(system_prompt, num_turns=1, num_rows=1, progress=progress)
58
  progress(1.0, desc="Sample dataset generated")
59
  return result
60
 
@@ -68,6 +69,7 @@ def generate_dataset(
68
  repo_name: str = None,
69
  oauth_token: str = None,
70
  progress=gr.Progress(),
 
71
  ):
72
  repo_id = (
73
  f"{org_name}/{repo_name}"
@@ -88,8 +90,9 @@ def generate_dataset(
88
  gr.Info(
89
  "You can only generate a dataset with 1000 or fewer rows. Setting to 1000."
90
  )
91
-
92
- if num_rows < 10:
 
93
  duration = 60
94
  elif num_rows < 30:
95
  duration = 120
@@ -105,7 +108,7 @@ def generate_dataset(
105
  result_queue = multiprocessing.Queue()
106
  p = multiprocessing.Process(
107
  target=_run_pipeline,
108
- args=(result_queue, num_turns, num_rows, system_prompt),
109
  )
110
 
111
  try:
@@ -175,28 +178,31 @@ with gr.Blocks(
175
  )
176
  with gr.Row():
177
  gr.Column(scale=1)
178
- btn_generate_system_prompt = gr.Button(value="Generate sample dataset")
179
  gr.Column(scale=1)
 
180
 
181
  system_prompt = gr.TextArea(
182
- label="If you want to improve the dataset, you can tune the system prompt and regenerate the sample",
183
  value=DEFAULT_SYSTEM_PROMPT,
184
  )
185
 
186
- with gr.Row():
187
- gr.Column(scale=1)
188
- btn_generate_sample_dataset = gr.Button(
189
- value="Regenerate sample dataset",
190
- )
191
- gr.Column(scale=1)
192
-
193
  with gr.Row():
194
  table = gr.DataFrame(
195
  value=DEFAULT_DATASET,
 
196
  interactive=False,
197
  wrap=True,
198
  )
199
 
 
 
 
 
 
 
 
 
200
  result = btn_generate_system_prompt.click(
201
  fn=generate_system_prompt,
202
  inputs=[dataset_description],
@@ -233,10 +239,10 @@ with gr.Blocks(
233
  info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
234
  )
235
  num_rows = gr.Number(
236
- value=100,
237
  label="Number of rows in the dataset",
238
  minimum=1,
239
- maximum=1000,
240
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
241
  )
242
 
@@ -249,16 +255,24 @@ with gr.Blocks(
249
  visible=False,
250
  )
251
  org_name = get_org_dropdown()
252
- repo_name = gr.Textbox(label="Repo name", placeholder="dataset_name")
253
  private = gr.Checkbox(
254
  label="Private dataset", value=True, interactive=True, scale=0.5
255
  )
256
-
257
- btn_generate_full_dataset = gr.Button(
258
- value="⚗️ Generate Full Dataset", variant="primary"
259
- )
260
-
 
261
  success_message = gr.Markdown(visible=False)
 
 
 
 
 
 
 
262
 
263
  def show_success_message(org_name, repo_name):
264
  return gr.Markdown(
@@ -294,7 +308,7 @@ with gr.Blocks(
294
  repo_name,
295
  oauth_token,
296
  ],
297
- outputs=[table],
298
  show_progress=True,
299
  ).success(
300
  fn=show_success_message,
 
22
  )
23
 
24
 
25
+ def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, is_sample):
26
  pipeline = get_pipeline(
27
  num_turns,
28
  num_rows,
29
  system_prompt,
30
+ is_sample
31
  )
32
  distiset: Distiset = pipeline.run(use_cache=False)
33
  result_queue.put(distiset)
 
55
 
56
  def generate_sample_dataset(system_prompt, progress=gr.Progress()):
57
  progress(0.1, desc="Initializing sample dataset generation")
58
+ result = generate_dataset(system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True)
59
  progress(1.0, desc="Sample dataset generated")
60
  return result
61
 
 
69
  repo_name: str = None,
70
  oauth_token: str = None,
71
  progress=gr.Progress(),
72
+ is_sample: bool = False,
73
  ):
74
  repo_id = (
75
  f"{org_name}/{repo_name}"
 
90
  gr.Info(
91
  "You can only generate a dataset with 1000 or fewer rows. Setting to 1000."
92
  )
93
+ if num_rows < 5:
94
+ duration = 25
95
+ elif num_rows < 10:
96
  duration = 60
97
  elif num_rows < 30:
98
  duration = 120
 
108
  result_queue = multiprocessing.Queue()
109
  p = multiprocessing.Process(
110
  target=_run_pipeline,
111
+ args=(result_queue, num_turns, num_rows, system_prompt, is_sample),
112
  )
113
 
114
  try:
 
178
  )
179
  with gr.Row():
180
  gr.Column(scale=1)
181
+ btn_generate_system_prompt = gr.Button(value="Generate sample")
182
  gr.Column(scale=1)
183
+
184
 
185
  system_prompt = gr.TextArea(
186
+ label="System prompt for dataset generation. You can tune it and regenerate the sample",
187
  value=DEFAULT_SYSTEM_PROMPT,
188
  )
189
 
 
 
 
 
 
 
 
190
  with gr.Row():
191
  table = gr.DataFrame(
192
  value=DEFAULT_DATASET,
193
+ label="Sample dataset. Prompts and completions truncated to 256 tokens.",
194
  interactive=False,
195
  wrap=True,
196
  )
197
 
198
+
199
+ with gr.Row():
200
+ gr.Column(scale=1)
201
+ btn_generate_sample_dataset = gr.Button(
202
+ value="Regenerate sample",
203
+ )
204
+ gr.Column(scale=1)
205
+
206
  result = btn_generate_system_prompt.click(
207
  fn=generate_system_prompt,
208
  inputs=[dataset_description],
 
239
  info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
240
  )
241
  num_rows = gr.Number(
242
+ value=10,
243
  label="Number of rows in the dataset",
244
  minimum=1,
245
+ maximum=500,
246
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
247
  )
248
 
 
255
  visible=False,
256
  )
257
  org_name = get_org_dropdown()
258
+ repo_name = gr.Textbox(label="Repo name", placeholder="dataset_name", value="my-distiset")
259
  private = gr.Checkbox(
260
  label="Private dataset", value=True, interactive=True, scale=0.5
261
  )
262
+ with gr.Row() as regenerate_row:
263
+ gr.Column(scale=1)
264
+ btn_generate_full_dataset = gr.Button(
265
+ value="Generate Full Dataset", variant="primary"
266
+ )
267
+ gr.Column(scale=1)
268
  success_message = gr.Markdown(visible=False)
269
+ with gr.Row():
270
+ final_dataset = gr.DataFrame(
271
+ value=DEFAULT_DATASET,
272
+ label="Generated dataset",
273
+ interactive=False,
274
+ wrap=True,
275
+ )
276
 
277
  def show_success_message(org_name, repo_name):
278
  return gr.Markdown(
 
308
  repo_name,
309
  oauth_token,
310
  ],
311
+ outputs=[final_dataset],
312
  show_progress=True,
313
  ).success(
314
  fn=show_success_message,