m-ric HF staff commited on
Commit
9f940a0
β€’
1 Parent(s): 939822c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -46,6 +46,7 @@ STYLE = """
46
  height: auto;
47
  text-align: center;
48
  display:inline-block;
 
49
  }
50
  #root {
51
  display: inline-grid!important;
@@ -417,6 +418,7 @@ def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequ
417
 
418
  return original_tree
419
 
 
420
  @spaces.GPU
421
  def get_beam_search_html(
422
  input_text, number_steps, number_beams, length_penalty, num_return_sequences
@@ -441,9 +443,12 @@ def get_beam_search_html(
441
  # Sequences are padded anyway so you can batch decode them
442
  decoded_sequences = tokenizer.batch_decode(outputs.sequences)
443
 
444
- sequence_scores = (outputs.sequences_scores if number_beams > 1 else outputs.scores)
445
- for i, sequence in enumerate(decoded_sequences):
446
- markdown += f"\n- Score `{sequence_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
 
 
 
447
 
448
  if number_beams > 1:
449
  original_tree = generate_beams(
@@ -455,9 +460,9 @@ def get_beam_search_html(
455
  )
456
  else:
457
  original_tree = generate_beams(
458
- n_beams,
459
  input_text,
460
- outputs.logits,
461
  0,
462
  decoded_sequences,
463
  )
 
46
  height: auto;
47
  text-align: center;
48
  display:inline-block;
49
+ padding-bottom: 10px!important;
50
  }
51
  #root {
52
  display: inline-grid!important;
 
418
 
419
  return original_tree
420
 
421
+
422
  @spaces.GPU
423
  def get_beam_search_html(
424
  input_text, number_steps, number_beams, length_penalty, num_return_sequences
 
443
  # Sequences are padded anyway so you can batch decode them
444
  decoded_sequences = tokenizer.batch_decode(outputs.sequences)
445
 
446
+ if number_beams > 1:
447
+ for i, sequence in enumerate(decoded_sequences):
448
+ markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
449
+ else:
450
+ markdown += f"\n- `{clean(decoded_sequences[0].replace('<s> ', ''))}`"
451
+ print(outputs.logits)
452
 
453
  if number_beams > 1:
454
  original_tree = generate_beams(
 
460
  )
461
  else:
462
  original_tree = generate_beams(
463
+ number_beams,
464
  input_text,
465
+ outputs.scores,
466
  0,
467
  decoded_sequences,
468
  )