abinashbordoloi committed on
Commit
f89e277
1 Parent(s): 91eeaad

Update app.py

Files changed (1)
  1. app.py +46 -81
app.py CHANGED
@@ -1,90 +1,55 @@
-import os
-import torch
 import gradio as gr
-import time
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from supported_languages import LANGS


-def load_models():
-    # build model and tokenizer
-    model_name_dict = {
-        # 'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
-        # 'nllb-1.3B': 'facebook/nllb-200-1.3B',
-        'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
-        # 'nllb-3.3B': 'facebook/nllb-200-3.3B',
-    }
-
-    model_dict = {}
-
-    for call_name, real_name in model_name_dict.items():
-        print('\tLoading model: %s' % call_name)
-        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
-        tokenizer = AutoTokenizer.from_pretrained(real_name)
-        model_dict[call_name + '_model'] = model
-        model_dict[call_name + '_tokenizer'] = tokenizer
-
-    return model_dict
-
-
-def translation(source, target, text):
-    if len(model_dict) == 2:
-        model_name = 'nllb-distilled-1.3B'
-
-    start_time = time.time()
-    source = LANGS[source]
-    target = LANGS[target]
-
-    model = model_dict[model_name + '_model']
-    tokenizer = model_dict[model_name + '_tokenizer']
-
-    translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
-    output = translator(text, max_length=400)
-
-    end_time = time.time()
-
-    full_output = output
-    output = output[0]['translation_text']
-    result = {'inference_time': end_time - start_time,
-              'source': source,
-              'target': target,
-              'result': output,
-              'full_output': full_output}
-    return result
-
-
-if __name__ == '__main__':
-    print('\tinit models')
-
-    global model_dict
-
-    model_dict = load_models()
-
-    # define gradio demo
-    lang_codes = list(LANGS.keys())
-    # inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
-    inputs = [gr.Dropdown(label='Source', choices=LANGS.keys()),
-              gr.Dropdown(label='Target', choices=LANGS.keys()),
-              gr.Textbox(label="Input text", lines=5, placeholder="Enter text to translate"),
-              ]
-
-    outputs = gr.JSON()
-
-    title = "Anubaad-Assamese-Translation-Application-NLLB-200"
-
-    demo_status = "Demo is running on CPU"
-    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
-    examples = [
-        ['Chinese (Simplified)', 'English', '你吃饭了吗?']
-    ]
-
-    gr.Interface(translation,
-                 inputs,
-                 outputs,
-                 title=title,
-                 description=description,
-                 examples=examples,
-                 examples_per_page=50,
-                 ).launch()

 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+import torch
+from ui import title, description, examples
 from supported_languages import LANGS
+# from langs_all import LANGS  # for 200+ languages

+TASK = "translation"
+CKPT = "facebook/nllb-200-distilled-1.3B"
+# CKPT = "facebook/nllb-200-distilled-600M"

+model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
+tokenizer = AutoTokenizer.from_pretrained(CKPT)

+# device = 0 if torch.cuda.is_available() else -1


+def translate(text, src_lang, tgt_lang, max_length=512):
+    """
+    Translate the text from source lang to target lang
+    """
+    translation_pipeline = pipeline(TASK,
+                                    model=model,
+                                    tokenizer=tokenizer,
+                                    src_lang=src_lang,
+                                    tgt_lang=tgt_lang,
+                                    max_length=max_length)
+
+    # translation_pipeline = pipeline(TASK,
+    #                                 model=model,
+    #                                 tokenizer=tokenizer,
+    #                                 src_lang=src_lang,
+    #                                 tgt_lang=tgt_lang,
+    #                                 max_length=max_length,
+    #                                 device=device)
+
+    result = translation_pipeline(text)
+    return result[0]['translation_text']
+
+
+gr.Interface(
+    translate,
+    [
+        gr.Textbox(label="Text", placeholder="Enter Your Text here"),
+        gr.Dropdown(label="Source Language", choices=list(LANGS.keys())),
+        gr.Dropdown(label="Target Language", choices=list(LANGS.keys())),
+        gr.Slider(8, 512, value=512, step=8, label="Max Length")
+    ],
+    ["text"],
+    examples=examples,
+    # article=article,
+    cache_examples=False,
+    title=title,
+    description=description
+).launch()
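
For context, the rewritten app.py imports LANGS from supported_languages and title, description, examples from ui, but neither module is part of this diff. Below is a minimal sketch of the shape those modules are assumed to have, plus a quick way to sanity-check translate() without the Gradio UI. The concrete names and values are illustrative assumptions, not the repository's actual contents (the title and description strings are carried over from the removed code).

# supported_languages.py -- assumed shape. Because translate() forwards the dropdown
# value straight to the pipeline's src_lang/tgt_lang, the keys are assumed to be
# FLORES-200 codes as used by NLLB-200.
LANGS = {
    "asm_Beng": "Assamese",
    "eng_Latn": "English",
    "zho_Hans": "Chinese (Simplified)",
}

# ui.py -- assumed shape: metadata consumed by gr.Interface. Example rows follow the
# input order of the interface: (text, source language, target language, max length).
title = "Anubaad-Assamese-Translation-Application-NLLB-200"
description = "Details: https://github.com/facebookresearch/fairseq/tree/nllb. Demo is running on CPU"
examples = [
    ["How are you?", "eng_Latn", "asm_Beng", 128],
]

# Quick sanity check (e.g. appended to app.py before the .launch() call):
# print(translate("How are you?", "eng_Latn", "asm_Beng", max_length=64))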