JoJosmin committed on
Commit
6d2622d
1 Parent(s): af5d866

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -52
app.py CHANGED
@@ -6,10 +6,9 @@ from PIL import Image
6
  from io import BytesIO
7
  import time
8
  import numpy as np
9
- from ultralytics import YOLO
10
- import chromadb
11
  from transformers import pipeline
12
- from sklearn.metrics.pairwise import cosine_similarity
 
13
 
14
  # Load segmentation model
15
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
@@ -25,20 +24,10 @@ def load_clip_model():
25
 
26
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
27
 
28
- # Load YOLOv8 model
29
- #@st.cache_resource
30
- #def load_yolo_model():
31
- # return YOLO("./best.pt")
32
-
33
- #yolo_model = load_yolo_model()
34
-
35
  # Load chromaDB
36
  client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
37
- #collection = client.get_collection(name="clothes_items_ver3")
38
  collection = client.get_collection(name="clothes")
39
 
40
-
41
-
42
  # Helper functions
43
  def load_image_from_url(url, max_retries=3):
44
  for attempt in range(max_retries):
@@ -53,63 +42,62 @@ def load_image_from_url(url, max_retries=3):
53
  else:
54
  return None
55
 
56
- def get_image_embedding(image):
57
- image_tensor = preprocess_val(image).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
58
  with torch.no_grad():
59
  image_features = clip_model.encode_image(image_tensor)
60
  image_features /= image_features.norm(dim=-1, keepdim=True)
61
  return image_features.cpu().numpy().flatten()
62
 
63
  def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
64
- # Segment image
65
  segments = segmenter(img)
66
-
67
- # Create list of masks
68
  mask_list = []
69
  detected_categories = []
70
  for s in segments:
71
  if s['label'] in clothes:
72
  mask_list.append(s['mask'])
73
- detected_categories.append(s['label']) # Store detected categories
74
 
75
- # Paste all masks on top of each other
76
- final_mask = np.zeros_like(np.array(img)[:, :, 0]) # Initialize mask
77
  for mask in mask_list:
78
  current_mask = np.array(mask)
79
- final_mask = np.maximum(final_mask, current_mask) # Use maximum to combine masks
80
 
81
- # Convert final mask from np array to PIL image
82
- final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255) # Convert to binary mask
83
-
84
- # Apply mask to original image
85
- img_with_alpha = img.convert("RGBA") # Ensure the image has an alpha channel
86
  img_with_alpha.putalpha(final_mask)
87
 
88
- return img_with_alpha.convert("RGB"), final_mask, detected_categories # Return detected categories
89
-
 
90
  def find_similar_images(query_embedding, collection, top_k=5):
91
- query_embedding = query_embedding.reshape(1, -1) # Reshape to 2D array for ChromaDB
92
- results = collection.query(
93
- query_embeddings=query_embedding,
94
- n_results=top_k,
95
- include=['metadatas', 'distances']
96
- )
97
 
98
- top_metadatas = results['metadatas'][0]
99
- top_distances = results['distances'][0]
 
 
 
 
100
 
101
  structured_results = []
102
  for metadata, distance in zip(top_metadatas, top_distances):
103
  structured_results.append({
104
  'info': metadata,
105
- 'similarity': 1 - distance
106
  })
107
 
108
  return structured_results
109
 
110
-
111
-
112
-
113
  # 세션 상태 초기화
114
  if 'step' not in st.session_state:
115
  st.session_state.step = 'input'
@@ -117,7 +105,7 @@ if 'query_image_url' not in st.session_state:
117
  st.session_state.query_image_url = ''
118
  if 'detections' not in st.session_state:
119
  st.session_state.detections = []
120
- if 'segmented_image' not in st.session_state: # Add segmented_image to session state
121
  st.session_state.segmented_image = None
122
  if 'selected_category' not in st.session_state:
123
  st.session_state.selected_category = None
@@ -125,7 +113,6 @@ if 'selected_category' not in st.session_state:
125
  # Streamlit app
126
  st.title("Advanced Fashion Search App")
127
 
128
- # 단계별 처리
129
  if st.session_state.step == 'input':
130
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
131
  if st.button("Detect Clothing"):
@@ -133,10 +120,9 @@ if st.session_state.step == 'input':
133
  query_image = load_image_from_url(st.session_state.query_image_url)
134
  if query_image is not None:
135
  st.session_state.query_image = query_image
136
- # Perform segmentation
137
  segmented_image, final_mask, detected_categories = segment_clothing(query_image)
138
- st.session_state.segmented_image = segmented_image # Store segmented image in session state
139
- st.session_state.detections = detected_categories # Store detected categories
140
  st.image(segmented_image, caption="Segmented Image", use_column_width=True)
141
  if st.session_state.detections:
142
  st.session_state.step = 'select_category'
@@ -160,19 +146,18 @@ elif st.session_state.step == 'select_category':
160
  st.warning("No categories detected.")
161
 
162
  elif st.session_state.step == 'show_results':
163
- original_image = st.session_state.query_image.convert("RGB") # Convert to RGB before displaying
164
  st.image(original_image, caption="Original Image", use_column_width=True)
165
 
166
- # Get the embedding of the segmented image
167
- query_embedding = get_image_embedding(st.session_state.segmented_image) # Use the segmented image from session state
168
-
169
  similar_images = find_similar_images(query_embedding, collection)
170
 
171
  st.subheader("Similar Items:")
172
  for img in similar_images:
173
  col1, col2 = st.columns(2)
174
  with col1:
175
- #st.image(img['image_url'], use_column_width=True)
176
  st.image(img['info']['image_url'], use_column_width=True)
177
  with col2:
178
  st.write(f"Name: {img['info']['name']}")
@@ -188,4 +173,4 @@ elif st.session_state.step == 'show_results':
188
  st.session_state.step = 'input'
189
  st.session_state.query_image_url = ''
190
  st.session_state.detections = []
191
- st.session_state.segmented_image = None # Reset segmented_image
 
6
  from io import BytesIO
7
  import time
8
  import numpy as np
 
 
9
  from transformers import pipeline
10
+ import chromadb
11
+ from sklearn.metrics.pairwise import euclidean_distances # 유클리드 거리 계산 추가
12
 
13
  # Load segmentation model
14
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
 
24
 
25
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
26
 
 
 
 
 
 
 
 
27
  # Load chromaDB
28
  client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
 
29
  collection = client.get_collection(name="clothes")
30
 
 
 
31
  # Helper functions
32
  def load_image_from_url(url, max_retries=3):
33
  for attempt in range(max_retries):
 
42
  else:
43
  return None
44
 
45
+ # 세그먼트 마스크 기반 임베딩 추출
46
+ def get_segmented_embedding(img, final_mask):
47
+ img_array = np.array(img)
48
+ final_mask_array = np.array(final_mask)
49
+
50
+ # 마스크 적용 (배경을 흰색으로 처리)
51
+ img_array[final_mask_array == 0] = 255
52
+ masked_img = Image.fromarray(img_array)
53
+
54
+ # 마스크된 이미지로부터 임베딩 추출
55
+ image_tensor = preprocess_val(masked_img).unsqueeze(0).to(device)
56
  with torch.no_grad():
57
  image_features = clip_model.encode_image(image_tensor)
58
  image_features /= image_features.norm(dim=-1, keepdim=True)
59
  return image_features.cpu().numpy().flatten()
60
 
61
def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
    """Segment clothing items out of *img* using the segformer pipeline.

    Args:
        img: PIL image to segment.
        clothes: labels to keep (never mutated, so the mutable default is safe).

    Returns:
        Tuple of (RGB image with background alpha-masked out,
                  binary PIL mask with 255 on clothing pixels,
                  list of detected category labels).
    """
    segments = segmenter(img)

    # Collect per-item masks and their labels for the requested categories.
    mask_list = []
    detected_categories = []
    for s in segments:
        if s['label'] in clothes:
            mask_list.append(s['mask'])
            detected_categories.append(s['label'])

    # Union of all masks; stays all-zero (fully transparent) when nothing
    # matched. Sized from width/height so non-RGB inputs don't crash.
    final_mask = np.zeros((img.height, img.width), dtype=np.uint8)
    for mask in mask_list:
        final_mask = np.maximum(final_mask, np.array(mask))

    # Binarize to 0/255. The previous `astype(np.uint8) * 255` wraps in uint8
    # when pipeline masks already contain 0/255 (255 * 255 overflows to 1),
    # which would make the alpha channel almost fully transparent.
    final_mask = Image.fromarray(np.where(final_mask > 0, 255, 0).astype(np.uint8))

    # Apply the combined mask as the alpha channel, then drop alpha again.
    img_with_alpha = img.convert("RGBA")
    img_with_alpha.putalpha(final_mask)

    return img_with_alpha.convert("RGB"), final_mask, detected_categories
80
+
81
# Nearest-neighbor search by Euclidean distance over the whole collection
def find_similar_images(query_embedding, collection, top_k=5):
    """Return the *top_k* items of *collection* closest to *query_embedding*.

    Args:
        query_embedding: 1-D array-like embedding of the query image.
        collection: ChromaDB collection exposing ``get(include=...)``.
        top_k: maximum number of results to return.

    Returns:
        List of dicts ``{'info': metadata, 'similarity': float}`` sorted by
        descending similarity (similarity = 1 / (1 + euclidean distance)).
    """
    query_vec = np.asarray(query_embedding, dtype=float).reshape(1, -1)

    # Fetch embeddings AND metadatas in one call so they stay positionally
    # aligned. The previous per-index `collection.get(ids=[str(idx)])` assumed
    # ids are stringified array positions — wrong for any real id scheme —
    # and issued top_k extra round trips.
    data = collection.get(include=['embeddings', 'metadatas'])
    all_embeddings = np.asarray(data['embeddings'], dtype=float)
    if all_embeddings.size == 0:
        return []  # empty collection: no candidates

    # Euclidean distance in plain numpy (same math as
    # sklearn.metrics.pairwise.euclidean_distances, no extra dependency).
    distances = np.linalg.norm(all_embeddings - query_vec, axis=1)
    top_indices = np.argsort(distances)[:top_k]

    metadatas = data['metadatas']
    structured_results = []
    for idx in top_indices:
        structured_results.append({
            'info': metadatas[idx],
            # Smaller distance -> higher similarity, mapped into (0, 1].
            'similarity': 1.0 / (1.0 + distances[idx]),
        })

    return structured_results
100
 
 
 
 
101
  # 세션 상태 초기화
102
  if 'step' not in st.session_state:
103
  st.session_state.step = 'input'
 
105
  st.session_state.query_image_url = ''
106
  if 'detections' not in st.session_state:
107
  st.session_state.detections = []
108
+ if 'segmented_image' not in st.session_state:
109
  st.session_state.segmented_image = None
110
  if 'selected_category' not in st.session_state:
111
  st.session_state.selected_category = None
 
113
  # Streamlit app
114
  st.title("Advanced Fashion Search App")
115
 
 
116
  if st.session_state.step == 'input':
117
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
118
  if st.button("Detect Clothing"):
 
120
  query_image = load_image_from_url(st.session_state.query_image_url)
121
  if query_image is not None:
122
  st.session_state.query_image = query_image
 
123
  segmented_image, final_mask, detected_categories = segment_clothing(query_image)
124
+ st.session_state.segmented_image = segmented_image
125
+ st.session_state.detections = detected_categories
126
  st.image(segmented_image, caption="Segmented Image", use_column_width=True)
127
  if st.session_state.detections:
128
  st.session_state.step = 'select_category'
 
146
  st.warning("No categories detected.")
147
 
148
  elif st.session_state.step == 'show_results':
149
+ original_image = st.session_state.query_image.convert("RGB")
150
  st.image(original_image, caption="Original Image", use_column_width=True)
151
 
152
+ # 세그먼트된 이미지에서 임베딩 추출
153
+ query_embedding = get_segmented_embedding(st.session_state.query_image, st.session_state.segmented_image)
154
+
155
  similar_images = find_similar_images(query_embedding, collection)
156
 
157
  st.subheader("Similar Items:")
158
  for img in similar_images:
159
  col1, col2 = st.columns(2)
160
  with col1:
 
161
  st.image(img['info']['image_url'], use_column_width=True)
162
  with col2:
163
  st.write(f"Name: {img['info']['name']}")
 
173
  st.session_state.step = 'input'
174
  st.session_state.query_image_url = ''
175
  st.session_state.detections = []
176
+ st.session_state.segmented_image = None