|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, pipeline
from abc import ABC, abstractmethod
from typing import Type
import torch
import torch.nn.functional as F
import os

script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
|
|
class Model:
    def __init__(self, modelname, selected_lora_model, selected_gpu):
        def get_gpu_index(gpu_info, target_gpu_name):
            """
            Look up the index of the target GPU in the list of detected GPUs.

            Args:
                gpu_info (list): List of GPU device names.
                target_gpu_name (str): Name of the GPU to look for.

            Returns:
                int: Index of the target GPU, or -1 if it is not found.
            """
            for i, name in enumerate(gpu_info):
                if target_gpu_name.lower() in name.lower():
                    return i
            return -1

        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
            if selected_gpu_index == -1:
                # Fall back to the first GPU instead of building an invalid "cuda:-1" device string.
                print(f"GPU '{selected_gpu}' not found, falling back to cuda:0")
                selected_gpu_index = 0
            self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
|
    def generate(self, inputs, original_language, target_languages, max_batch_size):
        def language_mapping(original_language):
            # Map human-readable language names to NLLB (FLORES-200) language codes.
            d = {
                "Achinese (Arabic script)": "ace_Arab",
                "Achinese (Latin script)": "ace_Latn",
                "Mesopotamian Arabic": "acm_Arab",
                "Ta'izzi-Adeni Arabic": "acq_Arab",
                "Tunisian Arabic": "aeb_Arab",
                "Afrikaans": "afr_Latn",
                "South Levantine Arabic": "ajp_Arab",
                "Akan": "aka_Latn",
                "Amharic": "amh_Ethi",
                "North Levantine Arabic": "apc_Arab",
                "Standard Arabic": "arb_Arab",
                "Najdi Arabic": "ars_Arab",
                "Moroccan Arabic": "ary_Arab",
                "Egyptian Arabic": "arz_Arab",
                "Assamese": "asm_Beng",
                "Asturian": "ast_Latn",
                "Awadhi": "awa_Deva",
                "Central Aymara": "ayr_Latn",
                "South Azerbaijani": "azb_Arab",
                "North Azerbaijani": "azj_Latn",
                "Bashkir": "bak_Cyrl",
                "Bambara": "bam_Latn",
                "Balinese": "ban_Latn",
                "Belarusian": "bel_Cyrl",
                "Bemba": "bem_Latn",
                "Bengali": "ben_Beng",
                "Bhojpuri": "bho_Deva",
                "Banjar (Arabic script)": "bjn_Arab",
                "Banjar (Latin script)": "bjn_Latn",
                "Tibetan": "bod_Tibt",
                "Bosnian": "bos_Latn",
                "Buginese": "bug_Latn",
                "Bulgarian": "bul_Cyrl",
                "Catalan": "cat_Latn",
                "Cebuano": "ceb_Latn",
                "Czech": "ces_Latn",
                "Chokwe": "cjk_Latn",
                "Central Kurdish": "ckb_Arab",
                "Crimean Tatar": "crh_Latn",
                "Welsh": "cym_Latn",
                "Danish": "dan_Latn",
                "German": "deu_Latn",
                "Dinka": "dik_Latn",
                "Jula": "dyu_Latn",
                "Dzongkha": "dzo_Tibt",
                "Greek": "ell_Grek",
                "English": "eng_Latn",
                "Esperanto": "epo_Latn",
                "Estonian": "est_Latn",
                "Basque": "eus_Latn",
                "Ewe": "ewe_Latn",
                "Faroese": "fao_Latn",
                "Persian": "pes_Arab",
                "Fijian": "fij_Latn",
                "Finnish": "fin_Latn",
                "Fon": "fon_Latn",
                "French": "fra_Latn",
                "Friulian": "fur_Latn",
                "Nigerian Fulfulde": "fuv_Latn",
                "Scottish Gaelic": "gla_Latn",
                "Irish": "gle_Latn",
                "Galician": "glg_Latn",
                "Guarani": "grn_Latn",
                "Gujarati": "guj_Gujr",
                "Haitian Creole": "hat_Latn",
                "Hausa": "hau_Latn",
                "Hebrew": "heb_Hebr",
                "Hindi": "hin_Deva",
                "Chhattisgarhi": "hne_Deva",
                "Croatian": "hrv_Latn",
                "Hungarian": "hun_Latn",
                "Armenian": "hye_Armn",
                "Igbo": "ibo_Latn",
                "Iloko": "ilo_Latn",
                "Indonesian": "ind_Latn",
                "Icelandic": "isl_Latn",
                "Italian": "ita_Latn",
                "Javanese": "jav_Latn",
                "Japanese": "jpn_Jpan",
                "Kabyle": "kab_Latn",
                "Kachin": "kac_Latn",
                "Arabic": "arb_Arab",
                "Chinese": "zho_Hans",
                "Spanish": "spa_Latn",
                "Dutch": "nld_Latn",
                "Kazakh": "kaz_Cyrl",
                "Korean": "kor_Hang",
                "Lithuanian": "lit_Latn",
                "Malayalam": "mal_Mlym",
                "Marathi": "mar_Deva",
                "Nepali": "npi_Deva",
                "Polish": "pol_Latn",
                "Portuguese": "por_Latn",
                "Russian": "rus_Cyrl",
                "Sinhala": "sin_Sinh",
                "Tamil": "tam_Taml",
                "Turkish": "tur_Latn",
                "Ukrainian": "ukr_Cyrl",
                "Urdu": "urd_Arab",
                "Vietnamese": "vie_Latn",
                "Thai": "tha_Thai"
            }
            return d[original_language]
|
        self.tokenizer.src_lang = language_mapping(original_language)
        if self.device_name == "cpu":
            # On CPU, encode all inputs as a single padded batch. truncation=True is
            # required for max_length to actually cap the input length.
            input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=128).to(self.device_name)
            output = []
            for target_language in target_languages:
                # convert_tokens_to_ids works across transformers versions, unlike the
                # older lang_code_to_id mapping, which newer releases dropped.
                target_lang_code = self.tokenizer.convert_tokens_to_ids(language_mapping(target_language))
                # Force the decoder to start with the target-language token.
                generated_tokens = self.model.generate(
                    **input_ids,
                    forced_bos_token_id=target_lang_code,
                    max_length=128
                )
                generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                output.append({
                    "target_language": target_language,
                    "generated_translation": generated_translation,
                })
            # Regroup per-language results so each input sentence gets its own
            # list of translations across all requested target languages.
            outputs = []
            length = len(output[0]["generated_translation"])
            for i in range(length):
                temp = []
                for trans in output:
                    temp.append({
                        "target_language": trans["target_language"],
                        "generated_translation": trans["generated_translation"][i],
                    })
                outputs.append(temp)
            return outputs
|
        else:
            # On GPU, split the inputs into batches of at most max_batch_size sentences.
            batch_size = min(len(inputs), int(max_batch_size))
            batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
            temp_outputs = []
            processed_num = 0
            for batch in batches:
                input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
                temp = []
                for target_language in target_languages:
                    target_lang_code = self.tokenizer.convert_tokens_to_ids(language_mapping(target_language))
                    generated_tokens = self.model.generate(
                        **input_ids,
                        forced_bos_token_id=target_lang_code,
                    )
                    generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                    temp.append({
                        "target_language": target_language,
                        "generated_translation": generated_translation,
                    })
                # Drop the batch encoding so its GPU tensors can be freed.
                del input_ids
                temp_outputs.append(temp)
                # Number of sentences handled so far, kept for progress reporting.
                processed_num += len(batch)
            # Regroup per-batch, per-language results so each input sentence gets
            # its own list of translations across all requested target languages.
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans["generated_translation"][i],
                        })
                    outputs.append(temp)
            return outputs
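
# Usage sketch (illustrative only): the checkpoint name, inputs, and arguments
# below are assumptions for demonstration, not values used elsewhere in this
# project. "facebook/nllb-200-distilled-600M" is an NLLB-200 checkpoint whose
# language codes match the language_mapping table above.
if __name__ == "__main__":
    model = Model("facebook/nllb-200-distilled-600M", None, "cpu")
    results = model.generate(
        inputs=["Hello, world!"],
        original_language="English",
        target_languages=["French", "Japanese"],
        max_batch_size=4,
    )
    # Each element of results corresponds to one input sentence and holds one
    # {"target_language", "generated_translation"} dict per requested language.
    print(results)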