鹿子木 絵里奈
commited on
Commit
•
ddbadbf
1
Parent(s):
beaa78d
Add application file
Browse files- app.py +210 -0
- requirements.txt +67 -0
app.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import T5Tokenizer, AutoModelForCausalLM, GenerationConfig
|
3 |
+
|
4 |
+
# 0. モデルとトークナイザーの定義
|
5 |
+
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
|
6 |
+
tokenizer.do_lower_case = True # rinna/japanese-gpt2特有のハック
|
7 |
+
model = AutoModelForCausalLM.from_pretrained(
|
8 |
+
"rinna/japanese-gpt2-small",
|
9 |
+
pad_token_id=tokenizer.eos_token_id # warningを避けるために、padにEOSトークンを割りあてる
|
10 |
+
)
|
11 |
+
|
12 |
+
# 1. Gradioのコンポーネントのイベント処理用の関数の定義
|
13 |
+
def generate(text, max_length, num_beams, p):
|
14 |
+
"""初回のテキスト生成
|
15 |
+
|
16 |
+
テキスト生成を行うが、デコード方法によって異なる結果になることを示すための処理を行う。
|
17 |
+
指定されたパラメタを使って、異なる4つデコード方法を同時に出力する。
|
18 |
+
|
19 |
+
Args:
|
20 |
+
text: str
|
21 |
+
Stateから取得(続きを生成するためのプロンプト)
|
22 |
+
max_length: int
|
23 |
+
Sliderから取得(全てのデコード方法に共通のパラメタ。生成する単語数)
|
24 |
+
num_beams: int
|
25 |
+
Sliderから取得(Beam Searchのパラメタ)
|
26 |
+
p: int
|
27 |
+
Sliderから取得(Top-p Samplingのパラメタ)
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
tuple(str1, str2, str3)
|
31 |
+
str1: State(生成結果を入出力の状態に反映)
|
32 |
+
str2: TextArea(全文表示用のコンポーネントで使用)
|
33 |
+
str3: TextArea(今回生成した文を表示するコンポーネントで使用)
|
34 |
+
"""
|
35 |
+
# テキスト生成用のconfigクラスを使って、4パターンの設定を定義する。
|
36 |
+
generate_config_list = [
|
37 |
+
GenerationConfig(
|
38 |
+
max_new_tokens=max_length,
|
39 |
+
no_repeat_ngram_size=3,
|
40 |
+
num_beams=1, # beam幅の設定、2以上ではbeam searchになる。
|
41 |
+
do_sample=False # Samplingの設定
|
42 |
+
),
|
43 |
+
GenerationConfig(
|
44 |
+
max_new_tokens=max_length,
|
45 |
+
no_repeat_ngram_size=3,
|
46 |
+
num_beams=1,
|
47 |
+
do_sample=True
|
48 |
+
),
|
49 |
+
GenerationConfig(
|
50 |
+
max_new_tokens=max_length,
|
51 |
+
no_repeat_ngram_size=3,
|
52 |
+
num_beams=num_beams,
|
53 |
+
do_sample=False
|
54 |
+
),
|
55 |
+
GenerationConfig(
|
56 |
+
max_new_tokens=max_length,
|
57 |
+
no_repeat_ngram_size=3,
|
58 |
+
do_sample=True,
|
59 |
+
top_p=p # Top-p Samplingのパラメタの設定
|
60 |
+
)
|
61 |
+
]
|
62 |
+
generated_texts = []
|
63 |
+
|
64 |
+
inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt")["input_ids"]
|
65 |
+
for generate_config in generate_config_list:
|
66 |
+
# テキスト生成
|
67 |
+
output = model.generate(inputs, generation_config=generate_config)
|
68 |
+
generated = tokenizer.decode(output[0], skip_special_tokens=True)
|
69 |
+
# 読みやすくさの処理を行なって、リストに追加
|
70 |
+
generated_texts.append("。\n".join(generated.replace(" ", "").split("。")))
|
71 |
+
|
72 |
+
# gradioはtupleを想定している。これと同じ処理:return generated_texts[0], generated_texts[1], generated_texts[2]
|
73 |
+
# pythonのタプルは「,」によって生成される。丸括弧は省略可能。参考:https://note.nkmk.me/python-function-return-multiple-values/
|
74 |
+
return tuple(generated_texts)
|
75 |
+
|
76 |
+
def select_out1(out1):
|
77 |
+
"""out1が生成された時に、out1を後続の処理のデフォルト値に入力
|
78 |
+
"""
|
79 |
+
return out1, out1, out1
|
80 |
+
|
81 |
+
def select_out(radio, out1, out2, out3, out4):
|
82 |
+
"""後続の処理に使用する、初回の処理結果を選択する
|
83 |
+
"""
|
84 |
+
if radio == "1.Greedy":
|
85 |
+
out = out1
|
86 |
+
elif radio == "2.Sampling":
|
87 |
+
out = out2
|
88 |
+
elif radio == "3.Beam Search":
|
89 |
+
out = out3
|
90 |
+
else:
|
91 |
+
out = out4
|
92 |
+
return out, out, out
|
93 |
+
|
94 |
+
def generate_next(now_text, radio, max_length, num_beams, p):
|
95 |
+
"""続き生成
|
96 |
+
|
97 |
+
これまで出力したテキストを入力して受け取り、続きを生成する。
|
98 |
+
デコード方法を指定することができるが、そのパラメタは初回のテキスト生成と同じになる。
|
99 |
+
|
100 |
+
Args:
|
101 |
+
now_text: str
|
102 |
+
Stateから取得(続きを生成するためのプロンプト)
|
103 |
+
radio: str
|
104 |
+
Radioから取得(使用するデコード方法の名前)
|
105 |
+
max_length: int
|
106 |
+
Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
|
107 |
+
num_beams: int
|
108 |
+
Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
|
109 |
+
p: int
|
110 |
+
Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
next_text: str
|
114 |
+
State(生成結果を入出力の状態に反映)
|
115 |
+
next_text: str
|
116 |
+
TextArea(全文表示用のコンポーネントで使用)
|
117 |
+
gen_text: str
|
118 |
+
TextArea(今回生成した文を表示するコンポーネントで使用)
|
119 |
+
"""
|
120 |
+
# デコード方法の指定に合わせて、cofingを定義
|
121 |
+
if radio == "1.Greedy":
|
122 |
+
generate_config = GenerationConfig(
|
123 |
+
max_new_tokens=max_length,
|
124 |
+
no_repeat_ngram_size=3,
|
125 |
+
num_beams=1,
|
126 |
+
do_sample=False
|
127 |
+
)
|
128 |
+
elif radio == "2.Sampling":
|
129 |
+
generate_config = GenerationConfig(
|
130 |
+
max_new_tokens=max_length,
|
131 |
+
no_repeat_ngram_size=3,
|
132 |
+
num_beams=1,
|
133 |
+
do_sample=True
|
134 |
+
)
|
135 |
+
elif radio == "3.Beam Search":
|
136 |
+
generate_config = GenerationConfig(
|
137 |
+
max_new_tokens=max_length,
|
138 |
+
no_repeat_ngram_size=3,
|
139 |
+
num_beams=num_beams,
|
140 |
+
do_sample=False
|
141 |
+
)
|
142 |
+
else:
|
143 |
+
generate_config = GenerationConfig(
|
144 |
+
max_new_tokens=max_length,
|
145 |
+
no_repeat_ngram_size=3,
|
146 |
+
do_sample=True,
|
147 |
+
top_p=p
|
148 |
+
)
|
149 |
+
|
150 |
+
# テキスト生成
|
151 |
+
inputs = tokenizer(now_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
|
152 |
+
output = model.generate(inputs, generation_config=generate_config)
|
153 |
+
generated = tokenizer.decode(output[0], skip_special_tokens=True)
|
154 |
+
# 結果の整形処理
|
155 |
+
next_text = "。\n".join(generated.replace(" ", "").split("。"))
|
156 |
+
gen_text = next_text[len(now_text)+1:] # 今回生成したテキストを抽出
|
157 |
+
|
158 |
+
return next_text, next_text, gen_text
|
159 |
+
|
160 |
+
# 2. GradioによるUI/イベント処理の定義
|
161 |
+
with gr.Blocks() as demo:
|
162 |
+
# 2.1. UI
|
163 |
+
gr.Markdown('''
|
164 |
+
# テキスト生成
|
165 |
+
テキストを入力すると、4パターンのデコード方法でテキスト生成を実行します。
|
166 |
+
## 4つのパターン(入門編)
|
167 |
+
1. Greedy: ビームサーチもサンプリングも行いません。毎回、最も確率の高い単語を選択します。
|
168 |
+
2. Sampling: モデルによって与えられた語彙全体の確率分布に基づいて次の単語を選択します。
|
169 |
+
3. Beam Search: 各タイムステップで複数の仮説を保持し、最終的に仮説ごとのシーケンス全体で最も高い確率を持つ仮説を選択します。
|
170 |
+
4. Top-p Sampling: 2の方法に関して、確率の和がpになる最小の単語にフィルタリングすることで、確率が低い単語が選ばれる可能性を無くします。
|
171 |
+
''')
|
172 |
+
|
173 |
+
with gr.Row(): # 行に分ける。なので、このブロック内にあるコンポーネントは横に並ぶ。
|
174 |
+
with gr.Column(): # さらに列に分ける。なので、このブロック内にあるコンポーネントは縦に並ぶ。
|
175 |
+
input_text = gr.Textbox(value="福岡のご飯は美味しい。", label="プロンプト")
|
176 |
+
max_length = gr.Slider(100, 1000, step=100, value=100, label="生成するテキストの長さ")
|
177 |
+
num_beams = gr.Slider(1, 10, step=1, value=6, label="beam幅")
|
178 |
+
p = gr.Slider(0, 1, step=0.01, value=0.92, label="p")
|
179 |
+
btn1 = gr.Button("4パターンで生成")
|
180 |
+
|
181 |
+
with gr.Column():
|
182 |
+
out1 = gr.Textbox(label="Greedy")
|
183 |
+
out2 = gr.Textbox(label="Sampling")
|
184 |
+
out3 = gr.Textbox(label="Beam Search")
|
185 |
+
out4 = gr.Textbox(label="Top-p Sampling")
|
186 |
+
|
187 |
+
with gr.Row():
|
188 |
+
with gr.Column():
|
189 |
+
gr.Markdown("## どの結果の続きが気になりますか?")
|
190 |
+
radio1 = gr.Radio(choices=["1.Greedy", "2.Sampling", "3.Beam Search", "4.Top-p Sampling"], value="1.Greedy", label="結果の選択")
|
191 |
+
output_text = gr.Textbox(label="初回の結果")
|
192 |
+
|
193 |
+
with gr.Row():
|
194 |
+
with gr.Column():
|
195 |
+
gr.Markdown(f"## どの方法で続きを生成しますか?")
|
196 |
+
history = gr.State()
|
197 |
+
now_text = gr.TextArea(label="これまでの結果")
|
198 |
+
radio2 = gr.Radio(choices=["1.Greedy", "2.Sampling", "3.Beam Search", "4.Top-p Sampling"], value="1.Greedy", label="続き生成のデコード方法")
|
199 |
+
btn2 = gr.Button("続きを生成")
|
200 |
+
next_text = gr.TextArea(label="今回の生成結果")
|
201 |
+
|
202 |
+
|
203 |
+
# 2.2 イベント処理
|
204 |
+
btn1.click(fn=generate, inputs=[input_text, max_length, num_beams, p], outputs=[out1, out2, out3, out4])
|
205 |
+
out1.change(select_out1, inputs=[out1], outputs=[output_text, history, now_text])
|
206 |
+
radio1.change(select_out, inputs=[radio1, out1, out2, out3, out4], outputs=[output_text, history, now_text])
|
207 |
+
btn2.click(fn=generate_next, inputs=[history, radio2, max_length, num_beams, p], outputs=[history, now_text, next_text])
|
208 |
+
|
209 |
+
if __name__ == "__main__":
|
210 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==22.1.0
|
2 |
+
aiohttp==3.8.3
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==4.2.2
|
5 |
+
anyio==3.6.2
|
6 |
+
async-timeout==4.0.2
|
7 |
+
attrs==22.2.0
|
8 |
+
certifi==2022.12.7
|
9 |
+
charset-normalizer==2.1.1
|
10 |
+
click==8.1.3
|
11 |
+
contourpy==1.0.7
|
12 |
+
cycler==0.11.0
|
13 |
+
entrypoints==0.4
|
14 |
+
fastapi==0.89.1
|
15 |
+
ffmpy==0.3.0
|
16 |
+
filelock==3.9.0
|
17 |
+
fonttools==4.38.0
|
18 |
+
frozenlist==1.3.3
|
19 |
+
fsspec==2023.1.0
|
20 |
+
gradio==3.17.1
|
21 |
+
h11==0.14.0
|
22 |
+
httpcore==0.16.3
|
23 |
+
httpx==0.23.3
|
24 |
+
huggingface-hub==0.12.0
|
25 |
+
idna==3.4
|
26 |
+
Jinja2==3.1.2
|
27 |
+
jsonschema==4.17.3
|
28 |
+
kiwisolver==1.4.4
|
29 |
+
linkify-it-py==1.0.3
|
30 |
+
markdown-it-py==2.1.0
|
31 |
+
MarkupSafe==2.1.2
|
32 |
+
matplotlib==3.6.3
|
33 |
+
mdit-py-plugins==0.3.3
|
34 |
+
mdurl==0.1.2
|
35 |
+
multidict==6.0.4
|
36 |
+
numpy==1.24.2
|
37 |
+
orjson==3.8.5
|
38 |
+
packaging==23.0
|
39 |
+
pandas==1.5.3
|
40 |
+
Pillow==9.4.0
|
41 |
+
pycryptodome==3.17
|
42 |
+
pydantic==1.10.4
|
43 |
+
pydub==0.25.1
|
44 |
+
pyparsing==3.0.9
|
45 |
+
pyrsistent==0.19.3
|
46 |
+
python-dateutil==2.8.2
|
47 |
+
python-multipart==0.0.5
|
48 |
+
pytz==2022.7.1
|
49 |
+
PyYAML==6.0
|
50 |
+
regex==2022.10.31
|
51 |
+
requests==2.28.2
|
52 |
+
rfc3986==1.5.0
|
53 |
+
sentencepiece==0.1.97
|
54 |
+
six==1.16.0
|
55 |
+
sniffio==1.3.0
|
56 |
+
starlette==0.22.0
|
57 |
+
tokenizers==0.13.2
|
58 |
+
toolz==0.12.0
|
59 |
+
torch==1.13.1
|
60 |
+
tqdm==4.64.1
|
61 |
+
transformers==4.26.0
|
62 |
+
typing_extensions==4.4.0
|
63 |
+
uc-micro-py==1.0.1
|
64 |
+
urllib3==1.26.14
|
65 |
+
uvicorn==0.20.0
|
66 |
+
websockets==10.4
|
67 |
+
yarl==1.8.2
|