davidberenstein1957 HF staff commited on
Commit
9b4773a
·
1 Parent(s): 5829740

feat: update batch size

Browse files
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -74,6 +74,8 @@ def generate_dataset(
74
  if repo_name is not None and org_name is not None
75
  else None
76
  )
 
 
77
  if repo_id is not None:
78
  if not all([repo_id, org_name, repo_name]):
79
  raise gr.Error(
@@ -295,7 +297,7 @@ with gr.Blocks(
295
  ],
296
  outputs=[table],
297
  show_progress=True,
298
- ).then(
299
  fn=show_success_message,
300
  inputs=[org_name, repo_name],
301
  outputs=[success_message],
 
74
  if repo_name is not None and org_name is not None
75
  else None
76
  )
77
+ if oauth_token is None or oauth_token == "":
78
+ print(oauth_token, repo_id)
79
  if repo_id is not None:
80
  if not all([repo_id, org_name, repo_name]):
81
  raise gr.Error(
 
297
  ],
298
  outputs=[table],
299
  show_progress=True,
300
+ ).success(
301
  fn=show_success_message,
302
  inputs=[org_name, repo_name],
303
  outputs=[success_message],
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -138,6 +138,7 @@ _STOP_SEQUENCES = [
138
  "assistant",
139
  " \n\n",
140
  ]
 
141
 
142
 
143
  def _get_output_mappings(num_turns):
@@ -205,33 +206,30 @@ def get_pipeline(num_turns, num_rows, system_prompt):
205
  "stop_sequences": _STOP_SEQUENCES,
206
  },
207
  ),
208
- batch_size=2,
209
  n_turns=num_turns,
210
  num_rows=num_rows,
211
  system_prompt=system_prompt,
212
  output_mappings={"instruction": "prompt"},
213
- only_instruction=True
214
  )
215
-
216
  generate_response = TextGeneration(
217
  llm=InferenceEndpointsLLM(
218
  model_id=MODEL,
219
  tokenizer_id=MODEL,
220
  api_key=os.environ["HF_TOKEN"],
221
- generation_kwargs={
222
- "temperature": 0.8,
223
- "max_new_tokens": 1024
224
- },
225
  ),
226
  system_prompt=system_prompt,
227
  output_mappings={"generation": "completion"},
228
- input_mappings={"instruction": "prompt"}
229
  )
230
-
231
  keep_columns = KeepColumns(
232
  columns=list(output_mappings.values()) + ["model_name"],
233
  )
234
-
235
  magpie.connect(generate_response)
236
  generate_response.connect(keep_columns)
237
  return pipeline
@@ -250,7 +248,7 @@ def get_pipeline(num_turns, num_rows, system_prompt):
250
  "stop_sequences": _STOP_SEQUENCES,
251
  },
252
  ),
253
- batch_size=2,
254
  n_turns=num_turns,
255
  num_rows=num_rows,
256
  system_prompt=system_prompt,
 
138
  "assistant",
139
  " \n\n",
140
  ]
141
+ DEFAULT_BATCH_SIZE = 1
142
 
143
 
144
  def _get_output_mappings(num_turns):
 
206
  "stop_sequences": _STOP_SEQUENCES,
207
  },
208
  ),
209
+ batch_size=DEFAULT_BATCH_SIZE,
210
  n_turns=num_turns,
211
  num_rows=num_rows,
212
  system_prompt=system_prompt,
213
  output_mappings={"instruction": "prompt"},
214
+ only_instruction=True,
215
  )
216
+
217
  generate_response = TextGeneration(
218
  llm=InferenceEndpointsLLM(
219
  model_id=MODEL,
220
  tokenizer_id=MODEL,
221
  api_key=os.environ["HF_TOKEN"],
222
+ generation_kwargs={"temperature": 0.8, "max_new_tokens": 1024},
 
 
 
223
  ),
224
  system_prompt=system_prompt,
225
  output_mappings={"generation": "completion"},
226
+ input_mappings={"instruction": "prompt"},
227
  )
228
+
229
  keep_columns = KeepColumns(
230
  columns=list(output_mappings.values()) + ["model_name"],
231
  )
232
+
233
  magpie.connect(generate_response)
234
  generate_response.connect(keep_columns)
235
  return pipeline
 
248
  "stop_sequences": _STOP_SEQUENCES,
249
  },
250
  ),
251
+ batch_size=DEFAULT_BATCH_SIZE,
252
  n_turns=num_turns,
253
  num_rows=num_rows,
254
  system_prompt=system_prompt,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -30,7 +30,7 @@ def get_login_button():
30
  return gr.LoginButton(
31
  value="Sign in with Hugging Face!",
32
  size="lg",
33
- )
34
 
35
 
36
  def get_duplicate_button():
 
30
  return gr.LoginButton(
31
  value="Sign in with Hugging Face!",
32
  size="lg",
33
+ ).activate()
34
 
35
 
36
  def get_duplicate_button():