m-ric HF staff commited on
Commit
73826bd
β€’
1 Parent(s): ceb7300

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -272,7 +272,7 @@ class BeamNode:
272
  is_selected_sequence: bool
273
 
274
 
275
- def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences, beam_indexes_source):
276
  original_tree = BeamNode(
277
  cumulative_score=0,
278
  current_token_ix=None,
@@ -415,8 +415,6 @@ def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequ
415
  current_token_choice_ix = top_df_selected_filtered.iloc[beam_ix]["token_index"]
416
  beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
417
 
418
- print(f"Step {step}, beams kept: {beams_to_keep}")
419
-
420
  return original_tree
421
 
422
  @spaces.GPU
@@ -445,14 +443,23 @@ def get_beam_search_html(
445
  for i, sequence in enumerate(decoded_sequences):
446
  markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
447
 
448
- original_tree = generate_beams(
449
- number_beams,
450
- input_text,
451
- outputs.scores[:],
452
- length_penalty,
453
- decoded_sequences,
454
- outputs.beam_indices,
455
- )
 
 
 
 
 
 
 
 
 
456
  html = generate_html(input_text, original_tree)
457
  return html, markdown
458
 
 
272
  is_selected_sequence: bool
273
 
274
 
275
+ def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences):
276
  original_tree = BeamNode(
277
  cumulative_score=0,
278
  current_token_ix=None,
 
415
  current_token_choice_ix = top_df_selected_filtered.iloc[beam_ix]["token_index"]
416
  beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
417
 
 
 
418
  return original_tree
419
 
420
  @spaces.GPU
 
443
  for i, sequence in enumerate(decoded_sequences):
444
  markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
445
 
446
+ if number_beams > 1:
447
+ original_tree = generate_beams(
448
+ number_beams,
449
+ input_text,
450
+ outputs.scores[:],
451
+ length_penalty,
452
+ decoded_sequences,
453
+ )
454
+ else:
455
+ original_tree = generate_beams(
456
+ n_beams,
457
+ start_sentence,
458
+ outputs.logits,
459
+ 0,
460
+ decoded_sequences,
461
+ )
462
+
463
  html = generate_html(input_text, original_tree)
464
  return html, markdown
465