|
import gradio as gr |
|
from transformers import T5Tokenizer, AutoModelForCausalLM, GenerationConfig |
|
|
|
|
|
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small") |
|
tokenizer.do_lower_case = True |
|
model = AutoModelForCausalLM.from_pretrained( |
|
"rinna/japanese-gpt2-small", |
|
pad_token_id=tokenizer.eos_token_id |
|
) |
|
|
|
|
|
def generate(text, max_length, num_beams, p): |
|
"""初回のテキスト生成 |
|
|
|
テキスト生成を行うが、デコード方法によって異なる結果になることを示すための処理を行う。 |
|
指定されたパラメタを使って、異なる4つデコード方法を同時に出力する。 |
|
|
|
Args: |
|
text: str |
|
Stateから取得(続きを生成するためのプロンプト) |
|
max_length: int |
|
Sliderから取得(全てのデコード方法に共通のパラメタ。生成する単語数) |
|
num_beams: int |
|
Sliderから取得(Beam Searchのパラメタ) |
|
p: int |
|
Sliderから取得(Top-p Samplingのパラメタ) |
|
|
|
Returns: |
|
tuple(str1, str2, str3) |
|
str1: State(生成結果を入出力の状態に反映) |
|
str2: TextArea(全文表示用のコンポーネントで使用) |
|
str3: TextArea(今回生成した文を表示するコンポーネントで使用) |
|
""" |
|
|
|
generate_config_list = [ |
|
GenerationConfig( |
|
max_new_tokens=max_length, |
|
no_repeat_ngram_size=3, |
|
num_beams=1, |
|
do_sample=False |
|
), |
|
GenerationConfig( |
|
max_new_tokens=max_length, |
|
no_repeat_ngram_size=3, |
|
num_beams=1, |
|
do_sample=True |
|
), |
|
GenerationConfig( |
|
max_new_tokens=max_length, |
|
no_repeat_ngram_size=3, |
|
num_beams=num_beams, |
|
do_sample=False |
|
), |
|
GenerationConfig( |
|
max_new_tokens=max_length, |
|
no_repeat_ngram_size=3, |
|
do_sample=True, |
|
top_p=p |
|
) |
|
] |
|
generated_texts = [] |
|
|
|
inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt")["input_ids"] |
|
for generate_config in generate_config_list: |
|
|
|
output = model.generate(inputs, generation_config=generate_config) |
|
generated = tokenizer.decode(output[0], skip_special_tokens=True) |
|
|
|
generated_texts.append("。\n".join(generated.replace(" ", "").split("。"))) |
|
|
|
|
|
|
|
return tuple(generated_texts) |
|
|
|
def select_out1(out1): |
|
"""out1が生成された時に、out1を後続の処理のデフォルト値に入力 |
|
""" |
|
return out1, out1, out1 |
|
|
|
def select_out(radio, out1, out2, out3, out4): |
|
"""後続の処理に使用する、初回の処理結果を選択する |
|
""" |
|
if radio == "1.Greedy": |
|
out = out1 |
|
elif radio == "2.Sampling": |
|
out = out2 |
|
elif radio == "3.Beam Search": |
|
out = out3 |
|
else: |
|
out = out4 |
|
return out, out, out |
|
|
|
def generate_next(now_text, radio, max_length, num_beams, p): |
|
"""続き生成 |
|
|
|
これまで出力したテキストを入力して受け取り、続きを生成する。 |
|
デコード方法を指定することができるが、そのパラメタは初回のテキスト生成と同じになる。 |
|
|
|
Args: |
|
now_text: str |
|
Stateから取得(続きを生成するためのプロンプト) |
|
radio: str |
|
Radioから取得(使用するデコード方法の名前) |
|
max_length: int |
|
Sliderから取得(初回のテキスト生成で使用した値をここでも使用) |
|
num_beams: int |
|
Sliderから取得(初回のテキスト生成で使用した値をここでも使用) |
|
p: int |
|
Sliderから取得(初回のテキスト生成で使用した値をここでも使用) |
|
|
|
Returns: |
|
next_text: str |
|
State(生成結果を入出力の状態に反映) |
|
next_text: str |
|
TextArea(全文表示用のコンポーネントで使用) |
|
gen_text: str |
|
TextArea(今回生成した文を表示するコンポーネントで使用) |
|
""" |
|
|
|
if radio == "1.Greedy": |
|
generate_config = GenerationConfig( |
|
max_new_tokens=max_length, |
|
no_repeat_ngram_size=3, |
|
num_beams=1, |
|
do_sample=False |
|
) |
|
elif radio == "2.Sampling": |
|
generate_config = GenerationConfig( |
|
max_new_tokens=max_length, |
|
no_repeat_ngram_size=3, |
|
num_beams=1, |
|
do_sample=True |
|
) |
|
elif radio == "3.Beam Search": |
|
generate_config = GenerationConfig( |
|
max_new_tokens=max_length, |
|
no_repeat_ngram_size=3, |
|
num_beams=num_beams, |
|
do_sample=False |
|
) |
|
else: |
|
generate_config = GenerationConfig( |
|
max_new_tokens=max_length, |
|
no_repeat_ngram_size=3, |
|
do_sample=True, |
|
top_p=p |
|
) |
|
|
|
|
|
inputs = tokenizer(now_text, add_special_tokens=False, return_tensors="pt")["input_ids"] |
|
output = model.generate(inputs, generation_config=generate_config) |
|
generated = tokenizer.decode(output[0], skip_special_tokens=True) |
|
|
|
next_text = "。\n".join(generated.replace(" ", "").split("。")) |
|
gen_text = next_text[len(now_text)+1:] |
|
|
|
return next_text, next_text, gen_text |
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
|
gr.Markdown(''' |
|
# テキスト生成 |
|
テキストを入力すると、4パターンのデコード方法でテキスト生成を実行します。 |
|
## 4つのパターン(入門編) |
|
1. Greedy: ビームサーチもサンプリングも行いません。毎回、最も確率の高い単語を選択します。 |
|
2. Sampling: モデルによって与えられた語彙全体の確率分布に基づいて次の単語を選択します。 |
|
3. Beam Search: 各タイムステップで複数の仮説を保持し、最終的に仮説ごとのシーケンス全体で最も高い確率を持つ仮説を選択します。 |
|
4. Top-p Sampling: 2の方法に関して、確率の和がpになる最小の単語にフィルタリングすることで、確率が低い単語が選ばれる可能性を無くします。 |
|
''') |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
input_text = gr.Textbox(value="福岡のご飯は美味しい。", label="プロンプト") |
|
max_length = gr.Slider(100, 1000, step=100, value=100, label="生成するテキストの長さ") |
|
num_beams = gr.Slider(1, 10, step=1, value=6, label="beam幅") |
|
p = gr.Slider(0, 1, step=0.01, value=0.92, label="p") |
|
btn1 = gr.Button("4パターンで生成") |
|
|
|
with gr.Column(): |
|
out1 = gr.Textbox(label="Greedy") |
|
out2 = gr.Textbox(label="Sampling") |
|
out3 = gr.Textbox(label="Beam Search") |
|
out4 = gr.Textbox(label="Top-p Sampling") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown("## どの結果の続きが気になりますか?") |
|
radio1 = gr.Radio(choices=["1.Greedy", "2.Sampling", "3.Beam Search", "4.Top-p Sampling"], value="1.Greedy", label="結果の選択") |
|
output_text = gr.Textbox(label="初回の結果") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown(f"## どの方法で続きを生成しますか?") |
|
history = gr.State() |
|
now_text = gr.TextArea(label="これまでの結果") |
|
radio2 = gr.Radio(choices=["1.Greedy", "2.Sampling", "3.Beam Search", "4.Top-p Sampling"], value="1.Greedy", label="続き生成のデコード方法") |
|
btn2 = gr.Button("続きを生成") |
|
next_text = gr.TextArea(label="今回の生成結果") |
|
|
|
|
|
|
|
btn1.click(fn=generate, inputs=[input_text, max_length, num_beams, p], outputs=[out1, out2, out3, out4]) |
|
out1.change(select_out1, inputs=[out1], outputs=[output_text, history, now_text]) |
|
radio1.change(select_out, inputs=[radio1, out1, out2, out3, out4], outputs=[output_text, history, now_text]) |
|
btn2.click(fn=generate_next, inputs=[history, radio2, max_length, num_beams, p], outputs=[history, now_text, next_text]) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|