Files changed (1) hide show
  1. app.py +26 -45
app.py CHANGED
@@ -9,6 +9,7 @@ import json
9
  import numpy as np
10
  from ultralytics import YOLO
11
  import cv2
 
12
 
13
  # Load CLIP model and tokenizer
14
  @st.cache_resource
@@ -28,14 +29,6 @@ def load_yolo_model():
28
 
29
  yolo_model = load_yolo_model()
30
 
31
- # Load and process data
32
- @st.cache_data
33
- def load_data():
34
- with open('./musinsa-final.json', 'r', encoding='utf-8') as f:
35
- return json.load(f)
36
-
37
- data = load_data()
38
-
39
  # Helper functions
40
  def load_image_from_url(url, max_retries=3):
41
  for attempt in range(max_retries):
@@ -49,6 +42,9 @@ def load_image_from_url(url, max_retries=3):
49
  time.sleep(1)
50
  else:
51
  return None
 
 
 
52
 
53
  def get_image_embedding(image):
54
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
@@ -57,37 +53,6 @@ def get_image_embedding(image):
57
  image_features /= image_features.norm(dim=-1, keepdim=True)
58
  return image_features.cpu().numpy()
59
 
60
- @st.cache_data
61
- def process_database():
62
- database_embeddings = []
63
- database_info = []
64
-
65
- for item in data:
66
- image_url = item['이미지 링크'][0]
67
- image = load_image_from_url(image_url)
68
- if image is not None:
69
- embedding = get_image_embedding(image)
70
- database_embeddings.append(embedding)
71
- database_info.append({
72
- 'id': item['\ufeff상품 ID'],
73
- 'category': item['카테고리'],
74
- 'brand': item['브랜드명'],
75
- 'name': item['제품명'],
76
- 'price': item['정가'],
77
- 'discount': item['할인율'],
78
- 'image_url': image_url
79
- })
80
- else:
81
- st.warning(f"Skipping item {item['상품 ID']} due to image loading failure")
82
-
83
- if database_embeddings:
84
- return np.vstack(database_embeddings), database_info
85
- else:
86
- st.error("No valid embeddings were generated.")
87
- return None, None
88
-
89
- database_embeddings, database_info = process_database()
90
-
91
  def get_text_embedding(text):
92
  text_tokens = tokenizer([text]).to(device)
93
  with torch.no_grad():
@@ -95,17 +60,33 @@ def get_text_embedding(text):
95
  text_features /= text_features.norm(dim=-1, keepdim=True)
96
  return text_features.cpu().numpy()
97
 
98
- def find_similar_images(query_embedding, top_k=5):
 
 
 
 
 
 
 
 
 
99
  similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
100
  top_indices = np.argsort(similarities)[::-1][:top_k]
 
 
 
 
 
101
  results = []
102
- for idx in top_indices:
103
  results.append({
104
- 'info': database_info[idx],
105
- 'similarity': similarities[idx]
106
  })
107
  return results
108
 
 
 
109
  def detect_clothing(image):
110
  results = yolo_model(image)
111
  detections = results[0].boxes.data.cpu().numpy()
@@ -182,7 +163,7 @@ elif st.session_state.step == 'show_results':
182
  cropped_image = crop_image(st.session_state.query_image, selected_detection['bbox'])
183
  st.image(cropped_image, caption="Cropped Image", use_column_width=True)
184
  query_embedding = get_image_embedding(cropped_image)
185
- similar_images = find_similar_images(query_embedding)
186
 
187
  st.subheader("Similar Items:")
188
  for img in similar_images:
@@ -208,7 +189,7 @@ else: # Text search
208
  if st.button("Search by Text"):
209
  if query_text:
210
  text_embedding = get_text_embedding(query_text)
211
- similar_images = find_similar_images(text_embedding)
212
  st.subheader("Similar Items:")
213
  for img in similar_images:
214
  col1, col2 = st.columns(2)
 
9
  import numpy as np
10
  from ultralytics import YOLO
11
  import cv2
12
+ import chromadb
13
 
14
  # Load CLIP model and tokenizer
15
  @st.cache_resource
 
29
 
30
  yolo_model = load_yolo_model()
31
 
 
 
 
 
 
 
 
 
32
  # Helper functions
33
  def load_image_from_url(url, max_retries=3):
34
  for attempt in range(max_retries):
 
42
  time.sleep(1)
43
  else:
44
  return None
45
+ #Load chromaDB
46
+ client = chromadb.PersistentClient(path="./clothesDB")
47
+ collection = client.get_collection(name="fashion_items_ver2")
48
 
49
  def get_image_embedding(image):
50
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
 
53
  image_features /= image_features.norm(dim=-1, keepdim=True)
54
  return image_features.cpu().numpy()
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def get_text_embedding(text):
57
  text_tokens = tokenizer([text]).to(device)
58
  with torch.no_grad():
 
60
  text_features /= text_features.norm(dim=-1, keepdim=True)
61
  return text_features.cpu().numpy()
62
 
63
+ def get_all_embeddings_from_collection(collection):
64
+ all_embeddings = collection.get(include=['embeddings'])['embeddings']
65
+ return np.array(all_embeddings)
66
+
67
+ def get_metadata_from_ids(collection, ids):
68
+ results = collection.get(ids=ids)
69
+ return results['metadatas']
70
+
71
+ def find_similar_images(query_embedding, collection, top_k=5):
72
+ database_embeddings = get_all_embeddings_from_collection(collection)
73
  similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
74
  top_indices = np.argsort(similarities)[::-1][:top_k]
75
+
76
+ all_data = collection.get(include=['metadatas'])['metadatas']
77
+
78
+ top_metadatas = [all_data[idx] for idx in top_indices]
79
+
80
  results = []
81
+ for idx, metadata in enumerate(top_metadatas):
82
  results.append({
83
+ 'info': metadata,
84
+ 'similarity': similarities[top_indices[idx]]
85
  })
86
  return results
87
 
88
+
89
+
90
  def detect_clothing(image):
91
  results = yolo_model(image)
92
  detections = results[0].boxes.data.cpu().numpy()
 
163
  cropped_image = crop_image(st.session_state.query_image, selected_detection['bbox'])
164
  st.image(cropped_image, caption="Cropped Image", use_column_width=True)
165
  query_embedding = get_image_embedding(cropped_image)
166
+ similar_images = find_similar_images(query_embedding, collection)
167
 
168
  st.subheader("Similar Items:")
169
  for img in similar_images:
 
189
  if st.button("Search by Text"):
190
  if query_text:
191
  text_embedding = get_text_embedding(query_text)
192
+ similar_images = find_similar_images(text_embedding, collection)
193
  st.subheader("Similar Items:")
194
  for img in similar_images:
195
  col1, col2 = st.columns(2)