import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gradio as gr
import spaces
# Load the GPT-2 tokenizer and causal language model once, at import time, so
# every request served by the app reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
print("Loading finished.")

# Report the accelerator situation in the start-up log.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    _device_index = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(_device_index)}")
# Custom CSS passed to `gr.Blocks(css=STYLE)`.
# It styles three things:
#   * the score tables (`.prose table`, `.prose td`, `.chosen-token`),
#   * the beam tree drawn as nested lists with connector lines and arrow heads
#     (`.tree`, `.tree ul`, `.tree li`, `.child`, `.box`, `#root`),
#   * finished sequences (`.end-of-text`, `.selected-sequence`,
#     `.nonselected-sequence`) versus intermediate nodes (`.nonfinal`).
# Colours come from theme variables (`--primary-*`, `--secondary-*`,
# `--body-text-color`).
STYLE = """
.custom-container {
display: grid;
align-items: center;
margin: 0!important;
overflow-y: hidden;
}
.prose ul ul {
font-size: 10px!important;
}
.prose li {
margin-bottom: 0!important;
}
.prose table {
margin-bottom: 0!important;
}
.prose td, th {
padding-left: 2px;
padding-right: 2px;
padding-top: 0;
padding-bottom: 0;
text-wrap:nowrap;
}
.tree {
padding: 0px;
margin: 0!important;
box-sizing: border-box;
font-size: 10px;
width: 100%;
height: auto;
text-align: center;
display:inline-block;
}
#root {
display: inline-grid!important;
width:auto!important;
min-width: 220px;
}
.tree ul {
padding-left: 20px;
position: relative;
transition: all 0.5s ease 0s;
display: flex;
flex-direction: column;
gap: 10px;
margin: 0px !important;
}
.tree li {
display: flex;
text-align: center;
list-style-type: none;
position: relative;
padding-left: 20px;
transition: all 0.5s ease 0s;
flex-direction: row;
justify-content: start;
align-items: center;
}
.tree li::before, .tree li::after {
content: "";
position: absolute;
left: 0px;
border-left: 1px solid var(--body-text-color);
width: 20px;
}
.tree li::before {
top: 0;
height:50%;
}
.tree li::after {
top: 50%;
height: 55%;
bottom: auto;
border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, li:only-child::before {
display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
border: 0 none;
}
.tree li:last-child::before {
border-bottom: 1px solid var(--body-text-color);
border-radius: 0px 0px 0px 5px;
-webkit-border-radius: 0px 0px 0px 5px;
-moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
border-radius: 5px 0 0 0;
-webkit-border-radius: 5px 0 0 0;
-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
content: "";
position: absolute;
left: 0;
top: 50%;
border-top: 1px solid var(--body-text-color);
width: 20px;
height: 0;
}
.tree ul:has(> li:only-child)::before {
width:40px;
}
.child:before {
border-right: 2px solid var(--body-text-color);
border-bottom: 2px solid var(--body-text-color);
content: "";
position: absolute;
width: 10px;
left: 8px;
height: 10px;
top: 50%;
margin-top: -5px;
transform: rotate(315deg);
}
.box {
border: 1px solid var(--body-text-color);
padding: 5px;
border-radius: 5px;
text-decoration-line: none;
border-radius: 5px;
transition: .5s;
display: flex;
align-items: center;
justify-content: space-between;
overflow: hidden;
cursor: pointer;
}
.box span {
padding: 5px;
font-size: 12px;
letter-spacing: 1px;
font-weight: 500;
}
/*Hover-Section*/
.box:hover, .box:hover+ul li .box {
background: var(--primary-500);
}
.box:hover+ul li::after, .box:hover+ul li::before, .box:hover+ul::before, .box:hover+ul ul::before, .box:hover+ul .box::before {
border-color: var(--primary-500);
}
.chosen-token {
background-color: var(--primary-400);
}
.chosen-token td, .chosen-token tr {
color: black!important;
}
.end-of-text {
width:auto!important;
}
.nonfinal {
width:280px;
min-width: 280px;
}
.selected-sequence {
background-color: var(--secondary-500);
}
.nonselected-sequence {
background-color: var(--primary-500);
}
"""
def clean(s):
    """Return *s* with newlines and tabs spelled out as literal ``\\n`` / ``\\t``.

    The substitution happens first, so only the remaining surrounding
    whitespace (spaces and the like) is trimmed from the ends.
    """
    visible_escapes = {ord("\n"): r"\n", ord("\t"): r"\t"}
    return s.translate(visible_escapes).strip()
def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
    """Render the best next-token candidates of one beam as an HTML table.

    Args:
        scores: Score vector for one beam at one step, indexable by token id.
        previous_cumul_score: Score accumulated by the beam before this step;
            added to each step score for the "Total score" column.
        score_divider: Divisor applied to that sum in the "Total score" column.
        top_k: Number of highest-scoring tokens to list.
        chosen_tokens: Optional collection of decoded token strings. Rows whose
            token is in it receive the ``chosen-token`` CSS class.

    Returns:
        str: ``<table>`` markup with a header row followed by one row per
        token, highest step score first.
    """
    from html import escape

    # The markup is emitted as one compact line (no blank or indented lines)
    # because it is rendered through a Markdown component, where such lines
    # could end the raw-HTML block or be read as a code block.
    rows = [
        "<table><tr>"
        "<th><b>Token</b></th>"
        "<th><b>Step score</b></th>"
        "<th><b>Total score</b></th>"
        "</tr>"
    ]
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        if chosen_tokens and token in chosen_tokens:
            row_open = "<tr class='chosen-token'>"
        else:
            row_open = "<tr>"
        rows.append(
            f"{row_open}"
            # Escape so tokens such as "<" or "&" show up as text, not markup.
            f"<td>{escape(clean(token))}</td>"
            f"<td>{scores[token_idx]:.4f}</td>"
            f"<td>{(scores[token_idx] + previous_cumul_score) / score_divider:.4f}</td>"
            "</tr>"
        )
    rows.append("</table>")
    return "".join(rows)
def generate_nodes(node, step):
    """Recursively render a beam-tree node and its descendants as HTML.

    Args:
        node: Node to render. Reads ``current_token_ix``, ``is_final``,
            ``is_selected_sequence``, ``total_score``, ``table`` and
            ``children``.
        step: Depth of ``node``; only forwarded (incremented) to the
            recursive calls.

    Returns:
        str: One ``<li>`` element. A final node is a single box showing the
        token and its total score; any other node is a box with the token and
        its score table, followed by a nested ``<ul>`` of its children.
    """
    from html import escape

    # Escape so tokens containing "<", ">" or "&" are displayed literally.
    token_label = escape(clean(tokenizer.decode([node.current_token_ix])))

    if node.is_final:
        if node.is_selected_sequence:
            selected_class = "selected-sequence"
        else:
            selected_class = "nonselected-sequence"
        return (
            f"<li> <div class='box end-of-text child {selected_class}'> "
            f"<span> <b>{token_label}</b> <br>"
            f"Total score: {node.total_score:.2f}</span> </div> </li>"
        )

    parts = [
        f"<li> <div class='box nonfinal child'> <span> <b>{token_label}</b> </span>"
    ]
    if node.table is not None:
        parts.append(node.table)
    parts.append("</div>")

    if node.children:
        parts.append("<ul> ")
        parts.extend(
            generate_nodes(subnode, step=step + 1)
            for subnode in node.children.values()
        )
        parts.append("</ul>")
    parts.append("</li>")
    return "".join(parts)
def generate_html(start_sentence, original_tree):
    """Render the whole beam tree, rooted at the start sentence, as HTML.

    Args:
        start_sentence: Prompt text shown in the root box. It is user input,
            so it is HTML-escaped before being embedded.
        original_tree: Root node. Its ``table`` is shown inside the root box
            and each entry of its ``children`` becomes a first-level subtree.

    Returns:
        str: Markup made of a ``custom-container`` wrapper, a ``tree`` wrapper
        and a nested list holding the root box and every subtree.
    """
    from html import escape

    # A root whose table was never filled in renders as an empty cell rather
    # than the literal text "None".
    root_table = original_tree.table if original_tree.table is not None else ""

    parts = [
        "<div class='custom-container'><div class='tree'><ul> <li> "
        "<div id='root' class='box nonfinal'> "
        f"<span> <b>{escape(start_sentence)}</b> </span> {root_table} </div>",
        "<ul> ",
    ]
    parts.extend(
        generate_nodes(subnode, step=1)
        for subnode in original_tree.children.values()
    )
    parts.append("</ul>")
    parts.append(" </li> </ul></div></div>")
    return "".join(parts)
import pandas as pd
from typing import Dict
from dataclasses import dataclass
@dataclass
class BeamNode:
    """One node of the beam-search tree built by ``generate_beams``."""

    # Vocabulary id of the token this node adds; the root is built with None.
    current_token_ix: int
    # Sum of the step scores along the path from the root (0 at the root).
    cumulative_score: float
    # Divisor used for the "Total score" column of this node's score table.
    children_score_divider: float
    # Score-table markup for this node's candidate tokens; None until filled.
    table: str
    # Start sentence followed by the decoded tokens on the path to this node.
    current_sequence: str
    # Child nodes keyed by the token id they add.
    children: Dict[int, "BeamNode"]
    # cumulative_score divided by a length term raised to length_penalty;
    # the root is built with None.
    total_score: float
    # True when the node was created at the last step or holds the EOS token.
    is_final: bool
    # True when current_sequence matches one of the sequences returned by
    # generation (compared with "<|endoftext|>" removed).
    is_selected_sequence: bool
def generate_beams(start_sentence, scores, length_penalty, decoded_sequences):
    """Rebuild the beam-search tree from the per-step scores of a generation.

    Args:
        start_sentence: Prompt text; becomes the root's ``current_sequence``.
        scores: Per-step scores, indexed as ``scores[step][beam_ix][token_id]``.
            The number of beams is taken from ``len(scores[0])``.
        length_penalty: Exponent applied to the length term that divides the
            cumulative scores.
        decoded_sequences: Decoded sequences returned by generation; a node is
            flagged as selected when its sequence matches one of them.

    Returns:
        BeamNode: The root node, with the explored tree attached through
        ``children``.
    """
    # NOTE(review): len() of the tokenizer output is presumably the number of
    # fields in the returned encoding rather than a token count -- confirm the
    # resulting totals against the scores reported by `generate`.
    input_length = len(tokenizer([start_sentence], return_tensors="pt"))
    original_tree = BeamNode(
        cumulative_score=0,
        current_token_ix=None,
        table=None,
        current_sequence=start_sentence,
        children={},
        # The root's children are created at step 0, where their total_score
        # is divided by (input_length + 0 - 1) ** length_penalty. Use the same
        # divisor here so the root table's "Total score" column agrees with
        # those nodes, as it already does for every other node.
        children_score_divider=((input_length - 1) ** length_penalty),
        total_score=None,
        is_final=False,
        is_selected_sequence=False,
    )
    n_beams = len(scores[0])
    # Every beam starts on the same root object.
    beam_trees = [original_tree] * n_beams

    for step, step_scores in enumerate(scores):
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            source_sequences,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):  # Get possible descendants for each beam
            current_beam = beam_trees[beam_ix]

            # skip if the beam is already final
            if current_beam.is_final:
                continue

            # Get top cumulative scores for the current beam
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            source_sequences += [beam_trees[beam_ix].current_sequence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_sequence": source_sequences,
                "token": top_tokens,
            }
        )
        # Beams sharing the same sequence propose the same continuations;
        # keep only the best-scoring row per (token, source sequence).
        maxes = top_df.groupby(["token_index", "current_sequence"])[
            "cumulative_score"
        ].idxmax()
        top_df = top_df.loc[maxes]

        # Sort all top probabilities and keep top n_beams
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Write the scores table - one per beam source
        for beam_ix in reversed(range(n_beams)):
            current_beam = beam_trees[beam_ix]
            if current_beam.table is None:
                selected_tokens = top_df_selected.loc[
                    top_df_selected["current_sequence"] == current_beam.current_sequence
                ]
                markdown_table = generate_markdown_table(
                    step_scores[beam_ix, :],
                    current_beam.cumulative_score,
                    current_beam.children_score_divider,
                    chosen_tokens=list(selected_tokens["token"].values),
                )
                beam_trees[beam_ix].table = markdown_table

        # Add new children for each beam
        # Snapshot the scores first: beam_trees is only re-pointed after all
        # children of this step have been attached.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])

            # Update the source tree
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])

            cumulative_score = (
                cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy()
            )
            current_sequence = (
                beam_trees[source_beam_ix].current_sequence + current_token_choice
            )
            beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
                current_token_ix=current_token_choice_ix,
                table=None,
                children={},
                current_sequence=current_sequence,
                cumulative_score=cumulative_score,
                total_score=cumulative_score
                / ((input_length + step - 1) ** length_penalty),
                children_score_divider=((input_length + step) ** length_penalty),
                is_final=(
                    step == len(scores) - 1
                    or current_token_choice_ix == tokenizer.eos_token_id
                ),
                is_selected_sequence=(
                    current_sequence.replace("<|endoftext|>", "")
                    in [el.replace("<|endoftext|>", "") for el in decoded_sequences]
                ),
            )

        # Reassign all beams at once
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]

        # Advance all beams by one token
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]

    return original_tree
@spaces.GPU
def get_beam_search_html(
    input_text, number_steps, number_beams, length_penalty, num_return_sequences
):
    """Run beam search on ``input_text`` and build both outputs of the app.

    Args:
        input_text: Prompt to decode from.
        number_steps: Passed to ``generate`` as ``max_new_tokens``.
        number_beams: Passed to ``generate`` as ``num_beams``.
        length_penalty: Passed to ``generate`` and used when rebuilding the
            tree scores.
        num_return_sequences: Passed to ``generate``; number of sequences
            listed in the summary.

    Returns:
        tuple[str, str]: The tree markup from ``generate_html`` and a Markdown
        summary listing each returned sequence with its score.
    """
    inputs = tokenizer([input_text], return_tensors="pt")

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=num_return_sequences,
        return_dict_in_generate=True,
        length_penalty=length_penalty,
        output_scores=True,
        do_sample=False,
    )

    markdown = "The conclusive sequences are the ones that end in an `<|endoftext|>` token or at the end of generation."
    markdown += "\n\nThey are ranked by their scores, as given by the formula `score = cumulative_score / (output_length ** length_penalty)`.\n\n"
    # `generate` returns `num_return_sequences` sequences, not `num_beams`.
    markdown += "Only the top `num_return_sequences` scoring sequences are returned: in the tree they are highlighted in **blue**."
    markdown += " The non-selected sequences are also shown in the tree, highlighted in **yellow**."
    markdown += "\n#### Output sequences:"

    # Sequences are padded anyway so you can batch decode them
    decoded_sequences = tokenizer.batch_decode(outputs.sequences)
    for i, sequence in enumerate(decoded_sequences):
        # Show the sequence as decoded: removing every space made the listed
        # sentences unreadable.
        markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence)}`"

    original_tree = generate_beams(
        input_text,
        outputs.scores[:],
        length_penalty,
        decoded_sequences,
    )
    html = generate_html(input_text, original_tree)
    return html, markdown
def change_num_return_sequences(n_beams):
    """Return an updated "Number of return sequences" slider capped at ``n_beams``.

    Args:
        n_beams: Currently selected number of beams; used as both the new
            maximum and the new value of the slider.

    Returns:
        gr.Slider: Slider definition with range ``1..n_beams`` and step 1.
    """
    # Keep the label the slider is created with, so changing the number of
    # beams does not silently rename the control.
    return gr.Slider(
        label="Number of return sequences",
        minimum=1,
        maximum=n_beams,
        step=1,
        value=n_beams,
    )
# Build the Gradio interface: Soft theme with yellow as primary and blue as
# secondary hue, plus the custom CSS defined in STYLE.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.yellow,
        secondary_hue=gr.themes.colors.blue,
    ),
    css=STYLE,
) as demo:
    # Introductory text describing each parameter.
    gr.Markdown(
        """# Beam Search Visualizer
Play with the parameters below to understand how beam search decoding works!
#### Parameters:
- **Sentence to decode from** (`inputs`): the input sequence to your decoder.
- **Number of steps** (`max_new_tokens`): the number of tokens to generate.
- **Number of beams** (`num_beams`): the number of beams to use.
- **Length penalty** (`length_penalty`): the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
This parameter will not impact the beam search paths, but only influence the choice of sequences in the end towards longer or shorter sequences.
- **Number of return sequences** (`num_return_sequences`): the number of sequences to be returned at the end of generation. Should be `<= num_beams`.
"""
    )
    # Prompt input, pre-filled with an example sentence.
    text = gr.Textbox(
        label="Sentence to decode from",
        value="Conclusion: thanks a lot. This article was originally published on",
    )
    # NOTE(review): which components sit inside this row was reconstructed
    # from statement order -- confirm against the intended layout.
    with gr.Row():
        n_steps = gr.Slider(
            label="Number of steps", minimum=1, maximum=10, step=1, value=4
        )
        n_beams = gr.Slider(
            label="Number of beams", minimum=2, maximum=4, step=1, value=3
        )
        length_penalty = gr.Slider(
            label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
        )
        num_return_sequences = gr.Slider(
            label="Number of return sequences", minimum=1, maximum=3, step=1, value=2
        )

    # When the number of beams changes, rebuild the return-sequences slider
    # from the new value.
    n_beams.change(
        fn=change_num_return_sequences, inputs=n_beams, outputs=num_return_sequences
    )
    button = gr.Button()
    # Two Markdown outputs: the first receives the tree markup, the second
    # the textual summary returned by get_beam_search_html.
    out_html = gr.Markdown()
    out_markdown = gr.Markdown()
    button.click(
        get_beam_search_html,
        inputs=[text, n_steps, n_beams, length_penalty, num_return_sequences],
        outputs=[out_html, out_markdown],
    )

# Start the app as soon as the module is executed.
demo.launch()