Spaces:
Sleeping
Sleeping
bibliotecadebabel
commited on
Commit
•
37c2a8d
1
Parent(s):
475dbf8
first commit
Browse files- README.md +5 -5
- app.py +117 -0
- requirements.txt +11 -0
- src/constants/__init__.py +0 -0
- src/constants/config.py +69 -0
- src/constants/credentials.py +11 -0
- src/decorators/decorators.py +11 -0
- src/pytorch_modules/datasets/schema_string_dataset.py +40 -0
- src/pytorch_modules/datasets/tokenized_dataset.py +62 -0
- src/pytorch_modules/models/utils_models.py +21 -0
- src/reader.py +92 -0
- src/utils.py +90 -0
- src/utils_search.py +153 -0
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
---
|
2 |
title: Search Demo
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license:
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: Search Demo
|
3 |
+
emoji: ⚡
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: purple
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.32.2
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: isc
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import src.constants.config as configurations
|
3 |
+
from sentence_transformers import SentenceTransformer
|
4 |
+
from sentence_transformers import CrossEncoder
|
5 |
+
from src.constants.credentials import cohere_trial_key
|
6 |
+
import streamlit as st
|
7 |
+
from src.reader import Reader
|
8 |
+
from src.utils_search import UtilsSearch
|
9 |
+
from copy import deepcopy
|
10 |
+
import numpy as np
|
11 |
+
import cohere
|
12 |
+
|
13 |
+
|
14 |
+
configurations = configurations.service_mxbai_msc_direct_config
|
15 |
+
api_key = cohere_trial_key
|
16 |
+
co = cohere.Client(api_key)
|
17 |
+
semantic_column_names = configurations["semantic_column_names"]
|
18 |
+
|
19 |
+
# Check CUDA availability and set device
|
20 |
+
if torch.cuda.is_available():
|
21 |
+
torch.cuda.set_device(0) # Use the first GPU
|
22 |
+
else:
|
23 |
+
st.write("CUDA is not available. Using CPU instead.")
|
24 |
+
|
25 |
+
@st.cache_data
|
26 |
+
def init():
|
27 |
+
config = configurations
|
28 |
+
search_utils = UtilsSearch(config)
|
29 |
+
reader = Reader(config=config["reader_config"])
|
30 |
+
model = SentenceTransformer(config['sentence_transformer_name'], device='cuda:0')
|
31 |
+
cross_encoder = CrossEncoder(config['cross_encoder_name'], device='cuda:0')
|
32 |
+
df = reader.read()
|
33 |
+
index = search_utils.dataframe_to_index(df)
|
34 |
+
return df, model, cross_encoder, index, search_utils
|
35 |
+
|
36 |
+
def get_possible_values_for_column(column_name, search_utils, df):
|
37 |
+
if column_name not in st.session_state:
|
38 |
+
setattr(st.session_state, column_name, search_utils.top_10_common_values(df, column_name))
|
39 |
+
return getattr(st.session_state, column_name)
|
40 |
+
|
41 |
+
|
42 |
+
# Initialize or retrieve from session state
|
43 |
+
if 'init_results' not in st.session_state:
|
44 |
+
st.session_state.init_results = init()
|
45 |
+
|
46 |
+
# Now you can access your initialized objects directly from the session state
|
47 |
+
df, model, cross_encoder, index, search_utils = st.session_state.init_results
|
48 |
+
|
49 |
+
# Streamlit app layout
|
50 |
+
st.title('Search Demo')
|
51 |
+
|
52 |
+
# Input fields
|
53 |
+
query = st.text_input('Enter your search query here')
|
54 |
+
use_cohere = st.checkbox('Use Cohere', value=False) # Default to checked
|
55 |
+
|
56 |
+
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])
|
57 |
+
|
58 |
+
dynamic_programmatic_search_config = {
|
59 |
+
"scalar_columns": [],
|
60 |
+
"discrete_columns": []
|
61 |
+
}
|
62 |
+
|
63 |
+
|
64 |
+
for column in programmatic_search_config['scalar_columns']:
|
65 |
+
# Create number input for scalar values
|
66 |
+
col_name = column["column_name"]
|
67 |
+
min_val = float(column["min_value"])
|
68 |
+
max_val = float(column["max_value"])
|
69 |
+
user_min = st.number_input(f'Minimum {col_name.capitalize()}', min_value=min_val, max_value=max_val, value=min_val)
|
70 |
+
user_max = st.number_input(f'Maximum {col_name.capitalize()}', min_value=min_val, max_value=max_val, value=max_val)
|
71 |
+
dynamic_programmatic_search_config['scalar_columns'].append({"column_name": col_name, "min_value": user_min, "max_value": user_max})
|
72 |
+
|
73 |
+
for column in programmatic_search_config['discrete_columns']:
|
74 |
+
# Create multiselect for discrete values
|
75 |
+
col_name = column["column_name"]
|
76 |
+
default_values = column["default_values"]
|
77 |
+
# Assuming you have a function to fetch possible values for the discrete columns based on the column name
|
78 |
+
possible_values = get_possible_values_for_column(col_name, search_utils, df) # Implement this function based on your application
|
79 |
+
selected_values = st.multiselect(f'Select {col_name.capitalize()}', options=possible_values, default=default_values)
|
80 |
+
dynamic_programmatic_search_config['discrete_columns'].append({"column_name": col_name, "default_values": selected_values})
|
81 |
+
|
82 |
+
|
83 |
+
programmatic_search_config['scalar_columns'] = dynamic_programmatic_search_config['scalar_columns']
|
84 |
+
programmatic_search_config['discrete_columns'] = dynamic_programmatic_search_config['discrete_columns']
|
85 |
+
|
86 |
+
|
87 |
+
# Search button
|
88 |
+
if st.button('Search'):
|
89 |
+
if query: # Checking if a query was entered
|
90 |
+
df_filtered = search_utils.filter_dataframe(df, programmatic_search_config)
|
91 |
+
if len(df_filtered) == 0:
|
92 |
+
st.write('No results found')
|
93 |
+
else:
|
94 |
+
index = search_utils.dataframe_to_index(df_filtered)
|
95 |
+
if use_cohere == False:
|
96 |
+
# Call your Cohere-based search function here
|
97 |
+
results_df = search_utils.search(query, df_filtered, model, cross_encoder, index)
|
98 |
+
results_df = search_utils.drop_columns(results_df, programmatic_search_config)
|
99 |
+
|
100 |
+
else:
|
101 |
+
df_retrieved = search_utils.retrieve(query, df_filtered, model, index)
|
102 |
+
df_retrieved = search_utils.drop_columns(df_retrieved, programmatic_search_config)
|
103 |
+
df_retrieved.fillna(value="", inplace=True)
|
104 |
+
docs = df_retrieved.to_dict('records')
|
105 |
+
column_names = semantic_column_names
|
106 |
+
docs = [{name: str(doc[name]) for name in column_names} for doc in docs]
|
107 |
+
rank_fields = list(docs[0].keys())
|
108 |
+
results = co.rerank(query=query, documents=docs, top_n=10, model='rerank-english-v3.0',
|
109 |
+
rank_fields=rank_fields)
|
110 |
+
top_ids = [hit.index for hit in results.results]
|
111 |
+
# Create the DataFrame with the rerank results
|
112 |
+
results_df = df_retrieved.iloc[top_ids].copy()
|
113 |
+
results_df['rank'] = (np.arange(len(results_df)) + 1)
|
114 |
+
|
115 |
+
st.write(results_df)
|
116 |
+
else:
|
117 |
+
st.write("Please enter a query to search.")
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
datasets
|
4 |
+
accelerate>=0.21.0
|
5 |
+
pandas
|
6 |
+
fastparquet
|
7 |
+
s3fs
|
8 |
+
numpy
|
9 |
+
faiss-gpu
|
10 |
+
sentence_transformers
|
11 |
+
cohere
|
src/constants/__init__.py
ADDED
File without changes
|
src/constants/config.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import src.constants.credentials as cred
|
2 |
+
import os
|
3 |
+
|
4 |
+
service_mxbai_made_in_china_config = {"reader_config": {"input_path": os.environ['made_in_china_s3_path'],
|
5 |
+
"credentials": cred.credentials_backblaze,
|
6 |
+
"format":"parquet"
|
7 |
+
},
|
8 |
+
"sample_size": 32,
|
9 |
+
"sentence_transformer_name": "mixedbread-ai/mxbai-embed-large-v1",
|
10 |
+
"cross_encoder_name": "mixedbread-ai/mxbai-rerank-large-v1",
|
11 |
+
"batch_size": 4,
|
12 |
+
"dataset_size": 32,
|
13 |
+
"seq_len": 256,
|
14 |
+
"top_k": 1000,
|
15 |
+
"programmatic_search_config": {
|
16 |
+
"scalar_columns": [{"column_name": "price", "min_value": 0, "max_value": "10000"}],
|
17 |
+
"discrete_columns": [{"column_name": "supplierName",
|
18 |
+
# "default_values": ['Zhongshan Norye Hardware Co., Ltd.']
|
19 |
+
"default_values": []
|
20 |
+
},
|
21 |
+
{"column_name": "warranty",
|
22 |
+
# "default_values": ['Zhongshan Norye Hardware Co., Ltd.']
|
23 |
+
"default_values": []
|
24 |
+
}
|
25 |
+
],
|
26 |
+
"columns_to_drop": ["similarities", "embeddings"]
|
27 |
+
}
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
service_mxbai_msc_direct_sample_config = {"reader_config": {"input_path": os.environ['msc_direct_s3_path'],
|
32 |
+
"credentials": cred.credentials_backblaze,
|
33 |
+
"format":"parquet"
|
34 |
+
},
|
35 |
+
"sample_size": 32,
|
36 |
+
"sentence_transformer_name": "mixedbread-ai/mxbai-embed-large-v1",
|
37 |
+
"cross_encoder_name": "mixedbread-ai/mxbai-rerank-large-v1",
|
38 |
+
"batch_size": 4,
|
39 |
+
"dataset_size": 32,
|
40 |
+
"seq_len": 256,
|
41 |
+
"top_k": 50,
|
42 |
+
"semantic_column_names": ['name', 'price', 'brand', 'keyword', 'description',
|
43 |
+
'specifications'],
|
44 |
+
"programmatic_search_config": {
|
45 |
+
"scalar_columns": [{"column_name": "price", "min_value": 0, "max_value": "10000"}],
|
46 |
+
"discrete_columns": [{"column_name": "brand", "default_values": []}],
|
47 |
+
"columns_to_drop": ["similarities", "embeddings", "index"]
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
service_mxbai_msc_direct_config = {"reader_config": {"input_path": os.environ['msc_direct_s3_path'],
|
52 |
+
"credentials": cred.credentials_backblaze,
|
53 |
+
"format":"parquet"
|
54 |
+
},
|
55 |
+
"sample_size": 32,
|
56 |
+
"sentence_transformer_name": "mixedbread-ai/mxbai-embed-large-v1",
|
57 |
+
"cross_encoder_name": "mixedbread-ai/mxbai-rerank-large-v1",
|
58 |
+
"batch_size": 4,
|
59 |
+
"dataset_size": 32,
|
60 |
+
"seq_len": 256,
|
61 |
+
"top_k": 50,
|
62 |
+
"semantic_column_names": ['name', 'price', 'brand', 'keyword', 'description',
|
63 |
+
'specifications'],
|
64 |
+
"programmatic_search_config": {
|
65 |
+
"scalar_columns": [{"column_name": "price", "min_value": 0, "max_value": "10000"}],
|
66 |
+
"discrete_columns": [{"column_name": "brand", "default_values": []}],
|
67 |
+
"columns_to_drop": ["similarities", "embeddings", "index"]
|
68 |
+
}
|
69 |
+
}
|
src/constants/credentials.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
credentials_backblaze = {"access_key_id": os.environ['credentials_backblaze_access_key_id'],
|
4 |
+
"secret_access_key": os.environ['credentials_backblaze_secret_access_key'],
|
5 |
+
"bucket_name": os.environ['credentials_backblaze_bucket_name'],
|
6 |
+
"endpoint_url": os.environ['credentials_backblaze_endpoint_url'],
|
7 |
+
"region_name": "us-east-1"
|
8 |
+
}
|
9 |
+
|
10 |
+
|
11 |
+
cohere_trial_key = os.environ["cohere_trial_key"]
|
src/decorators/decorators.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
def timeit_decorator(func):
|
4 |
+
def wrapper(*args, **kwargs):
|
5 |
+
start_time = time.time()
|
6 |
+
result = func(*args, **kwargs)
|
7 |
+
end_time = time.time()
|
8 |
+
print(f"Function {func.__name__} took {end_time-start_time:.4f} seconds to execute")
|
9 |
+
return result
|
10 |
+
return wrapper
|
11 |
+
|
src/pytorch_modules/datasets/schema_string_dataset.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
class SchemaStringDataset(Dataset):
|
7 |
+
def __init__(self, data, config):
|
8 |
+
self.data = data
|
9 |
+
self.config = config
|
10 |
+
|
11 |
+
def __len__(self):
|
12 |
+
# Return the dataset size specified in the configuration
|
13 |
+
return self.config["dataset_size"]
|
14 |
+
|
15 |
+
def transform_entry(self, entry):
|
16 |
+
# Filter out None and NaN values
|
17 |
+
filtered_entry = {k: v for k, v in entry.items() if v is not np.nan and v is not None}
|
18 |
+
|
19 |
+
# Check if there are any entries after filtering
|
20 |
+
if not filtered_entry:
|
21 |
+
return '', '' # Return empty strings if no valid entries exist
|
22 |
+
|
23 |
+
# Use the rest of the entry as input
|
24 |
+
inputs = [f"{k}:{v}" for k, v in filtered_entry.items()]
|
25 |
+
|
26 |
+
return ' '.join(inputs)
|
27 |
+
def __getitem__(self, idx):
|
28 |
+
transformed_data = {
|
29 |
+
'inputs': []
|
30 |
+
}
|
31 |
+
|
32 |
+
item = self.data[idx]
|
33 |
+
input_data = {k: v for k, v in item.items()}
|
34 |
+
inputs = self.transform_entry(input_data)
|
35 |
+
transformed_data['inputs'] = inputs
|
36 |
+
|
37 |
+
transformed_data['idx'] = idx
|
38 |
+
|
39 |
+
# Return the transformed item for the current idx
|
40 |
+
return transformed_data
|
src/pytorch_modules/datasets/tokenized_dataset.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class TokenizedDataset(Dataset):
|
7 |
+
def __init__(self, custom_dataset, tokenizer, max_seq_len):
|
8 |
+
"""
|
9 |
+
custom_dataset: An instance of CustomDataset
|
10 |
+
tokenizer: An instance of the tokenizer
|
11 |
+
max_seq_len: Maximum sequence length for padding
|
12 |
+
"""
|
13 |
+
self.dataset = custom_dataset
|
14 |
+
self.tokenizer = tokenizer
|
15 |
+
self.max_seq_len = max_seq_len
|
16 |
+
|
17 |
+
def __len__(self):
|
18 |
+
# The length is inherited from the custom dataset
|
19 |
+
return len(self.dataset)
|
20 |
+
|
21 |
+
def tokenize_and_pad(self, text_list):
|
22 |
+
"""
|
23 |
+
Tokenize and pad a list of text strings.
|
24 |
+
"""
|
25 |
+
# Tokenize all text strings in the list
|
26 |
+
tokens = self.tokenizer(text_list, padding='max_length', max_length=self.max_seq_len, truncation=True, return_tensors="pt")
|
27 |
+
return tokens
|
28 |
+
|
29 |
+
def __getitem__(self, idx):
|
30 |
+
# Fetch the transformed data from the CustomDataset instance
|
31 |
+
transformed_data = self.dataset[idx]
|
32 |
+
|
33 |
+
# Initialize containers for inputs and optionally labels
|
34 |
+
tokenized_inputs = {}
|
35 |
+
tokenized_labels = {}
|
36 |
+
|
37 |
+
# Dynamically process each item in the dataset
|
38 |
+
for key, value in transformed_data.items():
|
39 |
+
if type(value) == int: # Check if value is an integer
|
40 |
+
# Convert integer to tensor and directly assign to inputs or labels based on key prefix
|
41 |
+
if key.startswith('label'):
|
42 |
+
tokenized_labels[key] = torch.tensor(value) # Convert int to tensor for labels
|
43 |
+
else:
|
44 |
+
tokenized_inputs[key] = torch.tensor(value) # Convert int to tensor for inputs
|
45 |
+
|
46 |
+
if type(value) == str:
|
47 |
+
tokenized_data = self.tokenize_and_pad(value)
|
48 |
+
if key.startswith('label'):
|
49 |
+
tokenized_labels[key] = tokenized_data['input_ids']
|
50 |
+
tokenized_labels['attention_mask_' + key] = tokenized_data['attention_mask']
|
51 |
+
|
52 |
+
else:
|
53 |
+
tokenized_inputs[key] = tokenized_data['input_ids']
|
54 |
+
tokenized_inputs['attention_mask_' + key] = tokenized_data['attention_mask']
|
55 |
+
|
56 |
+
|
57 |
+
# Prepare the return structure, conditionally including 'label' if labels are present
|
58 |
+
output = {"inputs": tokenized_inputs}
|
59 |
+
if tokenized_labels: # Check if there are any labels before adding to the output
|
60 |
+
output["label"] = tokenized_labels
|
61 |
+
|
62 |
+
return output
|
src/pytorch_modules/models/utils_models.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
class UtilsModels:
|
5 |
+
@staticmethod
|
6 |
+
def compute_embeddings(sentence_transformer, tokenized_sentences, attention_mask):
|
7 |
+
# Flatten the batch and num_sentences dimensions
|
8 |
+
batch_size, num_sentences, seq_len = tokenized_sentences.size()
|
9 |
+
flat_input_ids = tokenized_sentences.view(-1, seq_len)
|
10 |
+
flat_attention_mask = attention_mask.view(-1, seq_len) if attention_mask is not None else None
|
11 |
+
|
12 |
+
# Process sentences through the sentence_transformer
|
13 |
+
outputs = sentence_transformer(input_ids=flat_input_ids, attention_mask=flat_attention_mask)
|
14 |
+
embeddings = outputs.last_hidden_state
|
15 |
+
|
16 |
+
# Pool the embeddings to get a single vector per sentence (optional)
|
17 |
+
# Here, simply taking the mean across the sequence_length dimension
|
18 |
+
sentence_embeddings = embeddings.mean(dim=1)
|
19 |
+
|
20 |
+
# Reshape back to [batch_size, num_sentences * 2, embedding_dim]
|
21 |
+
return sentence_embeddings.view(batch_size, num_sentences, -1)
|
src/reader.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
from src.utils import Utils
|
6 |
+
|
7 |
+
|
8 |
+
class Reader:
|
9 |
+
|
10 |
+
def __init__(self, config):
|
11 |
+
self.config = config
|
12 |
+
self.utils = Utils()
|
13 |
+
self.cache_dir = config.get("cache_dir", "./cache") # default cache directory
|
14 |
+
|
15 |
+
def read(self, input_path=None, reader_config=None):
|
16 |
+
# If reader_config is None, use the class-level config
|
17 |
+
if reader_config is None:
|
18 |
+
reader_config = self.config
|
19 |
+
|
20 |
+
file_format = reader_config.get("format", None)
|
21 |
+
input_path = input_path or reader_config.get("input_path", "")
|
22 |
+
|
23 |
+
# Decide which method to use based on file format
|
24 |
+
if file_format == "parquet":
|
25 |
+
return self._read_dataframe_from_parquet(input_path, reader_config)
|
26 |
+
elif file_format == "csv":
|
27 |
+
return self._read_dataframe_from_csv(input_path)
|
28 |
+
elif file_format == "s3_csv":
|
29 |
+
return self._read_dataframe_from_csv_s3(input_path, reader_config)
|
30 |
+
elif file_format == "json_folder":
|
31 |
+
return self._read_json_files_to_dataframe(input_path)
|
32 |
+
else:
|
33 |
+
raise ValueError(f"Unsupported file format: {file_format}")
|
34 |
+
|
35 |
+
def _read_dataframe_from_parquet(self, input_path=None, reader_config=None):
|
36 |
+
if reader_config is None:
|
37 |
+
reader_config = self.config
|
38 |
+
|
39 |
+
input_path = input_path or reader_config.get("input_path", "")
|
40 |
+
|
41 |
+
if input_path.startswith("s3://"):
|
42 |
+
# Check if the file is cached
|
43 |
+
local_cache_path = os.path.join(self.cache_dir, os.path.basename(input_path))
|
44 |
+
|
45 |
+
if os.path.exists(local_cache_path):
|
46 |
+
print("reading from cache")
|
47 |
+
print(local_cache_path)
|
48 |
+
return pd.read_parquet(local_cache_path)
|
49 |
+
|
50 |
+
print("reading from s3")
|
51 |
+
|
52 |
+
credentials = reader_config.get("credentials", {})
|
53 |
+
storage_options = {
|
54 |
+
'key': credentials.get("access_key_id", ""),
|
55 |
+
'secret': credentials.get("secret_access_key", ""),
|
56 |
+
'client_kwargs': {'endpoint_url': credentials.get("endpoint_url", "")}
|
57 |
+
}
|
58 |
+
|
59 |
+
# Read from S3 and cache locally
|
60 |
+
df = pd.read_parquet(input_path, storage_options=storage_options)
|
61 |
+
os.makedirs(self.cache_dir, exist_ok=True) # Check and create if not exists
|
62 |
+
df.to_parquet(local_cache_path) # Save to cache
|
63 |
+
return df
|
64 |
+
else:
|
65 |
+
return pd.read_parquet(input_path)
|
66 |
+
|
67 |
+
def _read_dataframe_from_csv(self, file_path):
|
68 |
+
return self.utils.read_dataframe_from_csv(file_path)
|
69 |
+
|
70 |
+
def _read_json_files_to_dataframe(self, folder_path):
|
71 |
+
self.utils.load_json_files_to_dataframe(folder_path)
|
72 |
+
|
73 |
+
def _read_dataframe_from_csv_s3(self, input_path, reader_config):
|
74 |
+
credentials = reader_config.get("credentials", {})
|
75 |
+
endpoint_url = credentials.get("endpoint_url", "")
|
76 |
+
access_key_id = credentials.get("access_key_id", "")
|
77 |
+
secret_access_key = credentials.get("secret_access_key", "")
|
78 |
+
|
79 |
+
# Constructing the storage options for s3fs
|
80 |
+
storage_options = {
|
81 |
+
'key': access_key_id,
|
82 |
+
'secret': secret_access_key,
|
83 |
+
'client_kwargs': {'endpoint_url': endpoint_url}
|
84 |
+
}
|
85 |
+
|
86 |
+
# Use pandas to read the CSV file directly from S3
|
87 |
+
try:
|
88 |
+
df = pd.read_csv(input_path, storage_options=storage_options)
|
89 |
+
return df
|
90 |
+
except Exception as e:
|
91 |
+
print(f"An error occurred while reading the CSV file from S3: {e}")
|
92 |
+
return None
|
src/utils.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
|
6 |
+
|
7 |
+
class Utils:
|
8 |
+
@staticmethod
|
9 |
+
def read_dataframe_from_csv(file_path):
|
10 |
+
"""
|
11 |
+
Reads a DataFrame from a CSV file if the file exists.
|
12 |
+
|
13 |
+
Parameters:
|
14 |
+
- file_path: The full path to the CSV file.
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
- A pandas DataFrame if the file exists and is read successfully; None otherwise.
|
18 |
+
"""
|
19 |
+
# Check if the file exists
|
20 |
+
if os.path.isfile(file_path):
|
21 |
+
try:
|
22 |
+
# Attempt to read the CSV file into a DataFrame
|
23 |
+
df = pd.read_csv(file_path)
|
24 |
+
return df
|
25 |
+
except Exception as e:
|
26 |
+
# If an error occurs during reading, print it
|
27 |
+
print(f"An error occurred while reading the file: {e}")
|
28 |
+
return None
|
29 |
+
else:
|
30 |
+
# If the file does not exist, print a message
|
31 |
+
print(f"File does not exist: {file_path}")
|
32 |
+
return None
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def read_json_files_to_dataframe(folder_path):
|
36 |
+
"""
|
37 |
+
Reads JSON files from a specified folder, automatically infers columns from the JSON files,
|
38 |
+
and returns the data as a pandas DataFrame.
|
39 |
+
|
40 |
+
:param folder_path: Path to the folder containing JSON files.
|
41 |
+
:return: A pandas DataFrame containing data from all JSON files in the folder.
|
42 |
+
"""
|
43 |
+
data = []
|
44 |
+
|
45 |
+
for filename in os.listdir(folder_path):
|
46 |
+
if filename.endswith('.json'):
|
47 |
+
file_path = os.path.join(folder_path, filename)
|
48 |
+
|
49 |
+
with open(file_path, 'r') as file:
|
50 |
+
# First attempt to load the JSON
|
51 |
+
json_data = json.load(file)
|
52 |
+
|
53 |
+
# Check if json_data is a string instead of a dict, decode it again
|
54 |
+
if isinstance(json_data, str):
|
55 |
+
json_data = json.loads(json_data)
|
56 |
+
|
57 |
+
data.append(json_data)
|
58 |
+
|
59 |
+
# Create a DataFrame from the list of dictionaries
|
60 |
+
df = pd.DataFrame(data)
|
61 |
+
|
62 |
+
return df
|
63 |
+
|
64 |
+
@staticmethod
|
65 |
+
def write_pandas_to_local(df, output_path):
|
66 |
+
"""
|
67 |
+
Writes a pandas DataFrame to a CSV file at the specified output path.
|
68 |
+
|
69 |
+
:param df: The pandas DataFrame to be saved.
|
70 |
+
:param output_path: The file path where the DataFrame should be saved as a CSV.
|
71 |
+
"""
|
72 |
+
# Create the directory if it does not exist
|
73 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
74 |
+
|
75 |
+
# Save the DataFrame to a CSV file without saving the index
|
76 |
+
df.to_csv(output_path, index=False)
|
77 |
+
|
78 |
+
@staticmethod
|
79 |
+
def convert_iterables_to_strings(df):
|
80 |
+
"""
|
81 |
+
Convert columns with iterable types (excluding strings) to string representations.
|
82 |
+
This includes handling numpy arrays or lists within dataframe cells.
|
83 |
+
"""
|
84 |
+
for col in df.columns:
|
85 |
+
# Apply conversion if the value is an iterable (excluding strings) or a numpy array
|
86 |
+
df[col] = df[col].apply(lambda x: str(x) if isinstance(x, (np.ndarray, list)) else x)
|
87 |
+
return df
|
88 |
+
|
89 |
+
|
90 |
+
|
src/utils_search.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset
|
2 |
+
import os
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
import faiss
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
class UtilsSearch:
|
11 |
+
def __init__(self, config):
|
12 |
+
self.config = config
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def dataframe_to_index(df):
|
16 |
+
embeddings = np.stack(df['embeddings'].to_numpy())
|
17 |
+
norm_embeddings = np.ascontiguousarray(embeddings / np.linalg.norm(embeddings, axis=1)[:, None])
|
18 |
+
# Create a FAISS index (Step 2, unchanged but using normalized embeddings)
|
19 |
+
dimension = norm_embeddings.shape[1]
|
20 |
+
index = faiss.IndexFlatL2(dimension)
|
21 |
+
index.add(norm_embeddings)
|
22 |
+
return index # Ad
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def retrieve(query, df, model, index, top_k=100):
|
26 |
+
query += "Represent this sentence for searching relevant passages: "
|
27 |
+
"""
|
28 |
+
Search the DataFrame for the given query and return a sorted DataFrame based on similarity.
|
29 |
+
|
30 |
+
:param query: The search query string.
|
31 |
+
:param df: The input DataFrame containing embeddings.
|
32 |
+
:param model: The model to encode the query and compute embeddings.
|
33 |
+
:param index: The search index for querying.
|
34 |
+
:param top_k: The number of top results to return.
|
35 |
+
:return: A new DataFrame sorted by similarity to the query, with a 'similarities' column.
|
36 |
+
"""
|
37 |
+
# Check if CUDA is available and set the device accordingly
|
38 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
39 |
+
model.to(device)
|
40 |
+
|
41 |
+
# Compute the query embedding
|
42 |
+
query_vector = model.encode(query, convert_to_tensor=True, device=device).cpu().numpy()
|
43 |
+
|
44 |
+
# Normalize the query vector
|
45 |
+
query_vector /= np.linalg.norm(query_vector)
|
46 |
+
|
47 |
+
# Perform the search
|
48 |
+
distances, indices = index.search(np.array([query_vector]), top_k)
|
49 |
+
|
50 |
+
# Retrieve the rows from the DataFrame corresponding to the indices
|
51 |
+
retrieved_df = df.iloc[indices[0]]
|
52 |
+
|
53 |
+
# Attach the distances as a new column named 'similarities'
|
54 |
+
# Ensure the distances array matches the size of the retrieved DataFrame, especially if using slicing or other operations that might change its shape
|
55 |
+
retrieved_df = retrieved_df.assign(similarities=distances[0])
|
56 |
+
|
57 |
+
if 'similarities' in retrieved_df.columns:
|
58 |
+
retrieved_df = retrieved_df.sort_values(by='similarities', ascending=False)
|
59 |
+
|
60 |
+
# Optionally, you might want to reset the index if the order matters or if you need to serialize the DataFrame without index issues
|
61 |
+
retrieved_df = retrieved_df.reset_index(drop=True)
|
62 |
+
|
63 |
+
|
64 |
+
return retrieved_df
|
65 |
+
|
66 |
+
def rerank(self, query, df_top_100, cross_encoder, index):
|
67 |
+
# Convert the top 5 records to a list of dictionaries for processing
|
68 |
+
# print(df_top_100)
|
69 |
+
config = self.config
|
70 |
+
df_copy = df_top_100.copy().reset_index(drop=True)
|
71 |
+
records = df_copy.to_dict(orient='records')[:100]
|
72 |
+
|
73 |
+
# Assuming SchemaStringDataset can handle GPU data
|
74 |
+
dataset_str = SchemaStringDataset(records, config)
|
75 |
+
|
76 |
+
# Extract documents from dataset
|
77 |
+
documents = [batch["inputs"][:256] for batch in dataset_str]
|
78 |
+
|
79 |
+
# Rank documents based on the query
|
80 |
+
# Ensure data processed by cross_encoder is moved to the correct device
|
81 |
+
ids = [item["corpus_id"] for item in cross_encoder.rank(query, documents, top_k=10)]
|
82 |
+
|
83 |
+
# Use the ids to filter and reorder the original DataFrame
|
84 |
+
df_sorted_by_relevance = df_copy.loc[ids]
|
85 |
+
return df_sorted_by_relevance
|
86 |
+
|
87 |
+
def search(self, query, df, model, cross_encoder, index):
|
88 |
+
sorted_df = self.retrieve(query, df, model, index)
|
89 |
+
return self.rerank(query, sorted_df, cross_encoder, index)
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def top_10_common_values(df, column_name):
|
93 |
+
"""
|
94 |
+
This function takes a pandas dataframe and a column name,
|
95 |
+
and returns the top 10 most common non-null values of that column as a list.
|
96 |
+
"""
|
97 |
+
# Drop null values from the specified column and count occurrences of each value
|
98 |
+
# Convert the index of the resulting Series (which contains the values) to a list
|
99 |
+
value_counts_list = df[column_name].dropna().value_counts().head(10).index.tolist()
|
100 |
+
|
101 |
+
return value_counts_list
|
102 |
+
|
103 |
+
@staticmethod
|
104 |
+
def filter_dataframe(df, config, top_k_programmatic=100):
|
105 |
+
"""
|
106 |
+
Filters a DataFrame based on scalar and discrete column configurations, with type handling and null filtering.
|
107 |
+
|
108 |
+
Parameters:
|
109 |
+
- df: pandas.DataFrame to filter.
|
110 |
+
- config: Dictionary containing 'scalar_columns' and 'discrete_columns' configurations.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
- Filtered pandas.DataFrame.
|
114 |
+
"""
|
115 |
+
scalar_columns = config.get('scalar_columns', [])
|
116 |
+
discrete_columns = config.get('discrete_columns', [])
|
117 |
+
|
118 |
+
# Combine all column names to check for nulls
|
119 |
+
all_columns = [col["column_name"] for col in scalar_columns] + [col["column_name"] for col in discrete_columns]
|
120 |
+
|
121 |
+
# Drop rows where any of the specified columns have null values
|
122 |
+
df = df.dropna(subset=all_columns)
|
123 |
+
|
124 |
+
# Filtering based on scalar columns
|
125 |
+
for col in scalar_columns:
|
126 |
+
column_name = col["column_name"]
|
127 |
+
# Ensure min_value and max_value are of numeric type
|
128 |
+
min_value = float(col["min_value"])
|
129 |
+
max_value = float(col["max_value"])
|
130 |
+
# Convert the DataFrame column to numeric type to avoid comparison issues
|
131 |
+
df[column_name] = pd.to_numeric(df[column_name], errors='coerce')
|
132 |
+
df = df[df[column_name].between(min_value, max_value)]
|
133 |
+
|
134 |
+
# Filtering based on discrete columns
|
135 |
+
for col in discrete_columns:
|
136 |
+
column_name = col["column_name"]
|
137 |
+
default_values = col["default_values"]
|
138 |
+
if len(default_values) > 0:
|
139 |
+
df = df[df[column_name].isin(default_values)]
|
140 |
+
|
141 |
+
if 'similarities' in df.columns:
|
142 |
+
df = df.sort_values(by='similarities', ascending=False)
|
143 |
+
|
144 |
+
# Return the top 100 items with the highest similarity
|
145 |
+
return df
|
146 |
+
|
147 |
+
@staticmethod
|
148 |
+
def drop_columns(df, config):
|
149 |
+
columns_to_drop = config.get('columns_to_drop', [])
|
150 |
+
df_dropped = df.drop(columns_to_drop, axis=1)
|
151 |
+
return df_dropped
|
152 |
+
|
153 |
+
|