dvilasuero HF staff committed on
Commit
c973277
·
verified ·
1 Parent(s): a13f86c

Reduce simple dataset generation time

Browse files
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -190,31 +190,73 @@ if __name__ == "__main__":
190
  def get_pipeline(num_turns, num_rows, system_prompt):
191
  input_mappings = _get_output_mappings(num_turns)
192
  output_mappings = input_mappings
193
- with Pipeline(name="sft") as pipeline:
194
- magpie = MagpieGenerator(
195
- llm=InferenceEndpointsLLM(
196
- model_id=MODEL,
197
- tokenizer_id=MODEL,
198
- api_key=os.environ["HF_TOKEN"],
199
- magpie_pre_query_template="llama3",
200
- generation_kwargs={
201
- "temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
202
- "do_sample": True,
203
- "max_new_tokens": 2048,
204
- "stop_sequences": _STOP_SEQUENCES,
205
- },
206
- ),
207
- batch_size=2,
208
- n_turns=num_turns,
209
- num_rows=num_rows,
210
- system_prompt=system_prompt,
211
- output_mappings=output_mappings,
212
- )
213
- keep_columns = KeepColumns(
214
- columns=list(output_mappings.values()) + ["model_name"],
215
- )
216
- magpie.connect(keep_columns)
217
- return pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
 
220
  def get_prompt_generation_step():
 
190
  def get_pipeline(num_turns, num_rows, system_prompt):
191
  input_mappings = _get_output_mappings(num_turns)
192
  output_mappings = input_mappings
193
+ if num_turns == 1:
194
+ with Pipeline(name="sft") as pipeline:
195
+ magpie = MagpieGenerator(
196
+ llm=InferenceEndpointsLLM(
197
+ model_id=MODEL,
198
+ tokenizer_id=MODEL,
199
+ api_key=os.environ["HF_TOKEN"],
200
+ magpie_pre_query_template="llama3",
201
+ generation_kwargs={
202
+ "temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
203
+ "do_sample": True,
204
+ "max_new_tokens": 512,
205
+ "stop_sequences": _STOP_SEQUENCES,
206
+ },
207
+ ),
208
+ batch_size=2,
209
+ n_turns=num_turns,
210
+ num_rows=num_rows,
211
+ system_prompt=system_prompt,
212
+ output_mappings=output_mappings,
213
+ only_instructions=True
214
+ )
215
+
216
+ generate_response = TextGeneration(
217
+ llm=InferenceEndpointsLLM(
218
+ model_id=MODEL,
219
+ tokenizer_id=MODEL,
220
+ generation_kwargs={
221
+ "temperature": 0.8,
222
+ "max_new_tokens": 1024
223
+ },
224
+ )
225
+ )
226
+
227
+ keep_columns = KeepColumns(
228
+ columns=list(output_mappings.values()) + ["model_name"],
229
+ )
230
+
231
+ magpie.connect(generate_response)
232
+ generate_response.connect(keep_columns)
233
+ return pipeline
234
+ else:
235
+ with Pipeline(name="sft") as pipeline:
236
+ magpie = MagpieGenerator(
237
+ llm=InferenceEndpointsLLM(
238
+ model_id=MODEL,
239
+ tokenizer_id=MODEL,
240
+ api_key=os.environ["HF_TOKEN"],
241
+ magpie_pre_query_template="llama3",
242
+ generation_kwargs={
243
+ "temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
244
+ "do_sample": True,
245
+ "max_new_tokens": 2048,
246
+ "stop_sequences": _STOP_SEQUENCES,
247
+ },
248
+ ),
249
+ batch_size=2,
250
+ n_turns=num_turns,
251
+ num_rows=num_rows,
252
+ system_prompt=system_prompt,
253
+ output_mappings=output_mappings,
254
+ )
255
+ keep_columns = KeepColumns(
256
+ columns=list(output_mappings.values()) + ["model_name"],
257
+ )
258
+ magpie.connect(keep_columns)
259
+ return pipeline
260
 
261
 
262
  def get_prompt_generation_step():