# Spaces:
# Sleeping
# Sleeping
import json
import os
import sys
import pandas as pd
from timeit import default_timer as timer
import nltk

# Interactive mode is selected by running the script as: python <script> chat
chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
if chatting:
    # Force single-question batches for interactive use; this must be set
    # before app_init() is imported/called, which reads BATCH_SIZE.
    os.environ["BATCH_SIZE"] = "1"

from app_modules.init import app_init
from app_modules.llm_qa_chain import QAChain
from app_modules.utils import print_llm_response, calc_metrics, detect_repetition_scores

llm_loader, qa_chain = app_init()

if chatting:
    # Simple REPL: ask questions until the user types "exit".
    print("Starting chat mode")
    while True:
        user_question = input("Please enter your question: ")
        if user_question.lower() == "exit":
            break
        response = qa_chain.call_chain(
            {"question": user_question, "chat_history": []}, None
        )
        print_llm_response(response)
    sys.exit(0)
# Number of questions to evaluate; 0 means "all of them".
# (In chat mode the script has already exited, so argv[1] here is numeric.)
num_of_questions = 0
if len(sys.argv) > 1:
    num_of_questions = int(sys.argv[1])

# One row per generated answer is appended to this frame during inference.
df = pd.DataFrame(
    columns=[
        "id",
        "question",
        "answer",
    ]
)

batch_size = int(os.getenv("BATCH_SIZE", "1"))
print(f"Batch size: {batch_size}")

questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
debug_retrieval = os.getenv("DEBUG_RETRIEVAL", "false").lower() == "true"

print(f"Reading questions from file: {questions_file_path}")
# Use a context manager so the file handle is closed deterministically;
# the original `json.loads(open(path).read())` relied on GC for cleanup.
with open(questions_file_path) as questions_file:
    test_data = json.load(questions_file)

if isinstance(test_data, dict):
    # Mapping form: keys are question ids, values are question records.
    questions = list(test_data.values())
    ids = list(test_data.keys())
else:
    # List form: each record carries its own "id" field.
    questions = test_data
    ids = [row["id"] for row in questions]

if num_of_questions > 0:
    questions = questions[:num_of_questions]

print(f"Number of questions: {len(questions)}")
if __name__ == "__main__":
    chat_start = timer()

    # --- Phase 1: batched inference, accumulating one row per answer in df.
    index = 0
    while index < len(questions):
        batch_ids = ids[index : index + batch_size]
        batch_questions = [
            q["question"] for q in questions[index : index + batch_size]
        ]
        if isinstance(qa_chain, QAChain):
            # The RAG chain expects a chat history alongside each question.
            inputs = [{"question": q, "chat_history": []} for q in batch_questions]
        else:
            inputs = [{"question": q} for q in batch_questions]

        start = timer()
        result = qa_chain.call_chain(inputs, None)
        end = timer()
        print(f"Completed in {end - start:.3f}s")

        batch_answers = [r["answer"] for r in result]
        # `qid` instead of `id` to avoid shadowing the builtin.
        for qid, question, answer in zip(batch_ids, batch_questions, batch_answers):
            df.loc[len(df)] = {
                "id": qid,
                "question": question,
                "answer": answer,
            }
        index += batch_size

        for r in result:
            print_llm_response(r, debug_retrieval)

    chat_end = timer()
    total_time = chat_end - chat_start
    print(f"Total time used: {total_time:.3f} s")

    # --- Phase 2: pair each answer with its ground truth and a token count.
    df2 = pd.DataFrame(
        columns=[
            "id",
            "question",
            "answer",
            "word_count",
            "ground_truth",
        ]
    )
    for i in range(len(df)):
        question = questions[i]  # aligned with df rows by construction
        answer = df["answer"][i]
        query = df["question"][i]
        qid = df["id"][i]
        # Prefer curated well-formed answers when the record provides them.
        ground_truth = question[
            "wellFormedAnswers" if "wellFormedAnswers" in question else "answers"
        ]
        word_count = len(nltk.word_tokenize(answer))
        df2.loc[len(df2)] = {
            "id": qid,
            "question": query,
            "answer": answer,
            "word_count": word_count,
            "ground_truth": ground_truth,
        }

    # detect_repetition_scores yields three values per answer, expanded into
    # three new columns.
    df2[["newline_score", "repetition_score", "total_repetitions"]] = df2[
        "answer"
    ].apply(detect_repetition_scores)

    pd.options.display.float_format = "{:.3f}".format
    print(df2.describe())

    word_count = df2["word_count"].sum()

    # --- Phase 3: persist per-run details, prefixed with two metadata
    # comment lines, then append summary metrics to the cross-run CSV.
    csv_file = (
        os.getenv("TEST_RESULTS_CSV_FILE") or f"qa_batch_{batch_size}_test_results.csv"
    )
    with open(csv_file, "w") as f:
        f.write(
            f"# RAG: {isinstance(qa_chain, QAChain)} questions: {questions_file_path}\n"
        )
        f.write(
            f"# model: {llm_loader.model_name} repetition_penalty: {llm_loader.repetition_penalty}\n"
        )
    df2.to_csv(csv_file, mode="a", index=False, header=True)
    print(f"test results saved to file: {csv_file}")

    scores = calc_metrics(df2)
    df = pd.DataFrame(
        {
            "model": [llm_loader.model_name],
            "repetition_penalty": [llm_loader.repetition_penalty],
            "word_count": [word_count],
            "inference_time": [total_time],
            "inference_speed": [word_count / total_time],
            "bleu1": [scores["bleu_scores"]["bleu"]],
            "rougeL": [scores["rouge_scores"]["rougeL"]],
        }
    )
    print(f"Number of words generated: {word_count}")
    print(f"Average generation speed: {word_count / total_time:.3f} words/s")

    csv_file = os.getenv("ALL_RESULTS_CSV_FILE") or "qa_chain_all_results.csv"
    # Only write the header when the file is new/empty, so repeated runs
    # append cleanly.
    file_existed = os.path.exists(csv_file) and os.path.getsize(csv_file) > 0
    df.to_csv(csv_file, mode="a", index=False, header=not file_existed)
    print(f"all results appended to file: {csv_file}")