import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import pandas as pd
import gradio as gr
import spaces
from typing import Dict
from dataclasses import dataclass

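# Load GPT-2 and its tokenizer once at startup; they are reused across requests.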
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

print("Loading finished.")

print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

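# CSS that renders the beam search tree as nested <ul>/<li> elements.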
STYLE = """
.custom-container {
	display: grid;
	align-items: center;
    margin: 0!important;
    overflow-y: hidden;
}
.prose ul ul {
    font-size: 10px!important;
}
.prose li {
    margin-bottom: 0!important;
}
.prose table {
    margin-bottom: 0!important;
}
.prose td, .prose th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
    text-wrap:nowrap;
}
.tree {
	padding: 0px;
	margin: 0!important;
	box-sizing: border-box;
    font-size: 10px;
	width: 100%;
	height: auto;
	text-align: center;
    display:inline-block;
}
#root {
    display: inline-grid!important;
    width:auto!important;
    min-width: 220px;
}
.tree ul {
    padding-left: 20px;
    position: relative;
    transition: all 0.5s ease 0s;
    display: flex;
    flex-direction: column;
    gap: 10px;
    margin: 0px !important;
}
.tree li {
    display: flex;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-left: 20px;
    transition: all 0.5s ease 0s;
    flex-direction: row;
    justify-content: start;
    align-items: center;
}
.tree li::before, .tree li::after {
    content: "";
    position: absolute;
    left: 0px;
    border-left: 1px solid var(--body-text-color);
    width: 20px;
}
.tree li::before {
    top: 0;
    height:50%;
}
.tree li::after {
    top: 50%;
    height: 55%;
    bottom: auto;
    border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
	border-bottom: 1px solid var(--body-text-color);
	border-radius: 0px 0px 0px 5px;
	-webkit-border-radius: 0px 0px 0px 5px;
	-moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
	border-radius: 5px 0 0 0;
	-webkit-border-radius: 5px 0 0 0;
	-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: "";
    position: absolute;
    left: 0;
    top: 50%;
    border-top: 1px solid var(--body-text-color);
    width: 20px;
    height: 0;
}
.tree ul:has(> li:only-child)::before {
    width:40px;
}
.child:before {
    border-right: 2px solid var(--body-text-color);
    border-bottom: 2px solid var(--body-text-color);
    content: "";
    position: absolute;
    width: 10px;
    left: 8px;
    height: 10px;
    top: 50%;
    margin-top: -5px;
    transform: rotate(315deg);
}
.tree li a {
	border: 1px solid var(--body-text-color);
	padding: 5px;
	border-radius: 5px;
	text-decoration-line: none;
	transition: .5s;
    display: flex;
    align-items: center;
    justify-content: space-between;
    overflow: hidden;
}
.tree li a span {
	padding: 5px;
	font-size: 12px;
	letter-spacing: 1px;
	font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
	background: var(--primary-500);
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before, .tree li a:hover+ul a::before {
	border-color: var(--primary-500);
}
.chosen-token {
    background-color: var(--primary-400);
}
.chosen-token td, .chosen-token tr {
    color: black!important;
}
.end-of-text {
    width:auto!important;
}
.nonfinal {
    width:280px;
    min-width: 280px;
}
.selected-sequence {
    background-color: var(--secondary-500);
}
.nonselected-sequence {
    background-color: var(--primary-500);
}
.nopadding {
    padding-left: 0;
}
"""


def clean(s):
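    """Escape newlines and tabs so tokens render on a single line in the HTML tables."""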
    return s.replace("\n", r"\n").replace("\t", r"\t").strip()


def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
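    """Build an HTML table of the top_k candidate tokens at one step, with each
    token's step score and length-normalized total score. Tokens that the beam
    search actually kept are highlighted via the `chosen-token` CSS class."""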
    markdown_table = """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Total score</b></th>
        </tr>"""
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        item_class = ""
        if chosen_tokens and token in chosen_tokens:
            item_class = "chosen-token"
        markdown_table += f"""
        <tr class={item_class}>
            <td>{clean(token)}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{(scores[token_idx] + previous_cumul_score)/score_divider:.4f}</td>
        </tr>"""
    markdown_table += """
    </table>"""
    return markdown_table


def generate_nodes(node, step):
    """Recursively generate HTML for the tree nodes."""
    token = tokenizer.decode([node.current_token_ix])

    if node.is_final:
        if node.is_selected_sequence:
            selected_class = "selected-sequence"
        else:
            selected_class = "nonselected-sequence"
        return f"<li> <a class='end-of-text child {selected_class}'> <span> <b>{clean(token)}</b> <br>Total score: {node.total_score:.2f}</span> </a> </li>"

    html_content = (
        f"<li> <a class='nonfinal child'> <span> <b>{clean(token)}</b> </span>"
    )
    if node.table is not None:
        html_content += node.table
    html_content += "</a>"

    if len(node.children.keys()) > 0:
        html_content += "<ul> "
        for token_ix, subnode in node.children.items():
            html_content += generate_nodes(subnode, step=step + 1)
        html_content += "</ul>"
    html_content += "</li>"

    return html_content


def generate_html(start_sentence, original_tree):
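    """Wrap the recursively generated beam nodes in the top-level tree markup."""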
    html_output = f"""<div class="custom-container">
				<div class="tree">
                <ul> <li> <a id='root' class="nopadding"> <span> <b>{start_sentence}</b> </span> {original_tree.table} </a>"""
    html_output += "<ul> "
    for subnode in original_tree.children.values():
        html_output += generate_nodes(subnode, step=1)
    html_output += "</ul>"
    html_output += """
        </li> </ul>
        </div>
    </body>
    """
    return html_output


@dataclass
class BeamNode:
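    """A node of the displayed tree: one token, its scores, and its child beams."""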
    current_token_ix: int
    cumulative_score: float
    children_score_divider: float
    table: str
    current_sequence: str
    children: Dict[int, "BeamNode"]
    total_score: float
    is_final: bool
    is_selected_sequence: bool


def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences):
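    """Replay the beam search from the per-step scores returned by `model.generate`
    and build the tree of explored beams that the demo displays."""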
    original_tree = BeamNode(
        cumulative_score=0,
        current_token_ix=None,
        table=None,
        current_sequence=start_sentence,
        children={},
        children_score_divider=(1 ** length_penalty),
        total_score=None,
        is_final=False,
        is_selected_sequence=False,
    )
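    # All beams initially share the same root node; the list is re-derived at each step.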
    beam_trees = [original_tree] * n_beams
    generation_length = len(scores)

    for step, step_scores in enumerate(scores):

        # Gather all possible descendants for each beam
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_sequence,
            top_tokens,
            token_scores,
        ) = ([], [], [], [], [], [])

        score_idx = 0
        for beam_ix in range(len(beam_trees)):
            current_beam = beam_trees[beam_ix]

            # skip if the beam is already final
            if current_beam.is_final:
                continue
                
            # Get top cumulative scores for the current beam
            current_top_token_indexes = list(
                np.array(scores[step][score_idx].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            token_scores += list(np.array(scores[step][score_idx][current_top_token_indexes]))
            top_cumulative_scores += list(
                np.array(scores[step][score_idx][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_sequence += [beam_trees[beam_ix].current_sequence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]
            score_idx += 1

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_sequence": current_sequence,
                "token": top_tokens,
                "token_score": token_scores,
            }
        )
        maxes = top_df.groupby(["token_index", "current_sequence"])[
            "cumulative_score"
        ].idxmax()

        top_df = top_df.loc[maxes]

        # Sort all top probabilities and keep top n_beams * 2 (* 2 because each beam may end this iteration, and we
        # want to keep at least `n_beams` beams alive)
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams * 2
        ]
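        # Walk down the ranked candidates until `n_beams` unfinished ones are kept,
        # so finished beams still appear in the tree without starving live ones.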
        beams_to_keep = 0
        unfinished_beams = 0
        for _, row in top_df_selected.iterrows():
            beams_to_keep += 1
            current_token_choice_ix = row["token_index"]
            is_final = step == len(scores) - 1 or current_token_choice_ix == tokenizer.eos_token_id
            if not is_final:
                unfinished_beams += 1
            if unfinished_beams >= n_beams:
                break
            if step == generation_length - 1 and beams_to_keep == n_beams:
                break
        top_df_selected_filtered = top_df_selected.iloc[:beams_to_keep]

        # Write the scores table in each beam tree
        score_idx = 0
        for beam_ix in range(len(beam_trees)):
            current_beam = beam_trees[beam_ix]
            if current_beam.table is None:
                selected_tokens = top_df_selected_filtered.loc[
                    top_df_selected_filtered["current_sequence"] == current_beam.current_sequence
                ]
                markdown_table = generate_markdown_table(
                    step_scores[score_idx, :],
                    current_beam.cumulative_score,
                    current_beam.children_score_divider,
                    chosen_tokens=list(selected_tokens["token"].values),
                )
                beam_trees[beam_ix].table = markdown_table
            if not current_beam.is_final:
                score_idx = min(score_idx + 1, n_beams - 1)

        # Add new children to each beam
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for _, row in top_df_selected_filtered.iterrows():
            # Update the source tree
            source_beam_ix = int(row["beam_index"])
            current_token_choice_ix = row["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            token_scores = row["token_score"]

            cumulative_score = cumulative_scores[source_beam_ix] + np.asarray(token_scores)
            current_sequence = (
                beam_trees[source_beam_ix].current_sequence + current_token_choice
            )
            is_final = step == len(scores) - 1 or current_token_choice_ix == tokenizer.eos_token_id
            beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
                current_token_ix=current_token_choice_ix,
                table=None,
                children={},
                current_sequence=current_sequence,
                cumulative_score=cumulative_score,
                total_score=cumulative_score / ((step + 1) ** length_penalty),  # length-normalized score
                children_score_divider=((step + 2) ** length_penalty),
                is_final=is_final,
                is_selected_sequence=(
                    current_sequence.replace("<|endoftext|>", "")
                    in [el.replace("<|endoftext|>", "") for el in decoded_sequences]
                ),
            )

        # Reorder beams by descending cumulative score, so that beam 0 always has the highest score
        beam_trees = [
            beam_trees[int(top_df_selected_filtered.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(beams_to_keep)
        ]

        # Advance all beams by one token
        for beam_ix in range(beams_to_keep):
            current_token_choice_ix = top_df_selected_filtered.iloc[beam_ix]["token_index"]
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]

    return original_tree

@spaces.GPU
def get_beam_search_html(
    input_text, number_steps, number_beams, length_penalty, num_return_sequences
):
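    """Run `model.generate` with the requested settings and return (tree HTML, summary markdown)."""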
    inputs = tokenizer([input_text], return_tensors="pt")

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=num_return_sequences,
        return_dict_in_generate=True,
        length_penalty=length_penalty,
        output_scores=True,
        do_sample=False,
    )
    markdown = "The conclusive sequences are the ones that end in an `<|endoftext|>` token or at the end of generation."
    markdown += "\n\nThey are ranked by their scores, as given by the formula `score = cumulative_score / (output_length ** length_penalty)`.\n\n"
    markdown += "Only the top `num_beams` scoring sequences are returned: in the tree they are highlighted in **<span style='color:var(--secondary-500)!important'>blue</span>**."
    markdown += " The non-selected sequences are also shown in the tree, highlighted in **<span style='color:var(--primary-500)!important'>yellow</span>**."
    markdown += "\n#### <span style='color:var(--secondary-500)!important'>Output sequences:</span>"
    # Sequences are padded anyway so you can batch decode them
    decoded_sequences = tokenizer.batch_decode(outputs.sequences)
    for i, sequence in enumerate(decoded_sequences):
        # `sequences_scores` is only returned by beam search; greedy decoding (1 beam) has none.
        score = outputs.sequences_scores[i] if number_beams > 1 else 0.0
        markdown += f"\n- Score `{score:.2f}`: `{clean(sequence)}`"

    if number_beams > 1:
        original_tree = generate_beams(
            number_beams,
            input_text,
            outputs.scores[:],
            length_penalty,
            decoded_sequences,
        )
    else:
        # Greedy decoding: visualize it as beam search with a single beam.
        original_tree = generate_beams(
            number_beams,
            input_text,
            outputs.scores[:],
            0,
            decoded_sequences,
        )
        
    html = generate_html(input_text, original_tree)
    return html, markdown


def change_num_return_sequences(n_beams):
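    """Cap the num_return_sequences slider at the current number of beams."""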
    return gr.Slider(
        label="Number of sequences", minimum=1, maximum=n_beams, step=1, value=n_beams
    )


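# Gradio UI: parameter controls on top, then the generated tree and the ranked output sequences.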
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.yellow,
        secondary_hue=gr.themes.colors.blue,
    ),
    css=STYLE,
) as demo:
    gr.Markdown(
        """# <span style='color:var(--primary-500)!important'>Beam Search Visualizer</span>

Play with the parameters below to understand how beam search decoding works!

#### <span style='color:var(--primary-500)!important'>Parameters:</span>
- **Sentence to decode from** (`inputs`): the input sequence to your decoder.
- **Number of steps** (`max_new_tokens`): the number of tokens to generate.
- **Number of beams** (`num_beams`): the number of beams to use.
- **Length penalty** (`length_penalty`): the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
This parameter will not impact the beam search paths, but only influence the choice of sequences in the end towards longer or shorter sequences.
- **Number of return sequences** (`num_return_sequences`): the number of sequences to be returned at the end of generation. Should be `<= num_beams`.
"""
    )
    text = gr.Textbox(
        label="Sentence to decode from",
        value="Conclusion: thanks a lot. That's all for today",
    )
    with gr.Row():
        n_steps = gr.Slider(
            label="Number of steps", minimum=1, maximum=12, step=1, value=5
        )
        n_beams = gr.Slider(
            label="Number of beams", minimum=1, maximum=4, step=1, value=4
        )
        length_penalty = gr.Slider(
            label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
        )
        num_return_sequences = gr.Slider(
            label="Number of return sequences", minimum=1, maximum=4, step=1, value=3
        )

    n_beams.change(
        fn=change_num_return_sequences, inputs=n_beams, outputs=num_return_sequences
    )
    button = gr.Button()
    out_html = gr.Markdown()
    out_markdown = gr.Markdown()
    button.click(
        get_beam_search_html,
        inputs=[text, n_steps, n_beams, length_penalty, num_return_sequences],
        outputs=[out_html, out_markdown],
    )

demo.launch()