Update app.py
app.py
CHANGED
@@ -6,10 +6,9 @@ from PIL import Image
 from io import BytesIO
 import time
 import numpy as np
-from ultralytics import YOLO
-import chromadb
 from transformers import pipeline
-
+import chromadb
+from sklearn.metrics.pairwise import euclidean_distances  # added: Euclidean distance computation
 
 # Load segmentation model
 segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
@@ -25,20 +24,10 @@ def load_clip_model():
 
 clip_model, preprocess_val, tokenizer, device = load_clip_model()
 
-# Load YOLOv8 model
-#@st.cache_resource
-#def load_yolo_model():
-#    return YOLO("./best.pt")
-
-#yolo_model = load_yolo_model()
-
 # Load chromaDB
 client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
-#collection = client.get_collection(name="clothes_items_ver3")
 collection = client.get_collection(name="clothes")
 
-
-
 # Helper functions
 def load_image_from_url(url, max_retries=3):
     for attempt in range(max_retries):
@@ -53,63 +42,62 @@ def load_image_from_url(url, max_retries=3):
     else:
         return None
 
-
-
+# Extract an embedding based on the segmentation mask
+def get_segmented_embedding(img, final_mask):
+    img_array = np.array(img)
+    final_mask_array = np.array(final_mask)
+
+    # Apply the mask (fill the background with white)
+    img_array[final_mask_array == 0] = 255
+    masked_img = Image.fromarray(img_array)
+
+    # Extract the embedding from the masked image
+    image_tensor = preprocess_val(masked_img).unsqueeze(0).to(device)
     with torch.no_grad():
         image_features = clip_model.encode_image(image_tensor)
         image_features /= image_features.norm(dim=-1, keepdim=True)
     return image_features.cpu().numpy().flatten()
 
 def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
-    # Segment image
     segments = segmenter(img)
-
-    # Create list of masks
     mask_list = []
     detected_categories = []
     for s in segments:
         if s['label'] in clothes:
             mask_list.append(s['mask'])
-            detected_categories.append(s['label'])
+            detected_categories.append(s['label'])
 
-
-    final_mask = np.zeros_like(np.array(img)[:, :, 0])  # Initialize mask
+    final_mask = np.zeros_like(np.array(img)[:, :, 0])
     for mask in mask_list:
         current_mask = np.array(mask)
-        final_mask = np.maximum(final_mask, current_mask)
+        final_mask = np.maximum(final_mask, current_mask)
 
-
-
-
-    # Apply mask to original image
-    img_with_alpha = img.convert("RGBA")  # Ensure the image has an alpha channel
+    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255)
+    img_with_alpha = img.convert("RGBA")
     img_with_alpha.putalpha(final_mask)
 
-    return img_with_alpha.convert("RGB"), final_mask, detected_categories
-
+    return img_with_alpha.convert("RGB"), final_mask, detected_categories
+
+# Find similar images based on Euclidean distance
 def find_similar_images(query_embedding, collection, top_k=5):
-    query_embedding = query_embedding.reshape(1, -1)
-    results = collection.query(
-        query_embeddings=query_embedding,
-        n_results=top_k,
-        include=['metadatas', 'distances']
-    )
+    query_embedding = query_embedding.reshape(1, -1)
 
-
-
+    all_embeddings = np.array(collection.get(include=['embeddings'])['embeddings'])
+    distances = euclidean_distances(query_embedding, all_embeddings).flatten()  # compute Euclidean distances
+
+    top_indices = np.argsort(distances)[:top_k]
+    top_metadatas = [collection.get(ids=[str(idx)])['metadatas'][0] for idx in top_indices]
+    top_distances = distances[top_indices]
 
     structured_results = []
     for metadata, distance in zip(top_metadatas, top_distances):
         structured_results.append({
             'info': metadata,
-            'similarity': 1
+            'similarity': 1 / (1 + distance)  # a smaller Euclidean distance means higher similarity
         })
 
     return structured_results
 
-
-
-
 # Initialize session state
 if 'step' not in st.session_state:
     st.session_state.step = 'input'
@@ -117,7 +105,7 @@ if 'query_image_url' not in st.session_state:
     st.session_state.query_image_url = ''
 if 'detections' not in st.session_state:
     st.session_state.detections = []
-if 'segmented_image' not in st.session_state:
+if 'segmented_image' not in st.session_state:
    st.session_state.segmented_image = None
 if 'selected_category' not in st.session_state:
     st.session_state.selected_category = None
@@ -125,7 +113,6 @@ if 'selected_category' not in st.session_state:
 # Streamlit app
 st.title("Advanced Fashion Search App")
 
-# Step-by-step processing
 if st.session_state.step == 'input':
     st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
     if st.button("Detect Clothing"):
@@ -133,10 +120,9 @@ if st.session_state.step == 'input':
         query_image = load_image_from_url(st.session_state.query_image_url)
         if query_image is not None:
             st.session_state.query_image = query_image
-            # Perform segmentation
             segmented_image, final_mask, detected_categories = segment_clothing(query_image)
-            st.session_state.segmented_image = segmented_image
-            st.session_state.detections = detected_categories
+            st.session_state.segmented_image = segmented_image
+            st.session_state.detections = detected_categories
             st.image(segmented_image, caption="Segmented Image", use_column_width=True)
             if st.session_state.detections:
                 st.session_state.step = 'select_category'
@@ -160,19 +146,18 @@ elif st.session_state.step == 'select_category':
     st.warning("No categories detected.")
 
 elif st.session_state.step == 'show_results':
-    original_image = st.session_state.query_image.convert("RGB")
+    original_image = st.session_state.query_image.convert("RGB")
     st.image(original_image, caption="Original Image", use_column_width=True)
 
-    #
-    query_embedding =
-
+    # Extract the embedding from the segmented image
+    query_embedding = get_segmented_embedding(st.session_state.query_image, st.session_state.segmented_image)
+
     similar_images = find_similar_images(query_embedding, collection)
 
     st.subheader("Similar Items:")
     for img in similar_images:
         col1, col2 = st.columns(2)
         with col1:
-            #st.image(img['image_url'], use_column_width=True)
             st.image(img['info']['image_url'], use_column_width=True)
         with col2:
             st.write(f"Name: {img['info']['name']}")
@@ -188,4 +173,4 @@ elif st.session_state.step == 'show_results':
         st.session_state.step = 'input'
         st.session_state.query_image_url = ''
         st.session_state.detections = []
-        st.session_state.segmented_image = None
+        st.session_state.segmented_image = None
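Note: the second hunk's context shows clip_model, preprocess_val, tokenizer, device = load_clip_model(), but the helper itself sits above the changed region and is not part of this commit. For readers following along, a minimal sketch of what such a loader could look like, assuming open_clip; the ViT-B-32 / laion2b checkpoint names are placeholder assumptions, not something this Space is confirmed to use:

# Hypothetical sketch of the elided load_clip_model() helper.
# Model and checkpoint names are placeholders, not taken from the commit.
import torch
import open_clip
import streamlit as st

@st.cache_resource
def load_clip_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _, preprocess_val = open_clip.create_model_and_transforms(
        "ViT-B-32", pretrained="laion2b_s34b_b79k"
    )
    tokenizer = open_clip.get_tokenizer("ViT-B-32")
    model = model.to(device).eval()
    return model, preprocess_val, tokenizer, device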
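One caveat in the new flow: segment_clothing() returns both the composited image and the mask, but only segmented_image is stored in session state, and the show_results step then passes that composite to get_segmented_embedding() as its final_mask argument. Since convert("RGB") simply drops the alpha channel, the composite is pixel-for-pixel the original image, so the background test final_mask_array == 0 will only match pixels that happen to be pure black. A sketch of one way to keep the actual mask around (the final_mask session key is an invented name, not in the commit):

# Hypothetical fix: persist the mask itself so get_segmented_embedding()
# receives a single-channel mask instead of an RGB composite.
segmented_image, final_mask, detected_categories = segment_clothing(query_image)
st.session_state.segmented_image = segmented_image
st.session_state.final_mask = final_mask  # invented session key
st.session_state.detections = detected_categories

# later, in the show_results step:
query_embedding = get_segmented_embedding(
    st.session_state.query_image, st.session_state.final_mask
)

Relatedly, if the segmenter pipeline already returns 0/255 masks, the final_mask.astype(np.uint8) * 255 conversion in segment_clothing() can overflow uint8; scaling only when the mask is 0/1 would be safer.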
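The rewritten find_similar_images() also bakes in an assumption worth flagging: collection.get(ids=[str(idx)]) treats the positions returned by np.argsort as chromadb document IDs, which only holds if items were inserted with IDs "0", "1", ... in the same order collection.get() returns embeddings. A variant that avoids the assumption by fetching metadatas alongside embeddings in one call (find_similar_images_safe is an illustrative name, not part of the commit):

# Sketch: align metadata positionally with the fetched embeddings so no
# assumption about the collection's ID scheme is needed.
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

def find_similar_images_safe(query_embedding, collection, top_k=5):
    data = collection.get(include=['embeddings', 'metadatas'])
    all_embeddings = np.array(data['embeddings'])
    distances = euclidean_distances(
        query_embedding.reshape(1, -1), all_embeddings
    ).flatten()
    top_indices = np.argsort(distances)[:top_k]
    return [
        {'info': data['metadatas'][i], 'similarity': 1 / (1 + distances[i])}
        for i in top_indices
    ]

For large collections, chromadb's own collection.query() performs this nearest-neighbor search on the server side; pulling every embedding into client memory, as this commit does, only stays practical while the database is small.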