# The full Gradio translation app below is wrapped in a triple-quoted string,
# so it does not run; only the CUDA check at the bottom of this file executes.
'''
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import nltk
from nltk.tokenize import sent_tokenize
import torch
# Initialize and download necessary NLTK resources
nltk.download('punkt')
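# Note: newer NLTK releases ship the Punkt sentence-tokenizer data under the
# separate 'punkt_tab' resource; downloading it as well keeps sent_tokenize
# working across NLTK versions (harmless no-op on older releases).
nltk.download('punkt_tab')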
# Load the models and tokenizers
model_checkpoint_fo_en = "barbaroo/nllb_200_600M_fo_en"
model_checkpoint_en_fo = "barbaroo/nllb_200_600M_en_fo"
tokenizer_fo_en = AutoTokenizer.from_pretrained(model_checkpoint_fo_en)
model_fo_en = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint_fo_en)
tokenizer_en_fo = AutoTokenizer.from_pretrained(model_checkpoint_en_fo)
model_en_fo = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint_en_fo)
# Check if a GPU is available and move models to GPU if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print("GPU is available. Initializing models on GPU.")
    model_fo_en.to(device)
    model_en_fo.to(device)
else:
    print("GPU is not available. Using CPU.")

def split_into_sentences(text):
    return sent_tokenize(text)

def translate(text, model, tokenizer, max_length=80):
    # Translate sentence by sentence so long inputs are not truncated wholesale.
    sentences = split_into_sentences(text)
    translated_text = []
    for sentence in sentences:
        inputs = tokenizer.encode(sentence, return_tensors="pt", max_length=max_length, truncation=True).to(device)
        outputs = model.generate(inputs, max_length=max_length, num_beams=4, early_stopping=True)
        translated_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
        translated_text.append(translated_sentence)
    return " ".join(translated_text)
def handle_input(text, file, direction):
    # If a file was uploaded, its raw bytes take precedence over the textbox.
    if file is not None:
        text = file.decode("utf-8")
    # Pick the model/tokenizer pair for the requested translation direction.
    if direction == "fo_en":
        model = model_fo_en
        tokenizer = tokenizer_fo_en
    else:
        model = model_en_fo
        tokenizer = tokenizer_en_fo
    # Translate the text if it's not empty.
    if text:
        return translate(text, model, tokenizer)
    else:
        return "Please enter text or upload a text file."
# Define the Gradio interface
iface = gr.Interface(
    fn=handle_input,
    inputs=[
        gr.Textbox(lines=2, placeholder="Type here or upload a text file..."),
        gr.File(label="or Upload Text File", type="binary"),
        gr.Dropdown(label="Translation Direction", choices=["fo_en", "en_fo"], value="fo_en"),
    ],
    outputs="text",
    title="Bidirectional Translator",
    description="Enter text directly or upload a text file (.txt) to translate between Faroese and English.",
)
# Launch the interface
iface.launch()
'''
import torch

print(f"Is CUDA available: {torch.cuda.is_available()}")  # expected: True on a GPU instance
# Guard the device query so the check does not raise on CPU-only machines.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")