Spaces:
Running
on
Zero
Running
on
Zero
import html
from typing import Optional

import gradio as gr
import numpy as np
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the GPT-2 tokenizer and model once, at import time; request handlers
# below reuse these module-level objects.
tokenizer, model = (
    AutoTokenizer.from_pretrained("gpt2"),
    AutoModelForCausalLM.from_pretrained("gpt2"),
)
# Pad with the end-of-sequence id (presumably because the GPT-2 tokenizer
# defines no pad token of its own).
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")

# Report which accelerator, if any, PyTorch can see.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print("CUDA device: " + torch.cuda.get_device_name(torch.cuda.current_device()))
# CSS passed to gr.Blocks(css=...). It lays out the beam tree that
# generate_html/generate_nodes emit: nested <ul>/<li> lists inside
# .custom-container > .tree, with ::before/::after pseudo-elements drawing
# the connector lines between parent and child nodes. Nodes that
# generate_nodes tags with the "chosen" class get a red background.
STYLE = """
.custom-container {
    width: 100%;
    display: grid;
    align-items: center;
    margin: 0!important;
    overflow: scroll;
}
.prose ul ul {
    margin: 0!important;
    font-size: 10px!important;
}
.prose td, .prose th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
}
.tree {
    padding: 0px;
    margin: 0!important;
    box-sizing: border-box;
    font-size: 10px;
    width: 100%;
    min-width: 2000px;
    height: auto;
    text-align: center;
}
.tree ul {
    padding-top: 20px;
    position: relative;
    transition: .5s;
    margin: 0!important;
    display: flex;
    flex-direction: row;
    justify-content: center;
    gap:10px;
}
.tree li {
    display: inline-table;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-top: 10px;
    transition: .5s;
}
.tree li::before, .tree li::after {
    content: '';
    position: absolute;
    top: 0;
    right: 50%;
    border-top: 1px solid #ccc;
    width: 51%;
    height: 10px;
}
.tree li::after {
    right: auto;
    left: 50%;
    border-left: 1px solid #ccc;
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
    border-right: 1px solid #ccc;
    border-radius: 0 5px 0 0;
    -webkit-border-radius: 0 5px 0 0;
    -moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
    border-radius: 5px 0 0 0;
    -webkit-border-radius: 5px 0 0 0;
    -moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: '';
    position: absolute;
    top: 0;
    left: 50%;
    border-left: 1px solid #ccc;
    width: 0;
    height: 20px;
}
.tree li a {
    border: 1px solid #ccc;
    padding: 5px;
    display: inline-grid;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
}
.tree li a span {
    color: #666;
    padding: 5px;
    font-size: 12px;
    text-transform: uppercase;
    letter-spacing: 1px;
    font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
    background: #c8e4f8;
    color: #000;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
    border-color: #94a0b4;
}
.chosen {
    background-color: red;
}
"""
def generate_nodes(token, node):
    """Recursively render one tree node and its subtree as an HTML ``<li>``.

    Args:
        token: Label shown for this node (a decoded token, or the prompt for
            the root). It is HTML-escaped, because decoded tokens can contain
            markup characters such as ``<|endoftext|>``.
        node: Object exposing ``table`` (an HTML fragment, or ``None``) and
            ``children`` (a mapping of label -> child node).

    Returns:
        An HTML string ``<li>...</li>``. The ``<a>`` holds the label followed
        by ``node.table`` (inserted verbatim as HTML). When the node has
        children, a nested ``<ul>`` with their rendered subtrees follows.
        Nodes whose ``table`` is ``None`` get the ``chosen`` CSS class.
    """
    css_class = "chosen" if node.table is None else ""
    parts = [
        f" <li> <a href='#' class='{css_class}'> <span> <b>{html.escape(str(token))}</b> </span> ",
        node.table if node.table is not None else "",
        "</a>",
    ]
    if node.children:
        parts.append("<ul> ")
        parts.extend(
            generate_nodes(child_token, child_node)
            for child_token, child_node in node.children.items()
        )
        parts.append("</ul>")
    parts.append("</li>")
    # Join once instead of repeated += (which is quadratic on deep trees).
    return "".join(parts)
def generate_markdown_table(scores, sequence_prob, top_k=4, chosen_tokens=None):
    """Render the ``top_k`` highest-scoring next tokens for one beam as HTML.

    Despite the name, the result is an HTML ``<table>``; ``generate_beams``
    stores it on a tree node, and ``generate_nodes`` embeds it in the page.

    Args:
        scores: Per-token scores for one beam at one step, indexable by token
            id. Callers pass ``step_scores[beam_ix, :]``; assumed to be a 1-D
            tensor/array over the vocabulary (TODO confirm).
        sequence_prob: Cumulative score of the beam so far. It is added to
            each step score to fill the "Total score" column.
        top_k: Number of rows to show. ``0`` or less yields a header-only
            table.
        chosen_tokens: Decoded token strings whose rows are highlighted in
            red. ``None`` or empty highlights nothing.

    Returns:
        The table markup as a string.
    """
    parts = [
        """
<table>
<tr>
<th><b>Token</b></th>
<th><b>Step score</b></th>
<th><b>Total score</b></th>
</tr>"""
    ]
    chosen = chosen_tokens or []
    # Convert to NumPy before reversing: ``scores`` may be a torch tensor, and
    # tensors reject negative-step slices. Taking [:top_k] of the descending
    # order (rather than [-top_k:] of the ascending one) avoids selecting the
    # whole vocabulary when top_k == 0.
    ranked_ids = np.asarray(np.argsort(scores))[::-1][: max(top_k, 0)]
    for token_idx in ranked_ids:
        token = tokenizer.decode([token_idx])
        row_attr = " style='background-color:red'" if token in chosen else ""
        parts.append(
            f"""
<tr{row_attr}>
<td>{html.escape(token)}</td>
<td>{scores[token_idx]:.4f}</td>
<td>{scores[token_idx] + sequence_prob:.4f}</td>
</tr>"""
        )
    parts.append("""
</table>""")
    return "".join(parts)
def generate_html(start_sentence, original_tree):
    """Render the whole beam tree as an HTML fragment for display.

    Args:
        start_sentence: Label for the root node (the prompt text).
        original_tree: Root node of the tree, as returned by ``generate_beams``.

    Returns:
        A ``custom-container`` div wrapping a ``tree`` div, which holds a
        ``<ul>`` containing the rendered root. The class names match the
        rules in ``STYLE``.
    """
    # Both opened <div>s are closed with </div> (not </body>), so the outer
    # container is properly terminated.
    return (
        '<div class="custom-container">\n<div class="tree">\n<ul>'
        + generate_nodes(start_sentence, original_tree)
        + "\n</ul>\n</div>\n</div>\n"
    )
import pandas as pd | |
from typing import Dict | |
from dataclasses import dataclass | |
@dataclass
class BeamNode:
    """One node of the beam-search exploration tree.

    It is a dataclass so that:
    - it can be built with keyword arguments, as ``generate_beams`` does;
    - its ``str()``/``repr`` includes its fields, including children, which
      the growth assertion in ``generate_beams`` compares by length.
    """

    # Root starts at 0; each child adds its step score to the parent's value.
    cumulative_score: float
    # HTML score table for this node's next-token candidates; None until written.
    table: Optional[str]
    # Prompt text plus every decoded token on the path to this node.
    current_sentence: str
    # Child nodes keyed by the decoded token text that leads to them.
    children: Dict[str, "BeamNode"]
def generate_beams(start_sentence, scores, sequences, beam_indices):
    """Rebuild the beam-search exploration tree from per-step generation scores.

    Starting from a root node for ``start_sentence``, each step does the following:
    1. For every live beam, take its ``n_beams`` highest-scoring next tokens.
    2. Add the beam's cumulative score to each candidate.
    3. Drop duplicate (token, completion) candidates and keep the ``n_beams``
       best overall.
    4. Write an HTML score table onto each live beam's node.
    5. Attach the chosen tokens as child nodes.
    6. Move the live beams onto those children.

    Args:
        start_sentence: Prompt text; becomes the root's ``current_sentence``.
        scores: Sequence of per-step score tensors (``outputs.scores`` from
            ``model.generate`` at the call site). Presumably each is
            (n_beams, vocab_size). This is inferred from ``len(scores[0])``
            being used as the beam count and from ``step_scores[beam_ix, :]``
            indexing; TODO confirm.
        sequences: Generated token ids; only decoded for the debug print here.
        beam_indices: Not used by this function.

    Returns:
        The root ``BeamNode`` of the tree.
    """
    # Debug output of the final decoded sequences.
    print(tokenizer.batch_decode(sequences))
    sequences = sequences.cpu().numpy()
    original_tree = BeamNode(
        cumulative_score=0, table=None, current_sentence=start_sentence, children={}
    )
    n_beams = len(scores[0])
    # Every live beam starts on the same root object (the list holds n_beams
    # references to it, not copies).
    beam_trees = [original_tree] * n_beams
    for step, step_scores in enumerate(scores):
        # Flat, parallel candidate lists: n_beams candidates per live beam.
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_completions,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):
            current_beam = beam_trees[beam_ix]
            # Get top cumulative scores for the current beam:
            # its n_beams best token ids, highest first.
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_completions += [beam_trees[beam_ix].current_sentence] * n_beams
            top_tokens += [
                tokenizer.decode([el]) for el in current_top_token_indexes
            ]
        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_completions": current_completions,
                "token": top_tokens,
            }
        )
        # Collapse candidates with the same (token, current completion), e.g.
        # the identical beams at step 0. Keep the row with the highest
        # cumulative score; idxmax picks the first such row on ties.
        maxes = top_df.groupby(["token_index", "current_completions"])[
            "cumulative_score"
        ].idxmax()
        top_df = top_df.loc[maxes]
        # Sort all top probabilities and keep top n_beams
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]
        # Write one score table per live beam, highlighting the tokens selected
        # from that beam.
        # Edge case: several entries of beam_trees can be the same node object
        # (at step 0 they all alias the root). After de-duplication, all the
        # selected tokens are attributed to the first such beam index, so the
        # later indices see an empty selection. Iterating in reverse makes the
        # lowest index write last, so the table that keeps the highlighting wins.
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            selected_tokens = top_df_selected.loc[top_df_selected["beam_index"] == beam_ix]
            markdown_table = generate_markdown_table(
                step_scores[beam_ix, :],
                current_beam.cumulative_score,
                chosen_tokens=list(selected_tokens["token"].values),
            )
            beam_trees[beam_ix].table = markdown_table
        # Add new children for each beam. Snapshot the cumulative scores first,
        # before any child is attached in this step.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            # Update the source tree: attach the chosen token under the beam it
            # was drawn from. Children are keyed by decoded text.
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])
            previous_len = len(str(original_tree))
            # NOTE(review): .numpy() presumably requires the scores to be on CPU.
            beam_trees[source_beam_ix].children[current_token_choice] = BeamNode(
                table=None,
                children={},
                current_sentence=beam_trees[source_beam_ix].current_sentence
                + current_token_choice,
                cumulative_score=cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy(),
            )
            # Sanity check that a new node really was added. NOTE(review): this
            # relies on str(BeamNode) including its children (as a @dataclass
            # repr does). It re-stringifies the whole tree every time.
            assert (
                len(str(original_tree)) > previous_len
            ), "Original tree has not increased size"
        # Reassign all beams at once: beam i now continues from the source beam
        # of the i-th selected candidate.
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]
        # Advance all beams by one token, onto the child created above.
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice]
    return original_tree
def get_beam_search_html(input_text, number_steps, number_beams):
    """Run GPT-2 beam search on ``input_text`` and return the beam tree as HTML.

    Args:
        input_text: Prompt to continue.
        number_steps: Number of new tokens to generate (``max_new_tokens``).
        number_beams: Beam width; also the number of returned sequences.

    Returns:
        HTML from ``generate_html``, shown in the ``gr.Markdown`` output.
    """
    # Gradio sliders may deliver floats; generate() needs integer counts.
    number_steps = int(number_steps)
    number_beams = int(number_beams)
    inputs = tokenizer([input_text], return_tensors="pt")
    # Deterministic beam search. top_k is omitted because it only applies to
    # sampling, and do_sample is False here.
    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=False,
    )
    original_tree = generate_beams(
        input_text,
        outputs.scores,
        outputs.sequences,
        outputs.beam_indices,
    )
    # Named tree_html so it does not shadow the ``html`` module.
    tree_html = generate_html(input_text, original_tree)
    return tree_html
# Gradio UI: a prompt box, sliders for decoding length and beam width, a
# trigger button, and a Markdown area that receives the HTML returned by
# get_beam_search_html.
with gr.Blocks(
    css=STYLE,
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.green,
        font=["monospace"],
        text_size="lg",
    ),
) as demo:
    text = gr.Textbox(value="Today is", label="Sentence to decode from")
    steps = gr.Slider(
        minimum=1, maximum=10, step=1, value=4, label="Number of steps"
    )
    beams = gr.Slider(
        minimum=2, maximum=4, step=1, value=3, label="Number of beams"
    )
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(fn=get_beam_search_html, inputs=[text, steps, beams], outputs=out)

demo.launch()