m-ric HF staff commited on
Commit
f9fcae6
β€’
1 Parent(s): 38ca58e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -272,7 +272,7 @@ class BeamNode:
272
  is_selected_sequence: bool
273
 
274
 
275
- def generate_beams(start_sentence, scores, length_penalty, decoded_sequences, beam_indexes_source):
276
  original_tree = BeamNode(
277
  cumulative_score=0,
278
  current_token_ix=None,
@@ -284,7 +284,6 @@ def generate_beams(start_sentence, scores, length_penalty, decoded_sequences, be
284
  is_final=False,
285
  is_selected_sequence=False,
286
  )
287
- n_beams = len(scores[0])
288
  beam_trees = [original_tree] * n_beams
289
  generation_length = len(scores)
290
 
@@ -429,7 +428,7 @@ def get_beam_search_html(
429
  outputs = model.generate(
430
  **inputs,
431
  max_new_tokens=number_steps,
432
- num_beams=number_beams,
433
  num_return_sequences=num_return_sequences,
434
  return_dict_in_generate=True,
435
  length_penalty=length_penalty,
@@ -447,6 +446,7 @@ def get_beam_search_html(
447
  markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
448
 
449
  original_tree = generate_beams(
 
450
  input_text,
451
  outputs.scores[:],
452
  length_penalty,
@@ -493,7 +493,7 @@ This parameter will not impact the beam search paths, but only influence the cho
493
  label="Number of steps", minimum=1, maximum=12, step=1, value=5
494
  )
495
  n_beams = gr.Slider(
496
- label="Number of beams", minimum=2, maximum=4, step=1, value=4
497
  )
498
  length_penalty = gr.Slider(
499
  label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
 
272
  is_selected_sequence: bool
273
 
274
 
275
+ def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences, beam_indexes_source):
276
  original_tree = BeamNode(
277
  cumulative_score=0,
278
  current_token_ix=None,
 
284
  is_final=False,
285
  is_selected_sequence=False,
286
  )
 
287
  beam_trees = [original_tree] * n_beams
288
  generation_length = len(scores)
289
 
 
428
  outputs = model.generate(
429
  **inputs,
430
  max_new_tokens=number_steps,
431
+ num_beams=max(number_beams, 2),
432
  num_return_sequences=num_return_sequences,
433
  return_dict_in_generate=True,
434
  length_penalty=length_penalty,
 
446
  markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
447
 
448
  original_tree = generate_beams(
449
+ number_beams,
450
  input_text,
451
  outputs.scores[:],
452
  length_penalty,
 
493
  label="Number of steps", minimum=1, maximum=12, step=1, value=5
494
  )
495
  n_beams = gr.Slider(
496
+ label="Number of beams", minimum=1, maximum=4, step=1, value=4
497
  )
498
  length_penalty = gr.Slider(
499
  label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1