Spaces:
Sleeping
Sleeping
show metrics in graido app
Browse files- app.py +23 -4
- app_modules/utils.py +17 -4
app.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import (
|
|
9 |
import os
|
10 |
from threading import Thread
|
11 |
import subprocess
|
|
|
12 |
|
13 |
from dotenv import find_dotenv, load_dotenv
|
14 |
|
@@ -93,10 +94,11 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_token
|
|
93 |
if item[1] is not None:
|
94 |
chat.append({"role": "assistant", "content": item[1]})
|
95 |
|
|
|
96 |
if [message] in examples:
|
97 |
index = examples.index([message])
|
98 |
message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
|
99 |
-
print(message)
|
100 |
|
101 |
chat.append({"role": "user", "content": message})
|
102 |
|
@@ -105,6 +107,10 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_token
|
|
105 |
streamer = TextIteratorStreamer(
|
106 |
tok, timeout=200.0, skip_prompt=True, skip_special_tokens=True
|
107 |
)
|
|
|
|
|
|
|
|
|
108 |
generate_kwargs = dict(
|
109 |
model_inputs,
|
110 |
streamer=streamer,
|
@@ -114,9 +120,6 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_token
|
|
114 |
eos_token_id=terminators,
|
115 |
)
|
116 |
|
117 |
-
if temperature == 0:
|
118 |
-
generate_kwargs["do_sample"] = False
|
119 |
-
|
120 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
121 |
t.start()
|
122 |
|
@@ -125,6 +128,22 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_token
|
|
125 |
partial_text += new_text
|
126 |
yield partial_text
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
yield partial_text
|
129 |
|
130 |
|
|
|
9 |
import os
|
10 |
from threading import Thread
|
11 |
import subprocess
|
12 |
+
from app_modules.utils import calc_bleu_rouge_scores, detect_repetitions
|
13 |
|
14 |
from dotenv import find_dotenv, load_dotenv
|
15 |
|
|
|
94 |
if item[1] is not None:
|
95 |
chat.append({"role": "assistant", "content": item[1]})
|
96 |
|
97 |
+
index = -1
|
98 |
if [message] in examples:
|
99 |
index = examples.index([message])
|
100 |
message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
|
101 |
+
print("RAG prompt:", message)
|
102 |
|
103 |
chat.append({"role": "user", "content": message})
|
104 |
|
|
|
107 |
streamer = TextIteratorStreamer(
|
108 |
tok, timeout=200.0, skip_prompt=True, skip_special_tokens=True
|
109 |
)
|
110 |
+
|
111 |
+
if temperature == 0:
|
112 |
+
temperature = 0.01
|
113 |
+
|
114 |
generate_kwargs = dict(
|
115 |
model_inputs,
|
116 |
streamer=streamer,
|
|
|
120 |
eos_token_id=terminators,
|
121 |
)
|
122 |
|
|
|
|
|
|
|
123 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
124 |
t.start()
|
125 |
|
|
|
128 |
partial_text += new_text
|
129 |
yield partial_text
|
130 |
|
131 |
+
answer = partial_text
|
132 |
+
(newline_score, repetition_score, total_repetitions) = detect_repetitions(answer)
|
133 |
+
partial_text += "\n\nRepetition Metrics:\n"
|
134 |
+
partial_text += f"1. Newline Score: {newline_score:.3f}\n"
|
135 |
+
partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
|
136 |
+
partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
|
137 |
+
|
138 |
+
if index >= 0: # RAG
|
139 |
+
scores = calc_bleu_rouge_scores(
|
140 |
+
[answer], [questions[index]["wellFormedAnswers"]], debug=True
|
141 |
+
)
|
142 |
+
|
143 |
+
partial_text += "\n\n Performance Metrics:\n"
|
144 |
+
partial_text += f'1. BLEU: {scores["bleu_scores"]["bleu"]:.3f}\n'
|
145 |
+
partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
|
146 |
+
|
147 |
yield partial_text
|
148 |
|
149 |
|
app_modules/utils.py
CHANGED
@@ -191,15 +191,28 @@ bleu = evaluate.load("bleu")
|
|
191 |
rouge = evaluate.load("rouge")
|
192 |
|
193 |
|
194 |
-
def
|
195 |
-
|
196 |
-
|
|
|
197 |
|
198 |
bleu_scores = bleu.compute(
|
199 |
predictions=predictions, references=references, max_order=1
|
200 |
)
|
201 |
rouge_scores = rouge.compute(predictions=predictions, references=references)
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
|
205 |
pattern_abnormal_newlines = re.compile(r"\n{5,}")
|
|
|
191 |
rouge = evaluate.load("rouge")
|
192 |
|
193 |
|
194 |
+
def calc_bleu_rouge_scores(predictions, references, debug=False):
|
195 |
+
if debug:
|
196 |
+
print("predictions:", predictions)
|
197 |
+
print("references:", references)
|
198 |
|
199 |
bleu_scores = bleu.compute(
|
200 |
predictions=predictions, references=references, max_order=1
|
201 |
)
|
202 |
rouge_scores = rouge.compute(predictions=predictions, references=references)
|
203 |
+
result = {"bleu_scores": bleu_scores, "rouge_scores": rouge_scores}
|
204 |
+
|
205 |
+
if debug:
|
206 |
+
print("result:", result)
|
207 |
+
|
208 |
+
return result
|
209 |
+
|
210 |
+
|
211 |
+
def calc_metrics(df):
|
212 |
+
predictions = [df["answer"][i] for i in range(len(df))]
|
213 |
+
references = [df["ground_truth"][i] for i in range(len(df))]
|
214 |
+
|
215 |
+
return calc_bleu_rouge_scores(predictions, references)
|
216 |
|
217 |
|
218 |
pattern_abnormal_newlines = re.compile(r"\n{5,}")
|