import gradio as gr import os import torch import gradio as gr from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration if torch.cuda.is_available(): device = torch.device("cuda:0") else: device = torch.device("cpu") tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B") model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B").to(device) model.eval() class Language: def __init__(self, name, code): self.name = name self.code = code lang_id = [ Language("Afrikaans", "af"), Language("Albanian", "sq"), Language("Amharic", "am"), Language("Arabic", "ar"), Language("Armenian", "hy"), Language("Asturian", "ast"), Language("Azerbaijani", "az"), Language("Bashkir", "ba"), Language("Belarusian", "be"), Language("Bulgarian", "bg"), Language("Bengali", "bn"), Language("Breton", "br"), Language("Bosnian", "bs"), Language("Burmese", "my"), Language("Catalan", "ca"), Language("Cebuano", "ceb"), Language("Chinese","zh"), Language("Croatian","hr"), Language("Czech","cs"), Language("Danish","da"), Language("Dutch","nl"), Language("English","en"), Language("Estonian","et"), Language("Fulah","ff"), Language("Finnish","fi"), Language("French","fr"), Language("Western Frisian","fy"), Language("Gaelic","gd"), Language("Galician","gl"), Language("Georgian","ka"), Language("German","de"), Language("Greek","el"), Language("Gujarati","gu"), Language("Hausa","ha"), Language("Hebrew","he"), Language("Hindi","hi"), Language("Haitian","ht"), Language("Hungarian","hu"), Language("Irish","ga"), Language("Indonesian","id"), Language("Igbo","ig"), Language("Iloko","ilo"), Language("Icelandic","is"), Language("Italian","it"), Language("Japanese","ja"), Language("Javanese","jv"), Language("Kazakh","kk"), Language("Central Khmer","km"), Language("Kannada","kn"), Language("Korean","ko"), Language("Luxembourgish","lb"), Language("Ganda","lg"), Language("Lingala","ln"), Language("Lao","lo"), Language("Lithuanian","lt"), Language("Latvian","lv"), Language("Malagasy","mg"), Language("Macedonian","mk"), Language("Malayalam","ml"), Language("Mongolian","mn"), Language("Marathi","mr"), Language("Malay","ms"), Language("Nepali","ne"), Language("Norwegian","no"), Language("Northern Sotho","ns"), Language("Occitan","oc"), Language("Oriya","or"), Language("Panjabi","pa"), Language("Persian","fa"), Language("Polish","pl"), Language("Pushto","ps"), Language("Portuguese","pt"), Language("Romanian","ro"), Language("Russian","ru"), Language("Sindhi","sd"), Language("Sinhala","si"), Language("Slovak","sk"), Language("Slovenian","sl"), Language("Spanish","es"), Language("Somali","so"), Language("Serbian","sr"), Language("Serbian (cyrillic)","sr"), Language("Serbian (latin)","sr"), Language("Swati","ss"), Language("Sundanese","su"), Language("Swedish","sv"), Language("Swahili","sw"), Language("Tamil","ta"), Language("Thai","th"), Language("Tagalog","tl"), Language("Tswana","tn"), Language("Turkish","tr"), Language("Ukrainian","uk"), Language("Urdu","ur"), Language("Uzbek","uz"), Language("Vietnamese","vi"), Language("Welsh","cy"), Language("Wolof","wo"), Language("Xhosa","xh"), Language("Yiddish","yi"), Language("Yoruba","yo"), Language("Zulu","zu"), ] d_lang = lang_id[21] #d_lang_code = d_lang.code def trans_page(input,trg): src_lang = d_lang.code for lang in lang_id: if lang.name == trg: trg_lang = lang.code if trg_lang != src_lang: tokenizer.src_lang = src_lang with torch.no_grad(): encoded_input = tokenizer(input, return_tensors="pt").to(device) generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)) translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] else: translated_text=input pass return translated_text def trans_to(input,src,trg): for lang in lang_id: if lang.name == trg: trg_lang = lang.code for lang in lang_id: if lang.name == src: src_lang = lang.code if trg_lang != src_lang: tokenizer.src_lang = src_lang with torch.no_grad(): encoded_input = tokenizer(input, return_tensors="pt").to(device) generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)) translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] else: translated_text=input pass return translated_text md1 = "Translate - 100 Languages" with gr.Blocks() as transbot: #this=gr.State() with gr.Row(): gr.Column() with gr.Column(): with gr.Row(): t_space = gr.Dropdown(label="Translate Space to:", choices=[l.name for l in lang_id], value="English") #t_space = gr.Dropdown(label="Translate Space", choices=list(lang_id.keys()),value="English") t_submit = gr.Button("Translate Space") gr.Column() with gr.Row(): gr.Column() with gr.Column(): md = gr.Markdown("""

Translate - 100 Languages

Translation may not be accurate

""") with gr.Row(): lang_from = gr.Dropdown(label="From:", choices=[l.name for l in lang_id],value="English") lang_to = gr.Dropdown(label="To:", choices=[l.name for l in lang_id],value="Chinese") #lang_from = gr.Dropdown(label="From:", choices=list(lang_id.keys()),value="English") #lang_to = gr.Dropdown(label="To:", choices=list(lang_id.keys()),value="Chinese") submit = gr.Button("Go") with gr.Row(): with gr.Column(): message = gr.Textbox(label="Prompt",placeholder="Enter Prompt",lines=4) translated = gr.Textbox(label="Translated",lines=4,interactive=False) gr.Column() t_submit.click(trans_page,[md,t_space],[md]) submit.click(trans_to, inputs=[message,lang_from,lang_to], outputs=[translated]) transbot.queue(concurrency_count=20) transbot.launch()