from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, pipeline
from abc import ABC, abstractmethod
from typing import Type
import torch
import torch.nn.functional as F
import os

script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))

# Human-readable language name -> NLLB / FLORES-200 language code.
# The codes are used both as the tokenizer's ``src_lang`` and as the
# forced-BOS language token for the target side.
LANGUAGE_CODES = {
    "Achinese (Arabic script)": "ace_Arab",
    "Achinese (Latin script)": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta'izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Standard Arabic": "arb_Arab",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar (Arabic script)": "bjn_Arab",
    "Banjar (Latin script)": "bjn_Latn",
    "Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Dinka": "dik_Latn",
    "Jula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Persian": "pes_Arab",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Iloko": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Kachin": "kac_Latn",
    # Was the mBART-50 code "ar_AR", which is not an NLLB language token.
    "Arabic": "arb_Arab",
    "Chinese": "zho_Hans",
    "Spanish": "spa_Latn",
    "Dutch": "nld_Latn",
    "Kazakh": "kaz_Cyrl",
    "Korean": "kor_Hang",
    "Lithuanian": "lit_Latn",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    # Was the mBART-50 code "ne_NP", which is not an NLLB language token.
    "Nepali": "npi_Deva",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Russian": "rus_Cyrl",
    "Sinhala": "sin_Sinh",
    "Tamil": "tam_Taml",
    "Turkish": "tur_Latn",
    "Ukrainian": "ukr_Cyrl",
    "Urdu": "urd_Arab",
    "Vietnamese": "vie_Latn",
    "Thai": "tha_Thai",
}


def _language_code(language):
    """Return the language code for a human-readable language name.

    Raises:
        KeyError: If ``language`` is not a key of ``LANGUAGE_CODES``.
    """
    try:
        return LANGUAGE_CODES[language]
    except KeyError:
        raise KeyError(f"Unsupported language: {language!r}") from None


def _get_gpu_index(gpu_info, target_gpu_name):
    """Return the index of the first GPU whose name contains ``target_gpu_name``.

    Args:
        gpu_info (list): GPU device names, indexed like the CUDA devices.
        target_gpu_name (str): Name (or substring of a name) to look for,
            compared case-insensitively.

    Returns:
        int: Index of the matching GPU, or -1 if none matches.
    """
    target = target_gpu_name.lower()
    for i, name in enumerate(gpu_info):
        if target in name.lower():
            return i
    return -1


class Model:
    """Sequence-to-sequence translation model plus tokenizer on one device."""

    def __init__(self, modelname, selected_lora_model, selected_gpu):
        """Load the model and tokenizer onto the requested device.

        Args:
            modelname: Model name or path passed to ``from_pretrained``.
            selected_lora_model: Currently unused; kept for interface
                compatibility.
            selected_gpu: ``"cpu"``, or a case-insensitive substring of a
                CUDA device name reported by ``torch.cuda.get_device_name``.

        Raises:
            ValueError: If ``selected_gpu`` matches no visible CUDA device.
        """
        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            selected_gpu_index = _get_gpu_index(gpu_info, selected_gpu)
            if selected_gpu_index < 0:
                # Without this check the device string would be "cuda:-1",
                # which only fails later with an obscure error in ``.to()``.
                raise ValueError(
                    f"GPU {selected_gpu!r} not found; available devices: {gpu_info}"
                )
            self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)

    def _lang_token_id(self, language):
        """Return the tokenizer id of the language-code token for ``language``.

        Raises:
            KeyError: If the language is unknown or its code is not a token
                in this tokenizer's vocabulary.
        """
        code = _language_code(language)
        # ``tokenizer.lang_code_to_id`` is deprecated/removed in newer
        # transformers releases; language codes are special tokens, so look
        # them up as tokens instead.
        token_id = self.tokenizer.convert_tokens_to_ids(code)
        if token_id is None or token_id == self.tokenizer.unk_token_id:
            raise KeyError(f"Tokenizer has no language token {code!r} for {language!r}")
        return token_id

    def _translate_batch(self, batch, targets, **generate_kwargs):
        """Translate ``batch`` into every target language.

        Args:
            batch: List of source strings.
            targets: List of ``(target_language_name, lang_token_id)`` pairs.
            **generate_kwargs: Extra keyword arguments for ``model.generate``.

        Returns:
            One list per input string, holding one result dict per target.
        """
        # Inputs are padded to the longest item but not truncated.
        encoded = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
        per_input = [[] for _ in batch]
        for target_language, target_lang_code in targets:
            generated_tokens = self.model.generate(
                **encoded,
                forced_bos_token_id=target_lang_code,
                **generate_kwargs,
            )
            translations = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            for row, translation in zip(per_input, translations):
                row.append({
                    "target_language": target_language,
                    "generated_translation": translation,
                })
        return per_input

    def generate(self, inputs, original_language, target_languages, max_batch_size):
        """Translate ``inputs`` from ``original_language`` into each target language.

        Args:
            inputs: List of source strings. A single string is treated as a
                one-item list.
            original_language: Source language name (a key of LANGUAGE_CODES).
            target_languages: Iterable of target language names.
            max_batch_size: Maximum number of inputs per ``generate`` call on
                GPU. Ignored on CPU, where all inputs form a single batch.
                Values below 1 are treated as 1.

        Returns:
            A list with one entry per input, in input order. Each entry is a
            list with one dict per target language, of the form
            ``{"target_language": name, "generated_translation": text}``.

        Raises:
            KeyError: If a language name is unsupported.
        """
        inputs = [inputs] if isinstance(inputs, str) else list(inputs)
        target_languages = list(target_languages)
        if not inputs:
            return []

        self.tokenizer.src_lang = _language_code(original_language)
        # Resolve target language tokens once so an unsupported target fails
        # before any generation work is done.
        targets = [(lang, self._lang_token_id(lang)) for lang in target_languages]

        if self.device_name == "cpu":
            # CPU: a single batch, with generated sequences capped at 128 tokens.
            return self._translate_batch(inputs, targets, max_length=128)

        # Original author's rule of thumb for choosing max_batch_size:
        # available GPU memory (bytes) / 4 / (tensor size + trainable parameters).
        batch_size = max(1, min(len(inputs), int(max_batch_size)))
        outputs = []
        for start in range(0, len(inputs), batch_size):
            outputs.extend(self._translate_batch(inputs[start:start + batch_size], targets))
        return outputs