Spaces:
Sleeping
Sleeping
import gradio as gr | |
from peft import PeftModel, PeftConfig | |
from transformers import ( | |
MistralForCausalLM, | |
TextIteratorStreamer, | |
AutoTokenizer, | |
BitsAndBytesConfig, | |
GenerationConfig, | |
) | |
from time import sleep | |
from threading import Thread | |
from torch import float16 | |
import spaces | |
import huggingface_hub | |
from threading import Thread | |
from queue import Queue | |
from time import sleep | |
from os import getenv | |
# from data_logger import log_data | |
from datetime import datetime | |
def check_thread(logging_queue: Queue):
    """Periodically flush queued log records to the output HF dataset.

    Runs forever. Every 60 seconds it drains everything currently in
    ``logging_queue`` into one batch and passes it to the callback returned by
    ``log_data``. If the push fails, the batch is put back on the queue so it
    is retried on the next tick.

    NOTE(review): ``log_data`` comes from ``data_logger``, whose import is
    commented out at the top of the file, so calling this function as-is
    raises ``NameError``. The thread that would run it is also commented out.

    Args:
        logging_queue: Queue filled by request handlers. The record format is
            whatever the ``log_data`` callback accepts (not visible here).
    """
    logging_callback = log_data(
        hf_token=getenv("HF_API_TOKEN"),
        dataset_name=getenv("OUTPUT_DATASET"),
        private=True,
    )
    while True:
        sleep(60)
        batch = []
        while not logging_queue.empty():
            batch.append(logging_queue.get())
        if not batch:
            continue
        try:
            logging_callback(batch)
        except Exception as err:
            # Exception rather than a bare except: KeyboardInterrupt and
            # SystemExit must still be able to stop this thread.
            print(
                "Error happened while pushing data to HF. Puttting items back in queue..."
            )
            print(f"Push error: {err!r}")
            for item in batch:
                logging_queue.put(item)
# Optional logging of user inputs to a private HF dataset.
# NOTE(review): hard-disabled with `if False`. `log_data` is not importable
# here (the `data_logger` import is commented out), so the original condition
# `getenv("HF_API_TOKEN") is not None` cannot be restored on its own.
if False:  # getenv("HF_API_TOKEN") is not None:
    # print("Starting logging thread...")
    # log_queue = Queue()
    # t = Thread(target=check_thread, args=(log_queue,))
    # t.start()
    logging_callback = log_data(
        hf_token=getenv("HF_API_TOKEN"),
        dataset_name=getenv("OUTPUT_DATASET"),
        private=True,
    )
elif getenv("HF_API_TOKEN") is None:
    print("No HF_API_TOKEN found. Logging is disabled.")
else:
    # The token is present; logging is off because of the code switch above,
    # not because of missing credentials.
    print("HF_API_TOKEN found, but logging is disabled in code.")
# PEFT adapter configuration for Dragoman.
# NOTE(review): `config` is not referenced anywhere else in the visible file.
config = PeftConfig.from_pretrained("lang-uk/dragoman")

# Base-model quantization: 4-bit NF4 weights, float16 compute dtype,
# no double quantization.
quant_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    load_in_4bit=True,
)

# Load the quantized Mistral-7B base, wrap it with the Dragoman adapter, and
# move the wrapped model to CUDA. `device_map` is not passed to the base load.
model = PeftModel.from_pretrained(
    MistralForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        quantization_config=quant_config,
    ),
    "lang-uk/dragoman",
).to("cuda")

# Slow (non-fast) tokenizer of the base model, configured so that no BOS
# token is prepended automatically when prompts are tokenized.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    use_fast=False,
    add_bos_token=False,
)
def translate(input_text):
    """Translate English text to Ukrainian with the Dragoman adapter.

    The text is wrapped in an ``[INST] ... [/INST]`` prompt and decoded with
    beam search (10 beams, at most 200 new tokens).

    NOTE(review): ``spaces`` is imported at the top of the file, but this
    function is not decorated with ``@spaces.GPU``. If the Space runs on
    ZeroGPU hardware it likely needs that decorator — confirm.

    Args:
        input_text: Source sentence from the Gradio textbox.

    Returns:
        The translation with surrounding whitespace stripped, or ``""`` when
        the input is empty, whitespace-only, or ``None``.
    """
    from datetime import timezone

    input_text = (input_text or "").strip()
    if not input_text:
        # Nothing to translate; avoid an expensive 10-beam generate() call.
        return ""
    # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated).
    print(f"{datetime.now(timezone.utc)} | Translating: {input_text}")

    # Input logging to an HF dataset is disabled: `log_data` is undefined
    # because the `data_logger` import is commented out. To re-enable it,
    # restore that import and use `getenv("HF_API_TOKEN") is not None` here.
    if False:
        try:
            logging_callback = log_data(
                hf_token=getenv("HF_API_TOKEN"),
                dataset_name=getenv("OUTPUT_DATASET"),
                private=True,
            )
            logging_callback([[input_text]])
        except Exception as err:
            print(f"Error happened while pushing data to HF: {err!r}")

    prompt = f"[INST] {input_text} [/INST]"
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=200,
        num_beams=10,
        temperature=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation for this causal LM. Decode only
    # the newly generated tokens instead of splitting the decoded text on
    # "[/INST] ". The split returned the whole prompt when the first generated
    # token had no leading space, and truncated outputs containing the marker.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
# Download the model card from the Hub. Its body, minus any YAML front matter
# and prefixed with the data-collection notice, is shown as the demo article.
desc_file = huggingface_hub.hf_hub_download("lang-uk/dragoman", "README.md")
# Explicit UTF-8: the card may contain non-ASCII text, and the platform's
# default encoding is not guaranteed to be UTF-8.
with open(desc_file, "r", encoding="utf-8") as f:
    model_description = f.read()
# Strip a leading "---\n<yaml>\n---" block only when the card starts with one.
# An unconditional `find(...) + offset` would cut real text from a card
# without front matter, because find() returns -1 there.
if model_description.startswith("---"):
    front_matter_end = model_description.find("\n---", 3)
    if front_matter_end != -1:
        model_description = model_description[
            front_matter_end + len("\n---") :
        ].lstrip()
model_description = (
    """### By using this service, users are required to agree to the following terms: you agree that user input will be collected for future research and model improvements. \n\n"""
    + model_description
)
# Gradio UI: one source textbox in, one translation textbox out. Both boxes
# are pre-filled with the paper blurb and its Ukrainian translation, so the
# page shows a worked example before the first request is made.
iface = gr.Interface(
    fn=translate,
    inputs=gr.Textbox(
        label="Source sentence",
        value='This demo contains a model from paper "Setting up the Data Printer with Improved English to Ukrainian Machine Translation", accepted to UNLP 2024 workshop at the LREC-COLING 2024 conference.',
    ),
    outputs=gr.Textbox(
        label="Translated sentence",
        value='Ця демо-версія містить модель із статті "Налаштування принтера даних із покращеним машинним перекладом з англійської на українську", яка була прийнята до семінару UNLP 2024 на конференції LREC-COLING 2024.',
    ),
    # Each example row holds a single value because the interface has exactly
    # one input component.
    examples=[
        [sentence]
        for sentence in (
            "The Colosseum in Rome was a symbol of the grandeur and power of the Roman Empire and was a place for the emperor to connect with the people by providing them with entertainment and free food.",
            "How many leaves would it drop in a month of February in a non-leap year?",
            "ChatGPT (Chat Generative Pre-trained Transformer) is a chatbot developed by OpenAI and launched on November 30, 2022. Based on a large language model, it enables users to refine and steer a conversation towards a desired length, format, style, level of detail, and language. Successive prompts and replies, known as prompt engineering, are considered at each conversation stage as a context.[2] ",
            "who holds this neighborhood?",
        )
    ],
    title="Dragoman: SOTA English-Ukrainian translation model",
    description='This demo contains a model from paper "Setting up the Data Printer with Improved English to Ukrainian Machine Translation", accepted to UNLP 2024 workshop at the LREC-COLING 2024 conference.',
    article=model_description,
    submit_btn="Translate",
)

iface.launch()