Commit
·
9b4773a
1
Parent(s):
5829740
feat: update batch size
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -74,6 +74,8 @@ def generate_dataset(
|
|
74 |
if repo_name is not None and org_name is not None
|
75 |
else None
|
76 |
)
|
|
|
|
|
77 |
if repo_id is not None:
|
78 |
if not all([repo_id, org_name, repo_name]):
|
79 |
raise gr.Error(
|
@@ -295,7 +297,7 @@ with gr.Blocks(
|
|
295 |
],
|
296 |
outputs=[table],
|
297 |
show_progress=True,
|
298 |
-
).
|
299 |
fn=show_success_message,
|
300 |
inputs=[org_name, repo_name],
|
301 |
outputs=[success_message],
|
|
|
74 |
if repo_name is not None and org_name is not None
|
75 |
else None
|
76 |
)
|
77 |
+
if oauth_token is None or oauth_token == "":
|
78 |
+
print(oauth_token, repo_id)
|
79 |
if repo_id is not None:
|
80 |
if not all([repo_id, org_name, repo_name]):
|
81 |
raise gr.Error(
|
|
|
297 |
],
|
298 |
outputs=[table],
|
299 |
show_progress=True,
|
300 |
+
).success(
|
301 |
fn=show_success_message,
|
302 |
inputs=[org_name, repo_name],
|
303 |
outputs=[success_message],
|
src/distilabel_dataset_generator/pipelines/sft.py
CHANGED
@@ -138,6 +138,7 @@ _STOP_SEQUENCES = [
|
|
138 |
"assistant",
|
139 |
" \n\n",
|
140 |
]
|
|
|
141 |
|
142 |
|
143 |
def _get_output_mappings(num_turns):
|
@@ -205,33 +206,30 @@ def get_pipeline(num_turns, num_rows, system_prompt):
|
|
205 |
"stop_sequences": _STOP_SEQUENCES,
|
206 |
},
|
207 |
),
|
208 |
-
batch_size=
|
209 |
n_turns=num_turns,
|
210 |
num_rows=num_rows,
|
211 |
system_prompt=system_prompt,
|
212 |
output_mappings={"instruction": "prompt"},
|
213 |
-
only_instruction=True
|
214 |
)
|
215 |
-
|
216 |
generate_response = TextGeneration(
|
217 |
llm=InferenceEndpointsLLM(
|
218 |
model_id=MODEL,
|
219 |
tokenizer_id=MODEL,
|
220 |
api_key=os.environ["HF_TOKEN"],
|
221 |
-
generation_kwargs={
|
222 |
-
"temperature": 0.8,
|
223 |
-
"max_new_tokens": 1024
|
224 |
-
},
|
225 |
),
|
226 |
system_prompt=system_prompt,
|
227 |
output_mappings={"generation": "completion"},
|
228 |
-
input_mappings={"instruction": "prompt"}
|
229 |
)
|
230 |
-
|
231 |
keep_columns = KeepColumns(
|
232 |
columns=list(output_mappings.values()) + ["model_name"],
|
233 |
)
|
234 |
-
|
235 |
magpie.connect(generate_response)
|
236 |
generate_response.connect(keep_columns)
|
237 |
return pipeline
|
@@ -250,7 +248,7 @@ def get_pipeline(num_turns, num_rows, system_prompt):
|
|
250 |
"stop_sequences": _STOP_SEQUENCES,
|
251 |
},
|
252 |
),
|
253 |
-
batch_size=
|
254 |
n_turns=num_turns,
|
255 |
num_rows=num_rows,
|
256 |
system_prompt=system_prompt,
|
|
|
138 |
"assistant",
|
139 |
" \n\n",
|
140 |
]
|
141 |
+
DEFAULT_BATCH_SIZE = 1
|
142 |
|
143 |
|
144 |
def _get_output_mappings(num_turns):
|
|
|
206 |
"stop_sequences": _STOP_SEQUENCES,
|
207 |
},
|
208 |
),
|
209 |
+
batch_size=DEFAULT_BATCH_SIZE,
|
210 |
n_turns=num_turns,
|
211 |
num_rows=num_rows,
|
212 |
system_prompt=system_prompt,
|
213 |
output_mappings={"instruction": "prompt"},
|
214 |
+
only_instruction=True,
|
215 |
)
|
216 |
+
|
217 |
generate_response = TextGeneration(
|
218 |
llm=InferenceEndpointsLLM(
|
219 |
model_id=MODEL,
|
220 |
tokenizer_id=MODEL,
|
221 |
api_key=os.environ["HF_TOKEN"],
|
222 |
+
generation_kwargs={"temperature": 0.8, "max_new_tokens": 1024},
|
|
|
|
|
|
|
223 |
),
|
224 |
system_prompt=system_prompt,
|
225 |
output_mappings={"generation": "completion"},
|
226 |
+
input_mappings={"instruction": "prompt"},
|
227 |
)
|
228 |
+
|
229 |
keep_columns = KeepColumns(
|
230 |
columns=list(output_mappings.values()) + ["model_name"],
|
231 |
)
|
232 |
+
|
233 |
magpie.connect(generate_response)
|
234 |
generate_response.connect(keep_columns)
|
235 |
return pipeline
|
|
|
248 |
"stop_sequences": _STOP_SEQUENCES,
|
249 |
},
|
250 |
),
|
251 |
+
batch_size=DEFAULT_BATCH_SIZE,
|
252 |
n_turns=num_turns,
|
253 |
num_rows=num_rows,
|
254 |
system_prompt=system_prompt,
|
src/distilabel_dataset_generator/utils.py
CHANGED
@@ -30,7 +30,7 @@ def get_login_button():
|
|
30 |
return gr.LoginButton(
|
31 |
value="Sign in with Hugging Face!",
|
32 |
size="lg",
|
33 |
-
)
|
34 |
|
35 |
|
36 |
def get_duplicate_button():
|
|
|
30 |
return gr.LoginButton(
|
31 |
value="Sign in with Hugging Face!",
|
32 |
size="lg",
|
33 |
+
).activate()
|
34 |
|
35 |
|
36 |
def get_duplicate_button():
|