Spaces:
Sleeping
Sleeping
File size: 5,267 Bytes
32a6937 c8fdb3b 8f1a330 5983ad7 f850f3b 7dc3087 32a6937 7dc3087 8f1a330 63b82b4 08c1bd3 79eed96 8f1a330 79eed96 8f1a330 79eed96 32a6937 3a82207 32a6937 5983ad7 32a6937 5983ad7 32a6937 3a82207 32a6937 8f1a330 5983ad7 8f1a330 5983ad7 8f1a330 63b82b4 8f1a330 3a82207 5983ad7 8f1a330 5983ad7 8f1a330 5983ad7 79eed96 5983ad7 79eed96 5983ad7 8f1a330 5983ad7 00f3401 3a82207 63b82b4 32a6937 79eed96 63b82b4 8f1a330 63b82b4 f850f3b 63b82b4 32a6937 f850f3b 32a6937 00f3401 63b82b4 ea9c0d3 63b82b4 8f1a330 63b82b4 3a82207 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import json
import os
import gradio as gr
from huggingface_hub import InferenceClient
from app_modules.utils import calc_bleu_rouge_scores, detect_repetitions
from dotenv import find_dotenv, load_dotenv
# Load settings from a local .env file when one is found (find_dotenv returns
# an empty path when nothing is found). With override=False, variables already
# set in the real environment take precedence over the file.
found_dotenv = find_dotenv(".env")
if found_dotenv:
    load_dotenv(found_dotenv, override=False)

# Default repetition penalty shown in the UI slider; configurable via HF_RP.
HF_RP = os.getenv("HF_RP", "1.2")
repetition_penalty = float(HF_RP)
print(f" repetition_penalty: {repetition_penalty}")

# Evaluation questions. Each record is read by key "question" here, and by
# "context" / "answers" / "wellFormedAnswers" in chat().
questions_file_path = (
    os.getenv("QUESTIONS_FILE_PATH") or "./data/datasets/ms_macro.json"
)
# Use a context manager so the file handle is closed; JSON is UTF-8 by spec,
# so don't depend on the platform's default locale encoding.
with open(questions_file_path, encoding="utf-8") as questions_file:
    questions = json.load(questions_file)
# Gradio examples are lists of input values; one text input per example.
examples = [[question["question"].strip()] for question in questions]
print(f"Loaded {len(examples)} examples")

qa_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."

# For more information on `huggingface_hub` Inference API support, please
# check the docs:
# https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def chat(
    message,
    history: list[tuple[str, str]],
    system_message,
    temperature=0,
    repetition_penalty=1.1,
    do_sample=True,
    max_tokens=1024,
    top_p=0.95,
    include_history=False,
):
    """Stream a model answer, then append repetition and quality metrics.

    Gradio ``ChatInterface`` callback. If ``message`` exactly matches one of
    the bundled example questions, the prompt is augmented (RAG-style) with
    that question's stored context, and BLEU-1 / ROUGE-L scores against the
    reference answers are appended after generation. Repetition metrics are
    appended for every answer.

    Args:
        message: The user's latest message.
        history: Prior ``(user, assistant)`` turns supplied by Gradio.
        system_message: Content of the leading system message.
        temperature: Sampling temperature passed to ``chat_completion``.
        repetition_penalty: Only printed; it is NOT sent to the model.
            NOTE(review): ``chat_completion`` is not given this value —
            confirm whether a frequency/presence penalty was intended.
        do_sample: Accepted for UI compatibility; not sent to the model.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass passed to the model.
        include_history: When True, forward ``history`` as prior chat turns.
            Defaults to False, so each turn is answered independently.

    Yields:
        The accumulated answer after each streamed chunk, and finally the
        answer followed by the metrics report.
    """
    print("repetition_penalty:", repetition_penalty)

    # A message that is one of the examples switches on the RAG prompt and
    # the reference-based scoring below; one scan instead of `in` + index().
    try:
        index = examples.index([message])
    except ValueError:
        index = -1
    if index >= 0:
        message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
        print("RAG prompt:", message)

    messages = [{"role": "system", "content": system_message}]
    if include_history:
        for item in history:
            messages.append({"role": "user", "content": item[0]})
            # The last turn may not have an assistant reply yet.
            if item[1] is not None:
                messages.append({"role": "assistant", "content": item[1]})
    messages.append({"role": "user", "content": message})

    # NOTE: client.text_generation() with this model failed with HTTP 422
    # ("Make sure 'text-generation' task is supported by the model"), so the
    # chat-completion endpoint is used instead.
    partial_text = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Some stream chunks carry no choices or a None delta (e.g. role-only
        # or final chunks); treat them as empty text instead of crashing.
        if not chunk.choices:
            continue
        new_text = chunk.choices[0].delta.content or ""
        partial_text += new_text
        yield partial_text

    answer = partial_text
    whitespace_score, repetition_score, total_repetitions = detect_repetitions(answer)
    partial_text += "\n\nRepetition Metrics:\n"
    partial_text += f"1. Whitespace Score: {whitespace_score:.3f}\n"
    partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
    partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"

    if index >= 0:  # RAG: score against the example's reference answers
        record = questions[index]
        key = "wellFormedAnswers" if "wellFormedAnswers" in record else "answers"
        scores = calc_bleu_rouge_scores([answer], [record[key]], debug=True)
        partial_text += "\n\n Performance Metrics:\n"
        partial_text += f'1. BLEU-1: {scores["bleu_scores"]["bleu"]:.3f}\n'
        partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'

    yield partial_text
# Collapsible panel that hosts the generation controls (created first, as in
# the keyword-argument evaluation order of the ChatInterface call below).
_parameters_accordion = gr.Accordion(
    label="⚙️ Parameters", open=False, render=False
)

# Extra inputs, listed in the same order as chat()'s parameters that follow
# (message, history): system_message, temperature, repetition_penalty,
# do_sample, max_tokens, top_p.
_parameter_inputs = [
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0, label="Temperature", render=False
    ),
    gr.Slider(
        minimum=1.0,
        maximum=1.5,
        step=0.1,
        value=repetition_penalty,
        label="Repetition Penalty",
        render=False,
    ),
    gr.Checkbox(label="Sampling", value=True),
    gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    ),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

# Chat UI wired to chat(); example questions are clickable, never cached.
demo = gr.ChatInterface(
    fn=chat,
    examples=examples,
    cache_examples=False,
    additional_inputs_accordion=_parameters_accordion,
    additional_inputs=_parameter_inputs,
)
demo.launch()
|