File size: 3,367 Bytes
71a8168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70b71a3
71a8168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import psycopg2
from sentence_transformers import SentenceTransformer

class ProductDatabase:
    def __init__(self, database_url):
        self.database_url = database_url
        self.conn = None
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    
    def connect(self):
        self.conn = psycopg2.connect(self.database_url)
    
    def close(self):
        if self.conn:
            self.conn.close()
    
    def setup_vector_extension_and_column(self):
        with self.conn.cursor() as cursor:
            # pgvector拡張機能のインストール
            cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            
            # ベクトルカラムの追加
            cursor.execute("ALTER TABLE products ADD COLUMN IF NOT EXISTS vector_col vector(384);")
            
            self.conn.commit()

    def get_embedding(self, text):
        embedding = self.model.encode(text)
        return embedding

    def insert_vector(self, product_id, text):
        vector = self.get_embedding(text).tolist()  # ndarray をリストに変換
        with self.conn.cursor() as cursor:
            cursor.execute("UPDATE products SET vector_col = %s WHERE id = %s", (vector, product_id))
            self.conn.commit()

    def search_similar_vectors(self, query_text, top_k=5):
        query_vector = self.get_embedding(query_text).tolist()  # ndarray をリストに変換
        with self.conn.cursor() as cursor:
            cursor.execute("""
                SELECT id, vector_col <=> %s::vector AS distance
                FROM products
                ORDER BY distance
                LIMIT %s;
            """, (query_vector, top_k))
            results = cursor.fetchall()
            return results

def main():
    # データベース接続情報
    DATABASE_URL = os.getenv("postgre_url")
    
    # ProductDatabaseクラスのインスタンスを作成
    db = ProductDatabase(DATABASE_URL)
    
    # データベースに接続
    db.connect()
    
    try:
        # pgvector拡張機能のインストールとカラムの追加
        db.setup_vector_extension_and_column()
        print("Vector extension installed and column added successfully.")
        
        # サンプルデータの挿入
        sample_text = """検査にはどのぐらい時間かかりますか?⇒当日に分かります。
法人取引やってますか?⇒大丈夫ですよ。成約時に必要な書類の説明
LINEで金粉送って、査定はできますか?⇒できますが、今お話した内容と同様で、検査が必要な旨を返すだけなので、金粉ではなく、他のお品物でLINE査定くださいと。
分かりました、またどうするか検討して連絡しますと"""
        sample_product_id = 1  # 実際の製品IDを使用
        db.insert_vector(sample_product_id, sample_text)
        db.insert_vector(2, sample_text)

        print(f"Vector inserted for product ID {sample_product_id}.")

        
        # ベクトル検索
        query_text = "今お話した内容と同様で"
        results = db.search_similar_vectors(query_text)
        print("Search results:")
        for result in results:
            print(result)
    
    finally:
        # 接続を閉じる
        db.close()

if __name__ == "__main__":
    main()