File size: 1,791 Bytes
aeeb9a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from llmware.prompts import Prompt
def load_rag_benchmark_tester_ds(ds_name="llmware/rag_instruct_benchmark_tester", split="train"):
    """Load a RAG benchmark dataset from the HuggingFace Hub as a list of rows.

    With the defaults this pulls LLMWare's 200-question RAG instruct
    benchmark tester dataset.

    Args:
        ds_name: HuggingFace Hub dataset id passed to ``load_dataset``.
        split: name of the split (key of the loaded dataset) whose rows are
            returned.

    Returns:
        A list containing every row of ``split``, in dataset order.
    """
    # Function-scope import: nothing else in this script uses ``datasets``.
    from datasets import load_dataset

    dataset = load_dataset(ds_name)
    print("update: loading test dataset - ", dataset)

    # Materialize the split into a plain list. Rows are later indexed by
    # "query", "context" and "answer" in run_test.
    test_set = list(dataset[split])

    return test_set
def run_test(model_name, prompt_list, temperature=0.3):
    """Prompt a model with every benchmark sample and print the results.

    For each sample, the model is prompted with the sample's query and
    context. The response is printed next to the gold answer, followed by
    the entries returned by the prompter's ``evidence_check_numbers``,
    ``evidence_comparison_stats`` and ``evidence_check_sources`` methods.

    Args:
        model_name: model identifier passed to ``Prompt().load_model`` with
            ``from_hf=True``.
        prompt_list: iterable of samples, each a mapping with ``"query"``,
            ``"context"`` and ``"answer"`` keys.
        temperature: sampling temperature forwarded to ``prompt_main``
            (defaults to 0.3, the previously hard-coded value).

    Returns:
        0 once every sample has been processed.
    """
    print("\nupdate: Starting RAG Benchmark Inference Test")

    prompter = Prompt().load_model(model_name, from_hf=True)

    for i, sample in enumerate(prompt_list):
        query = sample["query"]
        context = sample["context"]

        response = prompter.prompt_main(
            query,
            context=context,
            prompt_name="default_with_context",
            temperature=temperature,
        )

        fact_checks = prompter.evidence_check_numbers(response)
        comparison_stats = prompter.evidence_comparison_stats(response)
        source_reviews = prompter.evidence_check_sources(response)

        print("\nupdate: model inference output - ", i, response["llm_response"])
        print("update: gold_answer - ", i, sample["answer"])

        # Each evidence result is iterated as a sequence of dicts carrying the
        # printed key. The loop names are distinct from ``sample`` so the
        # current benchmark row is never overwritten inside this iteration.
        for fc_entry in fact_checks:
            print("update: fact check - ", fc_entry["fact_check"])

        for stats_entry in comparison_stats:
            print("update: comparison stats - ", stats_entry["comparison_stats"])

        for source_entry in source_reviews:
            print("update: sources - ", source_entry["source_review"])

    return 0
if __name__ == "__main__":
    # Script entry point: download the benchmark set, then run one
    # HuggingFace-hosted model over every sample in it.
    model_name = "llmware/dragon-red-pajama-7b-v0"
    core_test_set = load_rag_benchmark_tester_ds()
    output = run_test(model_name=model_name, prompt_list=core_test_set)
|