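# Gradio chat app: streams responses from a Hugging Face causal LM and, for the
# bundled example questions, rewrites the message into a RAG prompt and reports
# BLEU/ROUGE scores against the reference answers.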
import json
import os
import subprocess
from threading import Thread

import gradio as gr
import torch
from dotenv import find_dotenv, load_dotenv
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

from app_modules.utils import calc_bleu_rouge_scores, detect_repetitions
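# Load environment variables from .env, falling back to .env.example.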
found_dotenv = find_dotenv(".env")
if len(found_dotenv) == 0:
    found_dotenv = find_dotenv(".env.example")
print(f"loading env vars from: {found_dotenv}")
load_dotenv(found_dotenv, override=False)
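# Install flash-attn at runtime; skipping the CUDA build lets pip fall back to
# a prebuilt wheel instead of compiling the kernels.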
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},  # keep PATH etc. intact
    shell=True,
)
token = os.getenv("HUGGINGFACE_AUTH_TOKEN")
model_name = os.getenv(
    "HUGGINGFACE_MODEL_NAME_OR_PATH", "google/gemma-1.1-2b-it"
)  # alternative: "microsoft/Phi-3-mini-128k-instruct"
print(f"model_name: {model_name}")

repetition_penalty = float(os.getenv("HF_RP", "1.2"))
print(f"repetition_penalty: {repetition_penalty}")
questions_file_path = os.getenv("QUESTIONS_FILE_PATH") or "./data/datasets/ms_macro.json"
with open(questions_file_path) as f:
    questions = json.load(f)
examples = [[question["question"].strip()] for question in questions]
print(f"Loaded {len(examples)} examples")
qa_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=token,
    trust_remote_code=True,
)
tok = AutoTokenizer.from_pretrained(model_name, token=token)
terminators = [
    tok.eos_token_id,
]
# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print(
            "MPS not available because the current PyTorch install was not "
            "built with MPS enabled."
        )
    else:
        print(
            "MPS not available because the current macOS version is not 12.3+ "
            "and/or you do not have an MPS-enabled device on this machine."
        )
    mps_device = None
else:
    mps_device = torch.device("mps")

if mps_device is not None:
    device = mps_device
    print("Using MPS")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

model = model.to(device)
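# Streaming chat handler wired into gr.ChatInterface below: yields partial text
# as tokens arrive, then appends repetition metrics (and BLEU/ROUGE when the
# message is one of the example questions).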
def chat(
    message,
    history,
    temperature=0,
    repetition_penalty=1.1,
    do_sample=True,
    max_tokens=1024,
):
    print("repetition_penalty:", repetition_penalty)

    # Rebuild the conversation from Gradio's (user, assistant) history pairs.
    conversation = []
    for item in history:
        conversation.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            conversation.append({"role": "assistant", "content": item[1]})

    # If the message is one of the preloaded examples, wrap it in a RAG prompt
    # with the matching context from the questions file.
    index = -1
    if [message] in examples:
        index = examples.index([message])
        message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
        print("RAG prompt:", message)

    conversation.append({"role": "user", "content": message})
    messages = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tok([messages], return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(
        tok, timeout=200.0, skip_prompt=True, skip_special_tokens=True
    )

    if temperature == 0:
        temperature = 0.01  # transformers rejects temperature == 0 when sampling

    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=terminators,
    )

    # Generate in a background thread so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
partial_text = ""
for new_text in streamer:
partial_text += new_text
yield partial_text
answer = partial_text
(newline_score, repetition_score, total_repetitions) = detect_repetitions(answer)
partial_text += "\n\nRepetition Metrics:\n"
partial_text += f"1. Newline Score: {newline_score:.3f}\n"
partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
if index >= 0: # RAG
key = (
"wellFormedAnswers"
if "wellFormedAnswers" in questions[index]
else "answers"
)
scores = calc_bleu_rouge_scores([answer], [questions[index][key]], debug=True)
partial_text += "\n\n Performance Metrics:\n"
partial_text += f'1. BLEU: {scores["bleu_scores"]["bleu"]:.3f}\n'
partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
yield partial_text
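# Build the chat UI; the additional inputs map positionally onto chat()'s
# temperature, repetition_penalty, do_sample, and max_tokens parameters.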
demo = gr.ChatInterface(
    fn=chat,
    examples=examples,
    cache_examples=False,
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0, label="Temperature", render=False
        ),
        gr.Slider(
            minimum=1.0,
            maximum=1.5,
            step=0.1,
            value=repetition_penalty,
            label="Repetition Penalty",
            render=False,
        ),
        gr.Checkbox(label="Sampling", value=True, render=False),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description=f"Now Running [{model_name}](https://huggingface.co/{model_name})",
)

demo.launch()