Wang Zhipeng committed on
Commit 6a89750
1 Parent(s): 651b246

add readme.md support_language.json model.py

Files changed (4)
  1. .gitignore +1 -0
  2. README.md +14 -0
  3. model.py +233 -0
  4. support_language.json +208 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
README.md ADDED
@@ -0,0 +1,14 @@
+ # NLLB
+
+ The No Language Left Behind (NLLB) model is an ambitious project spearheaded by Meta AI (formerly Facebook AI), aimed at breaking down language barriers and enabling universal access to information across the globe. This machine translation system represents a significant step forward in AI-driven language technology, with the goal of providing high-quality translations across a wide range of languages, including those that are underrepresented in the digital world.
+
+ ## Introduction
+
+ The NLLB model is part of Meta's broader effort to democratize information and make the internet more inclusive. By leveraging modern machine learning techniques and large amounts of linguistic data, NLLB aims to deliver accurate, contextually relevant translations across a multitude of languages, many of which have traditionally been neglected by major technology providers.
+
+ ## Features
+
+ * **Wide Language Coverage:** NLLB supports an impressive array of languages, with a focus on inclusivity and the representation of underrepresented languages.
+ * **High-Quality Translations:** Uses advanced AI and machine learning techniques so that translations are not only accurate but also preserve context and cultural nuance.
+ * **Accessibility:** Designed to be easily integrated into various platforms and applications, NLLB aims to make multilingual content accessible to a global audience (see the usage sketch after this list).
+ * **Open Source:** In line with Meta's commitment to open science, parts of the NLLB project are made publicly available, enabling researchers and developers to contribute to and build upon this work.
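For reference, here is a minimal translation sketch (not part of this commit) using the Hugging Face `transformers` pipeline. The checkpoint name `facebook/nllb-200-distilled-600M` and the FLORES-200 language codes below are illustrative assumptions; this repository may use a different NLLB checkpoint.

```python
# Minimal sketch: translate one English sentence into French with a public NLLB
# checkpoint (assumed here; substitute the checkpoint this repo actually uses).
from transformers import pipeline

translator = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",  # FLORES-200 code for English
    tgt_lang="fra_Latn",  # FLORES-200 code for French
)

result = translator("No language left behind.", max_length=128)
print(result[0]["translation_text"])
```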
model.py ADDED
@@ -0,0 +1,233 @@
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, pipeline
+ from abc import ABC, abstractmethod
+ from typing import Type
+ import torch
+ import torch.nn.functional as F
+ from modules.file import ExcelFileWriter
+ import os
+
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
+
+ class Model():
+     def __init__(self, modelname, selected_lora_model, selected_gpu):
+         def get_gpu_index(gpu_info, target_gpu_name):
+             """
+             Get the index of the target GPU from the collected GPU info.
+             Args:
+                 gpu_info (list): list of GPU device names
+                 target_gpu_name (str): name of the target GPU
+
+             Returns:
+                 int: index of the target GPU, or -1 if it is not found
+             """
+             for i, name in enumerate(gpu_info):
+                 if target_gpu_name.lower() in name.lower():
+                     return i
+             return -1
+         if selected_gpu != "cpu":
+             gpu_count = torch.cuda.device_count()
+             gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
+             selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
+             self.device_name = f"cuda:{selected_gpu_index}"
+         else:
+             self.device_name = "cpu"
+         print("device_name", self.device_name)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(modelname)
+         # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
+
+     def generate(self, inputs, original_language, target_languages, max_batch_size):
+         def language_mapping(original_language):
+             d = {
+                 "Achinese (Arabic script)": "ace_Arab",
+                 "Achinese (Latin script)": "ace_Latn",
+                 "Mesopotamian Arabic": "acm_Arab",
+                 "Ta'izzi-Adeni Arabic": "acq_Arab",
+                 "Tunisian Arabic": "aeb_Arab",
+                 "Afrikaans": "afr_Latn",
+                 "South Levantine Arabic": "ajp_Arab",
+                 "Akan": "aka_Latn",
+                 "Amharic": "amh_Ethi",
+                 "North Levantine Arabic": "apc_Arab",
+                 "Standard Arabic": "arb_Arab",
+                 "Najdi Arabic": "ars_Arab",
+                 "Moroccan Arabic": "ary_Arab",
+                 "Egyptian Arabic": "arz_Arab",
+                 "Assamese": "asm_Beng",
+                 "Asturian": "ast_Latn",
+                 "Awadhi": "awa_Deva",
+                 "Central Aymara": "ayr_Latn",
+                 "South Azerbaijani": "azb_Arab",
+                 "North Azerbaijani": "azj_Latn",
+                 "Bashkir": "bak_Cyrl",
+                 "Bambara": "bam_Latn",
+                 "Balinese": "ban_Latn",
+                 "Belarusian": "bel_Cyrl",
+                 "Bemba": "bem_Latn",
+                 "Bengali": "ben_Beng",
+                 "Bhojpuri": "bho_Deva",
+                 "Banjar (Arabic script)": "bjn_Arab",
+                 "Banjar (Latin script)": "bjn_Latn",
+                 "Tibetan": "bod_Tibt",
+                 "Bosnian": "bos_Latn",
+                 "Buginese": "bug_Latn",
+                 "Bulgarian": "bul_Cyrl",
+                 "Catalan": "cat_Latn",
+                 "Cebuano": "ceb_Latn",
+                 "Czech": "ces_Latn",
+                 "Chokwe": "cjk_Latn",
+                 "Central Kurdish": "ckb_Arab",
+                 "Crimean Tatar": "crh_Latn",
+                 "Welsh": "cym_Latn",
+                 "Danish": "dan_Latn",
+                 "German": "deu_Latn",
+                 "Dinka": "dik_Latn",
+                 "Jula": "dyu_Latn",
+                 "Dzongkha": "dzo_Tibt",
+                 "Greek": "ell_Grek",
+                 "English": "eng_Latn",
+                 "Esperanto": "epo_Latn",
+                 "Estonian": "est_Latn",
+                 "Basque": "eus_Latn",
+                 "Ewe": "ewe_Latn",
+                 "Faroese": "fao_Latn",
+                 "Persian": "pes_Arab",
+                 "Fijian": "fij_Latn",
+                 "Finnish": "fin_Latn",
+                 "Fon": "fon_Latn",
+                 "French": "fra_Latn",
+                 "Friulian": "fur_Latn",
+                 "Nigerian Fulfulde": "fuv_Latn",
+                 "Scottish Gaelic": "gla_Latn",
+                 "Irish": "gle_Latn",
+                 "Galician": "glg_Latn",
+                 "Guarani": "grn_Latn",
+                 "Gujarati": "guj_Gujr",
+                 "Haitian Creole": "hat_Latn",
+                 "Hausa": "hau_Latn",
+                 "Hebrew": "heb_Hebr",
+                 "Hindi": "hin_Deva",
+                 "Chhattisgarhi": "hne_Deva",
+                 "Croatian": "hrv_Latn",
+                 "Hungarian": "hun_Latn",
+                 "Armenian": "hye_Armn",
+                 "Igbo": "ibo_Latn",
+                 "Iloko": "ilo_Latn",
+                 "Indonesian": "ind_Latn",
+                 "Icelandic": "isl_Latn",
+                 "Italian": "ita_Latn",
+                 "Javanese": "jav_Latn",
+                 "Japanese": "jpn_Jpan",
+                 "Kabyle": "kab_Latn",
+                 "Kachin": "kac_Latn",
+                 "Arabic": "arb_Arab",
+                 "Chinese": "zho_Hans",
+                 "Spanish": "spa_Latn",
+                 "Dutch": "nld_Latn",
+                 "Kazakh": "kaz_Cyrl",
+                 "Korean": "kor_Hang",
+                 "Lithuanian": "lit_Latn",
+                 "Malayalam": "mal_Mlym",
+                 "Marathi": "mar_Deva",
+                 "Nepali": "npi_Deva",
+                 "Polish": "pol_Latn",
+                 "Portuguese": "por_Latn",
+                 "Russian": "rus_Cyrl",
+                 "Sinhala": "sin_Sinh",
+                 "Tamil": "tam_Taml",
+                 "Turkish": "tur_Latn",
+                 "Ukrainian": "ukr_Cyrl",
+                 "Urdu": "urd_Arab",
+                 "Vietnamese": "vie_Latn",
+                 "Thai": "tha_Thai"
+             }
+             return d[original_language]
+         def process_gpu_translate_result(temp_outputs):
+             outputs = []
+             for temp_output in temp_outputs:
+                 length = len(temp_output[0]["generated_translation"])
+                 for i in range(length):
+                     temp = []
+                     for trans in temp_output:
+                         temp.append({
+                             "target_language": trans["target_language"],
+                             "generated_translation": trans['generated_translation'][i],
+                         })
+                     outputs.append(temp)
+             excel_writer = ExcelFileWriter()
+             excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
+         self.tokenizer.src_lang = language_mapping(original_language)
+         if self.device_name == "cpu":
+             # Tokenize input
+             input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
+             output = []
+             for target_language in target_languages:
+                 # Get language code for the target language
+                 target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
+                 # Generate translation
+                 generated_tokens = self.model.generate(
+                     **input_ids,
+                     forced_bos_token_id=target_lang_code,
+                     max_length=128
+                 )
+                 generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                 # Append result to output
+                 output.append({
+                     "target_language": target_language,
+                     "generated_translation": generated_translation,
+                 })
+             outputs = []
+             length = len(output[0]["generated_translation"])
+             for i in range(length):
+                 temp = []
+                 for trans in output:
+                     temp.append({
+                         "target_language": trans["target_language"],
+                         "generated_translation": trans['generated_translation'][i],
+                     })
+                 outputs.append(temp)
+             return outputs
+         else:
+             # Maximum batch size = available GPU memory (bytes) / 4 / (tensor size + trainable parameters)
+             # max_batch_size = 10
+             # Ensure batch size is within model limits:
+             batch_size = min(len(inputs), int(max_batch_size))
+             batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
+             temp_outputs = []
+             processed_num = 0
+             for index, batch in enumerate(batches):
+                 # Tokenize input
+                 input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
+                 temp = []
+                 for target_language in target_languages:
+                     target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
+                     generated_tokens = self.model.generate(
+                         **input_ids,
+                         forced_bos_token_id=target_lang_code,
+                     )
+                     generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                     # Append result to output
+                     temp.append({
+                         "target_language": target_language,
+                         "generated_translation": generated_translation,
+                     })
+                 input_ids.to('cpu')
+                 del input_ids
+                 temp_outputs.append(temp)
+                 processed_num += len(batch)
+                 if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
+                     print("Already processed number: ", len(temp_outputs))
+                     process_gpu_translate_result(temp_outputs)
+             outputs = []
+             for temp_output in temp_outputs:
+                 length = len(temp_output[0]["generated_translation"])
+                 for i in range(length):
+                     temp = []
+                     for trans in temp_output:
+                         temp.append({
+                             "target_language": trans["target_language"],
+                             "generated_translation": trans['generated_translation'][i],
+                         })
+                     outputs.append(temp)
+             return outputs
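A hypothetical usage sketch for the `Model` class above (not part of this commit). It assumes `model.py` is importable, the repo's `modules.file.ExcelFileWriter` dependency is available, a transformers version whose NLLB tokenizer still exposes `lang_code_to_id`, and the public `facebook/nllb-200-distilled-600M` checkpoint:

```python
# Hypothetical example: instantiate Model on CPU and translate two sentences
# from English into French and German. Checkpoint name and arguments are assumptions.
from model import Model

m = Model("facebook/nllb-200-distilled-600M", selected_lora_model=None, selected_gpu="cpu")
results = m.generate(
    inputs=["Hello, world!", "How are you today?"],
    original_language="English",
    target_languages=["French", "German"],
    max_batch_size=10,
)
# results holds one entry per input sentence; each entry is a list with one
# dict per target language containing the decoded translation.
for per_sentence in results:
    for item in per_sentence:
        print(item["target_language"], "->", item["generated_translation"])
```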
support_language.json ADDED
@@ -0,0 +1,208 @@
+ {
+     "original_language": [
+         "Achinese (Arabic script)",
+         "Achinese (Latin script)",
+         "Afrikaans",
+         "Akan",
+         "Amharic",
+         "Arabic",
+         "Armenian",
+         "Assamese",
+         "Asturian",
+         "Awadhi",
+         "Balinese",
+         "Bambara",
+         "Banjar (Arabic script)",
+         "Banjar (Latin script)",
+         "Bashkir",
+         "Basque",
+         "Belarusian",
+         "Bemba",
+         "Bengali",
+         "Bhojpuri",
+         "Bosnian",
+         "Buginese",
+         "Bulgarian",
+         "Catalan",
+         "Cebuano",
+         "Central Aymara",
+         "Central Kurdish",
+         "Chhattisgarhi",
+         "Chinese",
+         "Chokwe",
+         "Crimean Tatar",
+         "Croatian",
+         "Czech",
+         "Danish",
+         "Dinka",
+         "Dutch",
+         "Dzongkha",
+         "Egyptian Arabic",
+         "English",
+         "Esperanto",
+         "Estonian",
+         "Ewe",
+         "Faroese",
+         "Fijian",
+         "Finnish",
+         "Fon",
+         "French",
+         "Friulian",
+         "Galician",
+         "German",
+         "Greek",
+         "Guarani",
+         "Gujarati",
+         "Haitian Creole",
+         "Hausa",
+         "Hebrew",
+         "Hindi",
+         "Hungarian",
+         "Icelandic",
+         "Igbo",
+         "Iloko",
+         "Indonesian",
+         "Irish",
+         "Italian",
+         "Japanese",
+         "Javanese",
+         "Jula",
+         "Kabyle",
+         "Kachin",
+         "Kazakh",
+         "Korean",
+         "Lithuanian",
+         "Malayalam",
+         "Marathi",
+         "Mesopotamian Arabic",
+         "Moroccan Arabic",
+         "Najdi Arabic",
+         "Nepali",
+         "Nigerian Fulfulde",
+         "North Azerbaijani",
+         "North Levantine Arabic",
+         "Persian",
+         "Polish",
+         "Portuguese",
+         "Russian",
+         "Scottish Gaelic",
+         "Sinhala",
+         "South Azerbaijani",
+         "South Levantine Arabic",
+         "Spanish",
+         "Standard Arabic",
+         "Ta'izzi-Adeni Arabic",
+         "Tamil",
+         "Thai",
+         "Tibetan",
+         "Tunisian Arabic",
+         "Turkish",
+         "Ukrainian",
+         "Urdu",
+         "Vietnamese",
+         "Welsh"
+     ],
+     "target_language": [
+         "Achinese (Arabic script)",
+         "Achinese (Latin script)",
+         "Afrikaans",
+         "Akan",
+         "Amharic",
+         "Arabic",
+         "Armenian",
+         "Assamese",
+         "Asturian",
+         "Awadhi",
+         "Balinese",
+         "Bambara",
+         "Banjar (Arabic script)",
+         "Banjar (Latin script)",
+         "Bashkir",
+         "Basque",
+         "Belarusian",
+         "Bemba",
+         "Bengali",
+         "Bhojpuri",
+         "Bosnian",
+         "Buginese",
+         "Bulgarian",
+         "Catalan",
+         "Cebuano",
+         "Central Aymara",
+         "Central Kurdish",
+         "Chhattisgarhi",
+         "Chinese",
+         "Chokwe",
+         "Crimean Tatar",
+         "Croatian",
+         "Czech",
+         "Danish",
+         "Dinka",
+         "Dutch",
+         "Dzongkha",
+         "Egyptian Arabic",
+         "English",
+         "Esperanto",
+         "Estonian",
+         "Ewe",
+         "Faroese",
+         "Fijian",
+         "Finnish",
+         "Fon",
+         "French",
+         "Friulian",
+         "Galician",
+         "German",
+         "Greek",
+         "Guarani",
+         "Gujarati",
+         "Haitian Creole",
+         "Hausa",
+         "Hebrew",
+         "Hindi",
+         "Hungarian",
+         "Icelandic",
+         "Igbo",
+         "Iloko",
+         "Indonesian",
+         "Irish",
+         "Italian",
+         "Japanese",
+         "Javanese",
+         "Jula",
+         "Kabyle",
+         "Kachin",
+         "Kazakh",
+         "Korean",
+         "Lithuanian",
+         "Malayalam",
+         "Marathi",
+         "Mesopotamian Arabic",
+         "Moroccan Arabic",
+         "Najdi Arabic",
+         "Nepali",
+         "Nigerian Fulfulde",
+         "North Azerbaijani",
+         "North Levantine Arabic",
+         "Persian",
+         "Polish",
+         "Portuguese",
+         "Russian",
+         "Scottish Gaelic",
+         "Sinhala",
+         "South Azerbaijani",
+         "South Levantine Arabic",
+         "Spanish",
+         "Standard Arabic",
+         "Ta'izzi-Adeni Arabic",
+         "Tamil",
+         "Thai",
+         "Tibetan",
+         "Tunisian Arabic",
+         "Turkish",
+         "Ukrainian",
+         "Urdu",
+         "Vietnamese",
+         "Welsh"
+     ]
+ }
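The JSON file enumerates the display names offered as source and target languages; these appear to be the same display names that `language_mapping` in `model.py` expects. A small illustrative sketch (not part of this commit) for loading it, e.g. to populate UI dropdowns:

```python
# Illustrative only: read support_language.json and report the supported languages.
import json

with open("support_language.json", encoding="utf-8") as f:
    langs = json.load(f)

print(len(langs["original_language"]), "source languages supported")
print(len(langs["target_language"]), "target languages supported")
```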