# Hugging Face Spaces page banner captured with this file: "Spaces: Sleeping Sleeping"
import json
import os
import subprocess
import sys
from threading import Thread

import gradio as gr
import torch
from dotenv import find_dotenv, load_dotenv
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

from app_modules.utils import calc_bleu_rouge_scores, detect_repetitions
# Locate the env file: ".env" if one is found, otherwise ".env.example".
# find_dotenv returns an empty string when nothing is found, so `or` falls through.
found_dotenv = find_dotenv(".env") or find_dotenv(".env.example")
print(f"loading env vars from: {found_dotenv}")
# override=False: variables already present in the process environment are kept.
load_dotenv(found_dotenv, override=False)
# Best-effort install of flash-attn at startup.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is passed to the package build
# (NOTE(review): presumably to avoid compiling CUDA kernels on the Space — confirm).
# The variable is layered on top of the current environment: passing a dict
# containing only that variable would replace the whole environment and drop
# PATH, HOME, virtualenv settings, etc.
# Invoking pip as `sys.executable -m pip` installs into the interpreter that is
# running this app, and avoids going through a shell.
# Failure is not fatal (check=False), matching the original behaviour.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Hugging Face access token; None when the variable is not set.
token = os.environ.get("HUGGINGFACE_AUTH_TOKEN")

# Model repo id or local path to serve.
# Alternative option: "microsoft/Phi-3-mini-128k-instruct".
model_name = os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH", "google/gemma-1.1-2b-it")
print(f" model_name: {model_name}")

# Starting value for the UI's "Repetition Penalty" slider. HF_RP keeps the raw
# string read from the environment; repetition_penalty is its float value.
HF_RP = os.environ.get("HF_RP", "1.2")
repetition_penalty = float(HF_RP)
print(f" repetition_penalty: {repetition_penalty}")
# Example questions shown in the UI. Judging from how records are used in this
# file, each one has "question" and "context" keys plus "answers" and optionally
# "wellFormedAnswers" (NOTE(review): schema inferred from usage — confirm
# against the dataset).
# `or` (rather than a getenv default) also falls back when the variable is set
# but empty.
questions_file_path = (
    os.getenv("QUESTIONS_FILE_PATH") or "./data/datasets/ms_macro.json"
)
# Close the file deterministically and decode it as UTF-8 (the JSON encoding)
# instead of the platform's locale encoding.
with open(questions_file_path, encoding="utf-8") as questions_file:
    questions = json.load(questions_file)
# One single-element row per example, as Gradio's `examples` expects.
examples = [[question["question"].strip()] for question in questions]
print(f"Loaded {len(examples)} examples")
# Instruction placed before the retrieved context when chat() receives one of
# the bundled example questions.
qa_system_prompt = (
    "Use the following pieces of context to answer the question at the end. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer."
)
# Load the causal LM and its tokenizer for `model_name`, authenticating with
# `token`. trust_remote_code=True allows the repo to supply its own modeling code
# (NOTE(review): this executes code from the model repo — use trusted models only).
_hub_auth = {"token": token}
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, **_hub_auth
)
tok = AutoTokenizer.from_pretrained(model_name, **_hub_auth)

# Token ids that stop generation in chat(); currently only the tokenizer's EOS id.
terminators = [tok.eos_token_id]
def _report_missing_mps():
    """Print why the Apple MPS backend cannot be used on this machine."""
    if torch.backends.mps.is_built():
        reason = (
            "MPS not available because the current MacOS version is not 12.3+ "
            "and/or you do not have an MPS-enabled device on this machine."
        )
    else:
        reason = (
            "MPS not available because the current PyTorch install was not "
            "built with MPS enabled."
        )
    print(reason)


# Device preference: Apple MPS, then CUDA, then CPU.
mps_device = torch.device("mps") if torch.backends.mps.is_available() else None
if mps_device is None:
    _report_missing_mps()

if mps_device is not None:
    device = mps_device
    print("Using MPS")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

# Move the already-loaded model onto the chosen device.
model = model.to(device)
def chat(
    message,
    history,
    temperature=0,
    repetition_penalty=1.1,
    do_sample=True,
    max_tokens=1024,
):
    """Stream a model reply for one Gradio ChatInterface turn, then append metrics.

    If ``message`` exactly matches one of the bundled example questions, it is
    wrapped in a RAG prompt: ``qa_system_prompt``, then that record's
    ``context``, then the question. After generation the answer is also scored
    with BLEU / ROUGE-L against the record's reference answers.

    Args:
        message: The user's new message.
        history: Prior turns, each indexable as ``[user_text, assistant_text]``.
            An assistant entry of ``None`` is skipped.
        temperature: Sampling temperature. 0 is mapped to 0.01 because
            transformers rejects a non-positive temperature when sampling.
        repetition_penalty: Forwarded to ``model.generate``.
        do_sample: Sample when True; greedy decoding when False.
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        The accumulated reply after each streamed chunk. The final value is the
        reply followed by repetition metrics, plus performance metrics for
        bundled examples.
    """
    print("repetition_penalty:", repetition_penalty)

    # Rebuild the conversation in the role/content format used by chat
    # templates. Named `conversation` so it does not shadow this function.
    conversation = []
    for item in history:
        conversation.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            conversation.append({"role": "assistant", "content": item[1]})

    # Single scan of `examples` (the original tested membership, then searched again).
    try:
        index = examples.index([message])
    except ValueError:
        index = -1  # free-form question: no RAG context, no reference scoring
    else:
        message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
        print("RAG prompt:", message)

    conversation.append({"role": "user", "content": message})
    prompt = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    # The rendered template already contains the model's special tokens (e.g.
    # BOS). Letting the tokenizer add them again would duplicate them.
    model_inputs = tok([prompt], return_tensors="pt", add_special_tokens=False).to(
        device
    )

    streamer = TextIteratorStreamer(
        tok, timeout=200.0, skip_prompt=True, skip_special_tokens=True
    )
    if temperature == 0:
        temperature = 0.01
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        temperature=temperature,
        # Previously accepted but never forwarded, so the UI slider had no effect.
        repetition_penalty=repetition_penalty,
        eos_token_id=terminators,
    )

    # generate() blocks until finished, so run it on a worker thread and
    # consume the streamer here to yield text as it arrives.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    generation_thread.join()

    answer = partial_text
    newline_score, repetition_score, total_repetitions = detect_repetitions(answer)
    # The repeated "1." prefixes are intentional: Markdown renumbers the list.
    partial_text += "\n\nRepetition Metrics:\n"
    partial_text += f"1. Newline Score: {newline_score:.3f}\n"
    partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
    partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"

    if index >= 0:  # bundled example: score against its reference answers
        record = questions[index]
        key = "wellFormedAnswers" if "wellFormedAnswers" in record else "answers"
        scores = calc_bleu_rouge_scores([answer], [record[key]], debug=True)
        partial_text += "\n\n Performance Metrics:\n"
        partial_text += f'1. BLEU: {scores["bleu_scores"]["bleu"]:.3f}\n'
        partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'

    yield partial_text
# Collapsed "Parameters" panel holding the extra controls. Gradio passes their
# values to chat() after (message, history), in list order: temperature,
# repetition_penalty, do_sample, max_tokens.
_parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
_parameter_inputs = [
    gr.Slider(minimum=0, maximum=1, step=0.1, value=0, label="Temperature", render=False),
    gr.Slider(
        minimum=1.0,
        maximum=1.5,
        step=0.1,
        value=repetition_penalty,  # starting value comes from the HF_RP env var
        label="Repetition Penalty",
        render=False,
    ),
    gr.Checkbox(label="Sampling", value=True),
    gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    ),
]

demo = gr.ChatInterface(
    fn=chat,
    examples=examples,
    cache_examples=False,
    additional_inputs_accordion=_parameters_accordion,
    additional_inputs=_parameter_inputs,
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description=f"Now Running [{model_name}](https://huggingface.co/{model_name})",
)
demo.launch()