princepride committed on
Commit 0fc9e07
1 Parent(s): eb95586

Update model.py

Files changed (1)
  1. model.py +713 -352
model.py CHANGED
@@ -1,353 +1,714 @@
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
- import torch
- from modules.file import ExcelFileWriter
- import os
-
- from abc import ABC, abstractmethod
- from typing import List
- import re
-
- class FilterPipeline():
-     def __init__(self, filter_list):
-         self._filter_list:List[Filter] = filter_list
-
-     def append(self, filter):
-         self._filter_list.append(filter)
-
-     def batch_encoder(self, inputs):
-         for filter in self._filter_list:
-             inputs = filter.encoder(inputs)
-         return inputs
-
-     def batch_decoder(self, inputs):
-         for filter in reversed(self._filter_list):
-             inputs = filter.decoder(inputs)
-         return inputs
-
- class Filter(ABC):
-     def __init__(self):
-         self.name = 'filter'
-         self.code = []
-     @abstractmethod
-     def encoder(self, inputs):
-         pass
-
-     @abstractmethod
-     def decoder(self, inputs):
-         pass
-
- class SpecialTokenFilter(Filter):
-     def __init__(self):
-         self.name = 'special token filter'
-         self.code = []
-         self.special_tokens = ['!', '!']
-
-     def encoder(self, inputs):
-         filtered_inputs = []
-         self.code = []
-         for i, input_str in enumerate(inputs):
-             if not all(char in self.special_tokens for char in input_str):
-                 filtered_inputs.append(input_str)
-             else:
-                 self.code.append([i, input_str])
-         return filtered_inputs
-
-     def decoder(self, inputs):
-         original_inputs = inputs.copy()
-         for removed_indice in self.code:
-             original_inputs.insert(removed_indice[0], removed_indice[1])
-         return original_inputs
-
- class SperSignFilter(Filter):
-     def __init__(self):
-         self.name = 's persign filter'
-         self.code = []
-
-     def encoder(self, inputs):
-         encoded_inputs = []
-         self.code = []  # reset self.code
-         for i, input_str in enumerate(inputs):
-             if 's%' in input_str:
-                 encoded_str = input_str.replace('s%', '*')
-                 self.code.append(i)  # remember the indices of strings that contained 's%'
-             else:
-                 encoded_str = input_str
-             encoded_inputs.append(encoded_str)
-         return encoded_inputs
-
-     def decoder(self, inputs):
-         decoded_inputs = inputs.copy()
-         for i in self.code:
-             decoded_inputs[i] = decoded_inputs[i].replace('*', 's%')  # restore the original strings at the recorded indices
-         return decoded_inputs
-
- class SimilarFilter(Filter):
-     def __init__(self):
-         self.name = 'similar filter'
-         self.code = []
-
-     def is_similar(self, str1, str2):
-         # Two strings are similar if they differ only in their digits
-         pattern = re.compile(r'\d+')
-         return pattern.sub('', str1) == pattern.sub('', str2)
-
-     def encoder(self, inputs):
-         encoded_inputs = []
-         self.code = []  # reset self.code
-         i = 0
-         while i < len(inputs):
-             encoded_inputs.append(inputs[i])
-             similar_strs = [inputs[i]]
-             j = i + 1
-             while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
-                 similar_strs.append(inputs[j])
-                 j += 1
-             if len(similar_strs) > 1:
-                 self.code.append((i, similar_strs))  # record the start index and the actual similar strings
-             i = j
-         return encoded_inputs
-
-     def decoder(self, inputs):
-         decoded_inputs = []
-         index = 0
-         for i, similar_strs in self.code:
-             decoded_inputs.extend(inputs[index:i])
-             decoded_inputs.extend(similar_strs)  # splice the original similar strings straight back in
-             index = i + 1
-         decoded_inputs.extend(inputs[index:])
-         return decoded_inputs
-
- script_dir = os.path.dirname(os.path.abspath(__file__))
- parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
-
- class Model():
-     def __init__(self, modelname, selected_lora_model, selected_gpu):
-         def get_gpu_index(gpu_info, target_gpu_name):
-             """
-             Get the index of the target GPU from the GPU info.
-             Args:
-                 gpu_info (list): list of GPU names
-                 target_gpu_name (str): name of the target GPU
-
-             Returns:
-                 int: index of the target GPU, or -1 if not found
-             """
-             for i, name in enumerate(gpu_info):
-                 if target_gpu_name.lower() in name.lower():
-                     return i
-             return -1
-         if selected_gpu != "cpu":
-             gpu_count = torch.cuda.device_count()
-             gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
-             selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
-             self.device_name = f"cuda:{selected_gpu_index}"
-         else:
-             self.device_name = "cpu"
-         print("device_name", self.device_name)
-         self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
-         self.tokenizer = AutoTokenizer.from_pretrained(modelname)
-         # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
-
-     def generate(self, inputs, original_language, target_languages, max_batch_size):
-         filter_list = [SpecialTokenFilter(), SperSignFilter(), SimilarFilter()]
-         filter_pipeline = FilterPipeline(filter_list)
-         def language_mapping(original_language):
-             d = {
-                 "Achinese (Arabic script)": "ace_Arab",
-                 "Achinese (Latin script)": "ace_Latn",
-                 "Mesopotamian Arabic": "acm_Arab",
-                 "Ta'izzi-Adeni Arabic": "acq_Arab",
-                 "Tunisian Arabic": "aeb_Arab",
-                 "Afrikaans": "afr_Latn",
-                 "South Levantine Arabic": "ajp_Arab",
-                 "Akan": "aka_Latn",
-                 "Amharic": "amh_Ethi",
-                 "North Levantine Arabic": "apc_Arab",
-                 "Standard Arabic": "arb_Arab",
-                 "Najdi Arabic": "ars_Arab",
-                 "Moroccan Arabic": "ary_Arab",
-                 "Egyptian Arabic": "arz_Arab",
-                 "Assamese": "asm_Beng",
-                 "Asturian": "ast_Latn",
-                 "Awadhi": "awa_Deva",
-                 "Central Aymara": "ayr_Latn",
-                 "South Azerbaijani": "azb_Arab",
-                 "North Azerbaijani": "azj_Latn",
-                 "Bashkir": "bak_Cyrl",
-                 "Bambara": "bam_Latn",
-                 "Balinese": "ban_Latn",
-                 "Belarusian": "bel_Cyrl",
-                 "Bemba": "bem_Latn",
-                 "Bengali": "ben_Beng",
-                 "Bhojpuri": "bho_Deva",
-                 "Banjar (Arabic script)": "bjn_Arab",
-                 "Banjar (Latin script)": "bjn_Latn",
-                 "Tibetan": "bod_Tibt",
-                 "Bosnian": "bos_Latn",
-                 "Buginese": "bug_Latn",
-                 "Bulgarian": "bul_Cyrl",
-                 "Catalan": "cat_Latn",
-                 "Cebuano": "ceb_Latn",
-                 "Czech": "ces_Latn",
-                 "Chokwe": "cjk_Latn",
-                 "Central Kurdish": "ckb_Arab",
-                 "Crimean Tatar": "crh_Latn",
-                 "Welsh": "cym_Latn",
-                 "Danish": "dan_Latn",
-                 "German": "deu_Latn",
-                 "Dinka": "dik_Latn",
-                 "Jula": "dyu_Latn",
-                 "Dzongkha": "dzo_Tibt",
-                 "Greek": "ell_Grek",
-                 "English": "eng_Latn",
-                 "Esperanto": "epo_Latn",
-                 "Estonian": "est_Latn",
-                 "Basque": "eus_Latn",
-                 "Ewe": "ewe_Latn",
-                 "Faroese": "fao_Latn",
-                 "Persian": "pes_Arab",
-                 "Fijian": "fij_Latn",
-                 "Finnish": "fin_Latn",
-                 "Fon": "fon_Latn",
-                 "French": "fra_Latn",
-                 "Friulian": "fur_Latn",
-                 "Nigerian Fulfulde": "fuv_Latn",
-                 "Scottish Gaelic": "gla_Latn",
-                 "Irish": "gle_Latn",
-                 "Galician": "glg_Latn",
-                 "Guarani": "grn_Latn",
-                 "Gujarati": "guj_Gujr",
-                 "Haitian Creole": "hat_Latn",
-                 "Hausa": "hau_Latn",
-                 "Hebrew": "heb_Hebr",
-                 "Hindi": "hin_Deva",
-                 "Chhattisgarhi": "hne_Deva",
-                 "Croatian": "hrv_Latn",
-                 "Hungarian": "hun_Latn",
-                 "Armenian": "hye_Armn",
-                 "Igbo": "ibo_Latn",
-                 "Iloko": "ilo_Latn",
-                 "Indonesian": "ind_Latn",
-                 "Icelandic": "isl_Latn",
-                 "Italian": "ita_Latn",
-                 "Javanese": "jav_Latn",
-                 "Japanese": "jpn_Jpan",
-                 "Kabyle": "kab_Latn",
-                 "Kachin": "kac_Latn",
-                 "Arabic": "ar_AR",
-                 "Chinese": "zho_Hans",
-                 "Spanish": "spa_Latn",
-                 "Dutch": "nld_Latn",
-                 "Kazakh": "kaz_Cyrl",
-                 "Korean": "kor_Hang",
-                 "Lithuanian": "lit_Latn",
-                 "Malayalam": "mal_Mlym",
-                 "Marathi": "mar_Deva",
-                 "Nepali": "ne_NP",
-                 "Polish": "pol_Latn",
-                 "Portuguese": "por_Latn",
-                 "Russian": "rus_Cyrl",
-                 "Sinhala": "sin_Sinh",
-                 "Tamil": "tam_Taml",
-                 "Turkish": "tur_Latn",
-                 "Ukrainian": "ukr_Cyrl",
-                 "Urdu": "urd_Arab",
-                 "Vietnamese": "vie_Latn",
-                 "Thai":"tha_Thai"
-             }
-             return d[original_language]
-         def process_gpu_translate_result(temp_outputs):
-             outputs = []
-             for temp_output in temp_outputs:
-                 length = len(temp_output[0]["generated_translation"])
-                 for i in range(length):
-                     temp = []
-                     for trans in temp_output:
-                         temp.append({
-                             "target_language": trans["target_language"],
-                             "generated_translation": trans['generated_translation'][i],
-                         })
-                     outputs.append(temp)
-             excel_writer = ExcelFileWriter()
-             excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
-         self.tokenizer.src_lang = language_mapping(original_language)
-         if self.device_name == "cpu":
-             # Tokenize input
-             input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
-             output = []
-             for target_language in target_languages:
-                 # Get language code for the target language
-                 target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
-                 # Generate translation
-                 generated_tokens = self.model.generate(
-                     **input_ids,
-                     forced_bos_token_id=target_lang_code,
-                     max_length=128
-                 )
-                 generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-                 # Append result to output
-                 output.append({
-                     "target_language": target_language,
-                     "generated_translation": generated_translation,
-                 })
-             outputs = []
-             length = len(output[0]["generated_translation"])
-             for i in range(length):
-                 temp = []
-                 for trans in output:
-                     temp.append({
-                         "target_language": trans["target_language"],
-                         "generated_translation": trans['generated_translation'][i],
-                     })
-                 outputs.append(temp)
-             return outputs
-         else:
-             # max batch size = available GPU memory in bytes / 4 / (tensor size + trainable parameters)
-             # max_batch_size = 10
-             # Ensure batch size is within model limits:
-             print("length of inputs: ", len(inputs))
-             batch_size = min(len(inputs), int(max_batch_size))
-             batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
-             print("length of batches size: ", len(batches))
-             temp_outputs = []
-             processed_num = 0
-             for index, batch in enumerate(batches):
-                 # Tokenize input
-                 batch = filter_pipeline.batch_encoder(batch)
-                 print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
-                 print(batch)
-                 input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
-                 temp = []
-                 for target_language in target_languages:
-                     target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
-                     generated_tokens = self.model.generate(
-                         **input_ids,
-                         forced_bos_token_id=target_lang_code,
-                     )
-                     generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-                     print(generated_translation)
-                     generated_translation = filter_pipeline.batch_decoder(generated_translation)
-                     # Append result to output
-                     temp.append({
-                         "target_language": target_language,
-                         "generated_translation": generated_translation,
-                     })
-                 input_ids.to('cpu')
-                 del input_ids
-                 temp_outputs.append(temp)
-                 processed_num += len(batch)
-                 if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
-                     print("Already processed number: ", len(temp_outputs))
-                     process_gpu_translate_result(temp_outputs)
-             outputs = []
-             for temp_output in temp_outputs:
-                 length = len(temp_output[0]["generated_translation"])
-                 for i in range(length):
-                     temp = []
-                     for trans in temp_output:
-                         temp.append({
-                             "target_language": trans["target_language"],
-                             "generated_translation": trans['generated_translation'][i],
-                         })
-                     outputs.append(temp)
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ import torch
+ from modules.file import ExcelFileWriter
+ import os
+
+ from abc import ABC, abstractmethod
+ from typing import List
+ import re
+
+ class FilterPipeline():
+     def __init__(self, filter_list):
+         self._filter_list:List[Filter] = filter_list
+
+     def append(self, filter):
+         self._filter_list.append(filter)
+
+     def batch_encoder(self, inputs):
+         for filter in self._filter_list:
+             inputs = filter.encoder(inputs)
+         return inputs
+
+     def batch_decoder(self, inputs):
+         for filter in reversed(self._filter_list):
+             inputs = filter.decoder(inputs)
+         return inputs
+
+ class Filter(ABC):
+     def __init__(self):
+         self.name = 'filter'
+         self.code = []
+     @abstractmethod
+     def encoder(self, inputs):
+         pass
+
+     @abstractmethod
+     def decoder(self, inputs):
+         pass
+
+ class SpecialTokenFilter(Filter):
+     def __init__(self):
+         self.name = 'special token filter'
+         self.code = []
+         self.special_tokens = ['!', '!', '-']
+
+     def encoder(self, inputs):
+         filtered_inputs = []
+         self.code = []
+         for i, input_str in enumerate(inputs):
+             if not all(char in self.special_tokens for char in input_str):
+                 filtered_inputs.append(input_str)
+             else:
+                 self.code.append([i, input_str])
+         return filtered_inputs
+
+     def decoder(self, inputs):
+         original_inputs = inputs.copy()
+         for removed_indice in self.code:
+             original_inputs.insert(removed_indice[0], removed_indice[1])
+         return original_inputs
+
+ class SperSignFilter(Filter):
+     def __init__(self):
+         self.name = 's persign filter'
+         self.code = []
+
+     def encoder(self, inputs):
+         encoded_inputs = []
+         self.code = []  # reset self.code
+         for i, input_str in enumerate(inputs):
+             if 's%' in input_str:
+                 encoded_str = input_str.replace('s%', '*')
+                 self.code.append(i)  # remember the indices of strings that contained 's%'
+             else:
+                 encoded_str = input_str
+             encoded_inputs.append(encoded_str)
+         return encoded_inputs
+
+     def decoder(self, inputs):
+         decoded_inputs = inputs.copy()
+         for i in self.code:
+             decoded_inputs[i] = decoded_inputs[i].replace('*', 's%')  # restore the original strings at the recorded indices
+         return decoded_inputs
+
+ class SimilarFilter(Filter):
+     def __init__(self):
+         self.name = 'similar filter'
+         self.code = []
+
+     def is_similar(self, str1, str2):
+         # Two strings are similar if they differ only in their digits
+         pattern = re.compile(r'\d+')
+         return pattern.sub('', str1) == pattern.sub('', str2)
+
+     def encoder(self, inputs):
+         encoded_inputs = []
+         self.code = []  # reset self.code
+         i = 0
+         while i < len(inputs):
+             encoded_inputs.append(inputs[i])
+             similar_strs = [inputs[i]]
+             j = i + 1
+             while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
+                 similar_strs.append(inputs[j])
+                 j += 1
+             if len(similar_strs) > 1:
+                 self.code.append((i, similar_strs))  # record the start index and the actual similar strings
+             i = j
+         return encoded_inputs
+
+     def decoder(self, inputs:List):
+         decoded_inputs = inputs
+         for i, similar_strs in self.code:
+             pattern = re.compile(r'\d+')
+             for j in range(len(similar_strs)):
+                 if pattern.search(similar_strs[j]):
+                     number = re.findall(r'\d+', similar_strs[j])[0]  # digit part of this similar string
+                     new_str = pattern.sub(number, inputs[i])  # copy the group translation, swapping in this string's digits
+                 else:
+                     new_str = inputs[i]  # no digits in the similar string, reuse the group translation as-is
+                 if j > 0:
+                     decoded_inputs.insert(i+j, new_str)
+         return decoded_inputs
+
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
+
+ class Model():
+     def __init__(self, modelname, selected_lora_model, selected_gpu):
+         def get_gpu_index(gpu_info, target_gpu_name):
+             """
+             Get the index of the target GPU from the GPU info.
+             Args:
+                 gpu_info (list): list of GPU names
+                 target_gpu_name (str): name of the target GPU
+
+             Returns:
+                 int: index of the target GPU, or -1 if not found
+             """
+             for i, name in enumerate(gpu_info):
+                 if target_gpu_name.lower() in name.lower():
+                     return i
+             return -1
+         if selected_gpu != "cpu":
+             gpu_count = torch.cuda.device_count()
+             gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
+             selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
+             self.device_name = f"cuda:{selected_gpu_index}"
+         else:
+             self.device_name = "cpu"
+         print("device_name", self.device_name)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(modelname)
+         # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
+
+     def generate(self, inputs, original_language, target_languages, max_batch_size):
+         filter_list = [SpecialTokenFilter(), SperSignFilter(), SimilarFilter()]
+         filter_pipeline = FilterPipeline(filter_list)
+         def language_mapping(original_language):
+             d = {
+                 "Achinese (Arabic script)": "ace_Arab",
+                 "Achinese (Latin script)": "ace_Latn",
+                 "Mesopotamian Arabic": "acm_Arab",
+                 "Ta'izzi-Adeni Arabic": "acq_Arab",
+                 "Tunisian Arabic": "aeb_Arab",
+                 "Afrikaans": "afr_Latn",
+                 "South Levantine Arabic": "ajp_Arab",
+                 "Akan": "aka_Latn",
+                 "Amharic": "amh_Ethi",
+                 "North Levantine Arabic": "apc_Arab",
+                 "Standard Arabic": "arb_Arab",
+                 "Najdi Arabic": "ars_Arab",
+                 "Moroccan Arabic": "ary_Arab",
+                 "Egyptian Arabic": "arz_Arab",
+                 "Assamese": "asm_Beng",
+                 "Asturian": "ast_Latn",
+                 "Awadhi": "awa_Deva",
+                 "Central Aymara": "ayr_Latn",
+                 "South Azerbaijani": "azb_Arab",
+                 "North Azerbaijani": "azj_Latn",
+                 "Bashkir": "bak_Cyrl",
+                 "Bambara": "bam_Latn",
+                 "Balinese": "ban_Latn",
+                 "Belarusian": "bel_Cyrl",
+                 "Bemba": "bem_Latn",
+                 "Bengali": "ben_Beng",
+                 "Bhojpuri": "bho_Deva",
+                 "Banjar (Arabic script)": "bjn_Arab",
+                 "Banjar (Latin script)": "bjn_Latn",
+                 "Tibetan": "bod_Tibt",
+                 "Bosnian": "bos_Latn",
+                 "Buginese": "bug_Latn",
+                 "Bulgarian": "bul_Cyrl",
+                 "Catalan": "cat_Latn",
+                 "Cebuano": "ceb_Latn",
+                 "Czech": "ces_Latn",
+                 "Chokwe": "cjk_Latn",
+                 "Central Kurdish": "ckb_Arab",
+                 "Crimean Tatar": "crh_Latn",
+                 "Welsh": "cym_Latn",
+                 "Danish": "dan_Latn",
+                 "German": "deu_Latn",
+                 "Dinka": "dik_Latn",
+                 "Jula": "dyu_Latn",
+                 "Dzongkha": "dzo_Tibt",
+                 "Greek": "ell_Grek",
+                 "English": "eng_Latn",
+                 "Esperanto": "epo_Latn",
+                 "Estonian": "est_Latn",
+                 "Basque": "eus_Latn",
+                 "Ewe": "ewe_Latn",
+                 "Faroese": "fao_Latn",
+                 "Persian": "pes_Arab",
+                 "Fijian": "fij_Latn",
+                 "Finnish": "fin_Latn",
+                 "Fon": "fon_Latn",
+                 "French": "fra_Latn",
+                 "Friulian": "fur_Latn",
+                 "Nigerian Fulfulde": "fuv_Latn",
+                 "Scottish Gaelic": "gla_Latn",
+                 "Irish": "gle_Latn",
+                 "Galician": "glg_Latn",
+                 "Guarani": "grn_Latn",
+                 "Gujarati": "guj_Gujr",
+                 "Haitian Creole": "hat_Latn",
+                 "Hausa": "hau_Latn",
+                 "Hebrew": "heb_Hebr",
+                 "Hindi": "hin_Deva",
+                 "Chhattisgarhi": "hne_Deva",
+                 "Croatian": "hrv_Latn",
+                 "Hungarian": "hun_Latn",
+                 "Armenian": "hye_Armn",
+                 "Igbo": "ibo_Latn",
+                 "Iloko": "ilo_Latn",
+                 "Indonesian": "ind_Latn",
+                 "Icelandic": "isl_Latn",
+                 "Italian": "ita_Latn",
+                 "Javanese": "jav_Latn",
+                 "Japanese": "jpn_Jpan",
+                 "Kabyle": "kab_Latn",
+                 "Kachin": "kac_Latn",
+                 "Arabic": "ar_AR",
+                 "Chinese": "zho_Hans",
+                 "Spanish": "spa_Latn",
+                 "Dutch": "nld_Latn",
+                 "Kazakh": "kaz_Cyrl",
+                 "Korean": "kor_Hang",
+                 "Lithuanian": "lit_Latn",
+                 "Malayalam": "mal_Mlym",
+                 "Marathi": "mar_Deva",
+                 "Nepali": "ne_NP",
+                 "Polish": "pol_Latn",
+                 "Portuguese": "por_Latn",
+                 "Russian": "rus_Cyrl",
+                 "Sinhala": "sin_Sinh",
+                 "Tamil": "tam_Taml",
+                 "Turkish": "tur_Latn",
+                 "Ukrainian": "ukr_Cyrl",
+                 "Urdu": "urd_Arab",
+                 "Vietnamese": "vie_Latn",
+                 "Thai":"tha_Thai"
+             }
+             return d[original_language]
+         def process_gpu_translate_result(temp_outputs):
+             outputs = []
+             for temp_output in temp_outputs:
+                 length = len(temp_output[0]["generated_translation"])
+                 for i in range(length):
+                     temp = []
+                     for trans in temp_output:
+                         temp.append({
+                             "target_language": trans["target_language"],
+                             "generated_translation": trans['generated_translation'][i],
+                         })
+                     outputs.append(temp)
+             excel_writer = ExcelFileWriter()
+             excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
+         self.tokenizer.src_lang = language_mapping(original_language)
+         if self.device_name == "cpu":
+             # Tokenize input
+             input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
+             output = []
+             for target_language in target_languages:
+                 # Get language code for the target language
+                 target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
+                 # Generate translation
+                 generated_tokens = self.model.generate(
+                     **input_ids,
+                     forced_bos_token_id=target_lang_code,
+                     max_length=128
+                 )
+                 generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                 # Append result to output
+                 output.append({
+                     "target_language": target_language,
+                     "generated_translation": generated_translation,
+                 })
+             outputs = []
+             length = len(output[0]["generated_translation"])
+             for i in range(length):
+                 temp = []
+                 for trans in output:
+                     temp.append({
+                         "target_language": trans["target_language"],
+                         "generated_translation": trans['generated_translation'][i],
+                     })
+                 outputs.append(temp)
+             return outputs
+         else:
+             # max batch size = available GPU memory in bytes / 4 / (tensor size + trainable parameters)
+             # max_batch_size = 10
+             # Ensure batch size is within model limits:
+             print("length of inputs: ", len(inputs))
+             batch_size = min(len(inputs), int(max_batch_size))
+             batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
+             print("length of batches size: ", len(batches))
+             temp_outputs = []
+             processed_num = 0
+             for index, batch in enumerate(batches):
+                 # Tokenize input
+                 print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
+                 print(len(batch))
+                 print(batch)
+                 batch = filter_pipeline.batch_encoder(batch)
+                 print(batch)
+                 input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
+                 temp = []
+                 for target_language in target_languages:
+                     target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
+                     generated_tokens = self.model.generate(
+                         **input_ids,
+                         forced_bos_token_id=target_lang_code,
+                     )
+                     generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+
+                     print(generated_translation)
+                     generated_translation = filter_pipeline.batch_decoder(generated_translation)
+                     print(generated_translation)
+                     print(len(generated_translation))
+                     # Append result to output
+                     temp.append({
+                         "target_language": target_language,
+                         "generated_translation": generated_translation,
+                     })
+                 input_ids.to('cpu')
+                 del input_ids
+                 temp_outputs.append(temp)
+                 processed_num += len(batch)
+                 if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
+                     print("Already processed number: ", len(temp_outputs))
+                     process_gpu_translate_result(temp_outputs)
+             outputs = []
+             for temp_output in temp_outputs:
+                 length = len(temp_output[0]["generated_translation"])
+                 for i in range(length):
+                     temp = []
+                     for trans in temp_output:
+                         temp.append({
+                             "target_language": trans["target_language"],
+                             "generated_translation": trans['generated_translation'][i],
+                         })
+                     outputs.append(temp)
          return outputs
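
For reference, a minimal round-trip sketch of the filter pipeline this file defines, assuming the new model.py is importable as `model`; the sample strings are hypothetical, chosen only to exercise all three filters. The encoder pass drops rows made entirely of special tokens, masks `s%` as `*`, and collapses digit-variant duplicates; the decoder pass undoes each step in reverse order, so the output lines up one-to-one with the original batch.

    # Hypothetical usage sketch; `model` is assumed to be the module above.
    from model import FilterPipeline, SpecialTokenFilter, SperSignFilter, SimilarFilter

    pipeline = FilterPipeline([SpecialTokenFilter(), SperSignFilter(), SimilarFilter()])

    batch = ["Chapter 1", "Chapter 2", "Discount s%", "!"]
    encoded = pipeline.batch_encoder(batch)
    # encoded == ["Chapter 1", "Discount *"]: the "!" row was set aside,
    # "s%" was masked as "*", and "Chapter 2" was folded into the "Chapter 1" group.

    # Stand-in for tokenize -> model.generate -> batch_decode on `encoded`:
    translated = ["Kapitel 1", "Rabatt *"]

    decoded = pipeline.batch_decoder(translated)
    print(decoded)  # ['Kapitel 1', 'Kapitel 2', 'Rabatt s%', '!']

Because each filter keeps its bookkeeping in `self.code`, the same pipeline instance must serve both the encoder and decoder passes of a given batch, exactly as `generate` does.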
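
The main behavioral change in this commit is SimilarFilter.decoder: instead of splicing the original source strings back in, it now fans a collapsed group out by copying the group's single translation and substituting each sibling's digits into it. A standalone sketch of that substitution step, with hypothetical strings (it mirrors the committed logic rather than importing it):

    import re

    # What SimilarFilter.encoder would record for three digit-variant rows:
    code = [(0, ["Chapter 1", "Chapter 2", "Chapter 3"])]
    # The single translation produced for the collapsed group:
    translated = ["Kapitel 1"]

    pattern = re.compile(r'\d+')
    for i, similar_strs in code:
        for j, src in enumerate(similar_strs):
            match = pattern.search(src)
            # Reuse the group's translation, swapping in this sibling's digits.
            new_str = pattern.sub(match.group(0), translated[i]) if match else translated[i]
            if j > 0:
                translated.insert(i + j, new_str)

    print(translated)  # ['Kapitel 1', 'Kapitel 2', 'Kapitel 3']

The upside is that only one row per group goes through the model; the trade-off is that every sibling inherits the group translation verbatim except for its digits.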
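
The language_mapping table maps UI display names to the (mostly FLORES-200) codes the tokenizer expects; `generate` then looks the code up in `tokenizer.lang_code_to_id` and forces it as the first generated token. A minimal sketch of that call pattern, assuming an NLLB-style checkpoint (the checkpoint name here is illustrative; the class loads whatever `modelname` the caller passes):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

    tokenizer.src_lang = "eng_Latn"  # language_mapping("English")
    inputs = tokenizer(["Chapter 1"], return_tensors="pt", padding=True)
    tokens = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"],  # language_mapping("German")
    )
    print(tokenizer.batch_decode(tokens, skip_special_tokens=True))

Note that a few entries ("Arabic": "ar_AR", "Nepali": "ne_NP") are mBART-style codes rather than FLORES-200 codes, so they will only resolve if the loaded tokenizer actually defines them.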