DATN_Query_Routing / utils.py
qminh369's picture
Upload 7 files
eaf119d verified
raw
history blame
1.68 kB
from pymilvus import connections, utility, DataType, FieldSchema, CollectionSchema, Collection
from sentence_transformers import SentenceTransformer
from pyvi import ViTokenizer
import string
import json
def load_json(path):
    """Read the UTF-8 encoded JSON file at ``path`` and return its content.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The Python object produced by decoding the file's JSON document.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
def convert_query(query):
    """Lower-case ``query`` and pass it through ``ViTokenizer.tokenize``.

    Args:
        query: Raw query string.

    Returns:
        Whatever ``ViTokenizer.tokenize`` returns for the lower-cased query.
    """
    lowered = query.lower()
    return ViTokenizer.tokenize(lowered)
def load_stopword(path):
    """Load a stop-word list from a UTF-8 text file, one entry per line.

    Each line is stripped of surrounding whitespace; blank lines are kept as
    empty strings, and the file order is preserved.

    Args:
        path: Location of the stop-word file.

    Returns:
        A list with one stripped string per line of the file.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]
def remove_stop_words(path, split_prompts):
    """Filter the stop words listed in the file at ``path`` out of a word list.

    Membership is tested on each element exactly as given; the elements that
    survive are stripped of surrounding whitespace before being returned.

    Args:
        path: Location of the stop-word file, read via ``load_stopword``.
        split_prompts: Iterable of word strings to filter.

    Returns:
        A list of the stripped words that are not stop words, in input order.
    """
    # Build the set once: each membership test is then O(1) instead of a
    # linear scan of the stop-word list for every word.
    stop_words = set(load_stopword(path))
    return [word.strip() for word in split_prompts if word not in stop_words]
def clean_query(path, query):
    """Tokenize a query and return its unique, non-stop-word terms.

    The query is lower-cased and tokenized with ``ViTokenizer``, then split
    on single spaces. Tokens made up solely of punctuation characters are
    dropped, underscores in the remaining tokens are replaced by spaces,
    stop words are removed through ``remove_stop_words`` and duplicates are
    discarded while keeping first-occurrence order.

    Args:
        path: Location of the stop-word file, forwarded to
            ``remove_stop_words``.
        query: Raw query string.

    Returns:
        A list of cleaned terms in order of first appearance.
    """
    punctuation = set(string.punctuation)
    tokens = ViTokenizer.tokenize(query.lower()).split(' ')
    # Drop tokens consisting only of punctuation characters. The earlier
    # ``word not in string.punctuation`` check was a substring test against
    # the punctuation string, so multi-character tokens such as "..." or
    # "!!" were not filtered. ``issuperset('')`` is True, so empty tokens
    # produced by consecutive spaces are still dropped, as before.
    words = [
        token.replace('_', ' ')
        for token in tokens
        if not punctuation.issuperset(token)
    ]
    kept = remove_stop_words(path, words)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(kept))
def load_model(model_name):
    """Construct a ``SentenceTransformer`` from ``model_name`` and return it.

    Args:
        model_name: Value passed unchanged to the ``SentenceTransformer``
            constructor.

    Returns:
        The newly created ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
def connect_vector_db(host='localhost', port='19530', alias='default'):
    """Open a pymilvus connection and print a confirmation message.

    Called without arguments, this connects exactly as before: alias
    ``'default'`` on ``localhost:19530``. The parameters let callers point
    at a different server without editing this module.

    Args:
        host: Host name or address passed to ``connections.connect``.
        port: Port passed to ``connections.connect``.
        alias: Connection alias passed to ``connections.connect``.
    """
    connections.connect(alias, host=host, port=port)
    print("Connect finished!")
def load_collection(collection_name):
    """Create a ``Collection`` by name, call ``load()`` on it and return it.

    A confirmation message is printed once ``load()`` has returned.

    Args:
        collection_name: Name passed to the ``Collection`` constructor.

    Returns:
        The ``Collection`` object after ``load()`` has been called.
    """
    loaded = Collection(collection_name)
    loaded.load()
    print(f"{collection_name} load complete!")
    return loaded