dragoman / app.py
robinhad's picture
Disable logging
25432dd verified
raw
history blame
6.18 kB
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import (
MistralForCausalLM,
TextIteratorStreamer,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
)
from time import sleep
from threading import Thread
from torch import float16
import spaces
import huggingface_hub
from threading import Thread
from queue import Queue
from time import sleep
from os import getenv
# from data_logger import log_data
from datetime import datetime
def check_thread(logging_queue: Queue, flush_interval: float = 60):
    """Periodically push queued log records to a Hugging Face dataset.

    Every ``flush_interval`` seconds, drains everything currently in
    ``logging_queue`` and passes the batch to a ``log_data`` callback
    (configured from the ``HF_API_TOKEN`` / ``OUTPUT_DATASET`` environment
    variables, private dataset). If the push fails, the items are put back on
    the queue so a later iteration retries them.

    NOTE(review): ``log_data`` comes from ``data_logger``, whose import is
    commented out at the top of this file, so calling this function as-is
    raises ``NameError``. Nothing in this file currently starts this thread.

    Args:
        logging_queue: Queue of records produced by request handlers.
        flush_interval: Seconds to sleep between flushes (default 60).

    This function loops forever and never returns normally.
    """
    from queue import Empty

    logging_callback = log_data(
        hf_token=getenv("HF_API_TOKEN"),
        dataset_name=getenv("OUTPUT_DATASET"),
        private=True,
    )
    while True:
        sleep(flush_interval)

        # Drain with get_nowait()/Empty: Queue.empty() is only a hint and a
        # following get() is not guaranteed not to block.
        batch = []
        while True:
            try:
                batch.append(logging_queue.get_nowait())
            except Empty:
                break

        if not batch:
            continue

        try:
            logging_callback(batch)
        except Exception as err:
            # Best effort: keep the thread alive and retry these items on the
            # next flush instead of dropping them.
            print(
                f"Error happened while pushing data to HF ({err!r}). "
                "Putting items back in queue..."
            )
            for item in batch:
                logging_queue.put(item)
# Pushing user inputs to the HF dataset is switched off: the `log_data` import
# at the top of the file is commented out, and the original condition
# (`getenv("HF_API_TOKEN") is not None`) had been replaced with a hard-coded
# `False`, so the old logging branch could never run. To re-enable logging,
# restore that import and the token check.
if getenv("HF_API_TOKEN") is None:
    print("No HF_API_TOKEN found. Logging is disabled.")
else:
    # Previously this case also reported a missing token, which was misleading.
    print("Logging is disabled.")
# Load Mistral-7B in 4-bit NF4 (bitsandbytes, float16 compute) and attach the
# Dragoman PEFT adapter from the Hub on top of it.
config = PeftConfig.from_pretrained("lang-uk/dragoman")
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=float16,
    bnb_4bit_use_double_quant=False,
)
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=quant_config,
    # device_map="auto",  (left disabled; the model is moved with .to("cuda") below)
)
# Reuse the adapter config fetched above instead of letting PEFT load it a
# second time (previously `config` was loaded and never used).
# NOTE(review): moving to "cuda" at import time presumably relies on the
# `spaces` (ZeroGPU) runtime -- confirm before changing hardware.
model = PeftModel.from_pretrained(model, "lang-uk/dragoman", config=config).to("cuda")
# Slow (sentencepiece) tokenizer with no automatic BOS token; presumably this
# matches the prompt format the adapter was trained with -- TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1", use_fast=False, add_bos_token=False
)
@spaces.GPU(duration=30)
def translate(input_text):
    """Translate English text into Ukrainian with the Dragoman model.

    The text is wrapped in the ``[INST] ... [/INST]`` prompt format. The
    function runs beam search (10 beams, at most 200 new tokens) and returns
    only the text generated after the prompt. The ``spaces.GPU`` decorator
    attaches a GPU for the call; ``duration=30`` is the requested budget in
    seconds.

    Args:
        input_text: Source text from the Gradio textbox; ``None`` is treated
            as an empty string.

    Returns:
        The translation with surrounding whitespace stripped, or ``""`` for
        empty or whitespace-only input (no generation is run in that case).
    """
    from datetime import timezone

    input_text = (input_text or "").strip()
    print(f"{datetime.now(timezone.utc)} | Translating: {input_text}")
    if not input_text:
        # Nothing to translate; avoid generating from an empty prompt.
        return ""

    prompt = f"[INST] {input_text} [/INST]"
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # NOTE: output is not streamed because transformers' streamers do not
    # support beam search (num_beams > 1).
    output = model.generate(
        **inputs,
        max_new_tokens=200,
        num_beams=10,
        temperature=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation. Decode only the continuation.
    # Splitting the full decode on "[/INST] " echoed the whole prompt when the
    # generation did not start with a space, and truncated any output that
    # itself contained "[/INST] ".
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def _strip_front_matter(text):
    """Return ``text`` without a leading YAML front-matter block.

    If ``text`` starts with ``---`` and contains a later line starting with
    ``---``, everything up to and including that closing line is dropped,
    along with any blank lines right after it. Otherwise ``text`` is returned
    unchanged (the old ``find("---", 1) + 5`` slice cut 4 characters off text
    without front matter).
    """
    if not text.startswith("---"):
        return text
    closing = text.find("\n---", 3)
    if closing == -1:
        return text
    end_of_line = text.find("\n", closing + 1)
    if end_of_line == -1:
        # The closing delimiter is the last line: no body left.
        return ""
    return text[end_of_line + 1 :].lstrip("\r\n")


# Download the adapter's model card and show it (without its YAML header) below
# the demo, prefixed with the data-collection notice.
desc_file = huggingface_hub.hf_hub_download("lang-uk/dragoman", "README.md")
# Read explicitly as UTF-8: the card may contain non-ASCII text, and the
# locale's default encoding is not guaranteed to be UTF-8.
with open(desc_file, "r", encoding="utf-8") as f:
    model_description = f.read()
model_description = _strip_front_matter(model_description)
# NOTE(review): dataset logging is currently disabled elsewhere in this file;
# confirm that this notice is still accurate.
model_description = (
    """### By using this service, users are required to agree to the following terms: you agree that user input will be collected for future research and model improvements. \n\n"""
    + model_description
)
# Paper blurb: used as the page description and as the default source text.
_PAPER_NOTE = 'This demo contains a model from paper "Setting up the Data Printer with Improved English to Ukrainian Machine Translation", accepted to UNLP 2024 workshop at the LREC-COLING 2024 conference.'

# One-click example inputs; each is shown as a single-column row in Gradio.
_EXAMPLE_SENTENCES = (
    "The Colosseum in Rome was a symbol of the grandeur and power of the Roman Empire and was a place for the emperor to connect with the people by providing them with entertainment and free food.",
    "How many leaves would it drop in a month of February in a non-leap year?",
    "ChatGPT (Chat Generative Pre-trained Transformer) is a chatbot developed by OpenAI and launched on November 30, 2022. Based on a large language model, it enables users to refine and steer a conversation towards a desired length, format, style, level of detail, and language. Successive prompts and replies, known as prompt engineering, are considered at each conversation stage as a context.[2] ",
    "who holds this neighborhood?",
)

# Build the source box before the target box, as the original did.
source_box = gr.Textbox(value=_PAPER_NOTE, label="Source sentence")
target_box = gr.Textbox(
    value='Ця демо-версія містить модель із статті "Налаштування принтера даних із покращеним машинним перекладом з англійської на українську", яка була прийнята до семінару UNLP 2024 на конференції LREC-COLING 2024.',
    label="Translated sentence",
)

# Single-function Gradio UI around `translate`, with the model card as the
# article shown under the demo.
iface = gr.Interface(
    fn=translate,
    inputs=source_box,
    outputs=target_box,
    examples=[[sentence] for sentence in _EXAMPLE_SENTENCES],
    title="Dragoman: SOTA English-Ukrainian translation model",
    description=_PAPER_NOTE,
    article=model_description,
    submit_btn="Translate",
)
iface.launch()