from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from argparse import ArgumentParser
import time
import os, re
import fnmatch
import glob
import shutil
import zipfile
from tqdm import tqdm
def find_all_htmls(root_dir):
    """Collect every HTML-like document under ``root_dir``.

    Walks the tree with ``os.walk``. In each directory it yields ``*.html``
    matches first, then ``*.xhtml``, then ``*.htm``.

    Args:
        root_dir: Directory to search recursively.

    Returns:
        A list of paths, each joined onto the directory it was found in.
    """
    patterns = ('*.html', '*.xhtml', '*.htm')
    return [
        os.path.join(directory, name)
        for directory, _subdirs, names in os.walk(root_dir)
        for pattern in patterns
        for name in fnmatch.filter(names, pattern)
    ]
def get_html_text_list(epub_path, text_length):
    """Split the translatable text of one (X)HTML file into model-sized chunks.

    Every ``<h1>``..``<h6>`` and ``<p>`` element is located. Its inner markup
    is cleaned to plain text, and consecutive elements are packed into chunks
    whose '\\n'-joined text stays within ``text_length`` characters. A single
    element longer than the budget becomes a chunk on its own. Elements whose
    cleaned text is empty (e.g. image-only paragraphs) are skipped, so they
    are left untouched in the file.

    Args:
        epub_path: Path of the HTML/XHTML file to read (UTF-8). Despite the
            name, this is one extracted document, not the .epub archive.
        text_length: Maximum number of characters of cleaned text per chunk.

    Returns:
        ``(data_list, file_text)``. ``file_text`` is the full file contents.
        ``data_list`` holds ``(text, groups, pre_end)`` tuples:

        - ``text``: the cleaned element texts joined with '\\n';
        - ``groups``: the ``re.Match`` objects of those elements, in order;
        - ``pre_end``: the offset in ``file_text`` where the previous chunk's
          last element ended (0 for the first chunk).

        ``data_list`` is empty when the file has no heading/paragraph text.
    """
    data_list = []

    def clean_text(text):
        # Drop ruby annotations (<rt> readings, <rp> fallback parentheses) so
        # only the base text reaches the model.
        text = re.sub(r'<(rt|rp)\b[^>]*>.*?</\1>', '', text, flags=re.DOTALL)
        # Remove every remaining tag, plus raw newlines ('\n' is the chunk
        # separator, so it must not appear inside an element's text).
        text = re.sub(r'<[^>]*>|\n', '', text)
        return text

    with open(epub_path, 'r', encoding='utf-8') as f:
        file_text = f.read()

    # list(): a finditer iterator is always truthy, so emptiness could not be
    # tested. The \b keeps <pre>/<param> from being taken for <p>.
    matches = list(re.finditer(r'<(h[1-6]|p)\b.*?>(.+?)</\1>', file_text, flags=re.DOTALL))
    if not matches:
        print("perhaps this file is a struct file")
        return data_list, file_text

    groups = []
    text = ''
    pre_end = 0
    for match in matches:
        new_text = clean_text(match.group(2))
        if not new_text:
            continue
        # Flush the current chunk when this element would overflow the budget.
        # An empty chunk is never flushed, so an over-long element forms its
        # own chunk instead of crashing on groups[-1].
        if groups and len(text) + 1 + len(new_text) > text_length:
            data_list.append((text, groups, pre_end))
            pre_end = groups[-1].end()
            groups, text = [], ''
        groups.append(match)
        text = f"{text}\n{new_text}" if text else new_text
    if groups:
        data_list.append((text, groups, pre_end))
    return data_list, file_text
def get_prompt(input, model_version):
    """Wrap Japanese ``input`` in the prompt format of a Sakura model version.

    Args:
        input: Japanese source text to translate.
        model_version: One of '0.1', '0.4', '0.5', '0.7', '0.8'.

    Returns:
        The full prompt string. Its answer marker is what ``split_response``
        later splits on.

    Raises:
        ValueError: if ``model_version`` is not a known version.
    """
    instruction = "将下面的日文文本翻译成中文:"
    # (prefix, suffix) placed around instruction + input, per version.
    templates = {
        # Baichuan-based releases: user turn opened by <reserved_106> and the
        # assistant turn by <reserved_107>.
        '0.5': ("<reserved_106>", "<reserved_107>"),
        '0.8': ("<reserved_106>", "<reserved_107>"),
        # ChatML-style format.
        '0.7': ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
        '0.1': ("Human: \n", "\n\nAssistant: \n"),
        '0.4': ("User: ", "\nAssistant: "),
    }
    if model_version not in templates:
        raise ValueError(f"Wrong model version {model_version}, please view https://huggingface.co/sakuraumi/Sakura-13B-Galgame")
    prefix, suffix = templates[model_version]
    # Plain concatenation (not str.format) so braces in the input stay literal.
    return prefix + instruction + input + suffix
def split_response(response, model_version):
    """Extract the model's answer from a decoded prompt+completion string.

    Args:
        response: Decoded output of ``model.generate`` (prompt included).
        model_version: One of '0.1', '0.4', '0.5', '0.7', '0.8'.

    Returns:
        Everything after the first answer marker of that version's prompt
        format, with the ``</s>`` end-of-sequence text removed.

    Raises:
        ValueError: if ``model_version`` is not a known version.
        IndexError: if the answer marker does not occur in ``response``.
    """
    response = response.replace("</s>", "")
    # Marker that opens the assistant turn in each version's prompt format.
    separators = {
        '0.5': "<reserved_107>",
        '0.8': "<reserved_107>",
        '0.7': "<|im_start|>assistant\n",
        '0.1': "\n\nAssistant: \n",
        '0.4': "\nAssistant: ",
    }
    if model_version not in separators:
        raise ValueError(f"Wrong model version {model_version}, please view https://huggingface.co/sakuraumi/Sakura-13B-Galgame")
    # maxsplit=1: keep the whole answer even if it repeats the marker itself.
    return response.split(separators[model_version], 1)[1]
def detect_degeneration(generation: list, model_version, marker_token_id: int = 196, max_new_tokens: int = 1024) -> bool:
    """Heuristically flag a v0.8 generation that ran into the token limit.

    Only v0.8 is checked; every other version returns False. The answer is
    taken to be everything after the first ``marker_token_id``. 196 is
    presumably the id of the token that ends the v0.8 prompt — verify against
    the tokenizer. An answer of at least ``max_new_tokens - 1`` tokens is
    treated as degenerate (e.g. endless repetition).

    Args:
        generation: Token ids of prompt + completion.
        model_version: Model version string.
        marker_token_id: Token id after which the answer starts.
        max_new_tokens: Generation limit used when producing ``generation``.

    Returns:
        True when degeneration is detected (and a retry message is printed).
        False otherwise, including when the marker is missing and the answer
        span therefore cannot be located.
    """
    if model_version != "0.8":
        return False
    try:
        answer_start = generation.index(marker_token_id) + 1
    except ValueError:
        return False
    if len(generation) - answer_start >= max_new_tokens - 1:
        print("model degeneration detected, retrying...")
        return True
    return False
def get_model_response(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str, model_version: str, generation_config: GenerationConfig, text_length: int):
    """Generate a translation for ``prompt`` and return only the answer text.

    The prompt is first generated with ``generation_config``. If the result
    has more than ``2 * text_length`` tokens (prompt included) and
    ``detect_degeneration`` flags it, generation is retried with up to two
    fallback configs in turn: wider sampling, then contrastive search. If the
    last attempt is still flagged, a message is printed and that output is
    used anyway.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Prompt produced by ``get_prompt``.
        model_version: Model version string (selects answer extraction).
        generation_config: Decoding config for the first attempt.
        text_length: Chunk character budget, used for the length heuristic.

    Returns:
        The decoded answer, as extracted by ``split_response``.
    """
    backup_generation_configs = [
        # Retry 1: sampling with a higher temperature and wider nucleus.
        GenerationConfig(
            temperature=1,
            top_p=0.6,
            top_k=40,
            num_beams=1,
            bos_token_id=1,
            eos_token_id=2,
            pad_token_id=0,
            max_new_tokens=1024,
            min_new_tokens=1,
            do_sample=True
        ),
        # Retry 2: contrastive search (penalty_alpha with top_k).
        GenerationConfig(
            top_k=5,
            num_beams=1,
            bos_token_id=1,
            eos_token_id=2,
            pad_token_id=0,
            max_new_tokens=1024,
            min_new_tokens=1,
            penalty_alpha=0.3
        ),
    ]
    # Tokenize once; every retry reuses the same encoded prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generation = model.generate(**inputs, generation_config=generation_config)[0]
    if len(generation) > 2 * text_length:
        stage = 0
        while detect_degeneration(list(generation), model_version):
            stage += 1
            if stage > len(backup_generation_configs):
                print("model degeneration cannot be avoided.")
                break
            generation = model.generate(**inputs, generation_config=backup_generation_configs[stage - 1])[0]
    response = tokenizer.decode(generation)
    return split_response(response, model_version)
def _parse_args():
    """Define and parse the command-line options."""
    parser = ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="SakuraLLM/Sakura-13B-LNovel-v0.8", help="model huggingface id or local path.")
    parser.add_argument("--use_gptq_model", action="store_true", help="whether your model is gptq quantized.")
    parser.add_argument("--model_version", type=str, default="0.8", help="model version written on huggingface readme, now we have ['0.1', '0.4', '0.5', '0.7', '0.8']")
    parser.add_argument("--data_path", type=str, default="", help="file path of the epub you want to translate.")
    parser.add_argument("--data_folder", type=str, default="", help="folder path of the epubs you want to translate.")
    parser.add_argument("--output_folder", type=str, default="", help="save folder path of the epubs model translated.")
    parser.add_argument("--text_length", type=int, default=512, help="input max length in each inference.")
    return parser.parse_args()


def _load_model(args):
    """Load the tokenizer and model (GPTQ-quantized or plain) named by ``args``."""
    if args.use_gptq_model:
        from auto_gptq import AutoGPTQForCausalLM
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False, trust_remote_code=True)
    if args.use_gptq_model:
        model = AutoGPTQForCausalLM.from_quantized(args.model_name_or_path, device="cuda:0", trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map="auto", trust_remote_code=True)
    return tokenizer, model


def _collect_jobs(args):
    """Return (input_epub, output_epub) pairs and make sure the output folder exists."""
    # Created for both --data_path and --data_folder. Otherwise a single-file
    # run into a new folder fails only after all the translation work is done.
    if args.output_folder:
        os.makedirs(args.output_folder, exist_ok=True)
    jobs = []
    if args.data_path:
        jobs.append((args.data_path, os.path.join(args.output_folder, os.path.basename(args.data_path))))
    if args.data_folder:
        for name in sorted(os.listdir(args.data_folder)):
            if name.endswith(".epub"):
                jobs.append((os.path.join(args.data_folder, name), os.path.join(args.output_folder, name)))
    return jobs


def _fit_lines(output, n_groups):
    """Split model output into exactly ``n_groups`` lines (pad with '' or fold extras into the last)."""
    texts = output.strip().split('\n')
    if len(texts) < n_groups:
        return texts + [''] * (n_groups - len(texts))
    return texts[:n_groups - 1] + ['\n'.join(texts[n_groups - 1:])]


def _splice_translation(file_text, match, translated):
    """Return the element text of ``match`` with only its inner content (group 2) replaced."""
    # Slice by span rather than str.replace: replace would also rewrite the tag
    # name/attributes whenever they contain the inner text (e.g. '<p>p</p>').
    return file_text[match.start():match.start(2)] + translated + file_text[match.end(2):match.end()]


def _translate_html(html_path, model, tokenizer, generation_config, args):
    """Translate one extracted (X)HTML file in place; leave it untouched if it has no text."""
    print(f"\ttranslating {html_path}...")
    start_html = time.time()
    data_list, file_text = get_html_text_list(html_path, args.text_length)
    if not data_list:
        return
    parts = []
    cursor = 0  # offset in file_text up to which output has been emitted
    for text, groups, _pre_end in tqdm(data_list):
        prompt = get_prompt(text, args.model_version)
        output = get_model_response(model, tokenizer, prompt, args.model_version, generation_config, args.text_length)
        for line, match in zip(_fit_lines(output, len(groups)), groups):
            parts.append(file_text[cursor:match.start()])
            parts.append(_splice_translation(file_text, match, line))
            cursor = match.end()
    parts.append(file_text[cursor:])
    with open(html_path, 'w', encoding='utf-8') as f:
        f.write(''.join(parts))
    end_html = time.time()
    print(f"\t{html_path} translated, used time: ", end_html-start_html)


def _write_epub(src_dir, save_path):
    """Zip ``src_dir`` into ``save_path`` as an EPUB container."""
    with zipfile.ZipFile(save_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        # EPUB OCF: 'mimetype' must be the first entry and stored uncompressed.
        mimetype_path = os.path.join(src_dir, 'mimetype')
        if os.path.isfile(mimetype_path):
            archive.write(mimetype_path, 'mimetype', compress_type=zipfile.ZIP_STORED)
        # os.walk (unlike glob '**') also includes dot-files.
        for folder, _dirs, files in os.walk(src_dir):
            for name in files:
                file_path = os.path.join(folder, name)
                relative_path = os.path.relpath(file_path, src_dir)
                if relative_path == 'mimetype':
                    continue
                archive.write(file_path, relative_path)


def main():
    """CLI entry point: translate Japanese EPUBs to Chinese with a Sakura model.

    Translates ``--data_path`` and/or every ``*.epub`` in ``--data_folder``.
    Each book is unpacked into a private temporary directory, each HTML
    document in it is rewritten in place, and the book is repacked into
    ``--output_folder`` under its original file name.
    """
    import tempfile  # only the CLI path needs scratch directories

    args = _parse_args()
    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.3,
        top_k=40,
        num_beams=1,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=0,
        max_new_tokens=1024,
        min_new_tokens=1,
        do_sample=True
    )
    tokenizer, model = _load_model(args)
    print("Start translating...")
    start = time.time()
    for epub_path, save_path in _collect_jobs(args):
        print(f"translating {epub_path}...")
        start_epub = time.time()
        # Private scratch directory: a fixed './temp' that gets rmtree'd would
        # destroy any unrelated folder of that name in the working directory.
        with tempfile.TemporaryDirectory() as work_dir:
            with zipfile.ZipFile(epub_path, 'r') as archive:
                archive.extractall(work_dir)
            for html_path in find_all_htmls(work_dir):
                _translate_html(html_path, model, tokenizer, generation_config, args)
            _write_epub(work_dir, save_path)
        end_epub = time.time()
        print(f"{epub_path} translated, used time: ", end_epub-start_epub)
    end = time.time()
    print("translation completed, used time: ", end-start)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()