kenken999 commited on
Commit
41e5bd2
1 Parent(s): 36c2b65
Files changed (1) hide show
  1. babyagi/classesa/diamond.py +12 -7
babyagi/classesa/diamond.py CHANGED
@@ -34,14 +34,14 @@ class ProductDatabase:
34
  cursor.execute("UPDATE diamondprice SET vector_col = %s WHERE id = %s", (vector, product_id))
35
  self.conn.commit()
36
 
37
- def search_similar_vectors(self, query_text, top_k=5):
38
  query_vector = self.get_embedding(query_text).tolist() # ndarray をリストに変換
39
  with self.conn.cursor() as cursor:
40
  cursor.execute("""
41
- SELECT id,carat, cut, color, clarity, depth, diamondprice.table, x, y, z, vector_col <=> %s::vector AS distance
42
  FROM diamondprice
43
  WHERE vector_col IS NOT NULL
44
- ORDER BY distance desc
45
  LIMIT %s;
46
  """, (query_vector, top_k))
47
  results = cursor.fetchall()
@@ -54,6 +54,7 @@ class ProductDatabase:
54
  SELECT id,carat, cut, color, clarity, depth, diamondprice.table, x, y, z
55
  FROM diamondprice
56
  order by id asc
 
57
  """, (query_vector, top_k))
58
  results = cursor.fetchall()
59
  return results
@@ -75,12 +76,13 @@ def main():
75
  query_text="1"
76
  results = db.search_similar_all(query_text)
77
  print("Search results:")
78
- DEBUG=0
79
  if DEBUG==1:
80
  for result in results:
81
  print(result)
82
  id = result[0]
83
- sample_text = str(result[1])+"-"+str(result[2])+"-"+str(result[3])+"-"+str(result[4])+"-"+str(result[5])+"-"+str(result[6])+"-"+str(result[7])+"-"+str(result[8])+"-"+str(result[9])
 
84
  db.insert_vector(id, sample_text)
85
  #return
86
  # サンプルデータの挿入
@@ -93,8 +95,11 @@ def main():
93
 
94
 
95
  # ベクトル検索
96
- query_text = "12.03Very GoodJSI262.0587.27"
97
- query_text = "2.03-Very Good-J-SI2"
 
 
 
98
  results = db.search_similar_vectors(query_text)
99
  print("Search results:")
100
  for result in results:
 
34
  cursor.execute("UPDATE diamondprice SET vector_col = %s WHERE id = %s", (vector, product_id))
35
  self.conn.commit()
36
 
37
+ def search_similar_vectors(self, query_text, top_k=50):
38
  query_vector = self.get_embedding(query_text).tolist() # ndarray をリストに変換
39
  with self.conn.cursor() as cursor:
40
  cursor.execute("""
41
+ SELECT id,price,carat, cut, color, clarity, depth, diamondprice.table, x, y, z, vector_col <=> %s::vector AS distance
42
  FROM diamondprice
43
  WHERE vector_col IS NOT NULL
44
+ ORDER BY distance asc
45
  LIMIT %s;
46
  """, (query_vector, top_k))
47
  results = cursor.fetchall()
 
54
  SELECT id,carat, cut, color, clarity, depth, diamondprice.table, x, y, z
55
  FROM diamondprice
56
  order by id asc
57
+ limit 10
58
  """, (query_vector, top_k))
59
  results = cursor.fetchall()
60
  return results
 
76
  query_text="1"
77
  results = db.search_similar_all(query_text)
78
  print("Search results:")
79
+ DEBUG=1
80
  if DEBUG==1:
81
  for result in results:
82
  print(result)
83
  id = result[0]
84
+ sample_text = str(result[1])+str(result[2])+str(result[3])+str(result[4])+str(result[5])+str(result[6])+str(result[7])+str(result[8])+str(result[9])
85
+ print(sample_text)
86
  db.insert_vector(id, sample_text)
87
  #return
88
  # サンプルデータの挿入
 
95
 
96
 
97
  # ベクトル検索
98
+ query_text = "2.03Very GoodJSI262.058.08.068.125.05"
99
+
100
+ query_text = "2.03Very GoodJSI2"
101
+
102
+ #query_text = "2.03-Very Good-J-SI2-62.2-58.0-7.27-7.33-4.55"
103
  results = db.search_similar_vectors(query_text)
104
  print("Search results:")
105
  for result in results: