princepride commited on
Commit
65e3693
1 Parent(s): 6a89750

Create document_trans.py

Browse files
Files changed (1) hide show
  1. document_trans.py +214 -0
document_trans.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, pipeline
2
+ from abc import ABC, abstractmethod
3
+ from typing import Type
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import os
7
+
8
# Absolute directory containing this script, and the directory three
# levels above it (used elsewhere as the project root).
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir, os.pardir, os.pardir))
11
class Model():
    """Seq2seq translation model with by-name GPU selection.

    Loads a Hugging Face ``AutoModelForSeq2SeqLM`` and its tokenizer, pins the
    model to a GPU whose name contains ``selected_gpu`` (or to the CPU), and
    translates batches of text into one or more target languages using
    NLLB-style forced-BOS language codes.
    """

    # Human-readable language name -> language-code token fed to the tokenizer.
    # Hoisted to class level so the table is built once, not on every
    # generate() call as in the original nested function.
    # NOTE(review): "Arabic" -> "ar_AR" and "Nepali" -> "ne_NP" look like
    # MBart-style codes mixed into otherwise NLLB-style codes — confirm they
    # are valid for the model actually loaded.
    LANGUAGE_CODES = {
        "Achinese (Arabic script)": "ace_Arab",
        "Achinese (Latin script)": "ace_Latn",
        "Mesopotamian Arabic": "acm_Arab",
        "Ta'izzi-Adeni Arabic": "acq_Arab",
        "Tunisian Arabic": "aeb_Arab",
        "Afrikaans": "afr_Latn",
        "South Levantine Arabic": "ajp_Arab",
        "Akan": "aka_Latn",
        "Amharic": "amh_Ethi",
        "North Levantine Arabic": "apc_Arab",
        "Standard Arabic": "arb_Arab",
        "Najdi Arabic": "ars_Arab",
        "Moroccan Arabic": "ary_Arab",
        "Egyptian Arabic": "arz_Arab",
        "Assamese": "asm_Beng",
        "Asturian": "ast_Latn",
        "Awadhi": "awa_Deva",
        "Central Aymara": "ayr_Latn",
        "South Azerbaijani": "azb_Arab",
        "North Azerbaijani": "azj_Latn",
        "Bashkir": "bak_Cyrl",
        "Bambara": "bam_Latn",
        "Balinese": "ban_Latn",
        "Belarusian": "bel_Cyrl",
        "Bemba": "bem_Latn",
        "Bengali": "ben_Beng",
        "Bhojpuri": "bho_Deva",
        "Banjar (Arabic script)": "bjn_Arab",
        "Banjar (Latin script)": "bjn_Latn",
        "Tibetan": "bod_Tibt",
        "Bosnian": "bos_Latn",
        "Buginese": "bug_Latn",
        "Bulgarian": "bul_Cyrl",
        "Catalan": "cat_Latn",
        "Cebuano": "ceb_Latn",
        "Czech": "ces_Latn",
        "Chokwe": "cjk_Latn",
        "Central Kurdish": "ckb_Arab",
        "Crimean Tatar": "crh_Latn",
        "Welsh": "cym_Latn",
        "Danish": "dan_Latn",
        "German": "deu_Latn",
        "Dinka": "dik_Latn",
        "Jula": "dyu_Latn",
        "Dzongkha": "dzo_Tibt",
        "Greek": "ell_Grek",
        "English": "eng_Latn",
        "Esperanto": "epo_Latn",
        "Estonian": "est_Latn",
        "Basque": "eus_Latn",
        "Ewe": "ewe_Latn",
        "Faroese": "fao_Latn",
        "Persian": "pes_Arab",
        "Fijian": "fij_Latn",
        "Finnish": "fin_Latn",
        "Fon": "fon_Latn",
        "French": "fra_Latn",
        "Friulian": "fur_Latn",
        "Nigerian Fulfulde": "fuv_Latn",
        "Scottish Gaelic": "gla_Latn",
        "Irish": "gle_Latn",
        "Galician": "glg_Latn",
        "Guarani": "grn_Latn",
        "Gujarati": "guj_Gujr",
        "Haitian Creole": "hat_Latn",
        "Hausa": "hau_Latn",
        "Hebrew": "heb_Hebr",
        "Hindi": "hin_Deva",
        "Chhattisgarhi": "hne_Deva",
        "Croatian": "hrv_Latn",
        "Hungarian": "hun_Latn",
        "Armenian": "hye_Armn",
        "Igbo": "ibo_Latn",
        "Iloko": "ilo_Latn",
        "Indonesian": "ind_Latn",
        "Icelandic": "isl_Latn",
        "Italian": "ita_Latn",
        "Javanese": "jav_Latn",
        "Japanese": "jpn_Jpan",
        "Kabyle": "kab_Latn",
        "Kachin": "kac_Latn",
        "Arabic": "ar_AR",
        "Chinese": "zho_Hans",
        "Spanish": "spa_Latn",
        "Dutch": "nld_Latn",
        "Kazakh": "kaz_Cyrl",
        "Korean": "kor_Hang",
        "Lithuanian": "lit_Latn",
        "Malayalam": "mal_Mlym",
        "Marathi": "mar_Deva",
        "Nepali": "ne_NP",
        "Polish": "pol_Latn",
        "Portuguese": "por_Latn",
        "Russian": "rus_Cyrl",
        "Sinhala": "sin_Sinh",
        "Tamil": "tam_Taml",
        "Turkish": "tur_Latn",
        "Ukrainian": "ukr_Cyrl",
        "Urdu": "urd_Arab",
        "Vietnamese": "vie_Latn",
        "Thai": "tha_Thai",
    }

    def __init__(self, modelname, selected_lora_model, selected_gpu):
        """Load the model and tokenizer and move the model to the chosen device.

        Args:
            modelname: Hugging Face hub id or local path of a seq2seq model.
            selected_lora_model: accepted for interface compatibility; not
                used in this class.
            selected_gpu: "cpu", or a substring of a CUDA device name to
                select (case-insensitive).
        """
        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            selected_gpu_index = self._gpu_index(gpu_info, selected_gpu)
            if selected_gpu_index == -1:
                # Bug fix: the original built the invalid device string
                # "cuda:-1" when the requested GPU was not found.
                print(f"GPU matching '{selected_gpu}' not found; falling back to CPU")
                self.device_name = "cpu"
            else:
                self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)

    @staticmethod
    def _gpu_index(gpu_info, target_gpu_name):
        """Return the index of the first GPU whose name contains
        *target_gpu_name* (case-insensitive), or -1 if none matches."""
        for i, name in enumerate(gpu_info):
            if target_gpu_name.lower() in name.lower():
                return i
        return -1

    def _lang_code_id(self, language):
        """Return the tokenizer token id of *language*'s code token.

        Raises KeyError (same type the original raised) for unknown names.
        """
        try:
            code = self.LANGUAGE_CODES[language]
        except KeyError:
            raise KeyError(f"Unsupported language name: {language!r}") from None
        # Bug fix / compatibility: tokenizer.lang_code_to_id was removed from
        # recent transformers tokenizers; convert_tokens_to_ids works on both
        # old and new versions.
        if hasattr(self.tokenizer, "lang_code_to_id"):
            return self.tokenizer.lang_code_to_id[code]
        return self.tokenizer.convert_tokens_to_ids(code)

    @staticmethod
    def _transpose(per_language):
        """Regroup per-language results into per-input results.

        Input: [{"target_language": L, "generated_translation": [t0..tN]}, ...]
        Output: one list per input index i, each holding
        {"target_language": L, "generated_translation": ti} for every L.
        """
        count = len(per_language[0]["generated_translation"])
        return [
            [
                {
                    "target_language": entry["target_language"],
                    "generated_translation": entry["generated_translation"][i],
                }
                for entry in per_language
            ]
            for i in range(count)
        ]

    def generate(self, inputs, original_language, target_languages, max_batch_size):
        """Translate *inputs* from *original_language* into each target language.

        Args:
            inputs: list of source strings.
            original_language: human-readable source-language name (a key of
                LANGUAGE_CODES).
            target_languages: list of human-readable target-language names.
            max_batch_size: upper bound on per-batch size (GPU path).

        Returns:
            A list with one entry per input string; each entry is a list of
            {"target_language", "generated_translation"} dicts in the order
            of *target_languages*.

        Raises:
            KeyError: if a language name is not in LANGUAGE_CODES.
        """
        # Bug fix: the original crashed on empty inputs (ValueError from a
        # zero step in range() on the GPU path) and on empty target_languages
        # (IndexError on both paths).
        if not inputs or not target_languages:
            return []
        self.tokenizer.src_lang = self.LANGUAGE_CODES[original_language]
        if self.device_name == "cpu":
            # CPU path: single batch, sequence length capped at 128.
            # Bug fix: the original passed max_length without truncation=True,
            # which has no effect during tokenization.
            encoded = self.tokenizer(
                inputs, return_tensors="pt", padding=True,
                truncation=True, max_length=128,
            ).to(self.device_name)
            per_language = []
            for target_language in target_languages:
                generated_tokens = self.model.generate(
                    **encoded,
                    forced_bos_token_id=self._lang_code_id(target_language),
                    max_length=128,
                )
                per_language.append({
                    "target_language": target_language,
                    "generated_translation": self.tokenizer.batch_decode(
                        generated_tokens, skip_special_tokens=True),
                })
            return self._transpose(per_language)
        # GPU path: split inputs into batches no larger than max_batch_size
        # so GPU memory stays bounded.
        batch_size = min(len(inputs), int(max_batch_size))
        batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
        outputs = []
        for batch in batches:
            encoded = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
            per_language = []
            for target_language in target_languages:
                generated_tokens = self.model.generate(
                    **encoded,
                    forced_bos_token_id=self._lang_code_id(target_language),
                )
                per_language.append({
                    "target_language": target_language,
                    "generated_translation": self.tokenizer.batch_decode(
                        generated_tokens, skip_special_tokens=True),
                })
            # Release this batch's tensors promptly before the next batch.
            del encoded
            outputs.extend(self._transpose(per_language))
        return outputs