johnsu6616 committed
Commit 856e316
1 Parent(s): a384160
Files changed (2):
  1. app.py +147 -81
  2. requirements.txt +3 -3
app.py CHANGED
@@ -5,8 +5,8 @@ import gradio as gr
 import torch

 from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
 from transformers import AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer

 from transformers import AutoProcessor

@@ -14,12 +14,16 @@ from transformers import pipeline

 from transformers import set_seed

+global ButtonIndex
+
 device = "cuda" if torch.cuda.is_available() else "cpu"

 big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
 big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

-text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')
+pipeline_01 = pipeline('text-generation', model='succinctly/text2image-prompt-generator')
+pipeline_02 = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2')
+pipeline_03 = pipeline('text-generation', model='johnsu6616/ModelExport')

 zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
 zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
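
Note: the single succinctly pipeline becomes three interchangeable text-generation pipelines, selected later through the new radio button. Each pipeline call returns a list of dicts whose generated_text key holds one candidate prompt, which is how the generate_prompter_pipeline_* functions added below consume them. A minimal sketch of that contract (sample prompt hypothetical; do_sample made explicit here, whereas the app relies on the model's own generation defaults):

    from transformers import pipeline

    pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')

    # Each element is a dict; 'generated_text' holds one candidate prompt.
    for candidate in pipe('a cat in space', do_sample=True,
                          max_new_tokens=80, num_return_sequences=3):
        print(candidate['generated_text'].strip())
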
@@ -27,17 +31,14 @@ zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
 en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
 en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")

-
 def translate_zh2en(text):
     with torch.no_grad():

+        text = re.sub(r"[:\-–.!;?_#]", '', text)
         text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)
         text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)
-
         text = text.replace('\n', ',')
-
         text = re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)
-
         text = re.sub(r',+', ',', text)

         encoded = zh2en_tokenizer([text], return_tensors='pt')
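
Note: translate_zh2en now strips punctuation before separating Chinese from non-Chinese runs. A quick trace of the cleanup chain (sample input hypothetical):

    import re

    text = 'Hello! 你好 world'
    text = re.sub(r"[:\-–.!;?_#]", '', text)   # drop punctuation -> 'Hello 你好 world'
    text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)  # break before a CJK run
    text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)  # break after a CJK run
    text = text.replace('\n', ',')
    text = re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)  # keep spaces only between letters
    text = re.sub(r',+', ',', text)
    print(text)  # Hello,你好,world
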
@@ -45,80 +46,68 @@ def translate_zh2en(text):
         result = zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

         result = result.strip()
-        if result != "No,no,":
+
+        if result == "No,no,":
             result = text
-        return result

+        return result

 def translate_en2zh(text):
     with torch.no_grad():

         encoded = en2zh_tokenizer([text], return_tensors="pt")
         sequences = en2zh_model.generate(**encoded)
-        return en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+        result = en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

-def test05(text):
+        result = re.sub(r'(\b\w+\b)(?:\W+\1\b)+', r'\1', result)
+        return result

-    return text
+def load_prompter():
+    prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "left"
+    return prompter_model, tokenizer

-def test06(text):
+prompter_model, prompter_tokenizer = load_prompter()

-    return text
-
-
-def text_generate(text):
+def generate_prompter_pipeline_01(text):
     seed = random.randint(100, 1000000)
     set_seed(seed)
-
     text_in_english = translate_zh2en(text)
-    result = ""
-    for _ in range(6):
-        sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
-        list = []
-        for sequence in sequences:
-
-            line = sequence['generated_text'].strip()
-
-            if line != text_in_english and len(line) > (len(text_in_english) + 4):
-
-                list.append(translate_en2zh(line)+"\n")
-                list.append(line+"\n")
-                list.append("\n")
-
-        result = "".join(list)
-
-        result = re.sub('[^ ]+\.[^ ]+', '', result)
-
-        result = result.replace('<', '').replace('>', '')
-
-        if result != '':
-            break
-    return result
-
-def load_prompter():
-    prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.padding_side = "left"
-    return prompter_model, tokenizer
-
-prompter_model, prompter_tokenizer = load_prompter()
-
-def generate_prompter(text):
-    text = translate_zh2en(text)
-
-    input_ids = prompter_tokenizer(text.strip()+" Rephrase:", return_tensors="pt").input_ids
-    eos_id = prompter_tokenizer.eos_token_id
+    response = pipeline_01(text_in_english, max_new_tokens=80, num_return_sequences=3)
+    response_list = []
+    for x in response:
+        resp = x['generated_text'].strip()
+
+        if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
+
+            response_list.append(translate_en2zh(resp)+"\n")
+            response_list.append(resp+"\n")
+            response_list.append("\n")
+
+    result = "".join(response_list)
+    result = re.sub('[^ ]+\.[^ ]+', '', result)
+    result = result.replace('<', '').replace('>', '')
+
+    if result != '':
+        return result
+
+def generate_prompter_tokenizer_01(text):
+
+    text_in_english = translate_zh2en(text)
+
+    input_ids = prompter_tokenizer(text_in_english.strip()+" Rephrase:", return_tensors="pt").input_ids
+
+    eos_id = 50256
     outputs = prompter_model.generate(
         input_ids,
         do_sample=False,
-        max_new_tokens=75,
+        max_new_tokens=80,
         num_beams=3,
         num_return_sequences=3,
-        eos_token_id=eos_id,
         pad_token_id=eos_id,
+        eos_token_id=eos_id,
         length_penalty=-1.0
     )
     output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
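
Note: the hard-coded eos_id = 50256 matches what the deleted prompter_tokenizer.eos_token_id lookup returned, since load_prompter uses the GPT-2 tokenizer, whose end-of-text token <|endoftext|> has id 50256. A quick check:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    print(tok.eos_token_id)  # 50256
    print(tok.eos_token)     # <|endoftext|>
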
@@ -134,69 +123,143 @@ def generate_prompter(text):
         result.append("\n")
     return "".join(result)

-def combine_text(text):
-    text01 = generate_prompter(text)
-    text02 = text_generate(text)
-    return text01,text02
+def generate_prompter_pipeline_02(text):
+    seed = random.randint(100, 1000000)
+    set_seed(seed)
+    text_in_english = translate_zh2en(text)
+    response = pipeline_02(text_in_english, max_new_tokens=80, num_return_sequences=3)
+    response_list = []
+    for x in response:
+        resp = x['generated_text'].strip()
+        if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
+
+            response_list.append(translate_en2zh(resp)+"\n")
+            response_list.append(resp+"\n")
+            response_list.append("\n")
+
+    result = "".join(response_list)
+    result = re.sub('[^ ]+\.[^ ]+', '', result)
+    result = result.replace("<", "").replace(">", "")
+
+    if result != "":
+        return result
+
+def generate_prompter_pipeline_03(text):
+    seed = random.randint(100, 1000000)
+    set_seed(seed)
+    text_in_english = translate_zh2en(text)
+    response = pipeline_03(text_in_english, max_new_tokens=80, num_return_sequences=3)
+    response_list = []
+    for x in response:
+        resp = x['generated_text'].strip()
+        if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
+
+            response_list.append(translate_en2zh(resp)+"\n")
+            response_list.append(resp+"\n")
+            response_list.append("\n")
+
+    result = "".join(response_list)
+    result = re.sub('[^ ]+\.[^ ]+', '', result)
+    result = result.replace("<", "").replace(">", "")
+
+    if result != "":
+        return result
+
+def generate_render(text,choice):
+    if choice == '★pipeline模式(succinctly)':
+        outputs = generate_prompter_pipeline_01(text)
+        return outputs,choice
+    elif choice == '★★tokenizer模式':
+        outputs = generate_prompter_tokenizer_01(text)
+        return outputs,choice
+    elif choice == '★★★pipeline模型(Gustavosta)':
+        outputs = generate_prompter_pipeline_02(text)
+        return outputs,choice
+    elif choice == 'pipeline模型(John)_自訓測試,資料不穩定':
+        outputs = generate_prompter_pipeline_03(text)
+        return outputs,choice

-def get_prompt_from_image(input_image):
+def get_prompt_from_image(input_image,choice):
     image = input_image.convert('RGB')
     pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
-    generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
+    generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_new_tokens=80)
     generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    result01 = generate_prompter(generated_caption)
-    result02 = text_generate(generated_caption)
-    return result01,result02
+    text = re.sub(r"[:\-–.!;?_#]", '', generated_caption)
+
+    if choice == '★pipeline模式(succinctly)':
+        outputs = generate_prompter_pipeline_01(text)
+        return outputs
+    elif choice == '★★tokenizer模式':
+        outputs = generate_prompter_tokenizer_01(text)
+        return outputs
+    elif choice == '★★★pipeline模型(Gustavosta)':
+        outputs = generate_prompter_pipeline_02(text)
+        return outputs
+    elif choice == 'pipeline模型(John)_自訓測試,資料不穩定':
+        outputs = generate_prompter_pipeline_03(text)
+        return outputs

 with gr.Blocks() as block:
     with gr.Column():
         with gr.Tab('工作區'):
             with gr.Row():
                 input_text = gr.Textbox(lines=12, label='輸入文字', placeholder='在此输入文字...')
-                input_image = gr.Image(type='pil')
+                input_image = gr.Image(type='pil', label="選擇圖片(辨識度不佳)")
             with gr.Row():
                 txt_prompter_btn = gr.Button('文生文')
                 pic_prompter_btn = gr.Button('圖生文')
             with gr.Row():
-                Textbox_1 = gr.Textbox(lines=6, label='生成方式A')
+                radio_btn = gr.Radio(
+                    label="請選擇產出方式",
+                    choices=['★pipeline模式(succinctly)', '★★tokenizer模式', '★★★pipeline模型(Gustavosta)',
+                             'pipeline模型(John)_自訓測試,資料不穩定'],
+                    value='★pipeline模式(succinctly)'
+                )
+
+            with gr.Row():
+                Textbox_1 = gr.Textbox(lines=6, label='提示詞生成')
             with gr.Row():
-                Textbox_2 = gr.Textbox(lines=6, label='生成方式B')
+                Textbox_2 = gr.Textbox(lines=6, label='測試資訊')
+
         with gr.Tab('測試區'):
             with gr.Row():
                 input_test01 = gr.Textbox(lines=2, label='中英翻譯', placeholder='在此输入文字...')
                 test01_btn = gr.Button('執行')
                 Textbox_test01 = gr.Textbox(lines=2, label='輸出結果')
             with gr.Row():
-                input_test02 = gr.Textbox(lines=2, label='英中翻譯', placeholder='在此输入文字...')
+                input_test02 = gr.Textbox(lines=2, label='英中翻譯(不精準)', placeholder='在此输入文字...')
                 test02_btn = gr.Button('執行')
                 Textbox_test02 = gr.Textbox(lines=2, label='輸出結果')
             with gr.Row():
-                input_test03 = gr.Textbox(lines=2, label='SD模式', placeholder='在此输入文字...')
+                input_test03 = gr.Textbox(lines=2, label='★pipeline模式(succinctly)', placeholder='在此输入文字...')
                 test03_btn = gr.Button('執行')
                 Textbox_test03 = gr.Textbox(lines=2, label='輸出結果')
             with gr.Row():
-                input_test04 = gr.Textbox(lines=2, label='瞎掰模式', placeholder='在此输入文字...')
+                input_test04 = gr.Textbox(lines=2, label='★★tokenizer模式', placeholder='在此输入文字...')
                 test04_btn = gr.Button('執行')
                 Textbox_test04 = gr.Textbox(lines=2, label='輸出結果')
             with gr.Row():
-                input_test05 = gr.Textbox(lines=2, label='沒作用', placeholder='在此输入文字...')
+                input_test05 = gr.Textbox(lines=2, label='★★★pipeline模型(Gustavosta)', placeholder='在此输入文字...')
                 test05_btn = gr.Button('執行')
                 Textbox_test05 = gr.Textbox(lines=2, label='輸出結果')
             with gr.Row():
-                input_test06 = gr.Textbox(lines=2, label='沒作用', placeholder='在此输入文字...')
+                input_test06 = gr.Textbox(lines=2, label='pipeline模型(John)_自訓測試,資料不穩定', placeholder='在此输入文字...')
                 test06_btn = gr.Button('執行')
                 Textbox_test06 = gr.Textbox(lines=2, label='輸出結果')

     txt_prompter_btn.click(
-        fn=combine_text,
-        inputs=input_text,
+        fn=generate_render,
+        inputs=[input_text,radio_btn],
         outputs=[Textbox_1,Textbox_2]
     )

     pic_prompter_btn.click(
         fn=get_prompt_from_image,
-        inputs=input_image,
-        outputs=[Textbox_1,Textbox_2]
+        inputs=[input_image,radio_btn],
+        outputs=Textbox_1
     )

     test01_btn.click(
@@ -212,25 +275,28 @@ with gr.Blocks() as block:
     )

     test03_btn.click(
-        fn=generate_prompter,
+        fn=generate_prompter_pipeline_01,
         inputs=input_test03,
         outputs=Textbox_test03
     )

     test04_btn.click(
-        fn=text_generate,
+        fn=generate_prompter_tokenizer_01,
         inputs=input_test04,
         outputs=Textbox_test04
     )
+
     test05_btn.click(
-        fn=test05,
+        fn=generate_prompter_pipeline_02,
         inputs=input_test05,
         outputs=Textbox_test05
     )
+
     test06_btn.click(
-        fn=test06,
+        fn=generate_prompter_pipeline_03,
         inputs=input_test06,
         outputs=Textbox_test06
     )

 block.queue(max_size=64).launch(show_api=False, enable_queue=True, debug=True, share=False, server_name='0.0.0.0')
+
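Note: generate_render and get_prompt_from_image (added above) dispatch on the radio label string through four near-identical if/elif branches. A table-driven equivalent would behave the same for generate_render (hypothetical refactor sketch, not part of this commit):

    # Hypothetical: map each radio label to its generator once.
    GENERATORS = {
        '★pipeline模式(succinctly)': generate_prompter_pipeline_01,
        '★★tokenizer模式': generate_prompter_tokenizer_01,
        '★★★pipeline模型(Gustavosta)': generate_prompter_pipeline_02,
        'pipeline模型(John)_自訓測試,資料不穩定': generate_prompter_pipeline_03,
    }

    def generate_render(text, choice):
        return GENERATORS[choice](text), choice
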
requirements.txt CHANGED
@@ -1,6 +1,6 @@
-transformers==4.29.1
+transformers==4.29.2
 torch==2.0.0
-pytorch_lightning==1.7.7
+pytorch_lightning==2.0.2
 gradio==3.30.0
-sentencepiece==0.1.97
+sentencepiece==0.1.99
 sacremoses==0.0.53