Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -392,14 +392,14 @@ def generate_beams(start_sentence, scores, length_penalty, decoded_sequences):
|
|
392 |
return original_tree
|
393 |
|
394 |
@spaces.GPU
|
395 |
-
def get_beam_search_html(input_text, number_steps, number_beams, length_penalty,
|
396 |
inputs = tokenizer([input_text], return_tensors="pt")
|
397 |
|
398 |
outputs = model.generate(
|
399 |
**inputs,
|
400 |
max_new_tokens=number_steps,
|
401 |
num_beams=number_beams,
|
402 |
-
num_return_sequences=
|
403 |
return_dict_in_generate=True,
|
404 |
length_penalty=length_penalty,
|
405 |
output_scores=True,
|
@@ -425,7 +425,7 @@ def get_beam_search_html(input_text, number_steps, number_beams, length_penalty,
|
|
425 |
return html, markdown
|
426 |
|
427 |
|
428 |
-
def
|
429 |
return gr.Slider(label="Number of sequences", minimum=1, maximum=n_beams, step=1, value=n_beams)
|
430 |
|
431 |
|
@@ -447,7 +447,7 @@ Play with the parameters below to understand how beam search decoding works!
|
|
447 |
- **Number of beams** (`num_beams`): the number of beams to use
|
448 |
- **Length penalty** (`length_penalty`): the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
|
449 |
This parameter will not impact the beam search paths, but only influence the choice of sequences in the end towards longer or shorter sequences.
|
450 |
-
- **Number of sequences** (`num_return_sequences`): the number of sequences to be returned at the end of generation.
|
451 |
"""
|
452 |
)
|
453 |
text = gr.Textbox(
|
@@ -464,17 +464,17 @@ This parameter will not impact the beam search paths, but only influence the cho
|
|
464 |
length_penalty = gr.Slider(
|
465 |
label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
|
466 |
)
|
467 |
-
|
468 |
-
label="Number of sequences", minimum=1, maximum=4, step=1, value=3
|
469 |
)
|
470 |
|
471 |
-
n_beams.change(fn=
|
472 |
button = gr.Button()
|
473 |
out_html = gr.Markdown()
|
474 |
out_markdown = gr.Markdown()
|
475 |
button.click(
|
476 |
get_beam_search_html,
|
477 |
-
inputs=[text, n_steps, n_beams, length_penalty,
|
478 |
outputs=[out_html, out_markdown],
|
479 |
)
|
480 |
|
|
|
392 |
return original_tree
|
393 |
|
394 |
@spaces.GPU
|
395 |
+
def get_beam_search_html(input_text, number_steps, number_beams, length_penalty, num_return_sequences):
|
396 |
inputs = tokenizer([input_text], return_tensors="pt")
|
397 |
|
398 |
outputs = model.generate(
|
399 |
**inputs,
|
400 |
max_new_tokens=number_steps,
|
401 |
num_beams=number_beams,
|
402 |
+
num_return_sequences=num_return_sequences,
|
403 |
return_dict_in_generate=True,
|
404 |
length_penalty=length_penalty,
|
405 |
output_scores=True,
|
|
|
425 |
return html, markdown
|
426 |
|
427 |
|
428 |
+
def change_num_return_sequences(n_beams):
|
429 |
return gr.Slider(label="Number of sequences", minimum=1, maximum=n_beams, step=1, value=n_beams)
|
430 |
|
431 |
|
|
|
447 |
- **Number of beams** (`num_beams`): the number of beams to use
|
448 |
- **Length penalty** (`length_penalty`): the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
|
449 |
This parameter will not impact the beam search paths, but only influence the choice of sequences in the end towards longer or shorter sequences.
|
450 |
+
- **Number of return sequences** (`num_return_sequences`): the number of sequences to be returned at the end of generation. Should be `<= num_beams'
|
451 |
"""
|
452 |
)
|
453 |
text = gr.Textbox(
|
|
|
464 |
length_penalty = gr.Slider(
|
465 |
label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
|
466 |
)
|
467 |
+
num_return_sequences = gr.Slider(
|
468 |
+
label="Number of return sequences", minimum=1, maximum=4, step=1, value=3
|
469 |
)
|
470 |
|
471 |
+
n_beams.change(fn=change_num_return_sequences, inputs=n_beams, outputs=num_return_sequences)
|
472 |
button = gr.Button()
|
473 |
out_html = gr.Markdown()
|
474 |
out_markdown = gr.Markdown()
|
475 |
button.click(
|
476 |
get_beam_search_html,
|
477 |
+
inputs=[text, n_steps, n_beams, length_penalty, num_return_sequences],
|
478 |
outputs=[out_html, out_markdown],
|
479 |
)
|
480 |
|