# Spaces: Running on Zero
import html

import gradio as gr
import numpy as np
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the "gpt2" tokenizer and causal-LM checkpoint once, at import time;
# every function below reads these two module-level objects.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
print("Loading finished.")

# Report whether a CUDA device is visible and, if so, which one.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    _device_index = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(_device_index)}")
# CSS handed to `gr.Blocks(css=STYLE)`: lays out the beam tree as nested
# <ul>/<li> elements with connector lines drawn by ::before/::after pseudo
# elements, and styles the per-node score tables.
#
# Every selector in a comma-separated list is scoped (`.prose ...`,
# `.tree ...`) so the rules cannot leak onto unrelated elements of the page.
STYLE = """
.custom-container {
    display: grid;
    align-items: center;
    margin: 0!important;
    overflow-y: hidden;
}
.prose ul ul {
    font-size: 10px!important;
}
.prose li {
    margin-bottom: 0!important;
}
.prose table {
    margin-bottom: 0!important;
}
.prose td, .prose th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
    text-wrap:nowrap;
}
.tree {
    padding: 0px;
    margin: 0!important;
    box-sizing: border-box;
    font-size: 10px;
    width: 100%;
    height: auto;
    text-align: center;
    display:inline-block;
}
#root {
    display: inline-grid!important;
    width:auto!important;
    min-width: 220px;
}
.tree ul {
    padding-left: 20px;
    position: relative;
    transition: all 0.5s ease 0s;
    display: flex;
    flex-direction: column;
    gap: 10px;
    margin: 0px !important;
}
.tree li {
    display: flex;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-left: 20px;
    transition: all 0.5s ease 0s;
    flex-direction: row;
    justify-content: start;
    align-items: center;
}
.tree li::before, .tree li::after {
    content: "";
    position: absolute;
    left: 0px;
    border-left: 1px solid var(--body-text-color);
    width: 20px;
}
.tree li::before {
    top: 0;
    height:50%;
}
.tree li::after {
    top: 50%;
    height: 55%;
    bottom: auto;
    border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
    border-bottom: 1px solid var(--body-text-color);
    border-radius: 0px 0px 0px 5px;
    -webkit-border-radius: 0px 0px 0px 5px;
    -moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
    border-radius: 5px 0 0 0;
    -webkit-border-radius: 5px 0 0 0;
    -moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: "";
    position: absolute;
    left: 0;
    top: 50%;
    border-top: 1px solid var(--body-text-color);
    width: 20px;
    height: 0;
}
.tree ul:has(> li:only-child)::before {
    width:40px;
}
.tree li a:before {
    border-right: 2px solid var(--body-text-color);
    border-bottom: 2px solid var(--body-text-color);
    content: "";
    position: absolute;
    width: 10px;
    left: 8px;
    height: 10px;
    top: 50%;
    margin-top: -5px;
    transform: rotate(315deg);
}
.tree li a {
    border: 1px solid var(--body-text-color);
    padding: 5px;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
    display: flex;
    align-items: center;
    justify-content: space-between;
    overflow: hidden;
}
.tree li a span {
    padding: 5px;
    font-size: 12px;
    letter-spacing: 1px;
    font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
    background: #ffedd5;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
    border-color: #7c2d12;
}
.end-of-text, .chosen {
    background-color: #ea580c;
}
.end-of-text {
    width:auto!important;
}
.nonfinal {
    width:280px;
    min-width: 280px;
}
"""
def clean(s):
    """Return *s* with newlines and tabs shown as literal escapes, then trimmed.

    Each newline becomes the two characters ``\\n`` and each tab becomes
    ``\\t``. Trimming happens after the substitution, so a newline or tab at
    either end survives as its visible escape; only remaining edge whitespace
    (such as spaces) is stripped.
    """
    visible_escapes = {ord("\n"): r"\n", ord("\t"): r"\t"}
    return s.translate(visible_escapes).strip()
def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
    """Build the HTML score table for the ``top_k`` best next tokens.

    Args:
        scores: Score vector indexed by token id (the caller passes one beam's
            row of a step's scores).
        previous_cumul_score: Cumulative score of the beam before this step;
            added to each step score for the "Total score" column.
        score_divider: Divisor applied to the summed score in the
            "Total score" column.
        top_k: Number of highest-scoring tokens to list, best first.
        chosen_tokens: Optional collection of decoded token strings; rows whose
            token is in it get the ``chosen`` CSS class.

    Returns:
        An HTML ``<table>`` string with one header row and ``top_k`` data rows.
    """
    chosen = chosen_tokens or ()
    parts = [
        "\n<table>"
        "\n<tr>"
        "\n<th><b>Token</b></th>"
        "\n<th><b>Step score</b></th>"
        "\n<th><b>Total score</b></th>"
        "\n</tr>"
    ]
    # np.array(...) is required before the reversing slice: the argsort result
    # may not be a NumPy array, and only NumPy supports a negative step here.
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        # Quote the class attribute, and omit it entirely for ordinary rows.
        row_open = '<tr class="chosen">' if token in chosen else "<tr>"
        # Tokens are arbitrary vocabulary strings; escape them so characters
        # such as "<" or "&" are displayed instead of being parsed as markup.
        token_label = html.escape(clean(token))
        total = (scores[token_idx] + previous_cumul_score) / score_divider
        parts.append(
            f"\n{row_open}"
            f"\n<td>{token_label}</td>"
            f"\n<td>{scores[token_idx]:.4f}</td>"
            f"\n<td>{total:.4f}</td>"
            "\n</tr>"
        )
    parts.append("\n</table>")
    return "".join(parts)
def generate_nodes(token_ix, node, step):
    """Recursively generate HTML for the tree nodes.

    Args:
        token_ix: Token id of this node; decoded to produce the node label.
        node: Tree node providing ``is_final``, ``total_score``, ``table`` and
            ``children`` (a mapping of token id to child node).
        step: Depth of this node; only forwarded (incremented) to the
            recursive calls.

    Returns:
        An ``<li>`` element string containing this node and, nested in a
        ``<ul>``, all of its descendants.
    """
    # Escape the decoded token so vocabulary entries containing "<", ">" or
    # "&" are displayed verbatim rather than interpreted as markup.
    token_label = html.escape(clean(tokenizer.decode([token_ix])))

    if node.is_final:
        return f"<li> <a href='#' class='end-of-text'> <span> <b>{token_label}</b> <br>Total score: {node.total_score:.2f}</span> </a> </li>"

    parts = [f"<li> <a href='#' class='nonfinal'> <span> <b>{token_label}</b> </span>"]
    if node.table is not None:
        parts.append(node.table)
    parts.append("</a>")

    if node.children:
        parts.append("<ul> ")
        # Distinct loop names: the original loop rebound the `token_ix` parameter.
        parts.extend(
            generate_nodes(child_ix, child, step=step + 1)
            for child_ix, child in node.children.items()
        )
        parts.append("</ul>")
    parts.append("</li>")
    return "".join(parts)
def generate_html(start_sentence, original_tree):
    """Render the whole beam tree as one HTML fragment.

    Args:
        start_sentence: The prompt text, shown as the label of the root node.
            It is user-supplied, so it is HTML-escaped before insertion.
        original_tree: Root node providing ``table`` (an HTML string or None)
            and ``children`` (a mapping of token id to child node).

    Returns:
        A string holding a ``custom-container`` div that wraps a ``tree`` div
        with the nested ``<ul>``/``<li>`` structure; both divs are closed.
    """
    root_label = html.escape(start_sentence)
    # Mirror generate_nodes: a missing table renders as nothing, not "None".
    root_table = original_tree.table if original_tree.table is not None else ""

    parts = [
        '<div class="custom-container">\n'
        '<div class="tree">\n'
        f"<ul> <li> <a href='#' id='root'> <span> <b>{root_label}</b> </span> {root_table} </a>",
        "<ul> ",
    ]
    parts.extend(
        generate_nodes(token_ix, subnode, step=1)
        for token_ix, subnode in original_tree.children.items()
    )
    parts.append("</ul>")
    # Close both opened divs (the original emitted one </div> and a stray
    # </body> that had no matching opening tag).
    parts.append("\n</li> </ul>\n</div>\n</div>\n")
    return "".join(parts)
import pandas as pd
from typing import Dict
from dataclasses import dataclass


# The decorator generates the keyword-accepting __init__ that every
# construction site in this file relies on; without it, BeamNode(...) with
# arguments raises TypeError.
@dataclass
class BeamNode:
    """One node of the beam-search tree: a token plus its running scores."""

    # Token id of this node; the root is created with None.
    current_token_ix: int
    # Sum of the step scores along the path from the root to this node.
    cumulative_score: float
    # Divisor used for the "Total score" column of this node's children table.
    children_score_divider: float
    # HTML score table of candidate next tokens; None until it is filled in.
    table: str
    # Prompt text plus every decoded token on the path to this node.
    current_sentence: str
    # Child nodes keyed by their token id.
    children: Dict[int, "BeamNode"]
    # Length-normalised score of this node; the root is created with None.
    total_score: float
    # True when this node ends its beam (last step or end-of-sequence token).
    is_final: bool
def generate_beams(start_sentence, scores, sequences, length_penalty):
    """Rebuild the beam-search tree from the per-step generation scores.

    Args:
        start_sentence: Prompt text; stored as the root's ``current_sentence``.
        scores: Per-step scores; ``scores[step][beam_ix]`` is indexed by token
            id. ``.numpy()`` is called on single entries, so these are
            presumably CPU torch tensors — TODO confirm.
        sequences: Generated sequences; converted to NumPy below but not read
            again in this function.
        length_penalty: Exponent applied to the length terms used as score
            divisors.

    Returns:
        The root ``BeamNode``; the explored tree hangs off its ``children``.
    """
    sequences = sequences.cpu().numpy()
    # NOTE(review): this takes len() of the tokenizer's return value, not of
    # its token ids — presumably that counts the entries of the returned
    # mapping rather than prompt tokens. The divisors below depend on the
    # value, so confirm the intent before changing it.
    input_length = len(tokenizer([start_sentence], return_tensors="pt"))
    original_tree = BeamNode(
        cumulative_score=0,
        current_token_ix=None,
        table=None,
        current_sentence=start_sentence,
        children={},
        children_score_divider=((input_length + 1) ** length_penalty),
        total_score=None,
        is_final=False,
    )
    n_beams = len(scores[0])
    # Every slot initially references the same root object (no copies).
    beam_trees = [original_tree] * n_beams

    for step, step_scores in enumerate(scores):
        # Flat candidate lists for this step, filled in lockstep below.
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_completions,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):  # Get possible descendants for each beam
            current_beam = beam_trees[beam_ix]
            # skip if the beam is already final
            if current_beam.is_final:
                continue
            # Get top cumulative scores for the current beam
            # Last n_beams entries of the argsort, reversed; best first,
            # assuming argsort is ascending.
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_completions += [beam_trees[beam_ix].current_sentence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_completions": current_completions,
                "token": top_tokens,
            }
        )
        # For each (token, sentence-so-far) pair keep only the row with the
        # highest cumulative score, removing duplicates across beam slots.
        maxes = top_df.groupby(["token_index", "current_completions"])[
            "cumulative_score"
        ].idxmax()
        top_df = top_df.loc[maxes]

        # Sort all top probabilities and keep top n_beams
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Write the scores table - one per beam source?
        # Edge case: if several beam indexes are actually on the same beam, the selected tokens by beam_index for the second one will be empty. So we reverse
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            selected_tokens = top_df_selected.loc[
                top_df_selected["beam_index"] == beam_ix
            ]
            markdown_table = generate_markdown_table(
                step_scores[beam_ix, :],
                current_beam.cumulative_score,
                current_beam.children_score_divider,
                chosen_tokens=list(selected_tokens["token"].values),
            )
            beam_trees[beam_ix].table = markdown_table

        # Add new children for each beam
        # Snapshot the scores first: the list comprehension reads them before
        # any child node is attached in the loop below.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])

            # Update the source tree
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])
            cumulative_score = (
                cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy()
            )
            beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
                current_token_ix=current_token_choice_ix,
                table=None,
                children={},
                current_sentence=beam_trees[source_beam_ix].current_sentence
                + current_token_choice,
                cumulative_score=cumulative_score,
                total_score=cumulative_score
                / ((input_length + step - 1) ** length_penalty),
                children_score_divider=((input_length + step) ** length_penalty),
                # A node is final on the last step or when it is the EOS token.
                is_final=(
                    step == len(scores) - 1
                    or current_token_choice_ix == tokenizer.eos_token_id
                ),
            )

        # Reassign all beams at once
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]

        # Advance all beams by one token
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]

    return original_tree
def get_beam_search_html(input_text, number_steps, number_beams, length_penalty):
    """Run beam search on ``input_text`` and return its visualisation.

    Args:
        input_text: Prompt to decode from.
        number_steps: Passed to generation as ``max_new_tokens``.
        number_beams: Used both as the beam count and as the number of
            returned sequences.
        length_penalty: Forwarded to generation and to the tree builder.

    Returns:
        A ``(html, markdown)`` tuple: the rendered beam tree, and a bullet
        list of the decoded output sequences with their scores.

    NOTE(review): `spaces` is imported at the top of the file but never
    referenced in the visible code; if this handler is meant to run on
    ZeroGPU hardware it may need a `spaces` decorator — confirm.
    """
    encoded_prompt = tokenizer([input_text], return_tensors="pt")
    generation = model.generate(
        **encoded_prompt,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        length_penalty=length_penalty,
        output_scores=True,
        do_sample=False,
    )

    # Sequences are padded anyway so you can batch decode them
    decoded_texts = tokenizer.batch_decode(generation.sequences)
    bullet_lines = [
        f"\n- '{clean(text.replace('<s> ', ''))}' (score {generation.sequences_scores[position]:.2f})"
        for position, text in enumerate(decoded_texts)
    ]
    markdown = "Output sequences:" + "".join(bullet_lines)

    beam_tree = generate_beams(
        input_text,
        generation.scores[:],
        generation.sequences[:, :],
        length_penalty,
    )
    tree_html = generate_html(input_text, beam_tree)
    return tree_html, markdown
# Visual theme for the app: the Soft preset with large monospace text.
_APP_THEME = gr.themes.Soft(
    text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.yellow
)

# Introductory text shown at the top of the page.
_INTRO_MARKDOWN = (
    "# Beam search visualizer\n"
    "Play with the parameters below to understand how beam search decoding works!\n"
    "#### Parameters:\n"
    "- **Sentence to decode from**: the input sequence to your decoder.\n"
    "- **Number of steps**: the number of tokens to generate\n"
    "- **Number of beams**: the number of beams to use\n"
    "- **Length penalty**: the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.\n"
)

# Page layout: intro, prompt box, one row of three sliders, a button, and two
# Markdown outputs (the tree HTML and the list of decoded sequences).
with gr.Blocks(theme=_APP_THEME, css=STYLE) as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    text = gr.Textbox(
        value="Conclusion: thanks a lot. This article was originally published on",
        label="Sentence to decode from",
    )
    with gr.Row():
        steps = gr.Slider(
            minimum=1, maximum=8, step=1, value=4, label="Number of steps"
        )
        beams = gr.Slider(
            minimum=2, maximum=4, step=1, value=3, label="Number of beams"
        )
        length_penalty = gr.Slider(
            minimum=-3, maximum=3, step=0.5, value=1, label="Length penalty"
        )
    button = gr.Button()
    out_html = gr.Markdown()
    out_markdown = gr.Markdown()
    # Clicking the button runs the search and fills both outputs.
    button.click(
        fn=get_beam_search_html,
        inputs=[text, steps, beams, length_penalty],
        outputs=[out_html, out_markdown],
    )

demo.launch()