sakuraumi committed
Commit 88dd46b
Parent: 6c3e0fe

Create translate_epub.py

Files changed (1)
  1. translate_epub.py +248 -0
translate_epub.py ADDED
@@ -0,0 +1,248 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from argparse import ArgumentParser
import time
import os
import re
import fnmatch
import glob
import shutil
import zipfile
from tqdm import tqdm


def find_all_htmls(root_dir):
    # Collect every HTML-like document under the extracted epub directory.
    html_files = []
    for foldername, subfolders, filenames in os.walk(root_dir):
        for extension in ['*.html', '*.xhtml', '*.htm']:
            for filename in fnmatch.filter(filenames, extension):
                file_path = os.path.join(foldername, filename)
                html_files.append(file_path)
    return html_files


def get_html_text_list(html_path, text_length):
    # Split the html file into chunks of at most `text_length` characters,
    # keeping the regex matches so translations can be written back in place.
    data_list = []

    def clean_text(text):
        # Drop ruby annotations, any remaining tags, and newlines.
        text = re.sub(r'<rt[^>]*?>.*?</rt>', '', text)
        text = re.sub(r'<[^>]*>|\n', '', text)
        return text

    with open(html_path, 'r', encoding='utf-8') as f:
        file_text = f.read()
    # Materialize the iterator so the emptiness check below actually works.
    matches = list(re.finditer(r'<(h[1-6]|p).*?>(.+?)</\1>', file_text, flags=re.DOTALL))
    if not matches:
        print("no paragraphs matched, perhaps this is a structural file")
        return data_list, file_text
    groups = []
    text = ''
    pre_end = 0
    for match in matches:
        if len(text + match.group(2)) <= text_length:
            new_text = clean_text(match.group(2))
            if new_text:
                groups.append(match)
                text += '\n' + new_text
        else:
            # Only flush when there is something to flush; groups can be empty here.
            if groups:
                data_list.append((text, groups, pre_end))
                pre_end = groups[-1].end()
            new_text = clean_text(match.group(2))
            if new_text:
                groups = [match]
                text = new_text
            else:
                groups = []
                text = ''

    if text:
        data_list.append((text, groups, pre_end))
    # TEST:
    # for d in data_list:
    #     print(f"{len(d[0])}", end=" ")
    return data_list, file_text

def get_prompt(input_text, model_version):
    # Wrap the source text in the chat template each model version expects.
    if model_version == '0.5' or model_version == '0.8':
        prompt = "<reserved_106>将下面的日文文本翻译成中文:" + input_text + "<reserved_107>"
        return prompt
    if model_version == '0.7':
        prompt = f"<|im_start|>user\n将下面的日文文本翻译成中文:{input_text}<|im_end|>\n<|im_start|>assistant\n"
        return prompt
    if model_version == '0.1':
        prompt = "Human: \n将下面的日文文本翻译成中文:" + input_text + "\n\nAssistant: \n"
        return prompt
    if model_version == '0.4':
        prompt = "User: 将下面的日文文本翻译成中文:" + input_text + "\nAssistant: "
        return prompt

    raise ValueError(f"Wrong model version: {model_version}, please view https://huggingface.co/sakuraumi/Sakura-13B-Galgame")


def split_response(response, model_version):
    # Strip the prompt part from the decoded sequence and keep only the translation.
    response = response.replace("</s>", "")
    if model_version == '0.5' or model_version == '0.8':
        output = response.split("<reserved_107>")[1]
        return output
    if model_version == '0.7':
        output = response.split("<|im_start|>assistant\n")[1]
        return output
    if model_version == '0.1':
        output = response.split("\n\nAssistant: \n")[1]
        return output
    if model_version == '0.4':
        output = response.split("\nAssistant: ")[1]
        return output

    raise ValueError(f"Wrong model version: {model_version}, please view https://huggingface.co/sakuraumi/Sakura-13B-Galgame")

def detect_degeneration(generation: list, model_version):
    # Heuristic used for v0.8 only: if the tokens generated after token id 196
    # fill almost the whole 1024-token budget, assume the model is repeating itself.
    if model_version != "0.8":
        return False
    i = generation.index(196)
    generation = generation[i + 1:]
    if len(generation) >= 1023:
        print("model degeneration detected, retrying...")
        return True
    else:
        return False

def get_model_response(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str, model_version: str, generation_config: GenerationConfig, text_length: int):

    backup_generation_config_stage2 = GenerationConfig(
        temperature=1,
        top_p=0.6,
        top_k=40,
        num_beams=1,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=0,
        max_new_tokens=1024,
        min_new_tokens=1,
        do_sample=True
    )

    backup_generation_config_stage3 = GenerationConfig(
        top_k=5,
        num_beams=1,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=0,
        max_new_tokens=1024,
        min_new_tokens=1,
        penalty_alpha=0.3
    )

    backup_generation_config = [backup_generation_config_stage2, backup_generation_config_stage3]

    generation = model.generate(**tokenizer(prompt, return_tensors="pt").to(model.device), generation_config=generation_config)[0]
    if len(generation) > 2 * text_length:
        stage = 0
        while detect_degeneration(list(generation), model_version):
            stage += 1
            if stage > 2:
                print("model degeneration cannot be avoided.")
                break
            generation = model.generate(**tokenizer(prompt, return_tensors="pt").to(model.device), generation_config=backup_generation_config[stage - 1])[0]
    response = tokenizer.decode(generation)
    output = split_response(response, model_version)
    return output

def main():
    parser = ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="SakuraLLM/Sakura-13B-LNovel-v0.8", help="model huggingface id or local path.")
    parser.add_argument("--use_gptq_model", action="store_true", help="whether your model is gptq quantized.")
    parser.add_argument("--model_version", type=str, default="0.8", help="model version as written in the huggingface readme, currently one of ['0.1', '0.4', '0.5', '0.7', '0.8']")
    parser.add_argument("--data_path", type=str, default="", help="file path of the epub you want to translate.")
    parser.add_argument("--data_folder", type=str, default="", help="folder path of the epubs you want to translate.")
    parser.add_argument("--output_folder", type=str, default="", help="folder to save the translated epubs to.")
    parser.add_argument("--text_length", type=int, default=512, help="maximum input length per inference.")
    args = parser.parse_args()

    if args.use_gptq_model:
        from auto_gptq import AutoGPTQForCausalLM

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.3,
        top_k=40,
        num_beams=1,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=0,
        max_new_tokens=1024,
        min_new_tokens=1,
        do_sample=True
    )

    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False, trust_remote_code=True)

    if args.use_gptq_model:
        model = AutoGPTQForCausalLM.from_quantized(args.model_name_or_path, device="cuda:0", trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map="auto", trust_remote_code=True)

    print("Start translating...")
    start = time.time()

    # Ensure the output folder exists no matter which input mode is used.
    if args.output_folder:
        os.makedirs(args.output_folder, exist_ok=True)

    epub_list = []
    save_list = []
    if args.data_path:
        epub_list.append(args.data_path)
        save_list.append(os.path.join(args.output_folder, os.path.basename(args.data_path)))
    if args.data_folder:
        for f in os.listdir(args.data_folder):
            if f.endswith(".epub"):
                epub_list.append(os.path.join(args.data_folder, f))
                save_list.append(os.path.join(args.output_folder, f))

    for epub_path, save_path in zip(epub_list, save_list):
        print(f"translating {epub_path}...")
        start_epub = time.time()

        # Unpack the epub (it is just a zip archive) into a temporary folder.
        if os.path.exists('./temp'):
            shutil.rmtree('./temp')
        with zipfile.ZipFile(epub_path, 'r') as f:
            f.extractall('./temp')

        for html_path in find_all_htmls('./temp'):
            print(f"\ttranslating {html_path}...")
            start_html = time.time()

            translated = ''
            data_list, file_text = get_html_text_list(html_path, args.text_length)
            if len(data_list) == 0:
                continue
            for text, groups, pre_end in tqdm(data_list):
                prompt = get_prompt(text, args.model_version)
                output = get_model_response(model, tokenizer, prompt, args.model_version, generation_config, args.text_length)
                # Map translated lines back onto the original paragraph matches.
                texts = output.strip().split('\n')
                if len(texts) < len(groups):
                    texts += [''] * (len(groups) - len(texts))
                else:
                    texts = texts[:len(groups) - 1] + ['<br/>'.join(texts[len(groups) - 1:])]
                for t, match in zip(texts, groups):
                    t = match.group(0).replace(match.group(2), t)
                    translated += file_text[pre_end:match.start()] + t
                    pre_end = match.end()

            translated += file_text[data_list[-1][1][-1].end():]
            with open(html_path, 'w', encoding='utf-8') as f:
                f.write(translated)

            end_html = time.time()
            print(f"\t{html_path} translated, used time: ", end_html - start_html)

        # Repackage the (now partly translated) folder back into an epub.
        with zipfile.ZipFile(save_path, 'w', zipfile.ZIP_DEFLATED) as f:
            for file_path in glob.glob('./temp/**', recursive=True):
                if not os.path.isdir(file_path):
                    relative_path = os.path.relpath(file_path, './temp')
                    f.write(file_path, relative_path)
        shutil.rmtree('./temp')

        end_epub = time.time()
        print(f"{epub_path} translated, used time: ", end_epub - start_epub)

    end = time.time()
    print("translation completed, used time: ", end - start)


if __name__ == "__main__":
    main()
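
Example invocation (a sketch based on the argparse flags defined above; ./novels and ./output are placeholder paths, not part of the script):

python translate_epub.py --model_name_or_path SakuraLLM/Sakura-13B-LNovel-v0.8 --model_version 0.8 --data_folder ./novels --output_folder ./output --text_length 512

For a GPTQ-quantized checkpoint, add --use_gptq_model so the model is loaded through auto_gptq instead of transformers.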