Spaces:
Runtime error
Runtime error
File size: 6,064 Bytes
28635a8 3f325ac 28635a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import pandas as pd
from datasets import Dataset
from transformers import pipeline, GPT2Tokenizer
from sentence_transformers import SentenceTransformer, util
# Define paths and models
# Text file whose non-blank lines become the retrievable segments.
filename = "output_chess_details.txt"
# Location of an already fine-tuned SentenceTransformer retrieval model.
retrieval_model_name = 'output/sentence-transformer-finetuned/'
# Model identifier shared by the GPT-2 tokenizer and the text-generation pipeline.
gpt2_model_name = "gpt2"
# Training data: input CSV and the CSV written after processing.
csv_file_path = "train_dataset.csv"
output_csv_file_path = "updated_train_dataset.csv"
# Validation data: input CSV and the CSV written after processing.
val_csv_file_path = "val_dataset.csv"
output_val_csv_file_path = "updated_val_csv.csv"
# Tokenizer used to count prompt tokens when sizing the generation length.
# NOTE(review): this load runs outside the try/except that guards the model loads below.
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
# Initialize models
try:
    # Sentence-embedding model used for retrieval and for similarity scoring.
    retrieval_model = SentenceTransformer(retrieval_model_name)
    # GPT-2 text-generation pipeline used to extend the retrieved segment.
    gpt_model = pipeline("text-generation", model=gpt2_model_name)
    print("Models loaded successfully.")
except Exception as e:
    # The functions below read `retrieval_model` / `gpt_model` as globals.
    # Swallowing the failure would leave those names unbound and surface
    # later as an unrelated NameError, so report the cause and stop here.
    print(f"Failed to load models: {e}")
    raise
def load_and_preprocess_text(filename):
    """
    Read a text file and return its non-blank lines, whitespace-trimmed.

    Parameters:
    - filename (str): Path to the text file (read as UTF-8).

    Returns:
    - list[str]: The stripped, non-empty lines in file order, or an empty
      list if anything goes wrong (the error is printed, not raised).
    """
    try:
        with open(filename, 'r', encoding='utf-8') as handle:
            # Trim every line first, then drop the ones that end up empty.
            trimmed_lines = map(str.strip, handle)
            kept = [text for text in trimmed_lines if text]
        print("Text loaded and preprocessed successfully.")
        return kept
    except Exception as e:
        print(f"Failed to load or preprocess text: {e}")
        return []
# Load the text segments once at import time; generate_response() searches
# this module-level list. It is an empty list if the file could not be read.
segments = load_and_preprocess_text(filename)
def find_relevant_segment(user_query, segments):
    """
    Find the most relevant text segment based on a user query.

    The query and every segment are embedded with the module-level
    `retrieval_model`; the segment with the highest cosine similarity to the
    query is returned.

    Parameters:
    - user_query (str): The user's query.
    - segments (list[str]): List of text segments to search within.

    Returns:
    - str: The most relevant text segment, or "" when `segments` is empty or
      the lookup fails (the error is printed, not raised).
    """
    try:
        # Nothing to rank: return the documented fallback directly rather
        # than relying on the encoder/similarity calls to raise and printing
        # a misleading error message.
        if not segments:
            return ""
        query_embedding = retrieval_model.encode(user_query)
        # NOTE: every segment is re-encoded on each call.
        segment_embeddings = retrieval_model.encode(segments)
        similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        # argmax() yields a 0-d tensor; convert it to a plain int for list indexing.
        best_idx = int(similarities.argmax())
        return segments[best_idx]
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""
def generate_response(question):
    """
    Answer a question using the best-matching text segment as context.

    Picks the entry of the module-level `segments` list that is most relevant
    to the question, then passes both to generate_response_with_context.

    Parameters:
    - question (str): The user's question.

    Returns:
    - str: Generated response.
    """
    return generate_response_with_context(
        question,
        find_relevant_segment(question, segments),
    )
def generate_response_with_context(user_query, relevant_segment):
    """
    Build a prompt around a relevant segment and let the GPT-2 pipeline extend it.

    Parameters:
    - user_query (str): The user's query. It is accepted for the caller's
      convenience; the prompt itself is built from the segment only.
    - relevant_segment (str): A relevant fact or detail.

    Returns:
    - str: The generated text after clean-up, or "" if generation fails
      (the error is printed, not raised).
    """
    try:
        prompt = f"Thank you for your question! Here is an additional fact about your topic: {relevant_segment}"
        # Allow 50 tokens beyond the length of the prompt itself.
        prompt_token_count = len(tokenizer(prompt)['input_ids'])
        generations = gpt_model(
            prompt,
            max_length=prompt_token_count + 50,
            temperature=0.25,
        )
        raw_text = generations[0]['generated_text']
        return clean_up_response(raw_text, relevant_segment)
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""
def clean_up_response(response, segment):
    """
    Tidy a generated response into a sequence of period-separated sentences.

    The response is split on '.', each piece is trimmed, and pieces that are
    empty or occur verbatim inside `segment` are dropped. The remaining
    pieces are joined with '. ' and a final '.' is appended unless the text
    already ends with '.', '!' or '?'.

    Parameters:
    - response (str): The initial response generated by the model.
    - segment (str): The segment used to generate the response.

    Returns:
    - str: A cleaned and formatted response ("" if nothing is left).
    """
    kept = []
    for piece in response.split('.'):
        text = piece.strip()
        # Skip blanks and anything that already appears inside the segment.
        if not text or text in segment:
            continue
        kept.append(text)

    tidy = '. '.join(kept).strip()
    if not tidy:
        return tidy
    if tidy.endswith((".", "!", "?")):
        return tidy
    return tidy + "."
def process_dataset(csv_file_path, output_csv_file_path):
    """
    Generate answers for a CSV dataset, score them, and save the result.

    The CSV is loaded into a Hugging Face dataset, an "Answer" column is
    added by add_model_answers, a "similarity" column is added from
    evaluate_similarity, and the result is written back out as CSV.

    Parameters:
    - csv_file_path (str): Path to the CSV file containing the dataset.
    - output_csv_file_path (str): Path where the updated dataset will be saved.

    Prints:
    - Path to the saved results and the average similarity score
      (0 when there are no rows to score).
    """
    source_frame = pd.read_csv(csv_file_path)
    answered = add_model_answers(Dataset.from_pandas(source_frame))
    scores = evaluate_similarity(answered)
    scored = answered.add_column("similarity", scores)
    scored.to_pandas().to_csv(output_csv_file_path, index=False)

    # Guard the mean against an empty score list.
    if scores:
        average_similarity = sum(scores) / len(scores)
    else:
        average_similarity = 0
    print(f"Results saved to {output_csv_file_path}")
    print(f"Average Similarity Score: {average_similarity:.3f}")
def add_model_answers(dataset):
    """
    Attach a generated answer to every row of the dataset.

    Each value of the 'Question' column is passed to generate_response and
    the results are stored in a new "Answer" column.

    Parameters:
    - dataset (datasets.Dataset): The Hugging Face dataset object.

    Returns:
    - datasets.Dataset: Updated dataset with added answers.
    """
    generated = []
    for question in dataset['Question']:
        generated.append(generate_response(question))
    return dataset.add_column("Answer", generated)
def evaluate_similarity(dataset):
    """
    Score each generated answer against its ground-truth answer.

    Every 'Answer' / 'GroundTruth' pair is embedded with the module-level
    `retrieval_model` and compared by cosine similarity.

    Parameters:
    - dataset (datasets.Dataset): The dataset containing both answers and ground truths.

    Returns:
    - list[float]: List of similarity scores, one per row, in row order.
    """
    scores = []
    for answer, truth in zip(dataset['Answer'], dataset['GroundTruth']):
        answer_vec = retrieval_model.encode(answer)
        truth_vec = retrieval_model.encode(truth)
        # The similarity comes back as a 1x1 matrix; pull out the scalar.
        pair_score = util.pytorch_cos_sim(answer_vec, truth_vec)[0][0].item()
        scores.append(pair_score)
    return scores
# Process datasets
# Run the pipeline on the training CSV first, then on the validation CSV.
for source_csv, result_csv in (
    (csv_file_path, output_csv_file_path),
    (val_csv_file_path, output_val_csv_file_path),
):
    process_dataset(source_csv, result_csv)
|