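"""Beam search visualizer.

Generates text from GPT-2 with beam search and renders every explored beam as an
HTML/CSS tree, where each node shows the top candidate tokens together with their
step and cumulative scores at that decoding step.
"""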
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gradio as gr
import spaces
print(f"Is CUDA available: {torch.cuda.is_available()}")
# True
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
STYLE = """
.container {
width: 100%;
display: grid;
align-items: center;
margin: 0!important;
}
.prose ul ul {
margin: 0!important;
font-size: 13px!important;
}
.tree {
padding: 0px;
margin: 0!important;
box-sizing: border-box;
font-size: 16px;
width: 100%;
height: auto;
text-align: center;
}
.tree ul {
padding-top: 20px;
position: relative;
transition: .5s;
margin: 0!important;
}
.tree li {
display: inline-table;
text-align: center;
list-style-type: none;
position: relative;
padding: 10px;
transition: .5s;
}
.tree li::before, .tree li::after {
content: '';
position: absolute;
top: 0;
right: 50%;
border-top: 1px solid #ccc;
width: 51%;
height: 10px;
}
.tree li::after {
right: auto;
left: 50%;
border-left: 1px solid #ccc;
}
.tree li:only-child::after, .tree li:only-child::before {
display: none;
}
.tree li:only-child {
padding-top: 0;
}
.tree li:first-child::before, .tree li:last-child::after {
border: 0 none;
}
.tree li:last-child::before {
border-right: 1px solid #ccc;
border-radius: 0 5px 0 0;
-webkit-border-radius: 0 5px 0 0;
-moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
border-radius: 5px 0 0 0;
-webkit-border-radius: 5px 0 0 0;
-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
content: '';
position: absolute;
top: 0;
left: 50%;
border-left: 1px solid #ccc;
width: 0;
height: 20px;
}
.tree li a {
border: 1px solid #ccc;
padding: 10px;
display: inline-grid;
border-radius: 5px;
text-decoration-line: none;
transition: .5s;
}
.tree li a span {
border: 1px solid #ccc;
border-radius: 5px;
color: #666;
padding: 8px;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover i, .tree li a:hover span, .tree li a:hover+ul li a {
background: #c8e4f8;
color: #000;
border: 1px solid #94a0b4;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
border-color: #94a0b4;
}
"""
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")
def generate_html(token, node):
    """Recursively generate HTML for the tree."""
    html_content = f" <li> <a href='#'> <span> <b>{token}</b> </span> "
    html_content += node["table"] if node["table"] is not None else ""
    html_content += "</a>"
    if len(node["children"].keys()) > 0:
        html_content += "<ul> "
        for token, subnode in node["children"].items():
            html_content += generate_html(token, subnode)
        html_content += "</ul>"
    html_content += "</li>"
    return html_content
def generate_markdown_table(scores, sequence_prob, top_k=4, chosen_tokens=None):
    markdown_table = """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Cumulative score</b></th>
        </tr>"""
    for token_idx in np.argsort(scores)[-top_k:]:
        token = tokenizer.decode([token_idx])
        style = ""
        if chosen_tokens and token in chosen_tokens:
            style = "background-color:red"
        markdown_table += f"""
        <tr style="{style}">
            <td>{token}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{scores[token_idx] + sequence_prob:.4f}</td>
        </tr>"""
    markdown_table += """
    </table>"""
    return markdown_table
def display_tree(start_sentence, scores, sequences, beam_indices):
    display = """<div class="container">
<div class="tree">
<ul>"""
    sequences = sequences.cpu().numpy()
    print(tokenizer.batch_decode(sequences))
    original_tree = {"table": None, "children": {}}
    for sequence_ix in range(len(sequences)):
        current_sequence_score = 0
        current_tree = original_tree
        for step, step_scores in enumerate(scores):
            current_token_choice_ix = sequences[sequence_ix, step]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            current_beam = beam_indices[sequence_ix, step]
            if current_token_choice not in current_tree["children"]:
                current_tree["children"][current_token_choice] = {
                    "table": None,
                    "children": {},
                }
            # Rewrite the probability table even if it already existed, since new chosen
            # nodes may have appeared among the children of the current tree
            markdown_table = generate_markdown_table(
                step_scores[current_beam, :],
                current_sequence_score,
                chosen_tokens=current_tree["children"].keys(),
            )
            current_tree["table"] = markdown_table
            current_tree = current_tree["children"][current_token_choice]
            # Keep the cumulative score of the current sequence up to date
            current_sequence_score += step_scores[current_beam, current_token_choice_ix]
    display += generate_html(start_sentence, original_tree)
    display += """
</ul>
</div>
</div>
"""
    return display
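# On Hugging Face Spaces, @spaces.GPU requests a GPU for the duration of each call to the decorated function.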
@spaces.GPU
def get_tables(input_text, number_steps, number_beams):
    inputs = tokenizer([input_text], return_tensors="pt")
    # Number of prompt tokens (len(inputs) would count the dict keys, not the tokens)
    input_length = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        temperature=1.0,
        do_sample=True,
    )
    print(outputs.sequences_scores)
    tables = display_tree(
        input_text,
        outputs.scores,
        outputs.sequences[:, input_length:],
        outputs.beam_indices[:, :-input_length],
    )
    return tables
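# Gradio UI: a prompt textbox, sliders for the number of decoding steps and beams,
# and a Markdown component that displays the rendered HTML tree.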
with gr.Blocks(
    theme=gr.themes.Soft(
        text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.green
    ),
    css=STYLE,
) as demo:
    text = gr.Textbox(label="Sentence to decode from", value="Today is")
    steps = gr.Slider(label="Number of steps", minimum=1, maximum=10, step=1, value=4)
    beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(get_tables, inputs=[text, steps, beams], outputs=out)

demo.launch()