# ttschatbot / main.py
# (file-page header captured with the source: author leesuo215,
#  commit 8e8fbf9 "Update main.py", 5.22 kB)
from flask import Flask, request, jsonify
from dotenv import load_dotenv
import os
import pymongo
import google.generativeai as genai
from flask_cors import CORS
from tqdm import tqdm
# Load variables from a .env file (if one exists) into the process
# environment so that the os.getenv() calls below can see them.
load_dotenv()
# Read configuration from the environment. Only EMBEDDING_MODEL has a
# fallback (also used when the variable is set but empty, because of `or`);
# the other values are None when unset.
# NOTE(review): nothing here validates these values - an unset MONGODB_URI,
# DB_NAME or DB_COLLECTION is passed straight to pymongo below; confirm the
# deployment always provides them.
MONGODB_URI = os.getenv('MONGODB_URI')
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL') or 'keepitreal/vietnamese-sbert'
DB_NAME = os.getenv('DB_NAME')
DB_COLLECTION = os.getenv('DB_COLLECTION')
GEMINI_KEY = os.getenv('GEMINI_KEY')
# Configure the Gemini client before creating the model that /api/search
# uses to generate answers.
genai.configure(api_key=GEMINI_KEY)
model = genai.GenerativeModel('gemini-1.5-pro')
# MongoDB handles shared by the request handlers (search and re-embedding).
client = pymongo.MongoClient(MONGODB_URI)
db = client[DB_NAME]
collection = db[DB_COLLECTION]
# Flask application with CORS enabled using flask_cors default options.
app = Flask(__name__)
CORS(app)
# Sentence-transformers model used to embed both user queries and product
# text; loaded once, at import time.
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer(EMBEDDING_MODEL)
def vector_search(user_query, collection, limit=4, num_candidates=150):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
        user_query (str): The user's query string.
        collection (MongoCollection): The MongoDB collection to search.
        limit (int): Maximum number of documents to return.
        num_candidates (int): Number of nearest-neighbour candidates the
            ``$vectorSearch`` stage considers. It is raised to ``limit`` when
            smaller, because Atlas rejects numCandidates < limit.

    Returns:
        list: A list of matching documents containing only the projected
        product fields. An empty list when no embedding could be produced
        for the query (e.g. blank text), so callers can always iterate it.
    """
    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)
    # Treat both None and an empty vector as "no embedding": sending an empty
    # queryVector to $vectorSearch would only fail on the server side.
    if not query_embedding:
        return []
    # Define the vector search pipeline
    vector_search_stage = {
        "$vectorSearch": {
            "index": "vector_index",
            "queryVector": query_embedding,
            "path": "embedding",
            "numCandidates": max(num_candidates, limit),
            "limit": limit,
        }
    }
    # An inclusion projection already drops every unlisted field (including
    # the large "embedding" array), so no separate $unset stage is needed.
    project_stage = {
        "$project": {
            "_id": 0,
            "title": 1,
            "details": 1,
            "price": 1,
            "promotion_price": 1,
            "size_options": 1,
            "gender_options": 1,
            "quantity": 1,
            "stock": 1,
            "is_shoes": 1,
            "is_sandals": 1,
        }
    }
    pipeline = [vector_search_stage, project_stage]
    # Execute the search
    results = collection.aggregate(pipeline)
    return list(results)
def get_search_result(query, collection):
    """
    Render the top vector-search hits for a query as prompt-ready text.

    Each matching product becomes a paragraph of the form
    "Sản phẩm N: <title>, Giá: ..., ..." in which only present, truthy fields
    are included. Products are numbered consecutively starting at 1.

    Args:
        query (str): The user's (normalised) question.
        collection (MongoCollection): The MongoDB collection to search.

    Returns:
        str: The concatenated product paragraphs, each preceded by a blank
        line; an empty string when nothing matched.
    """
    get_knowledge = vector_search(query, collection, 10)
    # vector_search may return an error-message string instead of a list of
    # documents; iterating that would yield characters, not products.
    if not isinstance(get_knowledge, list):
        return ""
    paragraphs = []
    for i, result in enumerate(get_knowledge, start=1):
        # The heading is always written, so the optional fields below attach
        # to the correct product even when it has no price.
        parts = [f"Sản phẩm {i}: {result.get('title')}"]
        if result.get('price'):
            parts.append(f"Giá: {result.get('price')}")
        if result.get('promotion_price'):
            parts.append(f"Giá ưu đãi: {result.get('promotion_price')}")
        if result.get('stock'):
            parts.append(f"Trạng thái: {result.get('stock')}")
        # Strict equality with True (as before) so that truthy non-boolean
        # values such as strings do not mark the product type.
        if result.get('is_shoes') == True:  # noqa: E712
            parts.append("Loại: Giày")
        if result.get('is_sandals') == True:  # noqa: E712
            parts.append("Loại: Dép")
        if result.get('size_options'):
            parts.append(f"Size: {result.get('size_options')}")
        if result.get('gender_options'):
            parts.append(f"Dành cho: {result.get('gender_options')}")
        if result.get('details'):
            parts.append(f"Chi tiết sản phẩm: {result.get('details')}")
        paragraphs.append("\n\n" + ", ".join(parts))
    return "".join(paragraphs)
def get_embedding(text):
    """
    Encode text with the module's sentence-transformers model.

    Args:
        text (str): Text to embed.

    Returns:
        list | None: The embedding converted to a plain Python list (storable
        in MongoDB), or None when text is None, empty, or whitespace only.
        None is the "no embedding" value this module's callers check for.
    """
    if text is None or not text.strip():
        print("Attempted to get embedding for empty text.")
        return None
    embedding = embedding_model.encode(text)
    return embedding.tolist()
def process_query(query):
    """
    Normalise a raw user question before searching and prompting.

    Args:
        query: The value of the request's "question" field.

    Returns:
        str: The question stripped of surrounding whitespace and lower-cased,
        or "" when query is not a string (e.g. missing from the request), so
        that the caller's emptiness check can reject it.
    """
    if not isinstance(query, str):
        return ""
    return query.strip().lower()
@app.route('/api/search', methods=['POST'])
def handle_query():
    """
    Answer a customer's question with Gemini, grounded on retrieved products.

    Expects a JSON object body with a "question" string. The question is
    normalised with process_query(), matching products are retrieved with
    get_search_result(), and both are sent to the Gemini model.

    Returns:
        JSON {"content": <model answer>} on success, or
        JSON {"error": "No query provided"} with HTTP 400 when the body is not
        a JSON object, or "question" is missing, not a string, or blank.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so that case falls through to the 400 response below.
    data = request.get_json(silent=True)
    question = data.get('question') if isinstance(data, dict) else None
    # Guard before normalising: process_query() calls str methods on its input.
    query = process_query(question).strip() if isinstance(question, str) else ''
    if not query:
        return jsonify({'error': 'No query provided'}), 400
    # Retrieve data from vector database
    source_information = get_search_result(query, collection).replace('<br>', '\n')
    combined_information = f"Hãy trở thành chuyên gia tư vấn bán hàng cho một website bán giày dép ThuThaoShoes. Câu hỏi của khách hàng: {query}\nTrả lời câu hỏi dựa vào các thông tin sản phẩm dưới đây: {source_information}."
    response = model.generate_content(combined_information)
    return jsonify({
        'content': response.text
    })
@app.route('/api/embedding', methods=['GET'])
def get_embedding_api():
    """
    Recompute and store the embedding of every document in the collection.

    The embedded text is the document's title followed by " Danh mục: " and
    its category. Documents for which no embedding is produced are left
    unchanged.

    Returns:
        A JSON success message once all documents have been processed.
    """
    # NOTE(review): this endpoint modifies data but is exposed via GET; any
    # crawler or link prefetch can trigger a full re-embedding - consider POST.
    # Fetch only the fields used below (plus _id); skipping the existing
    # "embedding" arrays keeps transfer size and memory use down.
    documents = list(collection.find({}, {'title': 1, 'category': 1}))
    for doc in tqdm(documents, desc="Processing documents"):
        # `or ''` also covers fields stored as null, which .get(key, '') would
        # return as None and then break the string concatenation.
        product_specs = str(doc.get('title') or '')
        product_cat = str(doc.get('category') or '')
        print(product_specs + ' ' + product_cat)
        embedding = get_embedding(product_specs + ' Danh mục: ' + product_cat)
        # Skip both None and an empty vector so no unusable embedding is stored.
        if embedding:
            # Update the document with its new embedding
            collection.update_one(
                {'_id': doc['_id']},
                {'$set': {'embedding': embedding}}
            )
    return jsonify({'message': 'Embedding cập nhật thành công cho tất cả các tài liệu.'})
if __name__ == '__main__':
    # The Werkzeug debugger allows executing arbitrary code from the browser,
    # so it is no longer forced on. Enable it explicitly for local development
    # by setting FLASK_DEBUG to 1, true or yes.
    debug_mode = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug_mode)